diff options
| -rw-r--r-- | pkg/server/env/env.go | 42 | ||||
| -rw-r--r-- | pkg/server/hook/auth.go | 26 | ||||
| -rw-r--r-- | pkg/server/socket/auto.go | 80 | ||||
| -rw-r--r-- | pkg/server/socket/dial.go | 42 | ||||
| -rw-r--r-- | pkg/server/socket/listen.go | 49 | ||||
| -rw-r--r-- | pkg/server/socket/loop.go | 46 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 145 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 107 |
8 files changed, 334 insertions, 203 deletions
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index 118f853..16f5760 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -14,6 +14,7 @@ type env struct { type Env struct { *env + p *env } const namePattern = "[a-zA-Z][a-zA-Z0-9.]*" @@ -25,7 +26,7 @@ var errBadVariable = errors.New("bad variable name") var errEmptyVariable = errors.New("empty variable") func New() Env { - return Env{new(env)} + return Env{env: new(env)} } func (e *env) init() { @@ -34,18 +35,10 @@ func (e *env) init() { } } -func (e *env) Copy() Env { - c := New() - - if len(e.m) > 0 { - c.init() - - for k, v := range e.m { - c.m[k] = v - } - } - - return c +func (e *env) Fork() Env { + t := New() + t.p = e + return t } func (e *env) Find(key string) (string, bool) { @@ -57,7 +50,17 @@ func (e *env) Find(key string) (string, bool) { return v, ok } -func (e *env) Get(key string) string { +func (e Env) Find(key string) (string, bool) { + if e.p != nil { + if v, ok := e.p.Find(key); ok { + return v, ok + } + } + + return e.env.Find(key) +} + +func (e Env) Get(key string) string { v, _ := e.Find(key) return v } @@ -124,20 +127,13 @@ func (e *env) Clear() { e.m = nil } -func (e *env) Eval(s string) string { - e.Lock() - defer e.Unlock() - +func (e Env) Eval(s string) string { repl := func(v string) string { key := v[1:] if key[0] == '{' { key = key[1 : len(key)-1] } - - if v, ok := e.m[key]; ok { - return v - } - return "" + return e.Get(key) } for { diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go index 7f02816..f347b2a 100644 --- a/pkg/server/hook/auth.go +++ b/pkg/server/hook/auth.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "io" + "sync" "time" "tunnel/pkg/netstring" "tunnel/pkg/server/env" @@ -16,6 +17,8 @@ const authTimeout = 5 * time.Second const challengeLen = 16 type auth struct { + h *authHook + secret string challenge struct { @@ -34,11 +37,13 @@ type auth struct { ok chan struct{} } -var errDupChallenge = errors.New("peer duplicates challenge") +var errDupChallenge = errors.New("peer repeats challenge") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") -type authHook struct{} +type authHook struct { + m sync.Map +} func (a *auth) generateChallenge() error { b := make([]byte, challengeLen) @@ -48,14 +53,20 @@ func (a *auth) generateChallenge() error { a.challenge.self = string(b) + a.h.m.Store(a.challenge.self, struct{}{}) + return nil } +func (a *auth) deleteChallenge() { + a.h.m.Delete(a.challenge.self) +} + func (a *auth) getHash(c string) string { h := md5.New() - io.WriteString(h, a.secret) io.WriteString(h, c) + io.WriteString(h, a.secret) return string(h.Sum(nil)) } @@ -78,13 +89,15 @@ func (a *auth) Send(rq, wq queue.Q) error { return err } + defer a.deleteChallenge() + e.Encode(a.challenge.self) if err := a.wait(a.recvChallenge); err != nil { return err } - if a.challenge.self == a.challenge.peer { + if _, ok := a.h.m.Load(a.challenge.peer); ok { return errDupChallenge } @@ -131,8 +144,9 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) { return queue.IoCopy(r, wq.Writer()) } -func (authHook) Open(env env.Env) (interface{}, error) { +func (h *authHook) Open(env env.Env) (interface{}, error) { a := &auth{ + h: h, secret: getHookVar(env, "secret"), recvChallenge: make(chan struct{}), recvHash: make(chan struct{}), @@ -144,7 +158,7 @@ func (authHook) Open(env env.Env) (interface{}, error) { } func newAuthHook(opts.Opts, env.Env) (hook, error) { - return authHook{}, nil + return &authHook{}, nil } func init() { diff --git a/pkg/server/socket/auto.go b/pkg/server/socket/auto.go new file mode 100644 index 0000000..97bc625 --- /dev/null +++ b/pkg/server/socket/auto.go @@ -0,0 +1,80 @@ +package socket + +import ( + "tunnel/pkg/server/env" + "tunnel/pkg/server/queue" +) + +type autoSocket struct { + S +} + +type autoChannel struct { + s *autoSocket + c chan Channel + e env.Env +} + +func newAutoSocket(proto, addr string) (S, error) { + s, err := newDialSocket(proto, addr) + if err != nil { + return s, err + } + + return &autoSocket{s}, nil +} + +func (s *autoSocket) Open(env env.Env) (Channel, error) { + c := &autoChannel{ + s: s, + c: make(chan Channel), + e: env, + } + + return c, nil +} + +func (c *autoChannel) String() string { + return "auto" +} + +func (c *autoChannel) Send(wq queue.Q) error { + if x := <-c.c; x == nil { + return nil + } else { + return x.Send(wq) + } +} + +func (c *autoChannel) Recv(rq queue.Q) error { + b := <-rq + if b == nil { + close(c.c) + return nil + } + + x, err := c.s.S.Open(c.e) + if err != nil { + close(c.c) + return err + } + + c.c <- x + + q := queue.New() + + go func() { + q <- b + queue.Copy(rq, q) + close(q) + }() + + defer q.Dry() + + return x.Recv(q) +} + +/* TODO */ +func (c *autoChannel) Close() error { + return nil +} diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go new file mode 100644 index 0000000..818fbc6 --- /dev/null +++ b/pkg/server/socket/dial.go @@ -0,0 +1,42 @@ +package socket + +import ( + "fmt" + "net" + "strings" + "tunnel/pkg/server/env" +) + +type dialSocket struct { + proto, addr string +} + +func newDialSocket(proto, addr string) (S, error) { + switch proto { + case "tcp", "udp": + if !strings.Contains(addr, ":") { + addr = "localhost:" + addr + } + } + + return &dialSocket{proto: proto, addr: addr}, nil +} + +func (s *dialSocket) String() string { + return fmt.Sprintf("%s/%s", s.proto, s.addr) +} + +func (s *dialSocket) Open(env.Env) (Channel, error) { + conn, err := net.Dial(s.proto, s.addr) + if err != nil { + return nil, err + } + + addr := conn.RemoteAddr() + info := fmt.Sprintf(">%s/%s", addr.Network(), addr) + + return exported{info, newConn(conn)}, nil +} + +func (s *dialSocket) Close() { +} diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go new file mode 100644 index 0000000..c328945 --- /dev/null +++ b/pkg/server/socket/listen.go @@ -0,0 +1,49 @@ +package socket + +import ( + "fmt" + "net" + "strings" + "tunnel/pkg/server/env" +) + +func newListenSocket(proto, addr string) (S, error) { + if proto == "tcp" { + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + } + + listen, err := net.Listen(proto, addr) + if err != nil { + return nil, err + } + + s := &listenSocket{ + proto: proto, + addr: addr, + listen: listen, + } + + return s, nil +} + +func (s *listenSocket) Open(env.Env) (Channel, error) { + conn, err := s.listen.Accept() + if err != nil { + return nil, err + } + + addr := conn.RemoteAddr() + info := fmt.Sprintf("<%s/%s", addr.Network(), addr) + + return exported{info, newConn(conn)}, nil +} + +func (s *listenSocket) String() string { + return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) +} + +func (s *listenSocket) Close() { + s.listen.Close() +} diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go new file mode 100644 index 0000000..88e9491 --- /dev/null +++ b/pkg/server/socket/loop.go @@ -0,0 +1,46 @@ +package socket + +import ( + "tunnel/pkg/server/env" + "tunnel/pkg/server/queue" +) + +type loopSocket struct{} + +type loopChannel struct { + c chan queue.Q + q chan error +} + +func (c *loopChannel) Send(wq queue.Q) error { + c.c <- wq + return <-c.q +} + +func (c *loopChannel) Recv(rq queue.Q) error { + defer close(c.q) + return queue.Copy(rq, <-c.c) +} + +func (c *loopChannel) String() string { + return "loop" +} + +func (c *loopChannel) Close() error { + return nil +} + +func (s *loopSocket) Open(env.Env) (Channel, error) { + return &loopChannel{make(chan queue.Q), make(chan error)}, nil +} + +func (s *loopSocket) String() string { + return "loop" +} + +func (s *loopSocket) Close() { +} + +func newLoopSocket() (S, error) { + return &loopSocket{}, nil +} diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 5ccf1bd..a945ce0 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -14,7 +14,7 @@ import ( var errAlreadyClosed = errors.New("already closed") -type exportChannel struct { +type exported struct { info string Channel } @@ -35,30 +35,22 @@ type listenSocket struct { listen net.Listener } -type dialSocket struct { - proto, addr string -} - -type connChannel struct { - conn net.Conn +type conn struct { + net.Conn once sync.Once } -type loopSocket struct{} - -type loopChannel struct { - q queue.Q -} - -func (c exportChannel) String() string { +func (c exported) String() string { return c.info } -func newConnChannel(conn net.Conn) Channel { - return &connChannel{conn: conn} +func newConn(cn net.Conn) Channel { + c := &conn{Conn: cn} + log.Println("open", c) + return c } -func (c *connChannel) final(f func() error, err error) error { +func (c *conn) final(f func() error, err error) error { if e := f(); e != nil { if e == errAlreadyClosed { return nil @@ -70,131 +62,32 @@ func (c *connChannel) final(f func() error, err error) error { return err } -func (c *connChannel) Send(wq queue.Q) error { - err := queue.IoCopy(c.conn, wq.Writer()) +func (c *conn) Send(wq queue.Q) error { + err := queue.IoCopy(c, wq.Writer()) return c.final(c.Close, err) } -func (c *connChannel) Recv(rq queue.Q) error { - err := queue.IoCopy(rq.Reader(), c.conn) +func (c *conn) Recv(rq queue.Q) error { + err := queue.IoCopy(rq.Reader(), c) return c.final(c.Close, err) } -func (c *connChannel) String() string { - local, remote := c.conn.LocalAddr(), c.conn.RemoteAddr() +func (c *conn) String() string { + local, remote := c.LocalAddr(), c.RemoteAddr() return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) } -func (c *connChannel) Close() error { +func (c *conn) Close() error { err := errAlreadyClosed c.once.Do(func() { log.Println("close", c) - err = c.conn.Close() + err = c.Conn.Close() }) return err } -func newListenSocket(proto, addr string) (S, error) { - if proto == "tcp" { - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - } - - listen, err := net.Listen(proto, addr) - if err != nil { - return nil, err - } - - s := &listenSocket{ - proto: proto, - addr: addr, - listen: listen, - } - - return s, nil -} - -func (s *listenSocket) Open(env env.Env) (Channel, error) { - conn, err := s.listen.Accept() - if err != nil { - return nil, err - } - - addr := conn.RemoteAddr() - info := fmt.Sprintf("%s/%s", addr.Network(), addr) - - return exportChannel{info, newConnChannel(conn)}, nil -} - -func (s *listenSocket) String() string { - return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) -} - -func (s *listenSocket) Close() { - s.listen.Close() -} - -func newDialSocket(proto, addr string) (S, error) { - switch proto { - case "tcp", "udp": - if !strings.Contains(addr, ":") { - addr = "localhost:" + addr - } - } - - return &dialSocket{proto: proto, addr: addr}, nil -} - -func (s *dialSocket) String() string { - return fmt.Sprintf("%s/%s", s.proto, s.addr) -} - -func (s *dialSocket) Open(env env.Env) (Channel, error) { - conn, err := net.Dial(s.proto, s.addr) - if err != nil { - return nil, err - } - return exportChannel{"-", newConnChannel(conn)}, nil -} - -func (s *dialSocket) Close() { -} - -func (c *loopChannel) Send(wq queue.Q) error { - return queue.Copy(c.q, wq) -} - -func (c *loopChannel) Recv(rq queue.Q) error { - defer close(c.q) - return queue.Copy(rq, c.q) -} - -func (c *loopChannel) Close() error { - return nil -} - -func (c *loopChannel) String() string { - return "loop" -} - -func (s *loopSocket) Open(env.Env) (Channel, error) { - return &loopChannel{queue.New()}, nil -} - -func (s *loopSocket) String() string { - return "loop" -} - -func (s *loopSocket) Close() { -} - -func newLoopSocket() (S, error) { - return &loopSocket{}, nil -} - func New(desc string, env env.Env) (S, error) { base, opts := opts.Parse(desc) args := strings.SplitN(base, "/", 2) @@ -224,5 +117,9 @@ func New(desc string, env env.Env) (S, error) { return newListenSocket(proto, addr) } + if _, ok := opts["auto"]; ok { + return newAutoSocket(proto, addr) + } + return newDialSocket(proto, addr) } diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index d59d9db..592f48c 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -19,6 +19,7 @@ import ( ) const maxRecentSize = 8 +const maxQueueLimit = 16384 type metric struct { tx uint64 @@ -56,7 +57,7 @@ type tunnel struct { quit chan struct{} done chan struct{} - mono bool + queue chan struct{} in, out socket.S hooks []hook.H @@ -82,6 +83,7 @@ func (t *tunnel) stopServe() { func (t *tunnel) stopStreams() { t.mu.Lock() for _, s := range t.streams { + log.Println(s, "stop") s.stop() } t.mu.Unlock() @@ -96,70 +98,72 @@ func (t *tunnel) Close() { log.Println(t, "delete") } -func (t *tunnel) isQuit() bool { +func (t *tunnel) alive() bool { select { case <-t.quit: - return true - default: return false + default: + return true } } -func (t *tunnel) serve() { - var wg sync.WaitGroup +func (t *tunnel) acquire() bool { + select { + case t.queue <- struct{}{}: + return true + case <-t.quit: + return false + } +} - for { - if in, err := t.in.Open(t.env); err != nil { - if t.isQuit() { - break - } +func (t *tunnel) release() { + <-t.queue +} - log.Println(t, err) - time.Sleep(5 * time.Second) - } else { - log.Println(t, "open", in) +func (t *tunnel) sleep(d time.Duration) { + tmr := time.NewTimer(d) + select { + case <-tmr.C: + case <-t.quit: + } + tmr.Stop() +} - wg.Add(1) +func (t *tunnel) serve() { + for t.acquire() { + var ok bool - go func() { - t.handle(in) - wg.Done() - }() + env := t.env.Fork() - if t.mono { - wg.Wait() - t.wg.Wait() + if in, err := t.in.Open(env); err == nil { + if out, err := t.out.Open(env); err == nil { + s := t.newStream(env, in, out) + log.Println(t, s, "create", in, out) + ok = true + } else { + log.Println(t, err) + in.Close() } + } else if t.alive() { + log.Println(t, err) + t.sleep(5 * time.Second) } - } - - wg.Wait() - close(t.done) -} - -func (t *tunnel) handle(in socket.Channel) { - out, err := t.out.Open(t.env) - if err != nil { - log.Println(t, err) - in.Close() - return + if !ok { + t.release() + } } - log.Println(t, "open", out) - - s := t.newStream(in, out) - - log.Println(t, s, "create", in, out) + close(t.done) } -func (t *tunnel) newStream(in, out socket.Channel) *stream { +func (t *tunnel) newStream(env env.Env, in, out socket.Channel) *stream { s := &stream{ t: t, in: in, out: out, + env: env, id: t.nextSid, - env: t.env.Copy(), since: time.Now(), } @@ -193,7 +197,7 @@ func (t *tunnel) delStream(s *stream) { func (s *stream) info() string { d := time.Since(s.since).Milliseconds() - return fmt.Sprintf("%.3fms %d/%d -> %d/%d", + return fmt.Sprintf("%.3fs %d/%d -> %d/%d", float64(d)/1000.0, s.m.in.tx, s.m.in.rx, @@ -207,9 +211,12 @@ func (s *stream) waitAndClose() { s.until = time.Now() s.t.delStream(s) - + s.t.release() s.t.wg.Done() + s.in.Close() + s.out.Close() + for _, p := range s.pipes { p.Close() } @@ -318,7 +325,7 @@ func parseHooks(args []string, env env.Env) ([]hook.H, error) { return hooks, nil } -func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) { +func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) { var in, out socket.S var hooks []hook.H var err error @@ -344,11 +351,11 @@ func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) { args: strings.Join(args, " "), quit: make(chan struct{}), done: make(chan struct{}), - mono: mono, hooks: hooks, in: in, out: out, env: env, + queue: make(chan struct{}, limit), streams: make(map[int]*stream), } @@ -364,10 +371,10 @@ func isOkTunnelName(s string) bool { func tunnelAdd(r *request) { args := r.args name := "" - mono := false + limit := maxQueueLimit for len(args) > 1 { - if args[0] == "name" && len(args) > 1 { + if args[0] == "name" { name = args[1] if !isOkTunnelName(name) { r.Fatal("bad name") @@ -382,7 +389,7 @@ func tunnelAdd(r *request) { } if args[0] == "mono" { - mono = true + limit = 1 args = args[1:] continue } @@ -394,7 +401,7 @@ func tunnelAdd(r *request) { r.Fatal("not enough args") } - t, err := newTunnel(mono, args, r.c.s.env) + t, err := newTunnel(limit, args, r.c.s.env) if err != nil { r.Fatal(err) } |
