summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/env/env.go42
-rw-r--r--pkg/server/hook/auth.go26
-rw-r--r--pkg/server/socket/auto.go80
-rw-r--r--pkg/server/socket/dial.go42
-rw-r--r--pkg/server/socket/listen.go49
-rw-r--r--pkg/server/socket/loop.go46
-rw-r--r--pkg/server/socket/socket.go145
-rw-r--r--pkg/server/tunnel.go107
8 files changed, 334 insertions, 203 deletions
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index 118f853..16f5760 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -14,6 +14,7 @@ type env struct {
type Env struct {
*env
+ p *env
}
const namePattern = "[a-zA-Z][a-zA-Z0-9.]*"
@@ -25,7 +26,7 @@ var errBadVariable = errors.New("bad variable name")
var errEmptyVariable = errors.New("empty variable")
func New() Env {
- return Env{new(env)}
+ return Env{env: new(env)}
}
func (e *env) init() {
@@ -34,18 +35,10 @@ func (e *env) init() {
}
}
-func (e *env) Copy() Env {
- c := New()
-
- if len(e.m) > 0 {
- c.init()
-
- for k, v := range e.m {
- c.m[k] = v
- }
- }
-
- return c
+func (e *env) Fork() Env {
+ t := New()
+ t.p = e
+ return t
}
func (e *env) Find(key string) (string, bool) {
@@ -57,7 +50,17 @@ func (e *env) Find(key string) (string, bool) {
return v, ok
}
-func (e *env) Get(key string) string {
+func (e Env) Find(key string) (string, bool) {
+ if e.p != nil {
+ if v, ok := e.p.Find(key); ok {
+ return v, ok
+ }
+ }
+
+ return e.env.Find(key)
+}
+
+func (e Env) Get(key string) string {
v, _ := e.Find(key)
return v
}
@@ -124,20 +127,13 @@ func (e *env) Clear() {
e.m = nil
}
-func (e *env) Eval(s string) string {
- e.Lock()
- defer e.Unlock()
-
+func (e Env) Eval(s string) string {
repl := func(v string) string {
key := v[1:]
if key[0] == '{' {
key = key[1 : len(key)-1]
}
-
- if v, ok := e.m[key]; ok {
- return v
- }
- return ""
+ return e.Get(key)
}
for {
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index 7f02816..f347b2a 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"io"
+ "sync"
"time"
"tunnel/pkg/netstring"
"tunnel/pkg/server/env"
@@ -16,6 +17,8 @@ const authTimeout = 5 * time.Second
const challengeLen = 16
type auth struct {
+ h *authHook
+
secret string
challenge struct {
@@ -34,11 +37,13 @@ type auth struct {
ok chan struct{}
}
-var errDupChallenge = errors.New("peer duplicates challenge")
+var errDupChallenge = errors.New("peer repeats challenge")
var errAuthFail = errors.New("peer auth fail")
var errTimeout = errors.New("timeout")
-type authHook struct{}
+type authHook struct {
+ m sync.Map
+}
func (a *auth) generateChallenge() error {
b := make([]byte, challengeLen)
@@ -48,14 +53,20 @@ func (a *auth) generateChallenge() error {
a.challenge.self = string(b)
+ a.h.m.Store(a.challenge.self, struct{}{})
+
return nil
}
+func (a *auth) deleteChallenge() {
+ a.h.m.Delete(a.challenge.self)
+}
+
func (a *auth) getHash(c string) string {
h := md5.New()
- io.WriteString(h, a.secret)
io.WriteString(h, c)
+ io.WriteString(h, a.secret)
return string(h.Sum(nil))
}
@@ -78,13 +89,15 @@ func (a *auth) Send(rq, wq queue.Q) error {
return err
}
+ defer a.deleteChallenge()
+
e.Encode(a.challenge.self)
if err := a.wait(a.recvChallenge); err != nil {
return err
}
- if a.challenge.self == a.challenge.peer {
+ if _, ok := a.h.m.Load(a.challenge.peer); ok {
return errDupChallenge
}
@@ -131,8 +144,9 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) {
return queue.IoCopy(r, wq.Writer())
}
-func (authHook) Open(env env.Env) (interface{}, error) {
+func (h *authHook) Open(env env.Env) (interface{}, error) {
a := &auth{
+ h: h,
secret: getHookVar(env, "secret"),
recvChallenge: make(chan struct{}),
recvHash: make(chan struct{}),
@@ -144,7 +158,7 @@ func (authHook) Open(env env.Env) (interface{}, error) {
}
func newAuthHook(opts.Opts, env.Env) (hook, error) {
- return authHook{}, nil
+ return &authHook{}, nil
}
func init() {
diff --git a/pkg/server/socket/auto.go b/pkg/server/socket/auto.go
new file mode 100644
index 0000000..97bc625
--- /dev/null
+++ b/pkg/server/socket/auto.go
@@ -0,0 +1,80 @@
+package socket
+
+import (
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type autoSocket struct {
+ S
+}
+
+type autoChannel struct {
+ s *autoSocket
+ c chan Channel
+ e env.Env
+}
+
+func newAutoSocket(proto, addr string) (S, error) {
+ s, err := newDialSocket(proto, addr)
+ if err != nil {
+ return s, err
+ }
+
+ return &autoSocket{s}, nil
+}
+
+func (s *autoSocket) Open(env env.Env) (Channel, error) {
+ c := &autoChannel{
+ s: s,
+ c: make(chan Channel),
+ e: env,
+ }
+
+ return c, nil
+}
+
+func (c *autoChannel) String() string {
+ return "auto"
+}
+
+func (c *autoChannel) Send(wq queue.Q) error {
+ if x := <-c.c; x == nil {
+ return nil
+ } else {
+ return x.Send(wq)
+ }
+}
+
+func (c *autoChannel) Recv(rq queue.Q) error {
+ b := <-rq
+ if b == nil {
+ close(c.c)
+ return nil
+ }
+
+ x, err := c.s.S.Open(c.e)
+ if err != nil {
+ close(c.c)
+ return err
+ }
+
+ c.c <- x
+
+ q := queue.New()
+
+ go func() {
+ q <- b
+ queue.Copy(rq, q)
+ close(q)
+ }()
+
+ defer q.Dry()
+
+ return x.Recv(q)
+}
+
+/* TODO */
+func (c *autoChannel) Close() error {
+ return nil
+}
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
new file mode 100644
index 0000000..818fbc6
--- /dev/null
+++ b/pkg/server/socket/dial.go
@@ -0,0 +1,42 @@
+package socket
+
+import (
+ "fmt"
+ "net"
+ "strings"
+ "tunnel/pkg/server/env"
+)
+
+type dialSocket struct {
+ proto, addr string
+}
+
+func newDialSocket(proto, addr string) (S, error) {
+ switch proto {
+ case "tcp", "udp":
+ if !strings.Contains(addr, ":") {
+ addr = "localhost:" + addr
+ }
+ }
+
+ return &dialSocket{proto: proto, addr: addr}, nil
+}
+
+func (s *dialSocket) String() string {
+ return fmt.Sprintf("%s/%s", s.proto, s.addr)
+}
+
+func (s *dialSocket) Open(env.Env) (Channel, error) {
+ conn, err := net.Dial(s.proto, s.addr)
+ if err != nil {
+ return nil, err
+ }
+
+ addr := conn.RemoteAddr()
+ info := fmt.Sprintf(">%s/%s", addr.Network(), addr)
+
+ return exported{info, newConn(conn)}, nil
+}
+
+func (s *dialSocket) Close() {
+}
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go
new file mode 100644
index 0000000..c328945
--- /dev/null
+++ b/pkg/server/socket/listen.go
@@ -0,0 +1,49 @@
+package socket
+
+import (
+ "fmt"
+ "net"
+ "strings"
+ "tunnel/pkg/server/env"
+)
+
+func newListenSocket(proto, addr string) (S, error) {
+ if proto == "tcp" {
+ if !strings.Contains(addr, ":") {
+ addr = ":" + addr
+ }
+ }
+
+ listen, err := net.Listen(proto, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ s := &listenSocket{
+ proto: proto,
+ addr: addr,
+ listen: listen,
+ }
+
+ return s, nil
+}
+
+func (s *listenSocket) Open(env.Env) (Channel, error) {
+ conn, err := s.listen.Accept()
+ if err != nil {
+ return nil, err
+ }
+
+ addr := conn.RemoteAddr()
+ info := fmt.Sprintf("<%s/%s", addr.Network(), addr)
+
+ return exported{info, newConn(conn)}, nil
+}
+
+func (s *listenSocket) String() string {
+ return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
+}
+
+func (s *listenSocket) Close() {
+ s.listen.Close()
+}
diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go
new file mode 100644
index 0000000..88e9491
--- /dev/null
+++ b/pkg/server/socket/loop.go
@@ -0,0 +1,46 @@
+package socket
+
+import (
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type loopSocket struct{}
+
+type loopChannel struct {
+ c chan queue.Q
+ q chan error
+}
+
+func (c *loopChannel) Send(wq queue.Q) error {
+ c.c <- wq
+ return <-c.q
+}
+
+func (c *loopChannel) Recv(rq queue.Q) error {
+ defer close(c.q)
+ return queue.Copy(rq, <-c.c)
+}
+
+func (c *loopChannel) String() string {
+ return "loop"
+}
+
+func (c *loopChannel) Close() error {
+ return nil
+}
+
+func (s *loopSocket) Open(env.Env) (Channel, error) {
+ return &loopChannel{make(chan queue.Q), make(chan error)}, nil
+}
+
+func (s *loopSocket) String() string {
+ return "loop"
+}
+
+func (s *loopSocket) Close() {
+}
+
+func newLoopSocket() (S, error) {
+ return &loopSocket{}, nil
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 5ccf1bd..a945ce0 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -14,7 +14,7 @@ import (
var errAlreadyClosed = errors.New("already closed")
-type exportChannel struct {
+type exported struct {
info string
Channel
}
@@ -35,30 +35,22 @@ type listenSocket struct {
listen net.Listener
}
-type dialSocket struct {
- proto, addr string
-}
-
-type connChannel struct {
- conn net.Conn
+type conn struct {
+ net.Conn
once sync.Once
}
-type loopSocket struct{}
-
-type loopChannel struct {
- q queue.Q
-}
-
-func (c exportChannel) String() string {
+func (c exported) String() string {
return c.info
}
-func newConnChannel(conn net.Conn) Channel {
- return &connChannel{conn: conn}
+func newConn(cn net.Conn) Channel {
+ c := &conn{Conn: cn}
+ log.Println("open", c)
+ return c
}
-func (c *connChannel) final(f func() error, err error) error {
+func (c *conn) final(f func() error, err error) error {
if e := f(); e != nil {
if e == errAlreadyClosed {
return nil
@@ -70,131 +62,32 @@ func (c *connChannel) final(f func() error, err error) error {
return err
}
-func (c *connChannel) Send(wq queue.Q) error {
- err := queue.IoCopy(c.conn, wq.Writer())
+func (c *conn) Send(wq queue.Q) error {
+ err := queue.IoCopy(c, wq.Writer())
return c.final(c.Close, err)
}
-func (c *connChannel) Recv(rq queue.Q) error {
- err := queue.IoCopy(rq.Reader(), c.conn)
+func (c *conn) Recv(rq queue.Q) error {
+ err := queue.IoCopy(rq.Reader(), c)
return c.final(c.Close, err)
}
-func (c *connChannel) String() string {
- local, remote := c.conn.LocalAddr(), c.conn.RemoteAddr()
+func (c *conn) String() string {
+ local, remote := c.LocalAddr(), c.RemoteAddr()
return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
}
-func (c *connChannel) Close() error {
+func (c *conn) Close() error {
err := errAlreadyClosed
c.once.Do(func() {
log.Println("close", c)
- err = c.conn.Close()
+ err = c.Conn.Close()
})
return err
}
-func newListenSocket(proto, addr string) (S, error) {
- if proto == "tcp" {
- if !strings.Contains(addr, ":") {
- addr = ":" + addr
- }
- }
-
- listen, err := net.Listen(proto, addr)
- if err != nil {
- return nil, err
- }
-
- s := &listenSocket{
- proto: proto,
- addr: addr,
- listen: listen,
- }
-
- return s, nil
-}
-
-func (s *listenSocket) Open(env env.Env) (Channel, error) {
- conn, err := s.listen.Accept()
- if err != nil {
- return nil, err
- }
-
- addr := conn.RemoteAddr()
- info := fmt.Sprintf("%s/%s", addr.Network(), addr)
-
- return exportChannel{info, newConnChannel(conn)}, nil
-}
-
-func (s *listenSocket) String() string {
- return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
-}
-
-func (s *listenSocket) Close() {
- s.listen.Close()
-}
-
-func newDialSocket(proto, addr string) (S, error) {
- switch proto {
- case "tcp", "udp":
- if !strings.Contains(addr, ":") {
- addr = "localhost:" + addr
- }
- }
-
- return &dialSocket{proto: proto, addr: addr}, nil
-}
-
-func (s *dialSocket) String() string {
- return fmt.Sprintf("%s/%s", s.proto, s.addr)
-}
-
-func (s *dialSocket) Open(env env.Env) (Channel, error) {
- conn, err := net.Dial(s.proto, s.addr)
- if err != nil {
- return nil, err
- }
- return exportChannel{"-", newConnChannel(conn)}, nil
-}
-
-func (s *dialSocket) Close() {
-}
-
-func (c *loopChannel) Send(wq queue.Q) error {
- return queue.Copy(c.q, wq)
-}
-
-func (c *loopChannel) Recv(rq queue.Q) error {
- defer close(c.q)
- return queue.Copy(rq, c.q)
-}
-
-func (c *loopChannel) Close() error {
- return nil
-}
-
-func (c *loopChannel) String() string {
- return "loop"
-}
-
-func (s *loopSocket) Open(env.Env) (Channel, error) {
- return &loopChannel{queue.New()}, nil
-}
-
-func (s *loopSocket) String() string {
- return "loop"
-}
-
-func (s *loopSocket) Close() {
-}
-
-func newLoopSocket() (S, error) {
- return &loopSocket{}, nil
-}
-
func New(desc string, env env.Env) (S, error) {
base, opts := opts.Parse(desc)
args := strings.SplitN(base, "/", 2)
@@ -224,5 +117,9 @@ func New(desc string, env env.Env) (S, error) {
return newListenSocket(proto, addr)
}
+ if _, ok := opts["auto"]; ok {
+ return newAutoSocket(proto, addr)
+ }
+
return newDialSocket(proto, addr)
}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index d59d9db..592f48c 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -19,6 +19,7 @@ import (
)
const maxRecentSize = 8
+const maxQueueLimit = 16384
type metric struct {
tx uint64
@@ -56,7 +57,7 @@ type tunnel struct {
quit chan struct{}
done chan struct{}
- mono bool
+ queue chan struct{}
in, out socket.S
hooks []hook.H
@@ -82,6 +83,7 @@ func (t *tunnel) stopServe() {
func (t *tunnel) stopStreams() {
t.mu.Lock()
for _, s := range t.streams {
+ log.Println(s, "stop")
s.stop()
}
t.mu.Unlock()
@@ -96,70 +98,72 @@ func (t *tunnel) Close() {
log.Println(t, "delete")
}
-func (t *tunnel) isQuit() bool {
+func (t *tunnel) alive() bool {
select {
case <-t.quit:
- return true
- default:
return false
+ default:
+ return true
}
}
-func (t *tunnel) serve() {
- var wg sync.WaitGroup
+func (t *tunnel) acquire() bool {
+ select {
+ case t.queue <- struct{}{}:
+ return true
+ case <-t.quit:
+ return false
+ }
+}
- for {
- if in, err := t.in.Open(t.env); err != nil {
- if t.isQuit() {
- break
- }
+func (t *tunnel) release() {
+ <-t.queue
+}
- log.Println(t, err)
- time.Sleep(5 * time.Second)
- } else {
- log.Println(t, "open", in)
+func (t *tunnel) sleep(d time.Duration) {
+ tmr := time.NewTimer(d)
+ select {
+ case <-tmr.C:
+ case <-t.quit:
+ }
+ tmr.Stop()
+}
- wg.Add(1)
+func (t *tunnel) serve() {
+ for t.acquire() {
+ var ok bool
- go func() {
- t.handle(in)
- wg.Done()
- }()
+ env := t.env.Fork()
- if t.mono {
- wg.Wait()
- t.wg.Wait()
+ if in, err := t.in.Open(env); err == nil {
+ if out, err := t.out.Open(env); err == nil {
+ s := t.newStream(env, in, out)
+ log.Println(t, s, "create", in, out)
+ ok = true
+ } else {
+ log.Println(t, err)
+ in.Close()
}
+ } else if t.alive() {
+ log.Println(t, err)
+ t.sleep(5 * time.Second)
}
- }
-
- wg.Wait()
- close(t.done)
-}
-
-func (t *tunnel) handle(in socket.Channel) {
- out, err := t.out.Open(t.env)
- if err != nil {
- log.Println(t, err)
- in.Close()
- return
+ if !ok {
+ t.release()
+ }
}
- log.Println(t, "open", out)
-
- s := t.newStream(in, out)
-
- log.Println(t, s, "create", in, out)
+ close(t.done)
}
-func (t *tunnel) newStream(in, out socket.Channel) *stream {
+func (t *tunnel) newStream(env env.Env, in, out socket.Channel) *stream {
s := &stream{
t: t,
in: in,
out: out,
+ env: env,
id: t.nextSid,
- env: t.env.Copy(),
since: time.Now(),
}
@@ -193,7 +197,7 @@ func (t *tunnel) delStream(s *stream) {
func (s *stream) info() string {
d := time.Since(s.since).Milliseconds()
- return fmt.Sprintf("%.3fms %d/%d -> %d/%d",
+ return fmt.Sprintf("%.3fs %d/%d -> %d/%d",
float64(d)/1000.0,
s.m.in.tx,
s.m.in.rx,
@@ -207,9 +211,12 @@ func (s *stream) waitAndClose() {
s.until = time.Now()
s.t.delStream(s)
-
+ s.t.release()
s.t.wg.Done()
+ s.in.Close()
+ s.out.Close()
+
for _, p := range s.pipes {
p.Close()
}
@@ -318,7 +325,7 @@ func parseHooks(args []string, env env.Env) ([]hook.H, error) {
return hooks, nil
}
-func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) {
+func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) {
var in, out socket.S
var hooks []hook.H
var err error
@@ -344,11 +351,11 @@ func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) {
args: strings.Join(args, " "),
quit: make(chan struct{}),
done: make(chan struct{}),
- mono: mono,
hooks: hooks,
in: in,
out: out,
env: env,
+ queue: make(chan struct{}, limit),
streams: make(map[int]*stream),
}
@@ -364,10 +371,10 @@ func isOkTunnelName(s string) bool {
func tunnelAdd(r *request) {
args := r.args
name := ""
- mono := false
+ limit := maxQueueLimit
for len(args) > 1 {
- if args[0] == "name" && len(args) > 1 {
+ if args[0] == "name" {
name = args[1]
if !isOkTunnelName(name) {
r.Fatal("bad name")
@@ -382,7 +389,7 @@ func tunnelAdd(r *request) {
}
if args[0] == "mono" {
- mono = true
+ limit = 1
args = args[1:]
continue
}
@@ -394,7 +401,7 @@ func tunnelAdd(r *request) {
r.Fatal("not enough args")
}
- t, err := newTunnel(mono, args, r.c.s.env)
+ t, err := newTunnel(limit, args, r.c.s.env)
if err != nil {
r.Fatal(err)
}