package server import ( "errors" "fmt" "io" "log" "sort" "strconv" "strings" "sync" "time" "tunnel/pkg/config" "tunnel/pkg/server/env" "tunnel/pkg/server/hook" "tunnel/pkg/server/queue" "tunnel/pkg/server/socket" ) type stream struct { id int t *tunnel env env.Env since time.Time wg sync.WaitGroup in, out socket.Channel } type tunnel struct { id string args string streams map[int]*stream mu sync.Mutex wg sync.WaitGroup nextSid int quit chan struct{} done chan struct{} in, out socket.S hooks []hook.H env env.Env } func (s *stream) String() string { return fmt.Sprintf("stream:%d", s.id) } func (t *tunnel) String() string { return fmt.Sprintf("tunnel:%s", t.id) } func (t *tunnel) stopServe() { close(t.quit) t.in.Close() t.out.Close() <-t.done } func (t *tunnel) stopStreams() { t.mu.Lock() for _, s := range t.streams { s.stop() } t.mu.Unlock() t.wg.Wait() } func (t *tunnel) Close() { t.stopServe() t.stopStreams() log.Println(t, "delete") } func (t *tunnel) isQuit() bool { select { case <-t.quit: return true default: return false } } func (t *tunnel) serve() { var wg sync.WaitGroup for { if in, err := t.in.Open(t.env); err != nil { if t.isQuit() { break } log.Println(t, err) time.Sleep(5 * time.Second) } else { log.Println(t, "open", in) wg.Add(1) go func() { t.handle(in) wg.Done() }() } } wg.Wait() close(t.done) } func (t *tunnel) handle(in socket.Channel) { out, err := t.out.Open(t.env) if err != nil { log.Println(t, err) in.Close() return } log.Println(t, "open", out) s := t.newStream(in, out) log.Println(t, s, "create", in, out) } func (t *tunnel) newStream(in, out socket.Channel) *stream { s := &stream{ t: t, in: in, out: out, id: t.nextSid, env: t.env.Copy(), since: time.Now(), } s.env.Set("tunnel", t.id) s.env.Set("stream", strconv.Itoa(s.id)) s.run() t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() go func() { s.wg.Wait() s.t.mu.Lock() delete(s.t.streams, s.id) s.t.mu.Unlock() s.t.wg.Done() log.Println(s.t, s, "close") }() return s } func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { watch := func(q queue.Q, f func(q queue.Q) error) { defer s.wg.Done() if err := f(q); err != nil && !errors.Is(err, io.EOF) { log.Println(s.t, s, err) } } s.wg.Add(2) go func() { watch(wq, c.Send) close(wq) }() go func() { watch(rq, c.Recv) rq.Dry() }() } func (s *stream) pipe(h hook.H, f hook.Func, rq, wq queue.Q) { s.wg.Add(1) go func() { defer s.wg.Done() if err := f(rq, wq); err != nil && !errors.Is(err, io.EOF) { log.Println(s.t, s, h, err) } close(wq) rq.Dry() }() } func (s *stream) run() { s.t.wg.Add(1) rq, wq := queue.New(), queue.New() s.channel(s.in, rq, wq) for _, h := range s.t.hooks { send, recv, err := hook.Open(h, s.env) if err != nil { // FIXME: abort stream on error log.Println(s.t, s, h, err) continue } if send != nil { q := queue.New() s.pipe(h, send, wq, q) wq = q } if recv != nil { q := queue.New() s.pipe(h, recv, q, rq) rq = q } } s.channel(s.out, wq, rq) } func (s *stream) stop() { s.in.Close() s.out.Close() } func parseHooks(args []string, env env.Env) ([]hook.H, error) { var hooks []hook.H for _, arg := range args { if h, err := hook.New(arg, env); err != nil { return nil, err } else { hooks = append(hooks, h) } } return hooks, nil } func newTunnel(args []string, env env.Env) (*tunnel, error) { var in, out socket.S var hooks []hook.H var err error n := len(args) - 1 if in, err = socket.New(args[0], env); err != nil { return nil, err } if out, err = socket.New(args[n], env); err != nil { in.Close() return nil, err } if hooks, err = parseHooks(args[1:n], env); err != nil { in.Close() out.Close() return nil, err } t := &tunnel{ args: strings.Join(args, " "), quit: make(chan struct{}), done: make(chan struct{}), hooks: hooks, in: in, out: out, env: env, streams: make(map[int]*stream), } go t.serve() return t, nil } func isOkTunnelName(s string) bool { return s != "" } func tunnelAdd(r *request) { args := r.args name := "" if len(args) >= 2 { if args[0] == "name" { name = args[1] if !isOkTunnelName(name) { r.Fatal("bad name") } if _, ok := r.c.s.tunnels[name]; ok { r.Fatal("already exists") } args = args[2:] } } if len(args) < 2 { r.Fatal("not enough args") } t, err := newTunnel(args, r.c.s.env) if err != nil { r.Fatal(err) } if name == "" { t.id = r.c.s.tunnels.add(t) } else { t.id = name r.c.s.tunnels[t.id] = t } log.Println(r.c, r, t, "create") } func tunnelDel(r *request) { r.expect(1) id := r.args[0] if t, ok := r.c.s.tunnels[id]; !ok { r.Fatal("no such entry") } else { t.(*tunnel).Close() delete(r.c.s.tunnels, id) } } func tunnelRename(r *request) { r.expect(2) old, new := r.args[0], r.args[1] if !isOkTunnelName(new) { r.Fatal("bad name") } if t, err := r.c.s.tunnels.rename(old, new); err != nil { r.Fatal(err) } else { t.(*tunnel).id = new } } func foreachTunnel(m automap, f func(t *tunnel)) { var keys []string for k := range m { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { f(m[k].(*tunnel)) } } func foreachStream(m map[int]*stream, f func(s *stream)) { var keys []int for k := range m { keys = append(keys, k) } sort.Ints(keys) for _, k := range keys { f(m[k]) } } func showTunnels(r *request) { foreachTunnel(r.c.s.tunnels, func(t *tunnel) { r.Println(t.id, t.args) }) } func showStreams(r *request) { foreachTunnel(r.c.s.tunnels, func(t *tunnel) { t.mu.Lock() defer t.mu.Unlock() if len(t.streams) > 0 { r.Println(t.id, t.args) foreachStream(t.streams, func(s *stream) { tm := s.since.Format(config.TimeFormat) r.Println("\t", s.id, tm, s.in, s.out) }) } }) } func showHooks(r *request) { for _, h := range hook.GetList() { r.Println(h) } } func init() { newCmd(tunnelAdd, "add") newCmd(tunnelDel, "del") newCmd(tunnelRename, "rename") newCmd(showHooks, "hooks") newCmd(showStreams, "streams") newCmd(showTunnels, "show") }