summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/env/env.go16
-rw-r--r--pkg/server/hook/aes.go2
-rw-r--r--pkg/server/hook/auth.go2
-rw-r--r--pkg/server/hook/b64.go2
-rw-r--r--pkg/server/hook/b85.go2
-rw-r--r--pkg/server/hook/hex.go2
-rw-r--r--pkg/server/hook/hook.go13
-rw-r--r--pkg/server/hook/info-http.go2
-rw-r--r--pkg/server/hook/proxy.go27
-rw-r--r--pkg/server/hook/split.go2
-rw-r--r--pkg/server/hook/tee.go2
-rw-r--r--pkg/server/hook/zip.go2
-rw-r--r--pkg/server/opts/opts.go5
-rw-r--r--pkg/server/server.go27
-rw-r--r--pkg/server/socket/dial.go28
-rw-r--r--pkg/server/socket/listen.go51
-rw-r--r--pkg/server/socket/socket.go36
-rw-r--r--pkg/server/socket/sys.go131
-rw-r--r--pkg/server/socket/tun.go9
-rw-r--r--pkg/server/tunnel.go69
20 files changed, 307 insertions, 123 deletions
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index 45650ab..8594e10 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -19,8 +19,8 @@ type Env struct {
const namePattern = "[a-zA-Z][a-zA-Z0-9.]*"
-var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString
-var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})")
+var isGoodName = regexp.MustCompile("^" + namePattern + "$").MatchString
+var varRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})")
var errBadVariable = errors.New("bad variable name")
var errEmptyVariable = errors.New("empty variable")
@@ -67,8 +67,8 @@ func (e Env) Get(key string) string {
return v
}
-func (e *env) Set(key string, value string) error {
- if !isNamePattern(key) {
+func (e Env) Set(key string, value string) error {
+ if !isGoodName(key) {
return errBadVariable
}
@@ -139,11 +139,13 @@ func (e Env) Eval(s string) string {
}
for {
- if t := namePatternRe.ReplaceAllStringFunc(s, repl); t == s {
+ t := varRe.ReplaceAllStringFunc(s, repl)
+
+ if t == s {
break
- } else {
- s = t
}
+
+ s = t
}
return s
diff --git a/pkg/server/hook/aes.go b/pkg/server/hook/aes.go
index 8ae47a3..ef0ef1a 100644
--- a/pkg/server/hook/aes.go
+++ b/pkg/server/hook/aes.go
@@ -78,7 +78,7 @@ func (aesHook) Open(env env.Env) (interface{}, error) {
return newAes(env), nil
}
-func newAesHook(opts.Opts, env.Env) (hook, error) {
+func newAesHook(opts.Opts) (hook, error) {
return aesHook{}, nil
}
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index 14ad114..86acd91 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -161,7 +161,7 @@ func (h *authHook) Open(env env.Env) (interface{}, error) {
return a, nil
}
-func newAuthHook(opts.Opts, env.Env) (hook, error) {
+func newAuthHook(opts.Opts) (hook, error) {
return &authHook{}, nil
}
diff --git a/pkg/server/hook/b64.go b/pkg/server/hook/b64.go
index 3c21474..fce42a0 100644
--- a/pkg/server/hook/b64.go
+++ b/pkg/server/hook/b64.go
@@ -43,7 +43,7 @@ func (h b64Hook) Open(env.Env) (interface{}, error) {
return h, nil
}
-func newB64Hook(opts.Opts, env.Env) (hook, error) {
+func newB64Hook(opts.Opts) (hook, error) {
return b64Hook{}, nil
}
diff --git a/pkg/server/hook/b85.go b/pkg/server/hook/b85.go
index 9851ffc..bf36b56 100644
--- a/pkg/server/hook/b85.go
+++ b/pkg/server/hook/b85.go
@@ -48,7 +48,7 @@ func (h b85Hook) Open(env.Env) (interface{}, error) {
return h, nil
}
-func newB85Hook(opts.Opts, env.Env) (hook, error) {
+func newB85Hook(opts.Opts) (hook, error) {
return b85Hook{}, nil
}
diff --git a/pkg/server/hook/hex.go b/pkg/server/hook/hex.go
index e37bc6e..bc71bf2 100644
--- a/pkg/server/hook/hex.go
+++ b/pkg/server/hook/hex.go
@@ -28,7 +28,7 @@ func (h hexHook) Open(env.Env) (interface{}, error) {
return h, nil
}
-func newHexHook(opts.Opts, env.Env) (hook, error) {
+func newHexHook(opts.Opts) (hook, error) {
return hexHook{}, nil
}
diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go
index 6ac51a1..3065cbe 100644
--- a/pkg/server/hook/hook.go
+++ b/pkg/server/hook/hook.go
@@ -10,12 +10,13 @@ import (
"tunnel/pkg/server/queue"
)
-type hookInitFunc func(opts.Opts, env.Env) (hook, error)
+type hookInitFunc func(opts.Opts) (hook, error)
var hooks = map[string]hookInitFunc{}
type Pipe struct {
priv interface{}
+ Hook H
Send Func
Recv Func
}
@@ -67,7 +68,7 @@ func (w *wrapper) Open(env env.Env) (*Pipe, error) {
return nil, err
}
- pipe := &Pipe{priv: it}
+ pipe := &Pipe{priv: it, Hook: w}
if s, ok := it.(Sender); ok {
pipe.Send = s.Send
@@ -90,7 +91,7 @@ func (p *Pipe) Close() {
}
}
-func New(desc string, env env.Env) (H, error) {
+func New(desc string) (H, error) {
name, opts := opts.Parse(desc)
reverse := false
@@ -101,8 +102,8 @@ func New(desc string, env env.Env) (H, error) {
if f, ok := hooks[name]; !ok {
return nil, fmt.Errorf("unknown hook '%s'", name)
- } else if h, err := f(opts, env); err != nil {
- return nil, err
+ } else if h, err := f(opts); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
} else {
w := &wrapper{
hook: h,
@@ -122,7 +123,7 @@ func register(name string, f hookInitFunc) {
}
func registerFunc(name string, p Func) {
- register(name, func(opts.Opts, env.Env) (hook, error) {
+ register(name, func(opts.Opts) (hook, error) {
return p, nil
})
}
diff --git a/pkg/server/hook/info-http.go b/pkg/server/hook/info-http.go
index 7941072..73480ff 100644
--- a/pkg/server/hook/info-http.go
+++ b/pkg/server/hook/info-http.go
@@ -46,7 +46,7 @@ func (infoHttpHook) Open(env env.Env) (interface{}, error) {
return &infoHttp{env: env}, nil
}
-func newInfoHttpHook(opts opts.Opts, env env.Env) (hook, error) {
+func newInfoHttpHook(opts.Opts) (hook, error) {
return infoHttpHook{}, nil
}
diff --git a/pkg/server/hook/proxy.go b/pkg/server/hook/proxy.go
index 26be2d0..172b01a 100644
--- a/pkg/server/hook/proxy.go
+++ b/pkg/server/hook/proxy.go
@@ -11,7 +11,8 @@ import (
"tunnel/pkg/server/queue"
)
-var addrRe = regexp.MustCompile("^[0-9a-zA-Z-.]+:[0-9]+$")
+var addrPattern = "^([0-9a-zA-Z-.]+|\\[[0-9a-fA-F:]*\\]):[0-9]+$"
+var isGoodAddr = regexp.MustCompile(addrPattern).MatchString
type proxyHook struct {
addr string
@@ -61,8 +62,13 @@ func (p *proxy) Recv(rq, wq queue.Q) error {
}
func (h *proxyHook) Open(env env.Env) (interface{}, error) {
+ addr := env.Eval(h.addr)
+ if !isGoodAddr(addr) {
+ return nil, fmt.Errorf("invalid addr '%s'", addr)
+ }
+
p := &proxy{
- addr: h.addr,
+ addr: addr,
auth: h.auth,
c: make(chan bool),
}
@@ -74,18 +80,15 @@ func (h *proxyHook) Open(env env.Env) (interface{}, error) {
return p, nil
}
-func newProxyHook(opts opts.Opts, env env.Env) (hook, error) {
- h := &proxyHook{}
-
- if addr, ok := opts["addr"]; !ok {
- return nil, errors.New("proxy: missing addr")
- } else if !addrRe.MatchString(addr) {
- return nil, errors.New("proxy: invalid addr")
- } else {
- h.addr = addr
+func newProxyHook(opts opts.Opts) (hook, error) {
+ h := &proxyHook{
+ addr: opts["addr"],
+ auth: opts["auth"],
}
- h.auth = opts["auth"]
+ if h.addr == "" {
+ return nil, errors.New("expected addr")
+ }
return h, nil
}
diff --git a/pkg/server/hook/split.go b/pkg/server/hook/split.go
index 75faf48..6a2c4ca 100644
--- a/pkg/server/hook/split.go
+++ b/pkg/server/hook/split.go
@@ -38,7 +38,7 @@ func (h *splitHook) Open(env.Env) (interface{}, error) {
return h, nil
}
-func newSplitHook(opts opts.Opts, env env.Env) (hook, error) {
+func newSplitHook(opts opts.Opts) (hook, error) {
size := splitDefaultSize
if s, ok := opts["size"]; ok {
diff --git a/pkg/server/hook/tee.go b/pkg/server/hook/tee.go
index 2d13fcb..3f61b50 100644
--- a/pkg/server/hook/tee.go
+++ b/pkg/server/hook/tee.go
@@ -93,7 +93,7 @@ func (h *teeHook) Open(env env.Env) (interface{}, error) {
return &t, nil
}
-func newTeeHook(opts opts.Opts, env env.Env) (hook, error) {
+func newTeeHook(opts opts.Opts) (hook, error) {
h := &teeHook{}
if file, ok := opts["file"]; ok {
h.file = file
diff --git a/pkg/server/hook/zip.go b/pkg/server/hook/zip.go
index 94160fe..bde4957 100644
--- a/pkg/server/hook/zip.go
+++ b/pkg/server/hook/zip.go
@@ -46,7 +46,7 @@ func (h zipHook) Open(env.Env) (interface{}, error) {
return h, nil
}
-func newZipHook(opts.Opts, env.Env) (hook, error) {
+func newZipHook(opts.Opts) (hook, error) {
return zipHook{}, nil
}
diff --git a/pkg/server/opts/opts.go b/pkg/server/opts/opts.go
index 25dd8e6..22383d8 100644
--- a/pkg/server/opts/opts.go
+++ b/pkg/server/opts/opts.go
@@ -19,3 +19,8 @@ func Parse(s string) (string, Opts) {
return v[0], m
}
+
+func (m Opts) Bool(key string) bool {
+ _, ok := m[key]
+ return ok
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index e794b56..43a0309 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -288,26 +288,31 @@ func (c *client) decode(b []byte) ([]string, error) {
}
func (r *request) eval(args []string) []string {
- for n, s := range args {
- if strings.HasPrefix(s, "^") {
- args[n] = s[1:]
+ var out []string
+
+ for _, s := range args {
+ var t string
+
+ if strings.HasPrefix(s, ":") {
+ t = s[1:]
} else {
- args[n] = r.c.s.env.Eval(s)
+ t = r.c.s.env.Eval(s)
}
+
+ out = append(out, t)
}
- return args
+ return out
}
func (r *request) parse(args []string) {
- c, args := getCmd(r.eval(args))
- if c == nil {
+ if c, args := getCmd(r.eval(args)); c == nil {
r.Fatal("command not found")
+ } else {
+ r.args = args
+ r.argc = len(args)
+ r.cmd = c
}
-
- r.args = args
- r.argc = len(args)
- r.cmd = c
}
func (r *request) run(args []string) {
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index d7b232c..728269e 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -2,6 +2,7 @@ package socket
import (
"fmt"
+ "log"
"net"
"strings"
"tunnel/pkg/server/env"
@@ -12,13 +13,6 @@ type dialSocket struct {
}
func newDialSocket(proto, addr string) (S, error) {
- switch proto {
- case "tcp", "udp":
- if !strings.Contains(addr, ":") {
- addr = "localhost:" + addr
- }
- }
-
return &dialSocket{proto: proto, addr: addr}, nil
}
@@ -27,15 +21,27 @@ func (s *dialSocket) String() string {
}
func (s *dialSocket) Open(e env.Env) (Conn, error) {
- conn, err := net.Dial(s.proto, e.Eval(s.addr))
+ addr := e.Eval(s.addr)
+
+ switch s.proto {
+ case "tcp", "udp":
+ if !strings.Contains(addr, ":") {
+ addr = "localhost:" + addr
+ }
+ }
+
+ conn, err := net.Dial(s.proto, addr)
if err != nil {
return nil, err
}
- addr := conn.RemoteAddr()
- info := fmt.Sprintf(">%s/%s", addr.Network(), addr)
+ la, ra := conn.LocalAddr(), conn.RemoteAddr()
+ desc := fmt.Sprintf("%s/%s->%s", la.Network(), la, ra)
+ info := fmt.Sprintf(">%s/%s", ra.Network(), ra)
+
+ log.Println("dial", desc)
- return exported{info, newConn(conn)}, nil
+ return newConn(conn, desc, info), nil
}
func (s *dialSocket) Close() {
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go
index caf5fcf..e640e16 100644
--- a/pkg/server/socket/listen.go
+++ b/pkg/server/socket/listen.go
@@ -1,43 +1,78 @@
package socket
import (
+ "errors"
"fmt"
+ "log"
"net"
"strings"
"tunnel/pkg/server/env"
+ "tunnel/pkg/server/opts"
)
-func newListenSocket(proto, addr string) (S, error) {
+type listenSocket struct {
+ proto, addr string
+ listen net.Listener
+ redirect bool
+}
+
+func newListenSocket(proto, addr string, opts opts.Opts) (S, error) {
+ redirect := opts.Bool("redirect")
+
if proto == "tcp" {
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
}
+ if redirect && proto != "tcp" {
+ return nil, errors.New("redirect not supported")
+ }
+
listen, err := net.Listen(proto, addr)
if err != nil {
return nil, err
}
s := &listenSocket{
- proto: proto,
- addr: addr,
- listen: listen,
+ proto: proto,
+ addr: addr,
+ listen: listen,
+ redirect: redirect,
}
return s, nil
}
-func (s *listenSocket) Open(env.Env) (Conn, error) {
+func (s *listenSocket) Open(env env.Env) (Conn, error) {
+ var original string
+
conn, err := s.listen.Accept()
if err != nil {
return nil, err
}
- addr := conn.RemoteAddr()
- info := fmt.Sprintf("<%s/%s", addr.Network(), addr)
+ la, ra := conn.LocalAddr(), conn.RemoteAddr()
+ desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la)
+ info := fmt.Sprintf("<%s/%s", ra.Network(), ra)
+
+ if s.redirect {
+ if err := getConnOriginalAddr(conn, &original); err != nil {
+ log.Println("accept", desc, "failed")
+ conn.Close()
+ return nil, err
+ } else {
+ env.Set("original.addr", original)
+ }
+ }
+
+ if original == "" {
+ log.Println("accept", desc)
+ } else {
+ log.Println("accept", desc, "original", original)
+ }
- return exported{info, newConn(conn)}, nil
+ return newConn(conn, desc, info), nil
}
func (s *listenSocket) String() string {
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 48c06b4..b419468 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -14,11 +14,6 @@ import (
var ErrAlreadyClosed = errors.New("already closed")
-type exported struct {
- info string
- Conn
-}
-
type Conn interface {
Send(wq queue.Q) error
Recv(rq queue.Q) error
@@ -30,23 +25,17 @@ type S interface {
Close()
}
-type listenSocket struct {
- proto, addr string
- listen net.Listener
-}
-
type conn struct {
net.Conn
- once sync.Once
-}
-func (c exported) String() string {
- return c.info
+ desc string
+ info string
+
+ once sync.Once
}
-func newConn(cn net.Conn) Conn {
- c := &conn{Conn: cn}
- log.Println("open", c)
+func newConn(cn net.Conn, desc, info string) *conn {
+ c := &conn{Conn: cn, desc: desc, info: info}
return c
}
@@ -59,22 +48,21 @@ func (c *conn) Recv(rq queue.Q) error {
}
func (c *conn) String() string {
- local, remote := c.LocalAddr(), c.RemoteAddr()
- return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
+ return c.info
}
func (c *conn) Close() error {
err := ErrAlreadyClosed
c.once.Do(func() {
- log.Println("close", c)
+ log.Println("close", c.desc)
err = c.Conn.Close()
})
return err
}
-func New(desc string, env env.Env) (S, error) {
+func New(desc string) (S, error) {
base, opts := opts.Parse(desc)
args := strings.SplitN(base, "/", 2)
@@ -104,11 +92,11 @@ func New(desc string, env env.Env) (S, error) {
return newTunSocket(addr)
}
- if _, ok := opts["listen"]; ok {
- return newListenSocket(proto, addr)
+ if opts.Bool("listen") {
+ return newListenSocket(proto, addr, opts)
}
- if _, ok := opts["defer"]; ok {
+ if opts.Bool("defer") {
return newDeferSocket(proto, addr)
}
diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go
new file mode 100644
index 0000000..d09df12
--- /dev/null
+++ b/pkg/server/socket/sys.go
@@ -0,0 +1,131 @@
+package socket
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "golang.org/x/sys/unix"
+ "net"
+ "strconv"
+ "syscall"
+ "unsafe"
+)
+
+func be16toh(n uint16) uint16 {
+ b := (*[2]byte)(unsafe.Pointer(&n))
+ return binary.BigEndian.Uint16(b[:])
+}
+
+func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error {
+ tmpsize := uint32(size)
+
+ _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT,
+ uintptr(s),
+ uintptr(level),
+ uintptr(name),
+ uintptr(ptr),
+ uintptr(unsafe.Pointer(&tmpsize)),
+ 0)
+
+ if errno != 0 {
+ return fmt.Errorf("getsockopt: %w", errno)
+ }
+
+ return nil
+}
+
+func ioctl(fd int, req int, ptr unsafe.Pointer) error {
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL,
+ uintptr(fd),
+ uintptr(req),
+ uintptr(ptr))
+
+ if errno != 0 {
+ return fmt.Errorf("ioctl: %w", errno)
+ }
+
+ return nil
+}
+
+func getRawConn(conn net.Conn) (syscall.RawConn, error) {
+ switch c := conn.(type) {
+ case *net.TCPConn:
+ return c.SyscallConn()
+ default:
+ return nil, errors.New("unknown connection type")
+ }
+}
+
+func withConnControl(conn net.Conn, f func(fd int) error) (err error) {
+ var c syscall.RawConn
+ var ferr error
+
+ c, err = getRawConn(conn)
+ if err != nil {
+ return
+ }
+
+ err = c.Control(func(fd uintptr) {
+ ferr = f(int(fd))
+ })
+ if ferr != nil {
+ err = ferr
+ }
+
+ return
+}
+
+func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error {
+ const SO_ORIGINAL_DST = 80
+
+ p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa))
+
+ family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
+ if err != nil {
+ return err
+ }
+
+ switch family {
+ case unix.AF_INET6:
+ err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n)
+ if !errors.Is(err, unix.ENOENT) {
+ return err
+ }
+ // skipped check for ipv4 encoded as ipv6 address
+ fallthrough
+ case unix.AF_INET:
+ return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n)
+ default:
+ return errors.New("unknown address family")
+ }
+}
+
+func getConnOriginalAddr(conn net.Conn, addr *string) error {
+ var sa unix.RawSockaddrAny
+
+ f := func(fd int) error {
+ return getSocketOriginalDst(fd, &sa)
+ }
+
+ if err := withConnControl(conn, f); err != nil {
+ return fmt.Errorf("get-original-addr: %w", err)
+ }
+
+ var host net.IP
+ var port uint16
+
+ switch sa.Addr.Family {
+ case unix.AF_INET:
+ sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ case unix.AF_INET6:
+ sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ default:
+ return errors.New("get-original-addr: unknown address family")
+ }
+
+ *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port))))
+
+ return nil
+}
diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go
index c14125b..7336c04 100644
--- a/pkg/server/socket/tun.go
+++ b/pkg/server/socket/tun.go
@@ -34,15 +34,6 @@ type tunConn struct {
once sync.Once
}
-func ioctl(fd int, req uintptr, ptr unsafe.Pointer) error {
- _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(ptr))
- if errno != 0 {
- return errno
- }
-
- return nil
-}
-
func newTunSocket(name string) (S, error) {
return &tunSocket{name: name}, nil
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 309c272..d54208b 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -129,6 +129,27 @@ func (t *tunnel) sleep(d time.Duration) {
tmr.Stop()
}
+func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) {
+ var pipes []*hook.Pipe
+
+ cleanup := func() {
+ for _, p := range pipes {
+ p.Close()
+ }
+ }
+
+ for _, h := range t.hooks {
+ p, err := h.Open(env)
+ if err != nil {
+ cleanup()
+ return nil, fmt.Errorf("%s: %w", h, err)
+ }
+ pipes = append(pipes, p)
+ }
+
+ return pipes, nil
+}
+
func (t *tunnel) serve() {
for t.acquire() {
var ok bool
@@ -137,17 +158,21 @@ func (t *tunnel) serve() {
env.Set("tunnel", t.id)
- if in, err := t.in.Open(env); err == nil {
- if out, err := t.out.Open(env); err == nil {
- s := t.newStream(env, in, out)
- log.Println(t, s, "create", in, out)
- ok = true
- } else {
+ if in, err := t.in.Open(env); err != nil {
+ if t.alive() {
log.Println(t, err)
- in.Close()
}
- } else if t.alive() {
+ } else if out, err := t.out.Open(env); err != nil {
+ log.Println(t, err)
+ in.Close()
+ } else if pipes, err := t.openPipes(env); err != nil {
log.Println(t, err)
+ in.Close()
+ out.Close()
+ } else {
+ s := t.newStream(env, in, out, pipes)
+ log.Println(t, s, "create", in, out)
+ ok = true
}
if !ok {
@@ -159,11 +184,12 @@ func (t *tunnel) serve() {
close(t.done)
}
-func (t *tunnel) newStream(env env.Env, in, out socket.Conn) *stream {
+func (t *tunnel) newStream(env env.Env, in, out socket.Conn, pipes []*hook.Pipe) *stream {
s := &stream{
t: t,
in: in,
out: out,
+ pipes: pipes,
env: env,
id: t.nextSid,
since: time.Now(),
@@ -301,27 +327,18 @@ func (s *stream) run() {
s.channel(s.in, &s.m.in, rq, wq)
- for _, h := range s.t.hooks {
- p, err := h.Open(s.env)
- if err != nil {
- // FIXME: abort stream on error
- log.Println(s.t, s, h, err)
- continue
- }
-
+ for _, p := range s.pipes {
if p.Send != nil {
q := queue.New()
- s.pipe(h, p.Send, wq, q)
+ s.pipe(p.Hook, p.Send, wq, q)
wq = q
}
if p.Recv != nil {
q := queue.New()
- s.pipe(h, p.Recv, q, rq)
+ s.pipe(p.Hook, p.Recv, q, rq)
rq = q
}
-
- s.pipes = append(s.pipes, p)
}
s.channel(s.out, &s.m.out, wq, rq)
@@ -332,11 +349,11 @@ func (s *stream) stop() {
s.out.Close()
}
-func parseHooks(args []string, env env.Env) ([]hook.H, error) {
+func parseHooks(args []string) ([]hook.H, error) {
var hooks []hook.H
for _, arg := range args {
- if h, err := hook.New(arg, env); err != nil {
+ if h, err := hook.New(arg); err != nil {
return nil, err
} else {
hooks = append(hooks, h)
@@ -353,16 +370,16 @@ func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) {
n := len(args) - 1
- if in, err = socket.New(args[0], env); err != nil {
+ if in, err = socket.New(args[0]); err != nil {
return nil, err
}
- if out, err = socket.New(args[n], env); err != nil {
+ if out, err = socket.New(args[n]); err != nil {
in.Close()
return nil, err
}
- if hooks, err = parseHooks(args[1:n], env); err != nil {
+ if hooks, err = parseHooks(args[1:n]); err != nil {
in.Close()
out.Close()
return nil, err