diff options
Diffstat (limited to 'pkg/server/tunnel.go')
| -rw-r--r-- | pkg/server/tunnel.go | 294 |
1 files changed, 294 insertions, 0 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go new file mode 100644 index 0000000..b1c3fe8 --- /dev/null +++ b/pkg/server/tunnel.go @@ -0,0 +1,294 @@ +package server + +import ( + "tunnel/pkg/server/module" + "tunnel/pkg/server/queue" + "tunnel/pkg/server/socket" + "tunnel/pkg/config" + "sync/atomic" + "strings" + "time" + "sort" + "sync" + "fmt" + "log" +) + +type stream struct { + id int + n int32 + t *tunnel + since time.Time + in, out socket.Channel +} + +type tunnel struct { + id int + args string + + streams map[int]*stream + + mu sync.Mutex + + nextSid int + + in, out socket.S + m []module.M +} + +func (s *stream) String() string { + return fmt.Sprintf("stream(%d)", s.id) +} + +func (t *tunnel) String() string { + return fmt.Sprintf("tunnel(%d)", t.id) +} + +func (t *tunnel) Close() { + t.in.Close() + t.out.Close() +} + +func (t *tunnel) run() { + for { + if in, err := t.in.Open(); err != nil { + log.Println(t, err) + time.Sleep(5 * time.Second) + } else { + log.Println(t, "open", in) + go t.run2(in) + } + } +} + +func (t *tunnel) newStream(in, out socket.Channel) *stream { + s := &stream{ + t: t, + in: in, + out: out, + id: t.nextSid, + since: time.Now(), + } + + t.mu.Lock() + t.nextSid++ + t.streams[s.id] = s + t.mu.Unlock() + + return s +} + +func (s *stream) ref() { + atomic.AddInt32(&s.n, 1) +} + +func (s *stream) unref() { + if atomic.AddInt32(&s.n, -1) == 0 { + log.Println(s.t, s, "close") + + s.t.mu.Lock() + delete(s.t.streams, s.id) + s.t.mu.Unlock() + } +} + +func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { + watch := func (q queue.Q, f func (q queue.Q) error) { + s.ref() + defer s.unref() + + if err := f(q); err != nil { + log.Println(s.t, s, err) + } + } + + go func () { + watch(wq, c.Send) + close(wq) + }() + + go watch(rq, c.Recv) +} + +func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { + go func () { + s.ref() + defer s.unref() + + if err := f(rq, wq); err != nil { + log.Println(s.t, s, err) + } + + close(wq) + }() +} + +func (t *tunnel) run2(in socket.Channel) { + out, err := t.out.Open() + if err != nil { + log.Println(t, err) + in.Close() + return + } + + log.Println(t, "open", out) + + s := t.newStream(in, out) + + s.ref() + defer s.unref() + + rq, wq := queue.New(), queue.New() + + s.watchChannel(rq, wq, in) + + for _, m := range t.m { + send, recv := m.Open() + if send != nil { + q := queue.New() + s.watchPipe(wq, q, send) + wq = q + } + if recv != nil { + q := queue.New() + s.watchPipe(q, rq, recv) + rq = q + } + } + + s.watchChannel(wq, rq, out) + + log.Println(t, s, "create", in, out) +} + +func newTunnel(args []string) (*tunnel, error) { + var in, out socket.S + var err error + + n := len(args) - 1 + + if in, err = socket.New(args[0]); err != nil { + return nil, err + } + + if out, err = socket.New(args[n]); err != nil { + in.Close() + return nil, err + } + + t := &tunnel{ + args: strings.Join(args, " "), + in: in, + out: out, + streams: make(map[int]*stream), + } + + reverse := false + + for _, arg := range args[1:n] { + var m module.M + + if arg == "-" { + reverse = true + continue + } + + if arg == "+" { + reverse = false + continue + } + + if m, err = module.New(arg); err != nil { + t.Close() + return nil, err + } + + if reverse { + m = module.Reverse(m) + reverse = false + } + + t.m = append(t.m, m) + } + + if reverse { + t.Close() + return nil, fmt.Errorf("bad '-' usage") + } + + go t.run() + + return t, nil +} + +func tunnelAdd(r *request) { + if r.argc < 2 { + r.Fatal("not enough args") + } + + t, err := newTunnel(r.args) + if err != nil { + r.Fatal(err) + } + + log.Println(r.c, r, t, "create") + + t.id = r.c.s.tunnels.add(t) +} + +func foreachTunnel(m automap, f func (t *tunnel)) { + var keys []int + + for k := range m { + keys = append(keys, k) + } + + sort.Ints(keys) + + for _, k := range keys { + f(m[k].(*tunnel)) + } +} + +func foreachStream(m map[int]*stream, f func (s *stream)) { + var keys []int + + for k := range m { + keys = append(keys, k) + } + + sort.Ints(keys) + + for _, k := range keys { + f(m[k]) + } +} + +func tunnelShow(r *request) { + foreachTunnel(r.c.s.tunnels, func (t *tunnel) { + r.Println(t.id, t.args) + }) +} + +func streamShow(r *request) { + foreachTunnel(r.c.s.tunnels, func (t *tunnel) { + r.Println(t.id, t.args) + + t.mu.Lock() + if len(t.streams) == 0 { + r.Println("\t", "nothing") + } else { + foreachStream(t.streams, func (s *stream) { + tm := s.since.Format(config.TimeFormat) + r.Println("\t", s.id, tm, s.in, s.out) + }) + } + t.mu.Unlock() + }) +} + +func init() { + newCmd(tunnelAdd, "add") + newCmd(tunnelShow, "show") + newCmd(streamShow, "stream show") +} |
