summaryrefslogtreecommitdiff
path: root/pkg/server/socket/sys.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket/sys.go')
-rw-r--r--pkg/server/socket/sys.go131
1 files changed, 131 insertions, 0 deletions
diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go
new file mode 100644
index 0000000..d09df12
--- /dev/null
+++ b/pkg/server/socket/sys.go
@@ -0,0 +1,131 @@
+package socket
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "golang.org/x/sys/unix"
+ "net"
+ "strconv"
+ "syscall"
+ "unsafe"
+)
+
+func be16toh(n uint16) uint16 {
+ b := (*[2]byte)(unsafe.Pointer(&n))
+ return binary.BigEndian.Uint16(b[:])
+}
+
+func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error {
+ tmpsize := uint32(size)
+
+ _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT,
+ uintptr(s),
+ uintptr(level),
+ uintptr(name),
+ uintptr(ptr),
+ uintptr(unsafe.Pointer(&tmpsize)),
+ 0)
+
+ if errno != 0 {
+ return fmt.Errorf("getsockopt: %w", errno)
+ }
+
+ return nil
+}
+
+func ioctl(fd int, req int, ptr unsafe.Pointer) error {
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL,
+ uintptr(fd),
+ uintptr(req),
+ uintptr(ptr))
+
+ if errno != 0 {
+ return fmt.Errorf("ioctl: %w", errno)
+ }
+
+ return nil
+}
+
+func getRawConn(conn net.Conn) (syscall.RawConn, error) {
+ switch c := conn.(type) {
+ case *net.TCPConn:
+ return c.SyscallConn()
+ default:
+ return nil, errors.New("unknown connection type")
+ }
+}
+
+func withConnControl(conn net.Conn, f func(fd int) error) (err error) {
+ var c syscall.RawConn
+ var ferr error
+
+ c, err = getRawConn(conn)
+ if err != nil {
+ return
+ }
+
+ err = c.Control(func(fd uintptr) {
+ ferr = f(int(fd))
+ })
+ if ferr != nil {
+ err = ferr
+ }
+
+ return
+}
+
+func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error {
+ const SO_ORIGINAL_DST = 80
+
+ p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa))
+
+ family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
+ if err != nil {
+ return err
+ }
+
+ switch family {
+ case unix.AF_INET6:
+ err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n)
+ if !errors.Is(err, unix.ENOENT) {
+ return err
+ }
+ // skipped check for ipv4 encoded as ipv6 address
+ fallthrough
+ case unix.AF_INET:
+ return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n)
+ default:
+ return errors.New("unknown address family")
+ }
+}
+
+func getConnOriginalAddr(conn net.Conn, addr *string) error {
+ var sa unix.RawSockaddrAny
+
+ f := func(fd int) error {
+ return getSocketOriginalDst(fd, &sa)
+ }
+
+ if err := withConnControl(conn, f); err != nil {
+ return fmt.Errorf("get-original-addr: %w", err)
+ }
+
+ var host net.IP
+ var port uint16
+
+ switch sa.Addr.Family {
+ case unix.AF_INET:
+ sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ case unix.AF_INET6:
+ sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ default:
+ return errors.New("get-original-addr: unknown address family")
+ }
+
+ *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port))))
+
+ return nil
+}