diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2021-04-14 15:02:11 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2021-04-14 16:04:20 +0300 |
| commit | a6a89cd5e025c6a6e12ad6060aae347385869cd3 (patch) | |
| tree | 69e19085bc91aaa798bf1b8e5d39d8648c6e08ba | |
| parent | 1517965447c59f1426405bf775e4c7c1f0611354 (diff) | |
add redial socket
| -rw-r--r-- | BUGS | 2 | ||||
| -rw-r--r-- | pkg/server/cmds.go | 2 | ||||
| -rw-r--r-- | pkg/server/queue/queue.go | 15 | ||||
| -rw-r--r-- | pkg/server/socket/defer.go | 66 | ||||
| -rw-r--r-- | pkg/server/socket/dial.go | 14 | ||||
| -rw-r--r-- | pkg/server/socket/proxy.go | 18 | ||||
| -rw-r--r-- | pkg/server/socket/redial.go | 183 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 14 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 42 |
9 files changed, 297 insertions, 59 deletions
@@ -0,0 +1,2 @@ +- dial async ? +- force closed other socket in case of stream not closed during timeout diff --git a/pkg/server/cmds.go b/pkg/server/cmds.go index a525b86..febda6b 100644 --- a/pkg/server/cmds.go +++ b/pkg/server/cmds.go @@ -177,7 +177,7 @@ func (c *cmd) parseFuncSignature(i interface{}) error { } if f, ok := parseFuncMap[t.Kind()]; !ok { - return fmt.Errorf("%s: unsupported %d arg type %s", t, t) + return fmt.Errorf("%s: unsupported %d arg type %s", t, n, t) } else if slice { c.rest = newParseSlice(t, f) } else { diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go index f0c1fc9..1478c07 100644 --- a/pkg/server/queue/queue.go +++ b/pkg/server/queue/queue.go @@ -55,13 +55,15 @@ func (r *reader) WriteTo(w io.Writer) (int64, error) { } if len(r.b) > 0 { - if _, err := w.Write(r.b); err != nil { + if n, err := w.Write(r.b); err != nil { + r.b = r.b[n:] return 0, err } } for b := range r.q { - if _, err := w.Write(b); err != nil { + if n, err := w.Write(b); err != nil { + r.b = b[n:] return 0, err } } @@ -91,13 +93,8 @@ func (w *writer) Write(p []byte) (int, error) { } func IoCopy(r io.Reader, w io.Writer) error { - if _, err := io.Copy(w, r); err != nil { - if err != io.EOF { - return err - } - } - - return nil + _, err := io.Copy(w, r) + return err } func Copy(rq, wq Q) error { diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go index d63d51e..3eb2dfd 100644 --- a/pkg/server/socket/defer.go +++ b/pkg/server/socket/defer.go @@ -1,6 +1,8 @@ package socket import ( + "context" + "tunnel/pkg/server/env" "tunnel/pkg/server/queue" ) @@ -10,17 +12,26 @@ type deferSocket struct { } type deferConn struct { - sock S - wait chan bool - env env.Env - conn Conn + s *deferSocket + e env.Env + + ctx context.Context + cancel func() + + conn chan *conn } func (s *deferSocket) New(env env.Env) (Conn, error) { + ctx, cancel := context.WithCancel(context.TODO()) + c := &deferConn{ - sock: &s.dialSocket, - wait: make(chan bool), - env: env, + s: s, + e: env, + + conn: make(chan *conn), + + ctx: ctx, + cancel: cancel, } return c, nil @@ -31,48 +42,41 @@ func (c *deferConn) String() string { } func (c *deferConn) Send(wq queue.Q) error { - if !<-c.wait { + conn := <-c.conn + if conn == nil { return nil - } else { - return c.conn.Send(wq) } + return conn.Send(wq) } func (c *deferConn) Recv(rq queue.Q) error { - b := <-rq - if b == nil { - c.wait <- false + // TODO: and context check + r := rq.Reader() + + if _, err := r.Read(nil); err != nil { + c.conn <- nil return nil } - conn, err := c.sock.New(c.env) + conn, err := dial(c.ctx, c.e, c.s.Proto, c.s.Addr) if err != nil { - c.wait <- false + c.conn <- nil return err } - c.conn = conn - c.wait <- true - - q := queue.New() - go func() { - q <- b - queue.Copy(rq, q) - close(q) + <-c.ctx.Done() + conn.Close() }() - defer q.Dry() + c.conn <- conn - return c.conn.Recv(q) + return queue.IoCopy(r, conn) } -func (c *deferConn) Close() (err error) { - if c.conn != nil { - err = c.conn.Close() - } - - return +func (c *deferConn) Close() error { + c.cancel() + return nil } func init() { diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go index a31e192..aea74d2 100644 --- a/pkg/server/socket/dial.go +++ b/pkg/server/socket/dial.go @@ -1,6 +1,7 @@ package socket import ( + "context" "fmt" "log" "net" @@ -21,8 +22,17 @@ func (s *dialSocket) String() string { } func (s *dialSocket) New(e env.Env) (Conn, error) { - proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr)) - conn, err := net.DialTimeout(proto, addr, defaultTimeout) + return dial(context.TODO(), e, s.Proto, s.Addr) +} + +func dial(ctx context.Context, e env.Env, proto, addr string) (*conn, error) { + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + + var d net.Dialer + + proto, addr = parseProtoAddr(proto, e.Expand(addr)) + conn, err := d.DialContext(ctx, proto, addr) if err != nil { return nil, err } diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go index c46d1c1..f6c4f72 100644 --- a/pkg/server/socket/proxy.go +++ b/pkg/server/socket/proxy.go @@ -3,6 +3,7 @@ package socket import ( "bufio" "bytes" + "context" "errors" "fmt" @@ -26,7 +27,7 @@ type proxyServer struct { auth string wait chan status env env.Env - conn Conn + conn *conn } func (sock *proxySocket) New(env env.Env) (Conn, error) { @@ -74,15 +75,9 @@ func (s *proxyServer) Send(wq queue.Q) error { return s.conn.Send(wq) } -func (s *proxyServer) initConn(addr string) error { - dial := dialSocket{ - Proto: s.sock.Proto, - Addr: addr, - } - - conn, err := dial.New(s.env) +func (s *proxyServer) dial(addr string) error { + conn, err := dial(context.TODO(), s.env, s.sock.Proto, addr) if err != nil { - dial.Close() return err } @@ -118,17 +113,18 @@ func (s *proxyServer) Recv(rq queue.Q) error { } } - if err := s.initConn(req.URI); err != nil { + if err := s.dial(req.URI); err != nil { s.wait <- status{500, "Unable to connect"} return err } s.wait <- status{200, "Connection established"} - return queue.IoCopy(r, s.conn.(*conn)) + return queue.IoCopy(r, s.conn) } func (s *proxyServer) Close() (err error) { + // TODO safe close if s.conn != nil { err = s.conn.Close() } diff --git a/pkg/server/socket/redial.go b/pkg/server/socket/redial.go new file mode 100644 index 0000000..cdd1043 --- /dev/null +++ b/pkg/server/socket/redial.go @@ -0,0 +1,183 @@ +package socket + +import ( + "context" + "log" + "sync" + "time" + + "tunnel/pkg/server/env" + "tunnel/pkg/server/queue" +) + +type redialSocket struct { + dialSocket + + SendKeep bool +} + +type redialConn struct { + s *redialSocket + e env.Env + + mu sync.Mutex + skip chan struct{} + + cc chan *conn + q queue.Q + + ctx context.Context + cancel context.CancelFunc +} + +var closedchan = make(chan struct{}) + +func init() { + close(closedchan) +} + +func (s *redialSocket) New(env env.Env) (Conn, error) { + ctx, cancel := context.WithCancel(context.TODO()) + + c := &redialConn{ + s: s, + e: env, + + cc: make(chan *conn), + q: queue.New(), + + ctx: ctx, + cancel: cancel, + } + + go c.loop(ctx) + + return c, nil +} + +func (c *redialConn) loop(ctx context.Context) { + defer close(c.cc) + + conntry := 0 + + for { + conntry++ + conn, err := dial(ctx, c.e, c.s.Proto, c.s.Addr) + if err != nil { + if ctx.Err() != nil { + return + } + + log.Println(err) + + if conntry == 1 { + c.SkipUp() + } + + t := time.NewTimer(time.Second) + select { + case <-ctx.Done(): + return + case <-t.C: + } + continue + } + + if conntry > 1 { + c.SkipDown() + } + + conntry = 0 + + c.cc <- conn + + for loop := true; loop; { + select { + case b := <-c.q: + if _, err := conn.Write(b); err != nil { + conn.Close() + loop = false + } + case <-conn.closed: + loop = false + case <-ctx.Done(): + conn.Close() + return + } + } + } +} + +func (c *redialConn) Send(wq queue.Q) error { + for conn := range c.cc { + conn.Send(wq) + conn.Close() + } + return nil +} + +func (c *redialConn) SkipUp() { + if c.s.SendKeep { + return + } + + c.mu.Lock() + if c.skip == nil { + c.skip = closedchan + } else { + close(c.skip) + } + c.mu.Unlock() +} + +func (c *redialConn) SkipDown() { + if c.s.SendKeep { + return + } + + c.mu.Lock() + c.skip = nil + c.mu.Unlock() +} + +func (c *redialConn) Skip() chan struct{} { + if c.s.SendKeep { + return nil + } + + c.mu.Lock() + if c.skip == nil { + c.skip = make(chan struct{}) + } + d := c.skip + c.mu.Unlock() + return d +} + +func (c *redialConn) Recv(rq queue.Q) (err error) { + for b := range rq { + select { + case c.q <- b: + case <-c.Skip(): + case <-c.ctx.Done(): + return + } + } + return +} + +func (c *redialConn) String() string { + return "redial" +} + +func (c *redialConn) Close() error { + c.cancel() + return nil +} + +func (c *redialSocket) Close() { +} + +func init() { + register("redial", "dial even after close", redialSocket{}) +} diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index ea35be3..cb76cf7 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -38,11 +38,20 @@ type conn struct { desc string info string - once sync.Once + once sync.Once + closed chan struct{} } func newConn(cn net.Conn, desc, info string) *conn { - c := &conn{Conn: cn, desc: desc, info: info} + c := &conn{ + Conn: cn, + + desc: desc, + info: info, + + closed: make(chan struct{}), + } + return c } @@ -64,6 +73,7 @@ func (c *conn) Close() error { c.once.Do(func() { log.Println("close", c.desc) err = c.Conn.Close() + close(c.closed) }) return err diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index e2fce92..afa0f31 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -36,6 +36,8 @@ type stream struct { wg sync.WaitGroup in, out socket.Conn pipes []*hook.Pipe + mu sync.Mutex + zombie bool m struct { in metric @@ -295,6 +297,10 @@ func (s *stream) channel(c socket.Conn, m *metric, rq, wq queue.Q) { if err != nil { log.Println(s.t, s, c, err) } + + s.mu.Lock() + s.zombie = true + s.mu.Unlock() } counter := func(c *uint64, src, dst queue.Q) { @@ -499,6 +505,26 @@ func tunnelDel(r *request, id string) { } } +func streamKick(r *request, tid string, sid int) { + e, ok := r.c.s.tunnels[tid] + if !ok { + r.Fatal("no such tunnel") + } + + t := e.(*tunnel) + + t.mu.Lock() + defer t.mu.Unlock() + + s, ok := t.streams[sid] + if !ok { + r.Fatal("no such stream") + } + + log.Println(s, "kick") + s.stop() +} + func tunnelRename(r *request, old, new string) { if !isOkTunnelName(new) { r.Fatal("bad name") @@ -551,7 +577,15 @@ func showActive(r *request) { defer t.mu.Unlock() foreachStream(t.streams, func(s *stream) { - r.Println(t.id, s.id, s.in, s.out, s.info()) + var opts string + + s.mu.Lock() + if s.zombie { + opts = "zombie" + } + s.mu.Unlock() + + r.Println(t.id, s.id, s.in, s.out, s.info(), opts) }) }) } @@ -577,9 +611,11 @@ func showRecent(r *request) { func init() { newCmd("add", tunnelAdd, "[name id] [limit N] [single] socket [hook ...] socket") - newCmd("del", tunnelDel, "id") + newCmd("del", tunnelDel, "tunnel-id") + + newCmd("kick", streamKick, "tunnel-id stream-id") - newCmd("rename", tunnelRename, "old-id new-id") + newCmd("rename", tunnelRename, "tunnel-old-id tunnel-new-id") newCmd("show", showTunnels, "") |
