summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-02-29 21:50:58 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-02-29 23:48:53 +0300
commit8b7283ad01a8dde92cf708f81f6c1105647bafd7 (patch)
tree6598c17258bacb4e84f6e486e591460437e93086
parent7ab641d239e502e09c6f05dfc7efd069fcf3c314 (diff)
close pipes at end of stream
-rw-r--r--pkg/server/hook/aes.go2
-rw-r--r--pkg/server/hook/auth.go2
-rw-r--r--pkg/server/hook/hex.go4
-rw-r--r--pkg/server/hook/hook.go42
-rw-r--r--pkg/server/hook/tee.go20
-rw-r--r--pkg/server/hook/zip.go8
-rw-r--r--pkg/server/tunnel.go37
7 files changed, 64 insertions, 51 deletions
diff --git a/pkg/server/hook/aes.go b/pkg/server/hook/aes.go
index 414f6d6..dc9605a 100644
--- a/pkg/server/hook/aes.go
+++ b/pkg/server/hook/aes.go
@@ -74,7 +74,7 @@ func newAes(env env.Env) *aesPipe {
return a
}
-func (h aesHook) Open(env env.Env) (interface{}, error) {
+func (aesHook) Open(env env.Env) (interface{}, error) {
return newAes(env), nil
}
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index 8d7d23e..11068e5 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -122,7 +122,7 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) {
return queue.IoCopy(r, wq.Writer())
}
-func (h authHook) Open(env env.Env) (interface{}, error) {
+func (authHook) Open(env env.Env) (interface{}, error) {
a := &auth{
secret: getHookVar(env, "secret"),
recvChallenge: make(chan struct{}),
diff --git a/pkg/server/hook/hex.go b/pkg/server/hook/hex.go
index beaadeb..e37bc6e 100644
--- a/pkg/server/hook/hex.go
+++ b/pkg/server/hook/hex.go
@@ -9,7 +9,7 @@ import (
type hexHook struct{}
-func (h hexHook) Send(rq, wq queue.Q) error {
+func (hexHook) Send(rq, wq queue.Q) error {
enc := hex.NewEncoder(wq.Writer())
for b := range rq {
@@ -19,7 +19,7 @@ func (h hexHook) Send(rq, wq queue.Q) error {
return nil
}
-func (h hexHook) Recv(rq, wq queue.Q) error {
+func (hexHook) Recv(rq, wq queue.Q) error {
r := hex.NewDecoder(rq.Reader())
return queue.IoCopy(r, wq.Writer())
}
diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go
index e4497ff..1702afd 100644
--- a/pkg/server/hook/hook.go
+++ b/pkg/server/hook/hook.go
@@ -14,12 +14,18 @@ type hookInitFunc func(opts.Opts, env.Env) (hook, error)
var hooks = map[string]hookInitFunc{}
+type Pipe struct {
+ priv interface{}
+ Send Func
+ Recv Func
+}
+
type hook interface {
Open(env env.Env) (interface{}, error)
}
type H interface {
- hook
+ Open(env env.Env) (*Pipe, error)
String() string
}
@@ -31,6 +37,10 @@ type Recver interface {
Recv(rq, wq queue.Q) error
}
+type Closer interface {
+ Close()
+}
+
type Func func(rq, wq queue.Q) error
func (f Func) Send(rq, wq queue.Q) error {
@@ -51,29 +61,33 @@ func (w *wrapper) String() string {
return fmt.Sprintf("hook:%s", w.name)
}
-func Open(h H, env env.Env) (Func, Func, error) {
- var send, recv Func
-
- w := h.(*wrapper)
-
- it, err := h.Open(env)
+func (w *wrapper) Open(env env.Env) (*Pipe, error) {
+ it, err := w.hook.Open(env)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- if sender, ok := it.(Sender); ok {
- send = sender.Send
+ pipe := &Pipe{priv: it}
+
+ if s, ok := it.(Sender); ok {
+ pipe.Send = s.Send
}
- if recver, ok := it.(Recver); ok {
- recv = recver.Recv
+ if r, ok := it.(Recver); ok {
+ pipe.Recv = r.Recv
}
if w.reverse {
- send, recv = recv, send
+ pipe.Send, pipe.Recv = pipe.Recv, pipe.Send
}
- return send, recv, nil
+ return pipe, nil
+}
+
+func (p *Pipe) Close() {
+ if c, ok := p.priv.(Closer); ok {
+ c.Close()
+ }
}
func New(desc string, env env.Env) (H, error) {
diff --git a/pkg/server/hook/tee.go b/pkg/server/hook/tee.go
index fd30a81..521164b 100644
--- a/pkg/server/hook/tee.go
+++ b/pkg/server/hook/tee.go
@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path"
- "sync"
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
@@ -15,9 +14,7 @@ import (
const teeDefaultFile = "/tmp/tunnel/dump"
type tee struct {
- f *os.File
- mu sync.Mutex
- wg sync.WaitGroup
+ f *os.File
}
type teeHook struct {
@@ -41,8 +38,6 @@ func (t *tee) dump(s string, p []byte) error {
}
func (t *tee) Send(rq, wq queue.Q) error {
- defer t.wg.Done()
-
for b := range rq {
t.dump(">", b)
wq <- b
@@ -52,8 +47,6 @@ func (t *tee) Send(rq, wq queue.Q) error {
}
func (t *tee) Recv(rq, wq queue.Q) error {
- defer t.wg.Done()
-
for b := range rq {
t.dump("<", b)
wq <- b
@@ -62,6 +55,10 @@ func (t *tee) Recv(rq, wq queue.Q) error {
return nil
}
+func (t *tee) Close() {
+ t.f.Close()
+}
+
func (h *teeHook) where(env env.Env) string {
if h.file != "" {
return h.file
@@ -93,13 +90,6 @@ func (h *teeHook) Open(env env.Env) (interface{}, error) {
t.f = f
}
- t.wg.Add(2)
-
- go func() {
- t.wg.Wait()
- t.f.Close()
- }()
-
return &t, nil
}
diff --git a/pkg/server/hook/zip.go b/pkg/server/hook/zip.go
index 61264c9..94160fe 100644
--- a/pkg/server/hook/zip.go
+++ b/pkg/server/hook/zip.go
@@ -10,7 +10,7 @@ import (
type zipHook struct{}
-func (m zipHook) Send(rq, wq queue.Q) error {
+func (zipHook) Send(rq, wq queue.Q) error {
w, err := flate.NewWriter(wq.Writer(), flate.BestCompression)
if err != nil {
return err
@@ -28,7 +28,7 @@ func (m zipHook) Send(rq, wq queue.Q) error {
return w.Close()
}
-func (m zipHook) Recv(rq, wq queue.Q) error {
+func (zipHook) Recv(rq, wq queue.Q) error {
r := flate.NewReader(rq.Reader())
// FIXME: not received ending due to ultimate conn.Close
@@ -42,8 +42,8 @@ func (m zipHook) Recv(rq, wq queue.Q) error {
return r.Close()
}
-func (m zipHook) Open(env.Env) (interface{}, error) {
- return m, nil
+func (h zipHook) Open(env.Env) (interface{}, error) {
+ return h, nil
}
func newZipHook(opts.Opts, env.Env) (hook, error) {
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index e00db6c..8465fac 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -24,6 +24,7 @@ type stream struct {
since time.Time
wg sync.WaitGroup
in, out socket.Channel
+ pipes []*hook.Pipe
}
type tunnel struct {
@@ -150,19 +151,25 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream {
t.streams[s.id] = s
t.mu.Unlock()
- go func() {
- s.wg.Wait()
+ go s.waitAndClose()
- s.t.mu.Lock()
- delete(s.t.streams, s.id)
- s.t.mu.Unlock()
+ return s
+}
- s.t.wg.Done()
+func (s *stream) waitAndClose() {
+ s.wg.Wait()
- log.Println(s.t, s, "close")
- }()
+ s.t.mu.Lock()
+ delete(s.t.streams, s.id)
+ s.t.mu.Unlock()
- return s
+ s.t.wg.Done()
+
+ for _, p := range s.pipes {
+ p.Close()
+ }
+
+ log.Println(s.t, s, "close")
}
func (s *stream) channel(c socket.Channel, rq, wq queue.Q) {
@@ -210,24 +217,26 @@ func (s *stream) run() {
s.channel(s.in, rq, wq)
for _, h := range s.t.hooks {
- send, recv, err := hook.Open(h, s.env)
+ p, err := h.Open(s.env)
if err != nil {
// FIXME: abort stream on error
log.Println(s.t, s, h, err)
continue
}
- if send != nil {
+ if p.Send != nil {
q := queue.New()
- s.pipe(h, send, wq, q)
+ s.pipe(h, p.Send, wq, q)
wq = q
}
- if recv != nil {
+ if p.Recv != nil {
q := queue.New()
- s.pipe(h, recv, q, rq)
+ s.pipe(h, p.Recv, q, rq)
rq = q
}
+
+ s.pipes = append(s.pipes, p)
}
s.channel(s.out, wq, rq)