package server import ( "errors" "fmt" "io" "log" "sort" "strconv" "strings" "sync" "sync/atomic" "time" "tunnel/pkg/config" "tunnel/pkg/server/env" "tunnel/pkg/server/hook" "tunnel/pkg/server/queue" "tunnel/pkg/server/socket" ) const maxRecentSize = 8 const maxQueueLimit = 16384 type metric struct { tx uint64 rx uint64 } type stream struct { id int t *tunnel env env.Env since time.Time until time.Time wg sync.WaitGroup in, out socket.Conn pipes []*hook.Pipe m struct { in metric out metric } } type tunnel struct { id string args string streams map[int]*stream recent []*stream mu sync.Mutex wg sync.WaitGroup nextSid int quit chan struct{} done chan struct{} queue chan struct{} in, out socket.S hooks []hook.H env env.Env } func (s *stream) String() string { return fmt.Sprintf("stream:%d", s.id) } func (t *tunnel) String() string { return fmt.Sprintf("tunnel:%s", t.id) } func (t *tunnel) stopServe() { close(t.quit) t.in.Close() t.out.Close() <-t.done } func (t *tunnel) stopStreams() { t.mu.Lock() for _, s := range t.streams { log.Println(s, "stop") s.stop() } t.mu.Unlock() t.wg.Wait() } func (t *tunnel) Close() { t.stopServe() t.stopStreams() log.Println(t, "delete") } func (t *tunnel) alive() bool { select { case <-t.quit: return false default: return true } } func (t *tunnel) acquire() bool { select { case t.queue <- struct{}{}: return true case <-t.quit: return false } } func (t *tunnel) release() { <-t.queue } func (t *tunnel) sleep(d time.Duration) { tmr := time.NewTimer(d) select { case <-tmr.C: case <-t.quit: } tmr.Stop() } func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) { var pipes []*hook.Pipe cleanup := func() { for _, p := range pipes { p.Close() } } for _, h := range t.hooks { p, err := h.Open(env) if err != nil { cleanup() return nil, fmt.Errorf("%s: %w", h, err) } pipes = append(pipes, p) } return pipes, nil } func (t *tunnel) serve() { for t.acquire() { var ok bool env := t.env.Fork() env.Set("tunnel", t.id) if in, err := t.in.Open(env); err != nil { if t.alive() { log.Println(t, err) } } else if out, err := t.out.Open(env); err != nil { log.Println(t, err) in.Close() } else if pipes, err := t.openPipes(env); err != nil { log.Println(t, err) in.Close() out.Close() } else { s := t.newStream(env, in, out, pipes) log.Println(t, s, "create", in, out) ok = true } if !ok { t.sleep(5 * time.Second) t.release() } } close(t.done) } func (t *tunnel) newStream(env env.Env, in, out socket.Conn, pipes []*hook.Pipe) *stream { s := &stream{ t: t, in: in, out: out, pipes: pipes, env: env, id: t.nextSid, since: time.Now(), } s.env.Set("stream", strconv.Itoa(s.id)) s.run() t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() go s.waitAndClose() return s } func (t *tunnel) delStream(s *stream) { t.mu.Lock() defer t.mu.Unlock() delete(t.streams, s.id) t.recent = append(t.recent, s) if len(t.recent) > maxRecentSize { t.recent = t.recent[len(t.recent)-maxRecentSize:] } } func (s *stream) info() string { var d int64 if s.until.IsZero() { d = time.Since(s.since).Milliseconds() } else { d = s.until.Sub(s.since).Milliseconds() } return fmt.Sprintf("%.3fs [%s] %d/%d -> %d/%d", float64(d)/1000.0, s.env.Get("info"), s.m.in.tx, s.m.in.rx, s.m.out.rx, s.m.out.tx) } func (s *stream) waitAndClose() { s.wg.Wait() s.until = time.Now() s.t.delStream(s) s.t.release() s.t.wg.Done() s.in.Close() s.out.Close() for _, p := range s.pipes { p.Close() } log.Println(s.t, s, "done", s.info()) } func (s *stream) channel(c socket.Conn, m *metric, rq, wq queue.Q) { watch := func(q queue.Q, f func(q queue.Q) error) { defer s.wg.Done() err := f(q) if errors.Is(err, io.EOF) { err = nil } if e := c.Close(); e != nil { if e == socket.ErrAlreadyClosed { err = nil } else { err = e } } if err != nil { log.Println(s.t, s, c, err) } } counter := func(c *uint64, src, dst queue.Q) { for b := range src { dst <- b atomic.AddUint64(c, uint64(len(b))) } close(dst) } s.wg.Add(2) go func() { q := queue.New() go counter(&m.tx, q, wq) watch(q, c.Send) close(q) }() go func() { q := queue.New() go counter(&m.rx, rq, q) watch(q, c.Recv) q.Dry() }() } func (s *stream) pipe(h hook.H, f hook.Func, rq, wq queue.Q) { s.wg.Add(1) go func() { defer s.wg.Done() if err := f(rq, wq); err != nil && !errors.Is(err, io.EOF) { log.Println(s.t, s, h, err) } close(wq) rq.Dry() }() } func (s *stream) run() { s.t.wg.Add(1) rq, wq := queue.New(), queue.New() s.channel(s.in, &s.m.in, rq, wq) for _, p := range s.pipes { if p.Send != nil { q := queue.New() s.pipe(p.Hook, p.Send, wq, q) wq = q } if p.Recv != nil { q := queue.New() s.pipe(p.Hook, p.Recv, q, rq) rq = q } } s.channel(s.out, &s.m.out, wq, rq) } func (s *stream) stop() { s.in.Close() s.out.Close() } func parseHooks(args []string) ([]hook.H, error) { var hooks []hook.H for _, arg := range args { if h, err := hook.New(arg); err != nil { return nil, err } else { hooks = append(hooks, h) } } return hooks, nil } func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) { var in, out socket.S var hooks []hook.H var err error n := len(args) - 1 if in, err = socket.New(args[0]); err != nil { return nil, err } if out, err = socket.New(args[n]); err != nil { in.Close() return nil, err } if hooks, err = parseHooks(args[1:n]); err != nil { in.Close() out.Close() return nil, err } t := &tunnel{ args: strings.Join(args, " "), quit: make(chan struct{}), done: make(chan struct{}), hooks: hooks, in: in, out: out, env: env, queue: make(chan struct{}, limit), streams: make(map[int]*stream), } go t.serve() return t, nil } func isOkTunnelName(s string) bool { return s != "" } func tunnelAdd(r *request) { args := r.args name := "" limit := 1 for len(args) > 1 { if args[0] == "name" { name = args[1] if !isOkTunnelName(name) { r.Fatal("bad name") } if _, ok := r.c.s.tunnels[name]; ok { r.Fatal("already exists") } args = args[2:] continue } if args[0] == "limit" { if n, _ := strconv.Atoi(args[1]); n > 0 && n < maxQueueLimit { limit = n } else { r.Fatal("bad limit") } args = args[2:] } if args[0] == "unlim" { limit = maxQueueLimit args = args[1:] continue } break } if len(args) < 2 { r.Fatal("not enough args") } t, err := newTunnel(limit, args, r.c.s.env) if err != nil { r.Fatal(err) } if name == "" { t.id = r.c.s.tunnels.add(t) } else { t.id = name r.c.s.tunnels[t.id] = t } log.Println(r.c, r, t, "create") } func tunnelDel(r *request) { r.expect(1) id := r.args[0] if t, ok := r.c.s.tunnels[id]; !ok { r.Fatal("no such entry") } else { t.(*tunnel).Close() delete(r.c.s.tunnels, id) } } func tunnelRename(r *request) { r.expect(2) old, new := r.args[0], r.args[1] if !isOkTunnelName(new) { r.Fatal("bad name") } if t, err := r.c.s.tunnels.rename(old, new); err != nil { r.Fatal(err) } else { t.(*tunnel).id = new } } func foreachTunnel(m automap, f func(t *tunnel)) { var keys []string for k := range m { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { f(m[k].(*tunnel)) } } func foreachStream(m map[int]*stream, f func(s *stream)) { var keys []int for k := range m { keys = append(keys, k) } sort.Ints(keys) for _, k := range keys { f(m[k]) } } func showTunnels(r *request) { foreachTunnel(r.c.s.tunnels, func(t *tunnel) { r.Println(t.id, t.args) }) } func showActive(r *request) { foreachTunnel(r.c.s.tunnels, func(t *tunnel) { t.mu.Lock() defer t.mu.Unlock() foreachStream(t.streams, func(s *stream) { r.Println(t.id, s.id, s.in, s.out, s.info()) }) }) } func showRecent(r *request) { var streams []*stream foreachTunnel(r.c.s.tunnels, func(t *tunnel) { t.mu.Lock() streams = append(streams, t.recent...) t.mu.Unlock() }) sort.SliceStable(streams, func(i, j int) bool { return streams[i].until.Before(streams[j].until) }) for _, s := range streams { when := s.until.Format(config.TimeFormat) r.Println(when, s.t.id, s.id, s.in, s.out, s.info()) } } func showHooks(r *request) { for _, h := range hook.GetList() { r.Println(h) } } func init() { newCmd(tunnelAdd, "add") newCmd(tunnelDel, "del") newCmd(tunnelRename, "rename") newCmd(showHooks, "hooks") newCmd(showTunnels, "show") newCmd(showActive, "active") newCmd(showRecent, "recent") }