summaryrefslogtreecommitdiff
path: root/pkg/server/tunnel.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/tunnel.go')
-rw-r--r--pkg/server/tunnel.go93
1 files changed, 53 insertions, 40 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index b6c5ada..8b86ddc 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -11,6 +11,7 @@ import (
"sync"
"sync/atomic"
"time"
+
"tunnel/pkg/config"
"tunnel/pkg/server/env"
"tunnel/pkg/server/hook"
@@ -92,6 +93,8 @@ func (t *tunnel) stopStreams() {
}
func (t *tunnel) Close() {
+ t.env.Detach()
+
t.stopServe()
t.stopStreams()
@@ -139,7 +142,7 @@ func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) {
}
for _, h := range t.hooks {
- p, err := h.Open(env)
+ p, err := h.New(env)
if err != nil {
cleanup()
return nil, fmt.Errorf("%s: %w", h, err)
@@ -159,11 +162,11 @@ func (t *tunnel) serve() {
env.Set("tunnel", t.id)
env.Set("stream", strconv.Itoa(t.nextSid))
- if in, err := t.in.Open(env); err != nil {
+ if in, err := t.in.New(env); err != nil {
if t.alive() {
log.Println(t, err)
}
- } else if out, err := t.out.Open(env); err != nil {
+ } else if out, err := t.out.New(env); err != nil {
log.Println(t, err)
in.Close()
} else if pipes, err := t.openPipes(env); err != nil {
@@ -377,47 +380,41 @@ func parseHooks(args []string) ([]hook.H, error) {
return hooks, nil
}
-func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) {
- var in, out socket.S
- var hooks []hook.H
- var err error
+func (t *tunnel) init(limit int, args []string, env env.Env) (err error) {
+ t.env = env.Fork("tunnel", t.id)
+ defer func() {
+ if err != nil {
+ t.env.Detach()
+ }
+ }()
+
+ closeOnFail := func(s socket.S) {
+ if err != nil {
+ s.Close()
+ }
+ }
n := len(args) - 1
- if in, err = socket.New(args[0]); err != nil {
- return nil, err
+ if t.in, err = socket.New(args[0], t.env); err != nil {
+ return
}
+ defer closeOnFail(t.in)
- if _, ok := in.(socket.Single); ok {
+ if _, ok := t.in.(socket.Single); ok {
limit = 1
}
- if out, err = socket.New(args[n]); err != nil {
- in.Close()
- return nil, err
+ if t.out, err = socket.New(args[n], t.env); err != nil {
+ return
}
+ defer closeOnFail(t.out)
- if hooks, err = parseHooks(args[1:n]); err != nil {
- in.Close()
- out.Close()
- return nil, err
+ if t.hooks, err = parseHooks(args[1:n]); err != nil {
+ return
}
- t := &tunnel{
- args: strings.Join(args, " "),
- quit: make(chan struct{}),
- done: make(chan struct{}),
- hooks: hooks,
- in: in,
- out: out,
- env: env,
- queue: make(chan struct{}, limit),
- streams: make(map[int]*stream),
- }
-
- go t.serve()
-
- return t, nil
+ return
}
func isOkTunnelName(s string) bool {
@@ -471,18 +468,27 @@ func tunnelAdd(r *request) {
r.Fatal("not enough args")
}
- t, err := newTunnel(limit, args, r.c.s.env)
- if err != nil {
- r.Fatal(err)
+ if name == "" {
+ name = r.c.s.tunnels.next()
}
- if name == "" {
- t.id = r.c.s.tunnels.add(t)
- } else {
- t.id = name
- r.c.s.tunnels[t.id] = t
+ t := &tunnel{
+ id: name,
+ args: strings.Join(args, " "),
+ quit: make(chan struct{}),
+ done: make(chan struct{}),
+ queue: make(chan struct{}, limit),
+ streams: make(map[int]*stream),
}
+ if err := t.init(limit, args, r.c.s.env); err != nil {
+ r.Fatal(err)
+ }
+
+ r.c.s.tunnels[t.id] = t
+
+ go t.serve()
+
log.Println(r.c, r, t, "create")
}
@@ -584,6 +590,12 @@ func showHooks(r *request) {
}
}
+func showSockets(r *request) {
+ for _, s := range socket.GetList() {
+ r.Println(s)
+ }
+}
+
func init() {
newCmd(tunnelAdd, "add")
newCmd(tunnelDel, "del")
@@ -591,6 +603,7 @@ func init() {
newCmd(tunnelRename, "rename")
newCmd(showHooks, "hooks")
+ newCmd(showSockets, "sockets")
newCmd(showTunnels, "show")