summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--BUGS2
-rw-r--r--pkg/server/cmds.go2
-rw-r--r--pkg/server/queue/queue.go15
-rw-r--r--pkg/server/socket/defer.go66
-rw-r--r--pkg/server/socket/dial.go14
-rw-r--r--pkg/server/socket/proxy.go18
-rw-r--r--pkg/server/socket/redial.go183
-rw-r--r--pkg/server/socket/socket.go14
-rw-r--r--pkg/server/tunnel.go42
9 files changed, 297 insertions, 59 deletions
diff --git a/BUGS b/BUGS
new file mode 100644
index 0000000..1a65455
--- /dev/null
+++ b/BUGS
@@ -0,0 +1,2 @@
+- dial async ?
+- force closed other socket in case of stream not closed during timeout
diff --git a/pkg/server/cmds.go b/pkg/server/cmds.go
index a525b86..febda6b 100644
--- a/pkg/server/cmds.go
+++ b/pkg/server/cmds.go
@@ -177,7 +177,7 @@ func (c *cmd) parseFuncSignature(i interface{}) error {
}
if f, ok := parseFuncMap[t.Kind()]; !ok {
- return fmt.Errorf("%s: unsupported %d arg type %s", t, t)
+ return fmt.Errorf("%s: unsupported %d arg type %s", t, n, t)
} else if slice {
c.rest = newParseSlice(t, f)
} else {
diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go
index f0c1fc9..1478c07 100644
--- a/pkg/server/queue/queue.go
+++ b/pkg/server/queue/queue.go
@@ -55,13 +55,15 @@ func (r *reader) WriteTo(w io.Writer) (int64, error) {
}
if len(r.b) > 0 {
- if _, err := w.Write(r.b); err != nil {
+ if n, err := w.Write(r.b); err != nil {
+ r.b = r.b[n:]
return 0, err
}
}
for b := range r.q {
- if _, err := w.Write(b); err != nil {
+ if n, err := w.Write(b); err != nil {
+ r.b = b[n:]
return 0, err
}
}
@@ -91,13 +93,8 @@ func (w *writer) Write(p []byte) (int, error) {
}
func IoCopy(r io.Reader, w io.Writer) error {
- if _, err := io.Copy(w, r); err != nil {
- if err != io.EOF {
- return err
- }
- }
-
- return nil
+ _, err := io.Copy(w, r)
+ return err
}
func Copy(rq, wq Q) error {
diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go
index d63d51e..3eb2dfd 100644
--- a/pkg/server/socket/defer.go
+++ b/pkg/server/socket/defer.go
@@ -1,6 +1,8 @@
package socket
import (
+ "context"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
)
@@ -10,17 +12,26 @@ type deferSocket struct {
}
type deferConn struct {
- sock S
- wait chan bool
- env env.Env
- conn Conn
+ s *deferSocket
+ e env.Env
+
+ ctx context.Context
+ cancel func()
+
+ conn chan *conn
}
func (s *deferSocket) New(env env.Env) (Conn, error) {
+ ctx, cancel := context.WithCancel(context.TODO())
+
c := &deferConn{
- sock: &s.dialSocket,
- wait: make(chan bool),
- env: env,
+ s: s,
+ e: env,
+
+ conn: make(chan *conn),
+
+ ctx: ctx,
+ cancel: cancel,
}
return c, nil
@@ -31,48 +42,41 @@ func (c *deferConn) String() string {
}
func (c *deferConn) Send(wq queue.Q) error {
- if !<-c.wait {
+ conn := <-c.conn
+ if conn == nil {
return nil
- } else {
- return c.conn.Send(wq)
}
+ return conn.Send(wq)
}
func (c *deferConn) Recv(rq queue.Q) error {
- b := <-rq
- if b == nil {
- c.wait <- false
+ // TODO: and context check
+ r := rq.Reader()
+
+ if _, err := r.Read(nil); err != nil {
+ c.conn <- nil
return nil
}
- conn, err := c.sock.New(c.env)
+ conn, err := dial(c.ctx, c.e, c.s.Proto, c.s.Addr)
if err != nil {
- c.wait <- false
+ c.conn <- nil
return err
}
- c.conn = conn
- c.wait <- true
-
- q := queue.New()
-
go func() {
- q <- b
- queue.Copy(rq, q)
- close(q)
+ <-c.ctx.Done()
+ conn.Close()
}()
- defer q.Dry()
+ c.conn <- conn
- return c.conn.Recv(q)
+ return queue.IoCopy(r, conn)
}
-func (c *deferConn) Close() (err error) {
- if c.conn != nil {
- err = c.conn.Close()
- }
-
- return
+func (c *deferConn) Close() error {
+ c.cancel()
+ return nil
}
func init() {
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index a31e192..aea74d2 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -1,6 +1,7 @@
package socket
import (
+ "context"
"fmt"
"log"
"net"
@@ -21,8 +22,17 @@ func (s *dialSocket) String() string {
}
func (s *dialSocket) New(e env.Env) (Conn, error) {
- proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr))
- conn, err := net.DialTimeout(proto, addr, defaultTimeout)
+ return dial(context.TODO(), e, s.Proto, s.Addr)
+}
+
+func dial(ctx context.Context, e env.Env, proto, addr string) (*conn, error) {
+ ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
+ defer cancel()
+
+ var d net.Dialer
+
+ proto, addr = parseProtoAddr(proto, e.Expand(addr))
+ conn, err := d.DialContext(ctx, proto, addr)
if err != nil {
return nil, err
}
diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go
index c46d1c1..f6c4f72 100644
--- a/pkg/server/socket/proxy.go
+++ b/pkg/server/socket/proxy.go
@@ -3,6 +3,7 @@ package socket
import (
"bufio"
"bytes"
+ "context"
"errors"
"fmt"
@@ -26,7 +27,7 @@ type proxyServer struct {
auth string
wait chan status
env env.Env
- conn Conn
+ conn *conn
}
func (sock *proxySocket) New(env env.Env) (Conn, error) {
@@ -74,15 +75,9 @@ func (s *proxyServer) Send(wq queue.Q) error {
return s.conn.Send(wq)
}
-func (s *proxyServer) initConn(addr string) error {
- dial := dialSocket{
- Proto: s.sock.Proto,
- Addr: addr,
- }
-
- conn, err := dial.New(s.env)
+func (s *proxyServer) dial(addr string) error {
+ conn, err := dial(context.TODO(), s.env, s.sock.Proto, addr)
if err != nil {
- dial.Close()
return err
}
@@ -118,17 +113,18 @@ func (s *proxyServer) Recv(rq queue.Q) error {
}
}
- if err := s.initConn(req.URI); err != nil {
+ if err := s.dial(req.URI); err != nil {
s.wait <- status{500, "Unable to connect"}
return err
}
s.wait <- status{200, "Connection established"}
- return queue.IoCopy(r, s.conn.(*conn))
+ return queue.IoCopy(r, s.conn)
}
func (s *proxyServer) Close() (err error) {
+ // TODO safe close
if s.conn != nil {
err = s.conn.Close()
}
diff --git a/pkg/server/socket/redial.go b/pkg/server/socket/redial.go
new file mode 100644
index 0000000..cdd1043
--- /dev/null
+++ b/pkg/server/socket/redial.go
@@ -0,0 +1,183 @@
+package socket
+
+import (
+ "context"
+ "log"
+ "sync"
+ "time"
+
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type redialSocket struct {
+ dialSocket
+
+ SendKeep bool
+}
+
+type redialConn struct {
+ s *redialSocket
+ e env.Env
+
+ mu sync.Mutex
+ skip chan struct{}
+
+ cc chan *conn
+ q queue.Q
+
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+var closedchan = make(chan struct{})
+
+func init() {
+ close(closedchan)
+}
+
+func (s *redialSocket) New(env env.Env) (Conn, error) {
+ ctx, cancel := context.WithCancel(context.TODO())
+
+ c := &redialConn{
+ s: s,
+ e: env,
+
+ cc: make(chan *conn),
+ q: queue.New(),
+
+ ctx: ctx,
+ cancel: cancel,
+ }
+
+ go c.loop(ctx)
+
+ return c, nil
+}
+
+func (c *redialConn) loop(ctx context.Context) {
+ defer close(c.cc)
+
+ conntry := 0
+
+ for {
+ conntry++
+ conn, err := dial(ctx, c.e, c.s.Proto, c.s.Addr)
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+
+ log.Println(err)
+
+ if conntry == 1 {
+ c.SkipUp()
+ }
+
+ t := time.NewTimer(time.Second)
+ select {
+ case <-ctx.Done():
+ return
+ case <-t.C:
+ }
+ continue
+ }
+
+ if conntry > 1 {
+ c.SkipDown()
+ }
+
+ conntry = 0
+
+ c.cc <- conn
+
+ for loop := true; loop; {
+ select {
+ case b := <-c.q:
+ if _, err := conn.Write(b); err != nil {
+ conn.Close()
+ loop = false
+ }
+ case <-conn.closed:
+ loop = false
+ case <-ctx.Done():
+ conn.Close()
+ return
+ }
+ }
+ }
+}
+
+func (c *redialConn) Send(wq queue.Q) error {
+ for conn := range c.cc {
+ conn.Send(wq)
+ conn.Close()
+ }
+ return nil
+}
+
+func (c *redialConn) SkipUp() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = closedchan
+ } else {
+ close(c.skip)
+ }
+ c.mu.Unlock()
+}
+
+func (c *redialConn) SkipDown() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ c.skip = nil
+ c.mu.Unlock()
+}
+
+func (c *redialConn) Skip() chan struct{} {
+ if c.s.SendKeep {
+ return nil
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = make(chan struct{})
+ }
+ d := c.skip
+ c.mu.Unlock()
+ return d
+}
+
+func (c *redialConn) Recv(rq queue.Q) (err error) {
+ for b := range rq {
+ select {
+ case c.q <- b:
+ case <-c.Skip():
+ case <-c.ctx.Done():
+ return
+ }
+ }
+ return
+}
+
+func (c *redialConn) String() string {
+ return "redial"
+}
+
+func (c *redialConn) Close() error {
+ c.cancel()
+ return nil
+}
+
+func (c *redialSocket) Close() {
+}
+
+func init() {
+ register("redial", "dial even after close", redialSocket{})
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index ea35be3..cb76cf7 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -38,11 +38,20 @@ type conn struct {
desc string
info string
- once sync.Once
+ once sync.Once
+ closed chan struct{}
}
func newConn(cn net.Conn, desc, info string) *conn {
- c := &conn{Conn: cn, desc: desc, info: info}
+ c := &conn{
+ Conn: cn,
+
+ desc: desc,
+ info: info,
+
+ closed: make(chan struct{}),
+ }
+
return c
}
@@ -64,6 +73,7 @@ func (c *conn) Close() error {
c.once.Do(func() {
log.Println("close", c.desc)
err = c.Conn.Close()
+ close(c.closed)
})
return err
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index e2fce92..afa0f31 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -36,6 +36,8 @@ type stream struct {
wg sync.WaitGroup
in, out socket.Conn
pipes []*hook.Pipe
+ mu sync.Mutex
+ zombie bool
m struct {
in metric
@@ -295,6 +297,10 @@ func (s *stream) channel(c socket.Conn, m *metric, rq, wq queue.Q) {
if err != nil {
log.Println(s.t, s, c, err)
}
+
+ s.mu.Lock()
+ s.zombie = true
+ s.mu.Unlock()
}
counter := func(c *uint64, src, dst queue.Q) {
@@ -499,6 +505,26 @@ func tunnelDel(r *request, id string) {
}
}
+func streamKick(r *request, tid string, sid int) {
+ e, ok := r.c.s.tunnels[tid]
+ if !ok {
+ r.Fatal("no such tunnel")
+ }
+
+ t := e.(*tunnel)
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ s, ok := t.streams[sid]
+ if !ok {
+ r.Fatal("no such stream")
+ }
+
+ log.Println(s, "kick")
+ s.stop()
+}
+
func tunnelRename(r *request, old, new string) {
if !isOkTunnelName(new) {
r.Fatal("bad name")
@@ -551,7 +577,15 @@ func showActive(r *request) {
defer t.mu.Unlock()
foreachStream(t.streams, func(s *stream) {
- r.Println(t.id, s.id, s.in, s.out, s.info())
+ var opts string
+
+ s.mu.Lock()
+ if s.zombie {
+ opts = "zombie"
+ }
+ s.mu.Unlock()
+
+ r.Println(t.id, s.id, s.in, s.out, s.info(), opts)
})
})
}
@@ -577,9 +611,11 @@ func showRecent(r *request) {
func init() {
newCmd("add", tunnelAdd, "[name id] [limit N] [single] socket [hook ...] socket")
- newCmd("del", tunnelDel, "id")
+ newCmd("del", tunnelDel, "tunnel-id")
+
+ newCmd("kick", streamKick, "tunnel-id stream-id")
- newCmd("rename", tunnelRename, "old-id new-id")
+ newCmd("rename", tunnelRename, "tunnel-old-id tunnel-new-id")
newCmd("show", showTunnels, "")