diff options
Diffstat (limited to 'pkg/server/tunnel.go')
| -rw-r--r-- | pkg/server/tunnel.go | 69 |
1 files changed, 43 insertions, 26 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 309c272..d54208b 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -129,6 +129,27 @@ func (t *tunnel) sleep(d time.Duration) { tmr.Stop() } +func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) { + var pipes []*hook.Pipe + + cleanup := func() { + for _, p := range pipes { + p.Close() + } + } + + for _, h := range t.hooks { + p, err := h.Open(env) + if err != nil { + cleanup() + return nil, fmt.Errorf("%s: %w", h, err) + } + pipes = append(pipes, p) + } + + return pipes, nil +} + func (t *tunnel) serve() { for t.acquire() { var ok bool @@ -137,17 +158,21 @@ func (t *tunnel) serve() { env.Set("tunnel", t.id) - if in, err := t.in.Open(env); err == nil { - if out, err := t.out.Open(env); err == nil { - s := t.newStream(env, in, out) - log.Println(t, s, "create", in, out) - ok = true - } else { + if in, err := t.in.Open(env); err != nil { + if t.alive() { log.Println(t, err) - in.Close() } - } else if t.alive() { + } else if out, err := t.out.Open(env); err != nil { + log.Println(t, err) + in.Close() + } else if pipes, err := t.openPipes(env); err != nil { log.Println(t, err) + in.Close() + out.Close() + } else { + s := t.newStream(env, in, out, pipes) + log.Println(t, s, "create", in, out) + ok = true } if !ok { @@ -159,11 +184,12 @@ func (t *tunnel) serve() { close(t.done) } -func (t *tunnel) newStream(env env.Env, in, out socket.Conn) *stream { +func (t *tunnel) newStream(env env.Env, in, out socket.Conn, pipes []*hook.Pipe) *stream { s := &stream{ t: t, in: in, out: out, + pipes: pipes, env: env, id: t.nextSid, since: time.Now(), @@ -301,27 +327,18 @@ func (s *stream) run() { s.channel(s.in, &s.m.in, rq, wq) - for _, h := range s.t.hooks { - p, err := h.Open(s.env) - if err != nil { - // FIXME: abort stream on error - log.Println(s.t, s, h, err) - continue - } - + for _, p := range s.pipes { if p.Send != nil { q := queue.New() - s.pipe(h, p.Send, wq, q) + s.pipe(p.Hook, p.Send, wq, q) wq = q } if p.Recv != nil { q := queue.New() - s.pipe(h, p.Recv, q, rq) + s.pipe(p.Hook, p.Recv, q, rq) rq = q } - - s.pipes = append(s.pipes, p) } s.channel(s.out, &s.m.out, wq, rq) @@ -332,11 +349,11 @@ func (s *stream) stop() { s.out.Close() } -func parseHooks(args []string, env env.Env) ([]hook.H, error) { +func parseHooks(args []string) ([]hook.H, error) { var hooks []hook.H for _, arg := range args { - if h, err := hook.New(arg, env); err != nil { + if h, err := hook.New(arg); err != nil { return nil, err } else { hooks = append(hooks, h) @@ -353,16 +370,16 @@ func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) { n := len(args) - 1 - if in, err = socket.New(args[0], env); err != nil { + if in, err = socket.New(args[0]); err != nil { return nil, err } - if out, err = socket.New(args[n], env); err != nil { + if out, err = socket.New(args[n]); err != nil { in.Close() return nil, err } - if hooks, err = parseHooks(args[1:n], env); err != nil { + if hooks, err = parseHooks(args[1:n]); err != nil { in.Close() out.Close() return nil, err |
