From 8b7283ad01a8dde92cf708f81f6c1105647bafd7 Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Sat, 29 Feb 2020 21:50:58 +0300 Subject: close pipes at end of stream --- pkg/server/hook/aes.go | 2 +- pkg/server/hook/auth.go | 2 +- pkg/server/hook/hex.go | 4 ++-- pkg/server/hook/hook.go | 42 ++++++++++++++++++++++++++++-------------- pkg/server/hook/tee.go | 20 +++++--------------- pkg/server/hook/zip.go | 8 ++++---- pkg/server/tunnel.go | 37 +++++++++++++++++++++++-------------- 7 files changed, 64 insertions(+), 51 deletions(-) (limited to 'pkg/server') diff --git a/pkg/server/hook/aes.go b/pkg/server/hook/aes.go index 414f6d6..dc9605a 100644 --- a/pkg/server/hook/aes.go +++ b/pkg/server/hook/aes.go @@ -74,7 +74,7 @@ func newAes(env env.Env) *aesPipe { return a } -func (h aesHook) Open(env env.Env) (interface{}, error) { +func (aesHook) Open(env env.Env) (interface{}, error) { return newAes(env), nil } diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go index 8d7d23e..11068e5 100644 --- a/pkg/server/hook/auth.go +++ b/pkg/server/hook/auth.go @@ -122,7 +122,7 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) { return queue.IoCopy(r, wq.Writer()) } -func (h authHook) Open(env env.Env) (interface{}, error) { +func (authHook) Open(env env.Env) (interface{}, error) { a := &auth{ secret: getHookVar(env, "secret"), recvChallenge: make(chan struct{}), diff --git a/pkg/server/hook/hex.go b/pkg/server/hook/hex.go index beaadeb..e37bc6e 100644 --- a/pkg/server/hook/hex.go +++ b/pkg/server/hook/hex.go @@ -9,7 +9,7 @@ import ( type hexHook struct{} -func (h hexHook) Send(rq, wq queue.Q) error { +func (hexHook) Send(rq, wq queue.Q) error { enc := hex.NewEncoder(wq.Writer()) for b := range rq { @@ -19,7 +19,7 @@ func (h hexHook) Send(rq, wq queue.Q) error { return nil } -func (h hexHook) Recv(rq, wq queue.Q) error { +func (hexHook) Recv(rq, wq queue.Q) error { r := hex.NewDecoder(rq.Reader()) return queue.IoCopy(r, wq.Writer()) } diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go index e4497ff..1702afd 100644 --- a/pkg/server/hook/hook.go +++ b/pkg/server/hook/hook.go @@ -14,12 +14,18 @@ type hookInitFunc func(opts.Opts, env.Env) (hook, error) var hooks = map[string]hookInitFunc{} +type Pipe struct { + priv interface{} + Send Func + Recv Func +} + type hook interface { Open(env env.Env) (interface{}, error) } type H interface { - hook + Open(env env.Env) (*Pipe, error) String() string } @@ -31,6 +37,10 @@ type Recver interface { Recv(rq, wq queue.Q) error } +type Closer interface { + Close() +} + type Func func(rq, wq queue.Q) error func (f Func) Send(rq, wq queue.Q) error { @@ -51,29 +61,33 @@ func (w *wrapper) String() string { return fmt.Sprintf("hook:%s", w.name) } -func Open(h H, env env.Env) (Func, Func, error) { - var send, recv Func - - w := h.(*wrapper) - - it, err := h.Open(env) +func (w *wrapper) Open(env env.Env) (*Pipe, error) { + it, err := w.hook.Open(env) if err != nil { - return nil, nil, err + return nil, err } - if sender, ok := it.(Sender); ok { - send = sender.Send + pipe := &Pipe{priv: it} + + if s, ok := it.(Sender); ok { + pipe.Send = s.Send } - if recver, ok := it.(Recver); ok { - recv = recver.Recv + if r, ok := it.(Recver); ok { + pipe.Recv = r.Recv } if w.reverse { - send, recv = recv, send + pipe.Send, pipe.Recv = pipe.Recv, pipe.Send } - return send, recv, nil + return pipe, nil +} + +func (p *Pipe) Close() { + if c, ok := p.priv.(Closer); ok { + c.Close() + } } func New(desc string, env env.Env) (H, error) { diff --git a/pkg/server/hook/tee.go b/pkg/server/hook/tee.go index fd30a81..521164b 100644 --- a/pkg/server/hook/tee.go +++ b/pkg/server/hook/tee.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path" - "sync" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" @@ -15,9 +14,7 @@ import ( const teeDefaultFile = "/tmp/tunnel/dump" type tee struct { - f *os.File - mu sync.Mutex - wg sync.WaitGroup + f *os.File } type teeHook struct { @@ -41,8 +38,6 @@ func (t *tee) dump(s string, p []byte) error { } func (t *tee) Send(rq, wq queue.Q) error { - defer t.wg.Done() - for b := range rq { t.dump(">", b) wq <- b @@ -52,8 +47,6 @@ func (t *tee) Send(rq, wq queue.Q) error { } func (t *tee) Recv(rq, wq queue.Q) error { - defer t.wg.Done() - for b := range rq { t.dump("<", b) wq <- b @@ -62,6 +55,10 @@ func (t *tee) Recv(rq, wq queue.Q) error { return nil } +func (t *tee) Close() { + t.f.Close() +} + func (h *teeHook) where(env env.Env) string { if h.file != "" { return h.file @@ -93,13 +90,6 @@ func (h *teeHook) Open(env env.Env) (interface{}, error) { t.f = f } - t.wg.Add(2) - - go func() { - t.wg.Wait() - t.f.Close() - }() - return &t, nil } diff --git a/pkg/server/hook/zip.go b/pkg/server/hook/zip.go index 61264c9..94160fe 100644 --- a/pkg/server/hook/zip.go +++ b/pkg/server/hook/zip.go @@ -10,7 +10,7 @@ import ( type zipHook struct{} -func (m zipHook) Send(rq, wq queue.Q) error { +func (zipHook) Send(rq, wq queue.Q) error { w, err := flate.NewWriter(wq.Writer(), flate.BestCompression) if err != nil { return err @@ -28,7 +28,7 @@ func (m zipHook) Send(rq, wq queue.Q) error { return w.Close() } -func (m zipHook) Recv(rq, wq queue.Q) error { +func (zipHook) Recv(rq, wq queue.Q) error { r := flate.NewReader(rq.Reader()) // FIXME: not received ending due to ultimate conn.Close @@ -42,8 +42,8 @@ func (m zipHook) Recv(rq, wq queue.Q) error { return r.Close() } -func (m zipHook) Open(env.Env) (interface{}, error) { - return m, nil +func (h zipHook) Open(env.Env) (interface{}, error) { + return h, nil } func newZipHook(opts.Opts, env.Env) (hook, error) { diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index e00db6c..8465fac 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -24,6 +24,7 @@ type stream struct { since time.Time wg sync.WaitGroup in, out socket.Channel + pipes []*hook.Pipe } type tunnel struct { @@ -150,19 +151,25 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { t.streams[s.id] = s t.mu.Unlock() - go func() { - s.wg.Wait() + go s.waitAndClose() - s.t.mu.Lock() - delete(s.t.streams, s.id) - s.t.mu.Unlock() + return s +} - s.t.wg.Done() +func (s *stream) waitAndClose() { + s.wg.Wait() - log.Println(s.t, s, "close") - }() + s.t.mu.Lock() + delete(s.t.streams, s.id) + s.t.mu.Unlock() - return s + s.t.wg.Done() + + for _, p := range s.pipes { + p.Close() + } + + log.Println(s.t, s, "close") } func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { @@ -210,24 +217,26 @@ func (s *stream) run() { s.channel(s.in, rq, wq) for _, h := range s.t.hooks { - send, recv, err := hook.Open(h, s.env) + p, err := h.Open(s.env) if err != nil { // FIXME: abort stream on error log.Println(s.t, s, h, err) continue } - if send != nil { + if p.Send != nil { q := queue.New() - s.pipe(h, send, wq, q) + s.pipe(h, p.Send, wq, q) wq = q } - if recv != nil { + if p.Recv != nil { q := queue.New() - s.pipe(h, recv, q, rq) + s.pipe(h, p.Recv, q, rq) rq = q } + + s.pipes = append(s.pipes, p) } s.channel(s.out, wq, rq) -- cgit v1.2.3-70-g09d2