summaryrefslogtreecommitdiff
path: root/pkg/server/socket
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2021-04-14 15:02:11 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2021-04-14 16:04:20 +0300
commita6a89cd5e025c6a6e12ad6060aae347385869cd3 (patch)
tree69e19085bc91aaa798bf1b8e5d39d8648c6e08ba /pkg/server/socket
parent1517965447c59f1426405bf775e4c7c1f0611354 (diff)
add redial socket
Diffstat (limited to 'pkg/server/socket')
-rw-r--r--pkg/server/socket/defer.go66
-rw-r--r--pkg/server/socket/dial.go14
-rw-r--r--pkg/server/socket/proxy.go18
-rw-r--r--pkg/server/socket/redial.go183
-rw-r--r--pkg/server/socket/socket.go14
5 files changed, 249 insertions, 46 deletions
diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go
index d63d51e..3eb2dfd 100644
--- a/pkg/server/socket/defer.go
+++ b/pkg/server/socket/defer.go
@@ -1,6 +1,8 @@
package socket
import (
+ "context"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
)
@@ -10,17 +12,26 @@ type deferSocket struct {
}
type deferConn struct {
- sock S
- wait chan bool
- env env.Env
- conn Conn
+ s *deferSocket
+ e env.Env
+
+ ctx context.Context
+ cancel func()
+
+ conn chan *conn
}
func (s *deferSocket) New(env env.Env) (Conn, error) {
+ ctx, cancel := context.WithCancel(context.TODO())
+
c := &deferConn{
- sock: &s.dialSocket,
- wait: make(chan bool),
- env: env,
+ s: s,
+ e: env,
+
+ conn: make(chan *conn),
+
+ ctx: ctx,
+ cancel: cancel,
}
return c, nil
@@ -31,48 +42,41 @@ func (c *deferConn) String() string {
}
func (c *deferConn) Send(wq queue.Q) error {
- if !<-c.wait {
+ conn := <-c.conn
+ if conn == nil {
return nil
- } else {
- return c.conn.Send(wq)
}
+ return conn.Send(wq)
}
func (c *deferConn) Recv(rq queue.Q) error {
- b := <-rq
- if b == nil {
- c.wait <- false
+ // TODO: and context check
+ r := rq.Reader()
+
+ if _, err := r.Read(nil); err != nil {
+ c.conn <- nil
return nil
}
- conn, err := c.sock.New(c.env)
+ conn, err := dial(c.ctx, c.e, c.s.Proto, c.s.Addr)
if err != nil {
- c.wait <- false
+ c.conn <- nil
return err
}
- c.conn = conn
- c.wait <- true
-
- q := queue.New()
-
go func() {
- q <- b
- queue.Copy(rq, q)
- close(q)
+ <-c.ctx.Done()
+ conn.Close()
}()
- defer q.Dry()
+ c.conn <- conn
- return c.conn.Recv(q)
+ return queue.IoCopy(r, conn)
}
-func (c *deferConn) Close() (err error) {
- if c.conn != nil {
- err = c.conn.Close()
- }
-
- return
+func (c *deferConn) Close() error {
+ c.cancel()
+ return nil
}
func init() {
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index a31e192..aea74d2 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -1,6 +1,7 @@
package socket
import (
+ "context"
"fmt"
"log"
"net"
@@ -21,8 +22,17 @@ func (s *dialSocket) String() string {
}
func (s *dialSocket) New(e env.Env) (Conn, error) {
- proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr))
- conn, err := net.DialTimeout(proto, addr, defaultTimeout)
+ return dial(context.TODO(), e, s.Proto, s.Addr)
+}
+
+func dial(ctx context.Context, e env.Env, proto, addr string) (*conn, error) {
+ ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
+ defer cancel()
+
+ var d net.Dialer
+
+ proto, addr = parseProtoAddr(proto, e.Expand(addr))
+ conn, err := d.DialContext(ctx, proto, addr)
if err != nil {
return nil, err
}
diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go
index c46d1c1..f6c4f72 100644
--- a/pkg/server/socket/proxy.go
+++ b/pkg/server/socket/proxy.go
@@ -3,6 +3,7 @@ package socket
import (
"bufio"
"bytes"
+ "context"
"errors"
"fmt"
@@ -26,7 +27,7 @@ type proxyServer struct {
auth string
wait chan status
env env.Env
- conn Conn
+ conn *conn
}
func (sock *proxySocket) New(env env.Env) (Conn, error) {
@@ -74,15 +75,9 @@ func (s *proxyServer) Send(wq queue.Q) error {
return s.conn.Send(wq)
}
-func (s *proxyServer) initConn(addr string) error {
- dial := dialSocket{
- Proto: s.sock.Proto,
- Addr: addr,
- }
-
- conn, err := dial.New(s.env)
+func (s *proxyServer) dial(addr string) error {
+ conn, err := dial(context.TODO(), s.env, s.sock.Proto, addr)
if err != nil {
- dial.Close()
return err
}
@@ -118,17 +113,18 @@ func (s *proxyServer) Recv(rq queue.Q) error {
}
}
- if err := s.initConn(req.URI); err != nil {
+ if err := s.dial(req.URI); err != nil {
s.wait <- status{500, "Unable to connect"}
return err
}
s.wait <- status{200, "Connection established"}
- return queue.IoCopy(r, s.conn.(*conn))
+ return queue.IoCopy(r, s.conn)
}
func (s *proxyServer) Close() (err error) {
+ // TODO safe close
if s.conn != nil {
err = s.conn.Close()
}
diff --git a/pkg/server/socket/redial.go b/pkg/server/socket/redial.go
new file mode 100644
index 0000000..cdd1043
--- /dev/null
+++ b/pkg/server/socket/redial.go
@@ -0,0 +1,183 @@
+package socket
+
+import (
+ "context"
+ "log"
+ "sync"
+ "time"
+
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type redialSocket struct {
+ dialSocket
+
+ SendKeep bool
+}
+
+type redialConn struct {
+ s *redialSocket
+ e env.Env
+
+ mu sync.Mutex
+ skip chan struct{}
+
+ cc chan *conn
+ q queue.Q
+
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+var closedchan = make(chan struct{})
+
+func init() {
+ close(closedchan)
+}
+
+func (s *redialSocket) New(env env.Env) (Conn, error) {
+ ctx, cancel := context.WithCancel(context.TODO())
+
+ c := &redialConn{
+ s: s,
+ e: env,
+
+ cc: make(chan *conn),
+ q: queue.New(),
+
+ ctx: ctx,
+ cancel: cancel,
+ }
+
+ go c.loop(ctx)
+
+ return c, nil
+}
+
+func (c *redialConn) loop(ctx context.Context) {
+ defer close(c.cc)
+
+ conntry := 0
+
+ for {
+ conntry++
+ conn, err := dial(ctx, c.e, c.s.Proto, c.s.Addr)
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+
+ log.Println(err)
+
+ if conntry == 1 {
+ c.SkipUp()
+ }
+
+ t := time.NewTimer(time.Second)
+ select {
+ case <-ctx.Done():
+ return
+ case <-t.C:
+ }
+ continue
+ }
+
+ if conntry > 1 {
+ c.SkipDown()
+ }
+
+ conntry = 0
+
+ c.cc <- conn
+
+ for loop := true; loop; {
+ select {
+ case b := <-c.q:
+ if _, err := conn.Write(b); err != nil {
+ conn.Close()
+ loop = false
+ }
+ case <-conn.closed:
+ loop = false
+ case <-ctx.Done():
+ conn.Close()
+ return
+ }
+ }
+ }
+}
+
+func (c *redialConn) Send(wq queue.Q) error {
+ for conn := range c.cc {
+ conn.Send(wq)
+ conn.Close()
+ }
+ return nil
+}
+
+func (c *redialConn) SkipUp() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = closedchan
+ } else {
+ close(c.skip)
+ }
+ c.mu.Unlock()
+}
+
+func (c *redialConn) SkipDown() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ c.skip = nil
+ c.mu.Unlock()
+}
+
+func (c *redialConn) Skip() chan struct{} {
+ if c.s.SendKeep {
+ return nil
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = make(chan struct{})
+ }
+ d := c.skip
+ c.mu.Unlock()
+ return d
+}
+
+func (c *redialConn) Recv(rq queue.Q) (err error) {
+ for b := range rq {
+ select {
+ case c.q <- b:
+ case <-c.Skip():
+ case <-c.ctx.Done():
+ return
+ }
+ }
+ return
+}
+
+func (c *redialConn) String() string {
+ return "redial"
+}
+
+func (c *redialConn) Close() error {
+ c.cancel()
+ return nil
+}
+
+func (c *redialSocket) Close() {
+}
+
+func init() {
+ register("redial", "dial even after close", redialSocket{})
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index ea35be3..cb76cf7 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -38,11 +38,20 @@ type conn struct {
desc string
info string
- once sync.Once
+ once sync.Once
+ closed chan struct{}
}
func newConn(cn net.Conn, desc, info string) *conn {
- c := &conn{Conn: cn, desc: desc, info: info}
+ c := &conn{
+ Conn: cn,
+
+ desc: desc,
+ info: info,
+
+ closed: make(chan struct{}),
+ }
+
return c
}
@@ -64,6 +73,7 @@ func (c *conn) Close() error {
c.once.Do(func() {
log.Println("close", c.desc)
err = c.Conn.Close()
+ close(c.closed)
})
return err