diff options
Diffstat (limited to 'pkg/server/socket')
| -rw-r--r-- | pkg/server/socket/dial.go | 28 | ||||
| -rw-r--r-- | pkg/server/socket/listen.go | 51 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 36 | ||||
| -rw-r--r-- | pkg/server/socket/sys.go | 131 | ||||
| -rw-r--r-- | pkg/server/socket/tun.go | 9 |
5 files changed, 203 insertions, 52 deletions
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go index d7b232c..728269e 100644 --- a/pkg/server/socket/dial.go +++ b/pkg/server/socket/dial.go @@ -2,6 +2,7 @@ package socket import ( "fmt" + "log" "net" "strings" "tunnel/pkg/server/env" @@ -12,13 +13,6 @@ type dialSocket struct { } func newDialSocket(proto, addr string) (S, error) { - switch proto { - case "tcp", "udp": - if !strings.Contains(addr, ":") { - addr = "localhost:" + addr - } - } - return &dialSocket{proto: proto, addr: addr}, nil } @@ -27,15 +21,27 @@ func (s *dialSocket) String() string { } func (s *dialSocket) Open(e env.Env) (Conn, error) { - conn, err := net.Dial(s.proto, e.Eval(s.addr)) + addr := e.Eval(s.addr) + + switch s.proto { + case "tcp", "udp": + if !strings.Contains(addr, ":") { + addr = "localhost:" + addr + } + } + + conn, err := net.Dial(s.proto, addr) if err != nil { return nil, err } - addr := conn.RemoteAddr() - info := fmt.Sprintf(">%s/%s", addr.Network(), addr) + la, ra := conn.LocalAddr(), conn.RemoteAddr() + desc := fmt.Sprintf("%s/%s->%s", la.Network(), la, ra) + info := fmt.Sprintf(">%s/%s", ra.Network(), ra) + + log.Println("dial", desc) - return exported{info, newConn(conn)}, nil + return newConn(conn, desc, info), nil } func (s *dialSocket) Close() { diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index caf5fcf..e640e16 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -1,43 +1,78 @@ package socket import ( + "errors" "fmt" + "log" "net" "strings" "tunnel/pkg/server/env" + "tunnel/pkg/server/opts" ) -func newListenSocket(proto, addr string) (S, error) { +type listenSocket struct { + proto, addr string + listen net.Listener + redirect bool +} + +func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { + redirect := opts.Bool("redirect") + if proto == "tcp" { if !strings.Contains(addr, ":") { addr = ":" + addr } } + if redirect && proto != "tcp" { + return nil, errors.New("redirect not supported") + } + listen, err := net.Listen(proto, addr) if err != nil { return nil, err } s := &listenSocket{ - proto: proto, - addr: addr, - listen: listen, + proto: proto, + addr: addr, + listen: listen, + redirect: redirect, } return s, nil } -func (s *listenSocket) Open(env.Env) (Conn, error) { +func (s *listenSocket) Open(env env.Env) (Conn, error) { + var original string + conn, err := s.listen.Accept() if err != nil { return nil, err } - addr := conn.RemoteAddr() - info := fmt.Sprintf("<%s/%s", addr.Network(), addr) + la, ra := conn.LocalAddr(), conn.RemoteAddr() + desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la) + info := fmt.Sprintf("<%s/%s", ra.Network(), ra) + + if s.redirect { + if err := getConnOriginalAddr(conn, &original); err != nil { + log.Println("accept", desc, "failed") + conn.Close() + return nil, err + } else { + env.Set("original.addr", original) + } + } + + if original == "" { + log.Println("accept", desc) + } else { + log.Println("accept", desc, "original", original) + } - return exported{info, newConn(conn)}, nil + return newConn(conn, desc, info), nil } func (s *listenSocket) String() string { diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 48c06b4..b419468 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -14,11 +14,6 @@ import ( var ErrAlreadyClosed = errors.New("already closed") -type exported struct { - info string - Conn -} - type Conn interface { Send(wq queue.Q) error Recv(rq queue.Q) error @@ -30,23 +25,17 @@ type S interface { Close() } -type listenSocket struct { - proto, addr string - listen net.Listener -} - type conn struct { net.Conn - once sync.Once -} -func (c exported) String() string { - return c.info + desc string + info string + + once sync.Once } -func newConn(cn net.Conn) Conn { - c := &conn{Conn: cn} - log.Println("open", c) +func newConn(cn net.Conn, desc, info string) *conn { + c := &conn{Conn: cn, desc: desc, info: info} return c } @@ -59,22 +48,21 @@ func (c *conn) Recv(rq queue.Q) error { } func (c *conn) String() string { - local, remote := c.LocalAddr(), c.RemoteAddr() - return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) + return c.info } func (c *conn) Close() error { err := ErrAlreadyClosed c.once.Do(func() { - log.Println("close", c) + log.Println("close", c.desc) err = c.Conn.Close() }) return err } -func New(desc string, env env.Env) (S, error) { +func New(desc string) (S, error) { base, opts := opts.Parse(desc) args := strings.SplitN(base, "/", 2) @@ -104,11 +92,11 @@ func New(desc string, env env.Env) (S, error) { return newTunSocket(addr) } - if _, ok := opts["listen"]; ok { - return newListenSocket(proto, addr) + if opts.Bool("listen") { + return newListenSocket(proto, addr, opts) } - if _, ok := opts["defer"]; ok { + if opts.Bool("defer") { return newDeferSocket(proto, addr) } diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go new file mode 100644 index 0000000..d09df12 --- /dev/null +++ b/pkg/server/socket/sys.go @@ -0,0 +1,131 @@ +package socket + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/sys/unix" + "net" + "strconv" + "syscall" + "unsafe" +) + +func be16toh(n uint16) uint16 { + b := (*[2]byte)(unsafe.Pointer(&n)) + return binary.BigEndian.Uint16(b[:]) +} + +func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error { + tmpsize := uint32(size) + + _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, + uintptr(s), + uintptr(level), + uintptr(name), + uintptr(ptr), + uintptr(unsafe.Pointer(&tmpsize)), + 0) + + if errno != 0 { + return fmt.Errorf("getsockopt: %w", errno) + } + + return nil +} + +func ioctl(fd int, req int, ptr unsafe.Pointer) error { + _, _, errno := unix.Syscall(unix.SYS_IOCTL, + uintptr(fd), + uintptr(req), + uintptr(ptr)) + + if errno != 0 { + return fmt.Errorf("ioctl: %w", errno) + } + + return nil +} + +func getRawConn(conn net.Conn) (syscall.RawConn, error) { + switch c := conn.(type) { + case *net.TCPConn: + return c.SyscallConn() + default: + return nil, errors.New("unknown connection type") + } +} + +func withConnControl(conn net.Conn, f func(fd int) error) (err error) { + var c syscall.RawConn + var ferr error + + c, err = getRawConn(conn) + if err != nil { + return + } + + err = c.Control(func(fd uintptr) { + ferr = f(int(fd)) + }) + if ferr != nil { + err = ferr + } + + return +} + +func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { + const SO_ORIGINAL_DST = 80 + + p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa)) + + family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN) + if err != nil { + return err + } + + switch family { + case unix.AF_INET6: + err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n) + if !errors.Is(err, unix.ENOENT) { + return err + } + // skipped check for ipv4 encoded as ipv6 address + fallthrough + case unix.AF_INET: + return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n) + default: + return errors.New("unknown address family") + } +} + +func getConnOriginalAddr(conn net.Conn, addr *string) error { + var sa unix.RawSockaddrAny + + f := func(fd int) error { + return getSocketOriginalDst(fd, &sa) + } + + if err := withConnControl(conn, f); err != nil { + return fmt.Errorf("get-original-addr: %w", err) + } + + var host net.IP + var port uint16 + + switch sa.Addr.Family { + case unix.AF_INET: + sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa)) + host, port = sin.Addr[:], sin.Port + case unix.AF_INET6: + sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa)) + host, port = sin.Addr[:], sin.Port + default: + return errors.New("get-original-addr: unknown address family") + } + + *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port)))) + + return nil +} diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go index c14125b..7336c04 100644 --- a/pkg/server/socket/tun.go +++ b/pkg/server/socket/tun.go @@ -34,15 +34,6 @@ type tunConn struct { once sync.Once } -func ioctl(fd int, req uintptr, ptr unsafe.Pointer) error { - _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(ptr)) - if errno != 0 { - return errno - } - - return nil -} - func newTunSocket(name string) (S, error) { return &tunSocket{name: name}, nil |
