summaryrefslogtreecommitdiff
path: root/pkg/server/tunnel.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/tunnel.go')
-rw-r--r--pkg/server/tunnel.go69
1 files changed, 43 insertions, 26 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 309c272..d54208b 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -129,6 +129,27 @@ func (t *tunnel) sleep(d time.Duration) {
tmr.Stop()
}
+func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) {
+ var pipes []*hook.Pipe
+
+ cleanup := func() {
+ for _, p := range pipes {
+ p.Close()
+ }
+ }
+
+ for _, h := range t.hooks {
+ p, err := h.Open(env)
+ if err != nil {
+ cleanup()
+ return nil, fmt.Errorf("%s: %w", h, err)
+ }
+ pipes = append(pipes, p)
+ }
+
+ return pipes, nil
+}
+
func (t *tunnel) serve() {
for t.acquire() {
var ok bool
@@ -137,17 +158,21 @@ func (t *tunnel) serve() {
env.Set("tunnel", t.id)
- if in, err := t.in.Open(env); err == nil {
- if out, err := t.out.Open(env); err == nil {
- s := t.newStream(env, in, out)
- log.Println(t, s, "create", in, out)
- ok = true
- } else {
+ if in, err := t.in.Open(env); err != nil {
+ if t.alive() {
log.Println(t, err)
- in.Close()
}
- } else if t.alive() {
+ } else if out, err := t.out.Open(env); err != nil {
+ log.Println(t, err)
+ in.Close()
+ } else if pipes, err := t.openPipes(env); err != nil {
log.Println(t, err)
+ in.Close()
+ out.Close()
+ } else {
+ s := t.newStream(env, in, out, pipes)
+ log.Println(t, s, "create", in, out)
+ ok = true
}
if !ok {
@@ -159,11 +184,12 @@ func (t *tunnel) serve() {
close(t.done)
}
-func (t *tunnel) newStream(env env.Env, in, out socket.Conn) *stream {
+func (t *tunnel) newStream(env env.Env, in, out socket.Conn, pipes []*hook.Pipe) *stream {
s := &stream{
t: t,
in: in,
out: out,
+ pipes: pipes,
env: env,
id: t.nextSid,
since: time.Now(),
@@ -301,27 +327,18 @@ func (s *stream) run() {
s.channel(s.in, &s.m.in, rq, wq)
- for _, h := range s.t.hooks {
- p, err := h.Open(s.env)
- if err != nil {
- // FIXME: abort stream on error
- log.Println(s.t, s, h, err)
- continue
- }
-
+ for _, p := range s.pipes {
if p.Send != nil {
q := queue.New()
- s.pipe(h, p.Send, wq, q)
+ s.pipe(p.Hook, p.Send, wq, q)
wq = q
}
if p.Recv != nil {
q := queue.New()
- s.pipe(h, p.Recv, q, rq)
+ s.pipe(p.Hook, p.Recv, q, rq)
rq = q
}
-
- s.pipes = append(s.pipes, p)
}
s.channel(s.out, &s.m.out, wq, rq)
@@ -332,11 +349,11 @@ func (s *stream) stop() {
s.out.Close()
}
-func parseHooks(args []string, env env.Env) ([]hook.H, error) {
+func parseHooks(args []string) ([]hook.H, error) {
var hooks []hook.H
for _, arg := range args {
- if h, err := hook.New(arg, env); err != nil {
+ if h, err := hook.New(arg); err != nil {
return nil, err
} else {
hooks = append(hooks, h)
@@ -353,16 +370,16 @@ func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) {
n := len(args) - 1
- if in, err = socket.New(args[0], env); err != nil {
+ if in, err = socket.New(args[0]); err != nil {
return nil, err
}
- if out, err = socket.New(args[n], env); err != nil {
+ if out, err = socket.New(args[n]); err != nil {
in.Close()
return nil, err
}
- if hooks, err = parseHooks(args[1:n], env); err != nil {
+ if hooks, err = parseHooks(args[1:n]); err != nil {
in.Close()
out.Close()
return nil, err