diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-21 22:55:35 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-21 22:55:35 +0300 |
| commit | 7c7fafefef94c5fb8bfe319e7745d80a1e88205d (patch) | |
| tree | 40e6d3b5c5933fb3fa83819f40f0c02458cc297d /pkg | |
| parent | 6a25466ac3a8b94b08a3114c9e5cc721ed620d49 (diff) | |
tunnel del with active streams
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/server/socket/socket.go | 2 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 96 |
2 files changed, 59 insertions, 39 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index cad1ad3..bf754cf 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -61,7 +61,7 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) { func (cc *connChannel) String() string { local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr() - return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) + return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local) } func (cc *connChannel) shutdown(err *error) { diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 04f838b..e4a324c 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -5,7 +5,6 @@ import ( "tunnel/pkg/server/queue" "tunnel/pkg/server/socket" "tunnel/pkg/config" - "sync/atomic" "strings" "time" "sort" @@ -16,9 +15,9 @@ import ( type stream struct { id int - n int32 t *tunnel since time.Time + wg sync.WaitGroup in, out socket.Channel } @@ -29,6 +28,7 @@ type tunnel struct { streams map[int]*stream mu sync.Mutex + wg sync.WaitGroup nextSid int @@ -47,12 +47,26 @@ func (t *tunnel) String() string { return fmt.Sprintf("tunnel(%s)", t.id) } -/* FIXME close streams also */ -func (t *tunnel) Close() { +func (t *tunnel) stopServe() { close(t.quit) t.in.Close() t.out.Close() <-t.done +} + +func (t *tunnel) stopStreams() { + t.mu.Lock() + for _, s := range t.streams { + s.stop() + } + t.mu.Unlock() + + t.wg.Wait() +} + +func (t *tunnel) Close() { + t.stopServe() + t.stopStreams() log.Println(t, "delete") } @@ -94,6 +108,21 @@ func (t *tunnel) serve() { close(t.done) } +func (t *tunnel) handle(in socket.Channel) { + out, err := t.out.Open() + if err != nil { + log.Println(t, err) + in.Close() + return + } + + log.Println(t, "open", out) + + s := t.newStream(in, out) + + log.Println(t, s, "create", in, out) +} + func (t *tunnel) newStream(in, out socket.Channel) *stream { s := &stream{ t: t, @@ -108,33 +137,34 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { t.streams[s.id] = s t.mu.Unlock() - return s -} + s.run() -func (s *stream) ref() { - atomic.AddInt32(&s.n, 1) -} - -func (s *stream) unref() { - if atomic.AddInt32(&s.n, -1) == 0 { - log.Println(s.t, s, "close") + go func () { + s.wg.Wait() s.t.mu.Lock() delete(s.t.streams, s.id) s.t.mu.Unlock() - } + + s.t.wg.Done() + + log.Println(s.t, s, "close") + }() + + return s } func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { watch := func (q queue.Q, f func (q queue.Q) error) { - s.ref() - defer s.unref() + defer s.wg.Done() if err := f(q); err != nil { log.Println(s.t, s, err) } } + s.wg.Add(2) + go func () { watch(wq, c.Send) close(wq) @@ -144,9 +174,10 @@ func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { } func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { + s.wg.Add(1) + go func () { - s.ref() - defer s.unref() + defer s.wg.Done() if err := f(rq, wq); err != nil { log.Println(s.t, s, err) @@ -156,28 +187,14 @@ func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { }() } -func (t *tunnel) handle(in socket.Channel) { - log.Println(t, "handle") - - out, err := t.out.Open() - if err != nil { - log.Println(t, err) - in.Close() - return - } - - log.Println(t, "open", out) - - s := t.newStream(in, out) - - s.ref() - defer s.unref() +func (s *stream) run() { + s.t.wg.Add(1) rq, wq := queue.New(), queue.New() - s.watchChannel(rq, wq, in) + s.watchChannel(rq, wq, s.in) - for _, m := range t.m { + for _, m := range s.t.m { send, recv := m.Open() if send != nil { q := queue.New() @@ -191,9 +208,12 @@ func (t *tunnel) handle(in socket.Channel) { } } - s.watchChannel(wq, rq, out) + s.watchChannel(wq, rq, s.out) +} - log.Println(t, s, "create", in, out) +func (s *stream) stop() { + s.in.Close() + s.out.Close() } func newTunnel(args []string) (*tunnel, error) { |
