summaryrefslogtreecommitdiff
path: root/pkg/server/socket/redial.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket/redial.go')
-rw-r--r--pkg/server/socket/redial.go183
1 files changed, 183 insertions, 0 deletions
diff --git a/pkg/server/socket/redial.go b/pkg/server/socket/redial.go
new file mode 100644
index 0000000..cdd1043
--- /dev/null
+++ b/pkg/server/socket/redial.go
@@ -0,0 +1,183 @@
+package socket
+
+import (
+ "context"
+ "log"
+ "sync"
+ "time"
+
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type redialSocket struct {
+ dialSocket
+
+ SendKeep bool
+}
+
+type redialConn struct {
+ s *redialSocket
+ e env.Env
+
+ mu sync.Mutex
+ skip chan struct{}
+
+ cc chan *conn
+ q queue.Q
+
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+var closedchan = make(chan struct{})
+
+func init() {
+ close(closedchan)
+}
+
+func (s *redialSocket) New(env env.Env) (Conn, error) {
+ ctx, cancel := context.WithCancel(context.TODO())
+
+ c := &redialConn{
+ s: s,
+ e: env,
+
+ cc: make(chan *conn),
+ q: queue.New(),
+
+ ctx: ctx,
+ cancel: cancel,
+ }
+
+ go c.loop(ctx)
+
+ return c, nil
+}
+
+func (c *redialConn) loop(ctx context.Context) {
+ defer close(c.cc)
+
+ conntry := 0
+
+ for {
+ conntry++
+ conn, err := dial(ctx, c.e, c.s.Proto, c.s.Addr)
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+
+ log.Println(err)
+
+ if conntry == 1 {
+ c.SkipUp()
+ }
+
+ t := time.NewTimer(time.Second)
+ select {
+ case <-ctx.Done():
+ return
+ case <-t.C:
+ }
+ continue
+ }
+
+ if conntry > 1 {
+ c.SkipDown()
+ }
+
+ conntry = 0
+
+ c.cc <- conn
+
+ for loop := true; loop; {
+ select {
+ case b := <-c.q:
+ if _, err := conn.Write(b); err != nil {
+ conn.Close()
+ loop = false
+ }
+ case <-conn.closed:
+ loop = false
+ case <-ctx.Done():
+ conn.Close()
+ return
+ }
+ }
+ }
+}
+
+func (c *redialConn) Send(wq queue.Q) error {
+ for conn := range c.cc {
+ conn.Send(wq)
+ conn.Close()
+ }
+ return nil
+}
+
+func (c *redialConn) SkipUp() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = closedchan
+ } else {
+ close(c.skip)
+ }
+ c.mu.Unlock()
+}
+
+func (c *redialConn) SkipDown() {
+ if c.s.SendKeep {
+ return
+ }
+
+ c.mu.Lock()
+ c.skip = nil
+ c.mu.Unlock()
+}
+
+func (c *redialConn) Skip() chan struct{} {
+ if c.s.SendKeep {
+ return nil
+ }
+
+ c.mu.Lock()
+ if c.skip == nil {
+ c.skip = make(chan struct{})
+ }
+ d := c.skip
+ c.mu.Unlock()
+ return d
+}
+
+func (c *redialConn) Recv(rq queue.Q) (err error) {
+ for b := range rq {
+ select {
+ case c.q <- b:
+ case <-c.Skip():
+ case <-c.ctx.Done():
+ return
+ }
+ }
+ return
+}
+
+func (c *redialConn) String() string {
+ return "redial"
+}
+
+func (c *redialConn) Close() error {
+ c.cancel()
+ return nil
+}
+
+func (c *redialSocket) Close() {
+}
+
+func init() {
+ register("redial", "dial even after close", redialSocket{})
+}