diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-05-26 03:22:32 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-05-26 03:22:32 +0300 |
| commit | e43c60b56401be7515d7fbfdfe3e4e56d1886a23 (patch) | |
| tree | 58711f3bbbe2b34e91836087c274773564e29029 /pkg/server/socket | |
| parent | 97802849e3a21952e6c5896622d52153b2d96ee7 (diff) | |
add tproxy listen option
Diffstat (limited to 'pkg/server/socket')
| -rw-r--r-- | pkg/server/socket/listen.go | 22 | ||||
| -rw-r--r-- | pkg/server/socket/sys.go | 44 |
2 files changed, 63 insertions, 3 deletions
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index e640e16..94fb85c 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -14,10 +14,12 @@ type listenSocket struct { proto, addr string listen net.Listener redirect bool + tproxy bool } func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { redirect := opts.Bool("redirect") + tproxy := opts.Bool("tproxy") if proto == "tcp" { if !strings.Contains(addr, ":") { @@ -29,16 +31,32 @@ func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { return nil, errors.New("redirect not supported") } + if tproxy && proto != "tcp" { + return nil, errors.New("tproxy not supported") + } + + if redirect && tproxy { + return nil, errors.New("redirect and tproxy cannot be used together") + } + listen, err := net.Listen(proto, addr) if err != nil { return nil, err } + if tproxy { + if err := setConnTransparent(listen); err != nil { + listen.Close() + return nil, err + } + } + s := &listenSocket{ proto: proto, addr: addr, listen: listen, redirect: redirect, + tproxy: tproxy, } return s, nil @@ -66,6 +84,10 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } } + if s.tproxy { + env.Set("original.addr", la.String()) + } + if original == "" { log.Println("accept", desc) } else { diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go index d09df12..70b59a6 100644 --- a/pkg/server/socket/sys.go +++ b/pkg/server/socket/sys.go @@ -47,16 +47,18 @@ func ioctl(fd int, req int, ptr unsafe.Pointer) error { return nil } -func getRawConn(conn net.Conn) (syscall.RawConn, error) { +func getRawConn(conn interface{}) (syscall.RawConn, error) { switch c := conn.(type) { case *net.TCPConn: return c.SyscallConn() + case *net.TCPListener: + return c.SyscallConn() default: return nil, errors.New("unknown connection type") } } -func withConnControl(conn net.Conn, f func(fd int) error) (err error) { +func withConnControl(conn interface{}, f func(fd int) error) (err error) { var c syscall.RawConn var ferr error @@ -75,12 +77,16 @@ func withConnControl(conn net.Conn, f func(fd int) error) (err error) { return } +func getSocketDomain(fd int) (int, error) { + return unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) +} + func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { const SO_ORIGINAL_DST = 80 p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa)) - family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) + family, err := getSocketDomain(fd) if err != nil { return err } @@ -129,3 +135,35 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { return nil } + +func setSocketTransparent(fd int) error { + family, err := getSocketDomain(fd) + if err != nil { + return err + } + + var level, opt int + + switch family { + case unix.AF_INET6: + level, opt = unix.SOL_IPV6, unix.IPV6_TRANSPARENT + case unix.AF_INET: + level, opt = unix.SOL_IP, unix.IP_TRANSPARENT + default: + return errors.New("unknown address family") + } + + return unix.SetsockoptInt(fd, level, opt, 1) +} + +func setConnTransparent(conn interface{}) error { + f := func(fd int) error { + return setSocketTransparent(fd) + } + + if err := withConnControl(conn, f); err != nil { + return fmt.Errorf("set-transparent: %w", err) + } + + return nil +} |
