summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/server/socket/socket.go2
-rw-r--r--pkg/server/tunnel.go96
2 files changed, 59 insertions, 39 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index cad1ad3..bf754cf 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -61,7 +61,7 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) {
func (cc *connChannel) String() string {
local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr()
- return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
+ return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local)
}
func (cc *connChannel) shutdown(err *error) {
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 04f838b..e4a324c 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -5,7 +5,6 @@ import (
"tunnel/pkg/server/queue"
"tunnel/pkg/server/socket"
"tunnel/pkg/config"
- "sync/atomic"
"strings"
"time"
"sort"
@@ -16,9 +15,9 @@ import (
type stream struct {
id int
- n int32
t *tunnel
since time.Time
+ wg sync.WaitGroup
in, out socket.Channel
}
@@ -29,6 +28,7 @@ type tunnel struct {
streams map[int]*stream
mu sync.Mutex
+ wg sync.WaitGroup
nextSid int
@@ -47,12 +47,26 @@ func (t *tunnel) String() string {
return fmt.Sprintf("tunnel(%s)", t.id)
}
-/* FIXME close streams also */
-func (t *tunnel) Close() {
+func (t *tunnel) stopServe() {
close(t.quit)
t.in.Close()
t.out.Close()
<-t.done
+}
+
+func (t *tunnel) stopStreams() {
+ t.mu.Lock()
+ for _, s := range t.streams {
+ s.stop()
+ }
+ t.mu.Unlock()
+
+ t.wg.Wait()
+}
+
+func (t *tunnel) Close() {
+ t.stopServe()
+ t.stopStreams()
log.Println(t, "delete")
}
@@ -94,6 +108,21 @@ func (t *tunnel) serve() {
close(t.done)
}
+func (t *tunnel) handle(in socket.Channel) {
+ out, err := t.out.Open()
+ if err != nil {
+ log.Println(t, err)
+ in.Close()
+ return
+ }
+
+ log.Println(t, "open", out)
+
+ s := t.newStream(in, out)
+
+ log.Println(t, s, "create", in, out)
+}
+
func (t *tunnel) newStream(in, out socket.Channel) *stream {
s := &stream{
t: t,
@@ -108,33 +137,34 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream {
t.streams[s.id] = s
t.mu.Unlock()
- return s
-}
+ s.run()
-func (s *stream) ref() {
- atomic.AddInt32(&s.n, 1)
-}
-
-func (s *stream) unref() {
- if atomic.AddInt32(&s.n, -1) == 0 {
- log.Println(s.t, s, "close")
+ go func () {
+ s.wg.Wait()
s.t.mu.Lock()
delete(s.t.streams, s.id)
s.t.mu.Unlock()
- }
+
+ s.t.wg.Done()
+
+ log.Println(s.t, s, "close")
+ }()
+
+ return s
}
func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
watch := func (q queue.Q, f func (q queue.Q) error) {
- s.ref()
- defer s.unref()
+ defer s.wg.Done()
if err := f(q); err != nil {
log.Println(s.t, s, err)
}
}
+ s.wg.Add(2)
+
go func () {
watch(wq, c.Send)
close(wq)
@@ -144,9 +174,10 @@ func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
}
func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
+ s.wg.Add(1)
+
go func () {
- s.ref()
- defer s.unref()
+ defer s.wg.Done()
if err := f(rq, wq); err != nil {
log.Println(s.t, s, err)
@@ -156,28 +187,14 @@ func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
}()
}
-func (t *tunnel) handle(in socket.Channel) {
- log.Println(t, "handle")
-
- out, err := t.out.Open()
- if err != nil {
- log.Println(t, err)
- in.Close()
- return
- }
-
- log.Println(t, "open", out)
-
- s := t.newStream(in, out)
-
- s.ref()
- defer s.unref()
+func (s *stream) run() {
+ s.t.wg.Add(1)
rq, wq := queue.New(), queue.New()
- s.watchChannel(rq, wq, in)
+ s.watchChannel(rq, wq, s.in)
- for _, m := range t.m {
+ for _, m := range s.t.m {
send, recv := m.Open()
if send != nil {
q := queue.New()
@@ -191,9 +208,12 @@ func (t *tunnel) handle(in socket.Channel) {
}
}
- s.watchChannel(wq, rq, out)
+ s.watchChannel(wq, rq, s.out)
+}
- log.Println(t, s, "create", in, out)
+func (s *stream) stop() {
+ s.in.Close()
+ s.out.Close()
}
func newTunnel(args []string) (*tunnel, error) {