summaryrefslogtreecommitdiff
path: root/pkg/server/tunnel.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/tunnel.go')
-rw-r--r--pkg/server/tunnel.go294
1 files changed, 294 insertions, 0 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
new file mode 100644
index 0000000..b1c3fe8
--- /dev/null
+++ b/pkg/server/tunnel.go
@@ -0,0 +1,294 @@
+package server
+
+import (
+ "tunnel/pkg/server/module"
+ "tunnel/pkg/server/queue"
+ "tunnel/pkg/server/socket"
+ "tunnel/pkg/config"
+ "sync/atomic"
+ "strings"
+ "time"
+ "sort"
+ "sync"
+ "fmt"
+ "log"
+)
+
+type stream struct {
+ id int
+ n int32
+ t *tunnel
+ since time.Time
+ in, out socket.Channel
+}
+
+type tunnel struct {
+ id int
+ args string
+
+ streams map[int]*stream
+
+ mu sync.Mutex
+
+ nextSid int
+
+ in, out socket.S
+ m []module.M
+}
+
+func (s *stream) String() string {
+ return fmt.Sprintf("stream(%d)", s.id)
+}
+
+func (t *tunnel) String() string {
+ return fmt.Sprintf("tunnel(%d)", t.id)
+}
+
+func (t *tunnel) Close() {
+ t.in.Close()
+ t.out.Close()
+}
+
+func (t *tunnel) run() {
+ for {
+ if in, err := t.in.Open(); err != nil {
+ log.Println(t, err)
+ time.Sleep(5 * time.Second)
+ } else {
+ log.Println(t, "open", in)
+ go t.run2(in)
+ }
+ }
+}
+
+func (t *tunnel) newStream(in, out socket.Channel) *stream {
+ s := &stream{
+ t: t,
+ in: in,
+ out: out,
+ id: t.nextSid,
+ since: time.Now(),
+ }
+
+ t.mu.Lock()
+ t.nextSid++
+ t.streams[s.id] = s
+ t.mu.Unlock()
+
+ return s
+}
+
+func (s *stream) ref() {
+ atomic.AddInt32(&s.n, 1)
+}
+
+func (s *stream) unref() {
+ if atomic.AddInt32(&s.n, -1) == 0 {
+ log.Println(s.t, s, "close")
+
+ s.t.mu.Lock()
+ delete(s.t.streams, s.id)
+ s.t.mu.Unlock()
+ }
+}
+
+func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
+ watch := func (q queue.Q, f func (q queue.Q) error) {
+ s.ref()
+ defer s.unref()
+
+ if err := f(q); err != nil {
+ log.Println(s.t, s, err)
+ }
+ }
+
+ go func () {
+ watch(wq, c.Send)
+ close(wq)
+ }()
+
+ go watch(rq, c.Recv)
+}
+
+func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
+ go func () {
+ s.ref()
+ defer s.unref()
+
+ if err := f(rq, wq); err != nil {
+ log.Println(s.t, s, err)
+ }
+
+ close(wq)
+ }()
+}
+
+func (t *tunnel) run2(in socket.Channel) {
+ out, err := t.out.Open()
+ if err != nil {
+ log.Println(t, err)
+ in.Close()
+ return
+ }
+
+ log.Println(t, "open", out)
+
+ s := t.newStream(in, out)
+
+ s.ref()
+ defer s.unref()
+
+ rq, wq := queue.New(), queue.New()
+
+ s.watchChannel(rq, wq, in)
+
+ for _, m := range t.m {
+ send, recv := m.Open()
+ if send != nil {
+ q := queue.New()
+ s.watchPipe(wq, q, send)
+ wq = q
+ }
+ if recv != nil {
+ q := queue.New()
+ s.watchPipe(q, rq, recv)
+ rq = q
+ }
+ }
+
+ s.watchChannel(wq, rq, out)
+
+ log.Println(t, s, "create", in, out)
+}
+
+func newTunnel(args []string) (*tunnel, error) {
+ var in, out socket.S
+ var err error
+
+ n := len(args) - 1
+
+ if in, err = socket.New(args[0]); err != nil {
+ return nil, err
+ }
+
+ if out, err = socket.New(args[n]); err != nil {
+ in.Close()
+ return nil, err
+ }
+
+ t := &tunnel{
+ args: strings.Join(args, " "),
+ in: in,
+ out: out,
+ streams: make(map[int]*stream),
+ }
+
+ reverse := false
+
+ for _, arg := range args[1:n] {
+ var m module.M
+
+ if arg == "-" {
+ reverse = true
+ continue
+ }
+
+ if arg == "+" {
+ reverse = false
+ continue
+ }
+
+ if m, err = module.New(arg); err != nil {
+ t.Close()
+ return nil, err
+ }
+
+ if reverse {
+ m = module.Reverse(m)
+ reverse = false
+ }
+
+ t.m = append(t.m, m)
+ }
+
+ if reverse {
+ t.Close()
+ return nil, fmt.Errorf("bad '-' usage")
+ }
+
+ go t.run()
+
+ return t, nil
+}
+
+func tunnelAdd(r *request) {
+ if r.argc < 2 {
+ r.Fatal("not enough args")
+ }
+
+ t, err := newTunnel(r.args)
+ if err != nil {
+ r.Fatal(err)
+ }
+
+ log.Println(r.c, r, t, "create")
+
+ t.id = r.c.s.tunnels.add(t)
+}
+
+func foreachTunnel(m automap, f func (t *tunnel)) {
+ var keys []int
+
+ for k := range m {
+ keys = append(keys, k)
+ }
+
+ sort.Ints(keys)
+
+ for _, k := range keys {
+ f(m[k].(*tunnel))
+ }
+}
+
+func foreachStream(m map[int]*stream, f func (s *stream)) {
+ var keys []int
+
+ for k := range m {
+ keys = append(keys, k)
+ }
+
+ sort.Ints(keys)
+
+ for _, k := range keys {
+ f(m[k])
+ }
+}
+
+func tunnelShow(r *request) {
+ foreachTunnel(r.c.s.tunnels, func (t *tunnel) {
+ r.Println(t.id, t.args)
+ })
+}
+
+func streamShow(r *request) {
+ foreachTunnel(r.c.s.tunnels, func (t *tunnel) {
+ r.Println(t.id, t.args)
+
+ t.mu.Lock()
+ if len(t.streams) == 0 {
+ r.Println("\t", "nothing")
+ } else {
+ foreachStream(t.streams, func (s *stream) {
+ tm := s.since.Format(config.TimeFormat)
+ r.Println("\t", s.id, tm, s.in, s.out)
+ })
+ }
+ t.mu.Unlock()
+ })
+}
+
+func init() {
+ newCmd(tunnelAdd, "add")
+ newCmd(tunnelShow, "show")
+ newCmd(streamShow, "stream show")
+}