package server import ( "tunnel/pkg/server/socket" "tunnel/pkg/server/module" "tunnel/pkg/server/queue" "tunnel/pkg/server/env" "tunnel/pkg/config" "strings" "time" "sort" "sync" "fmt" "log" ) type stream struct { id int t *tunnel since time.Time wg sync.WaitGroup in, out socket.Channel } type tunnel struct { id string args string streams map[int]*stream mu sync.Mutex wg sync.WaitGroup nextSid int quit chan struct{} done chan struct{} in, out socket.S m []module.M env env.Env } func (s *stream) String() string { return fmt.Sprintf("stream:%d", s.id) } func (t *tunnel) String() string { return fmt.Sprintf("tunnel:%s", t.id) } func (t *tunnel) stopServe() { close(t.quit) t.in.Close() t.out.Close() <-t.done } func (t *tunnel) stopStreams() { t.mu.Lock() for _, s := range t.streams { s.stop() } t.mu.Unlock() t.wg.Wait() } func (t *tunnel) Close() { t.stopServe() t.stopStreams() log.Println(t, "delete") } func (t *tunnel) isQuit() bool { select { case <-t.quit: return true default: return false } } func (t *tunnel) serve() { var wg sync.WaitGroup for { if in, err := t.in.Open(t.env); err != nil { if t.isQuit() { break } log.Println(t, err) time.Sleep(5 * time.Second) } else { log.Println(t, "open", in) wg.Add(1) go func () { t.handle(in) wg.Done() }() } } wg.Wait() close(t.done) } func (t *tunnel) handle(in socket.Channel) { out, err := t.out.Open(t.env) if err != nil { log.Println(t, err) in.Close() return } log.Println(t, "open", out) s := t.newStream(in, out) log.Println(t, s, "create", in, out) } func (t *tunnel) newStream(in, out socket.Channel) *stream { s := &stream{ t: t, in: in, out: out, id: t.nextSid, since: time.Now(), } s.run() t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() go func () { s.wg.Wait() s.t.mu.Lock() delete(s.t.streams, s.id) s.t.mu.Unlock() s.t.wg.Done() log.Println(s.t, s, "close") }() return s } func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { watch := func (q queue.Q, f func (q queue.Q) error) { defer s.wg.Done() if err := f(q); err != nil { log.Println(s.t, s, err) } } s.wg.Add(2) go func () { watch(wq, c.Send) close(wq) }() go func () { watch(rq, c.Recv) rq.Dry() }() } func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) { s.wg.Add(1) go func () { defer s.wg.Done() if err := p(rq, wq); err != nil { log.Println(s.t, s, m, err) } close(wq) rq.Dry() }() } func (s *stream) run() { env := s.t.env.Copy() env.Set("tunnel", s.t.id) s.t.wg.Add(1) rq, wq := queue.New(), queue.New() s.channel(s.in, rq, wq) for _, m := range s.t.m { send, recv := m.Open(env) if send != nil { q := queue.New() s.pipe(m, send, wq, q) wq = q } if recv != nil { q := queue.New() s.pipe(m, recv, q, rq) rq = q } } s.channel(s.out, wq, rq) } func (s *stream) stop() { s.in.Close() s.out.Close() } func parseModules(args []string, env env.Env) ([]module.M, error) { var mm []module.M reverse := false for _, arg := range args { var m module.M var err error if arg == "-" { reverse = true continue } if arg == "+" { reverse = false continue } if m, err = module.New(arg, env); err != nil { return nil, err } if reverse { m = module.Reverse(m) reverse = false } mm = append(mm, m) } if reverse { return nil, fmt.Errorf("bad '-' usage") } return mm, nil } func newTunnel(args []string, env env.Env) (*tunnel, error) { var in, out socket.S var mm []module.M var err error n := len(args) - 1 if in, err = socket.New(args[0], env); err != nil { return nil, err } if out, err = socket.New(args[n], env); err != nil { in.Close() return nil, err } if mm, err = parseModules(args[1:n], env); err != nil { in.Close() out.Close() return nil, err } t := &tunnel{ args: strings.Join(args, " "), quit: make(chan struct{}), done: make(chan struct{}), m: mm, in: in, out: out, env: env, streams: make(map[int]*stream), } go t.serve() return t, nil } func isOkTunnelName(s string) bool { return s != "" } func tunnelAdd(r *request) { args := r.args name := "" if len(args) >= 2 { if args[0] == "name" { name = args[1] if !isOkTunnelName(name) { r.Fatal("bad name") } if _, ok := r.c.s.tunnels[name]; ok { r.Fatal("already exists") } args = args[2:] } } if len(args) < 2 { r.Fatal("not enough args") } t, err := newTunnel(args, r.c.s.env) if err != nil { r.Fatal(err) } if name == "" { t.id = r.c.s.tunnels.add(t) } else { t.id = name r.c.s.tunnels[t.id] = t } log.Println(r.c, r, t, "create") } func tunnelDel(r *request) { r.expect(1) id := r.args[0] if t, ok := r.c.s.tunnels[id]; !ok { r.Fatal("no such entry") } else { t.(*tunnel).Close() delete(r.c.s.tunnels, id) } } func tunnelRename(r *request) { r.expect(2) old, new := r.args[0], r.args[1] if !isOkTunnelName(new) { r.Fatal("bad name") } if t, err := r.c.s.tunnels.rename(old, new); err != nil { r.Fatal(err) } else { t.(*tunnel).id = new } } func foreachTunnel(m automap, f func (t *tunnel)) { var keys []string for k := range m { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { f(m[k].(*tunnel)) } } func foreachStream(m map[int]*stream, f func (s *stream)) { var keys []int for k := range m { keys = append(keys, k) } sort.Ints(keys) for _, k := range keys { f(m[k]) } } func tunnelShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { r.Println(t.id, t.args) }) } func streamShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { t.mu.Lock() defer t.mu.Unlock() if len(t.streams) > 0 { r.Println(t.id, t.args) foreachStream(t.streams, func (s *stream) { tm := s.since.Format(config.TimeFormat) r.Println("\t", s.id, tm, s.in, s.out) }) } }) } func init() { newCmd(tunnelAdd, "add") newCmd(tunnelDel, "del") newCmd(tunnelRename, "rename") newCmd(tunnelShow, "show") newCmd(streamShow, "streams") }