diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/server/socket/listen.go | 22 | ||||
| -rw-r--r-- | pkg/server/socket/sys.go | 44 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 3 |
3 files changed, 64 insertions, 5 deletions
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index e640e16..94fb85c 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -14,10 +14,12 @@ type listenSocket struct { proto, addr string listen net.Listener redirect bool + tproxy bool } func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { redirect := opts.Bool("redirect") + tproxy := opts.Bool("tproxy") if proto == "tcp" { if !strings.Contains(addr, ":") { @@ -29,16 +31,32 @@ func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { return nil, errors.New("redirect not supported") } + if tproxy && proto != "tcp" { + return nil, errors.New("tproxy not supported") + } + + if redirect && tproxy { + return nil, errors.New("redirect and tproxy cannot be used together") + } + listen, err := net.Listen(proto, addr) if err != nil { return nil, err } + if tproxy { + if err := setConnTransparent(listen); err != nil { + listen.Close() + return nil, err + } + } + s := &listenSocket{ proto: proto, addr: addr, listen: listen, redirect: redirect, + tproxy: tproxy, } return s, nil @@ -66,6 +84,10 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } } + if s.tproxy { + env.Set("original.addr", la.String()) + } + if original == "" { log.Println("accept", desc) } else { diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go index d09df12..70b59a6 100644 --- a/pkg/server/socket/sys.go +++ b/pkg/server/socket/sys.go @@ -47,16 +47,18 @@ func ioctl(fd int, req int, ptr unsafe.Pointer) error { return nil } -func getRawConn(conn net.Conn) (syscall.RawConn, error) { +func getRawConn(conn interface{}) (syscall.RawConn, error) { switch c := conn.(type) { case *net.TCPConn: return c.SyscallConn() + case *net.TCPListener: + return c.SyscallConn() default: return nil, errors.New("unknown connection type") } } -func withConnControl(conn net.Conn, f func(fd int) error) (err error) { +func withConnControl(conn interface{}, f func(fd int) error) (err error) { var c syscall.RawConn var ferr error @@ -75,12 +77,16 @@ func withConnControl(conn net.Conn, f func(fd int) error) (err error) { return } +func getSocketDomain(fd int) (int, error) { + return unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) +} + func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { const SO_ORIGINAL_DST = 80 p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa)) - family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) + family, err := getSocketDomain(fd) if err != nil { return err } @@ -129,3 +135,35 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { return nil } + +func setSocketTransparent(fd int) error { + family, err := getSocketDomain(fd) + if err != nil { + return err + } + + var level, opt int + + switch family { + case unix.AF_INET6: + level, opt = unix.SOL_IPV6, unix.IPV6_TRANSPARENT + case unix.AF_INET: + level, opt = unix.SOL_IP, unix.IP_TRANSPARENT + default: + return errors.New("unknown address family") + } + + return unix.SetsockoptInt(fd, level, opt, 1) +} + +func setConnTransparent(conn interface{}) error { + f := func(fd int) error { + return setSocketTransparent(fd) + } + + if err := withConnControl(conn, f); err != nil { + return fmt.Errorf("set-transparent: %w", err) + } + + return nil +} diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index d54208b..ccd501e 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -157,6 +157,7 @@ func (t *tunnel) serve() { env := t.env.Fork() env.Set("tunnel", t.id) + env.Set("stream", strconv.Itoa(t.nextSid)) if in, err := t.in.Open(env); err != nil { if t.alive() { @@ -195,8 +196,6 @@ func (t *tunnel) newStream(env env.Env, in, out socket.Conn, pipes []*hook.Pipe) since: time.Now(), } - s.env.Set("stream", strconv.Itoa(s.id)) - s.run() t.mu.Lock() |
