diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-05-19 01:47:49 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-05-23 01:14:11 +0300 |
| commit | 1c4c61c90272fe251245da5f30b6134ba5a410f1 (patch) | |
| tree | d2809dd944de691e66422aec1becb4acc17b9a19 /pkg/server/socket/sys.go | |
| parent | 2c5259a594f5d8ddf12691deb6a79a0b566f024e (diff) | |
add redirect listen option
Diffstat (limited to 'pkg/server/socket/sys.go')
| -rw-r--r-- | pkg/server/socket/sys.go | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go new file mode 100644 index 0000000..d09df12 --- /dev/null +++ b/pkg/server/socket/sys.go @@ -0,0 +1,131 @@ +package socket + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/sys/unix" + "net" + "strconv" + "syscall" + "unsafe" +) + +func be16toh(n uint16) uint16 { + b := (*[2]byte)(unsafe.Pointer(&n)) + return binary.BigEndian.Uint16(b[:]) +} + +func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error { + tmpsize := uint32(size) + + _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, + uintptr(s), + uintptr(level), + uintptr(name), + uintptr(ptr), + uintptr(unsafe.Pointer(&tmpsize)), + 0) + + if errno != 0 { + return fmt.Errorf("getsockopt: %w", errno) + } + + return nil +} + +func ioctl(fd int, req int, ptr unsafe.Pointer) error { + _, _, errno := unix.Syscall(unix.SYS_IOCTL, + uintptr(fd), + uintptr(req), + uintptr(ptr)) + + if errno != 0 { + return fmt.Errorf("ioctl: %w", errno) + } + + return nil +} + +func getRawConn(conn net.Conn) (syscall.RawConn, error) { + switch c := conn.(type) { + case *net.TCPConn: + return c.SyscallConn() + default: + return nil, errors.New("unknown connection type") + } +} + +func withConnControl(conn net.Conn, f func(fd int) error) (err error) { + var c syscall.RawConn + var ferr error + + c, err = getRawConn(conn) + if err != nil { + return + } + + err = c.Control(func(fd uintptr) { + ferr = f(int(fd)) + }) + if ferr != nil { + err = ferr + } + + return +} + +func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { + const SO_ORIGINAL_DST = 80 + + p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa)) + + family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) + if err != nil { + return err + } + + switch family { + case unix.AF_INET6: + err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n) + if !errors.Is(err, unix.ENOENT) { + return err + } + // skipped check for ipv4 encoded as ipv6 address + fallthrough + case unix.AF_INET: + return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n) + default: + return errors.New("unknown address family") + } +} + +func getConnOriginalAddr(conn net.Conn, addr *string) error { + var sa unix.RawSockaddrAny + + f := func(fd int) error { + return getSocketOriginalDst(fd, &sa) + } + + if err := withConnControl(conn, f); err != nil { + return fmt.Errorf("get-original-addr: %w", err) + } + + var host net.IP + var port uint16 + + switch sa.Addr.Family { + case unix.AF_INET: + sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa)) + host, port = sin.Addr[:], sin.Port + case unix.AF_INET6: + sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa)) + host, port = sin.Addr[:], sin.Port + default: + return errors.New("get-original-addr: unknown address family") + } + + *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port)))) + + return nil +} |
