package socket import ( "encoding/binary" "errors" "fmt" "golang.org/x/sys/unix" "net" "strconv" "syscall" "unsafe" ) func be16toh(n uint16) uint16 { b := (*[2]byte)(unsafe.Pointer(&n)) return binary.BigEndian.Uint16(b[:]) } func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error { tmpsize := uint32(size) _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(ptr), uintptr(unsafe.Pointer(&tmpsize)), 0) if errno != 0 { return fmt.Errorf("getsockopt: %w", errno) } return nil } func ioctl(fd int, req uint, ptr unsafe.Pointer) error { _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(ptr)) if errno != 0 { return fmt.Errorf("ioctl: %w", errno) } return nil } func getRawConn(conn interface{}) (syscall.RawConn, error) { switch c := conn.(type) { case *net.TCPConn: return c.SyscallConn() case *net.TCPListener: return c.SyscallConn() default: return nil, errors.New("unknown connection type") } } func withConnControl(conn interface{}, f func(fd int) error) (err error) { var c syscall.RawConn var ferr error c, err = getRawConn(conn) if err != nil { return } err = c.Control(func(fd uintptr) { ferr = f(int(fd)) }) if ferr != nil { err = ferr } return } func getSocketDomain(fd int) (int, error) { return unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) } func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { const SO_ORIGINAL_DST = 80 p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa)) family, err := getSocketDomain(fd) if err != nil { return err } switch family { case unix.AF_INET6: err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n) if !errors.Is(err, unix.ENOENT) { return err } // skipped check for ipv4 encoded as ipv6 address fallthrough case unix.AF_INET: return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n) default: return errors.New("unknown address family") } } func getConnOriginalAddr(conn net.Conn, addr *string) error { var sa unix.RawSockaddrAny f := func(fd int) error { return getSocketOriginalDst(fd, &sa) } if err := withConnControl(conn, f); err != nil { return fmt.Errorf("get-original-addr: %w", err) } var host net.IP var port uint16 switch sa.Addr.Family { case unix.AF_INET: sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa)) host, port = sin.Addr[:], sin.Port case unix.AF_INET6: sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa)) host, port = sin.Addr[:], sin.Port default: return errors.New("get-original-addr: unknown address family") } *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port)))) return nil } func setSocketTransparent(fd int) error { family, err := getSocketDomain(fd) if err != nil { return err } var level, opt int switch family { case unix.AF_INET6: level, opt = unix.SOL_IPV6, unix.IPV6_TRANSPARENT case unix.AF_INET: level, opt = unix.SOL_IP, unix.IP_TRANSPARENT default: return errors.New("unknown address family") } return unix.SetsockoptInt(fd, level, opt, 1) } func setConnTransparent(conn interface{}) error { f := func(fd int) error { return setSocketTransparent(fd) } if err := withConnControl(conn, f); err != nil { return fmt.Errorf("set-transparent: %w", err) } return nil }