package server import ( "tunnel/pkg/server/module" "tunnel/pkg/server/queue" "tunnel/pkg/server/socket" "tunnel/pkg/config" "sync/atomic" "strings" "time" "sort" "sync" "fmt" "log" ) type stream struct { id int n int32 t *tunnel since time.Time in, out socket.Channel } type tunnel struct { id int args string streams map[int]*stream mu sync.Mutex nextSid int in, out socket.S m []module.M } func (s *stream) String() string { return fmt.Sprintf("stream(%d)", s.id) } func (t *tunnel) String() string { return fmt.Sprintf("tunnel(%d)", t.id) } func (t *tunnel) Close() { t.in.Close() t.out.Close() } func (t *tunnel) run() { for { if in, err := t.in.Open(); err != nil { log.Println(t, err) time.Sleep(5 * time.Second) } else { log.Println(t, "open", in) go t.run2(in) } } } func (t *tunnel) newStream(in, out socket.Channel) *stream { s := &stream{ t: t, in: in, out: out, id: t.nextSid, since: time.Now(), } t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() return s } func (s *stream) ref() { atomic.AddInt32(&s.n, 1) } func (s *stream) unref() { if atomic.AddInt32(&s.n, -1) == 0 { log.Println(s.t, s, "close") s.t.mu.Lock() delete(s.t.streams, s.id) s.t.mu.Unlock() } } func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { watch := func (q queue.Q, f func (q queue.Q) error) { s.ref() defer s.unref() if err := f(q); err != nil { log.Println(s.t, s, err) } } go func () { watch(wq, c.Send) close(wq) }() go watch(rq, c.Recv) } func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { go func () { s.ref() defer s.unref() if err := f(rq, wq); err != nil { log.Println(s.t, s, err) } close(wq) }() } func (t *tunnel) run2(in socket.Channel) { out, err := t.out.Open() if err != nil { log.Println(t, err) in.Close() return } log.Println(t, "open", out) s := t.newStream(in, out) s.ref() defer s.unref() rq, wq := queue.New(), queue.New() s.watchChannel(rq, wq, in) for _, m := range t.m { send, recv := m.Open() if send != nil { q := queue.New() s.watchPipe(wq, q, send) wq = q } if recv != nil { q := queue.New() s.watchPipe(q, rq, recv) rq = q } } s.watchChannel(wq, rq, out) log.Println(t, s, "create", in, out) } func newTunnel(args []string) (*tunnel, error) { var in, out socket.S var err error n := len(args) - 1 if in, err = socket.New(args[0]); err != nil { return nil, err } if out, err = socket.New(args[n]); err != nil { in.Close() return nil, err } t := &tunnel{ args: strings.Join(args, " "), in: in, out: out, streams: make(map[int]*stream), } reverse := false for _, arg := range args[1:n] { var m module.M if arg == "-" { reverse = true continue } if arg == "+" { reverse = false continue } if m, err = module.New(arg); err != nil { t.Close() return nil, err } if reverse { m = module.Reverse(m) reverse = false } t.m = append(t.m, m) } if reverse { t.Close() return nil, fmt.Errorf("bad '-' usage") } go t.run() return t, nil } func tunnelAdd(r *request) { if r.argc < 2 { r.Fatal("not enough args") } t, err := newTunnel(r.args) if err != nil { r.Fatal(err) } log.Println(r.c, r, t, "create") t.id = r.c.s.tunnels.add(t) } func foreachTunnel(m automap, f func (t *tunnel)) { var keys []int for k := range m { keys = append(keys, k) } sort.Ints(keys) for _, k := range keys { f(m[k].(*tunnel)) } } func foreachStream(m map[int]*stream, f func (s *stream)) { var keys []int for k := range m { keys = append(keys, k) } sort.Ints(keys) for _, k := range keys { f(m[k]) } } func tunnelShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { r.Println(t.id, t.args) }) } func streamShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { r.Println(t.id, t.args) t.mu.Lock() if len(t.streams) == 0 { r.Println("\t", "nothing") } else { foreachStream(t.streams, func (s *stream) { tm := s.since.Format(config.TimeFormat) r.Println("\t", s.id, tm, s.in, s.out) }) } t.mu.Unlock() }) } func init() { newCmd(tunnelAdd, "add") newCmd(tunnelShow, "show") newCmd(streamShow, "stream show") }