summaryrefslogtreecommitdiff
path: root/pkg/server/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket')
-rw-r--r--pkg/server/socket/dial.go28
-rw-r--r--pkg/server/socket/listen.go51
-rw-r--r--pkg/server/socket/socket.go36
-rw-r--r--pkg/server/socket/sys.go131
-rw-r--r--pkg/server/socket/tun.go9
5 files changed, 203 insertions, 52 deletions
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index d7b232c..728269e 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -2,6 +2,7 @@ package socket
import (
"fmt"
+ "log"
"net"
"strings"
"tunnel/pkg/server/env"
@@ -12,13 +13,6 @@ type dialSocket struct {
}
func newDialSocket(proto, addr string) (S, error) {
- switch proto {
- case "tcp", "udp":
- if !strings.Contains(addr, ":") {
- addr = "localhost:" + addr
- }
- }
-
return &dialSocket{proto: proto, addr: addr}, nil
}
@@ -27,15 +21,27 @@ func (s *dialSocket) String() string {
}
func (s *dialSocket) Open(e env.Env) (Conn, error) {
- conn, err := net.Dial(s.proto, e.Eval(s.addr))
+ addr := e.Eval(s.addr)
+
+ switch s.proto {
+ case "tcp", "udp":
+ if !strings.Contains(addr, ":") {
+ addr = "localhost:" + addr
+ }
+ }
+
+ conn, err := net.Dial(s.proto, addr)
if err != nil {
return nil, err
}
- addr := conn.RemoteAddr()
- info := fmt.Sprintf(">%s/%s", addr.Network(), addr)
+ la, ra := conn.LocalAddr(), conn.RemoteAddr()
+ desc := fmt.Sprintf("%s/%s->%s", la.Network(), la, ra)
+ info := fmt.Sprintf(">%s/%s", ra.Network(), ra)
+
+ log.Println("dial", desc)
- return exported{info, newConn(conn)}, nil
+ return newConn(conn, desc, info), nil
}
func (s *dialSocket) Close() {
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go
index caf5fcf..e640e16 100644
--- a/pkg/server/socket/listen.go
+++ b/pkg/server/socket/listen.go
@@ -1,43 +1,78 @@
package socket
import (
+ "errors"
"fmt"
+ "log"
"net"
"strings"
"tunnel/pkg/server/env"
+ "tunnel/pkg/server/opts"
)
-func newListenSocket(proto, addr string) (S, error) {
+type listenSocket struct {
+ proto, addr string
+ listen net.Listener
+ redirect bool
+}
+
+func newListenSocket(proto, addr string, opts opts.Opts) (S, error) {
+ redirect := opts.Bool("redirect")
+
if proto == "tcp" {
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
}
+ if redirect && proto != "tcp" {
+ return nil, errors.New("redirect not supported")
+ }
+
listen, err := net.Listen(proto, addr)
if err != nil {
return nil, err
}
s := &listenSocket{
- proto: proto,
- addr: addr,
- listen: listen,
+ proto: proto,
+ addr: addr,
+ listen: listen,
+ redirect: redirect,
}
return s, nil
}
-func (s *listenSocket) Open(env.Env) (Conn, error) {
+func (s *listenSocket) Open(env env.Env) (Conn, error) {
+ var original string
+
conn, err := s.listen.Accept()
if err != nil {
return nil, err
}
- addr := conn.RemoteAddr()
- info := fmt.Sprintf("<%s/%s", addr.Network(), addr)
+ la, ra := conn.LocalAddr(), conn.RemoteAddr()
+ desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la)
+ info := fmt.Sprintf("<%s/%s", ra.Network(), ra)
+
+ if s.redirect {
+ if err := getConnOriginalAddr(conn, &original); err != nil {
+ log.Println("accept", desc, "failed")
+ conn.Close()
+ return nil, err
+ } else {
+ env.Set("original.addr", original)
+ }
+ }
+
+ if original == "" {
+ log.Println("accept", desc)
+ } else {
+ log.Println("accept", desc, "original", original)
+ }
- return exported{info, newConn(conn)}, nil
+ return newConn(conn, desc, info), nil
}
func (s *listenSocket) String() string {
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 48c06b4..b419468 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -14,11 +14,6 @@ import (
var ErrAlreadyClosed = errors.New("already closed")
-type exported struct {
- info string
- Conn
-}
-
type Conn interface {
Send(wq queue.Q) error
Recv(rq queue.Q) error
@@ -30,23 +25,17 @@ type S interface {
Close()
}
-type listenSocket struct {
- proto, addr string
- listen net.Listener
-}
-
type conn struct {
net.Conn
- once sync.Once
-}
-func (c exported) String() string {
- return c.info
+ desc string
+ info string
+
+ once sync.Once
}
-func newConn(cn net.Conn) Conn {
- c := &conn{Conn: cn}
- log.Println("open", c)
+func newConn(cn net.Conn, desc, info string) *conn {
+ c := &conn{Conn: cn, desc: desc, info: info}
return c
}
@@ -59,22 +48,21 @@ func (c *conn) Recv(rq queue.Q) error {
}
func (c *conn) String() string {
- local, remote := c.LocalAddr(), c.RemoteAddr()
- return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
+ return c.info
}
func (c *conn) Close() error {
err := ErrAlreadyClosed
c.once.Do(func() {
- log.Println("close", c)
+ log.Println("close", c.desc)
err = c.Conn.Close()
})
return err
}
-func New(desc string, env env.Env) (S, error) {
+func New(desc string) (S, error) {
base, opts := opts.Parse(desc)
args := strings.SplitN(base, "/", 2)
@@ -104,11 +92,11 @@ func New(desc string, env env.Env) (S, error) {
return newTunSocket(addr)
}
- if _, ok := opts["listen"]; ok {
- return newListenSocket(proto, addr)
+ if opts.Bool("listen") {
+ return newListenSocket(proto, addr, opts)
}
- if _, ok := opts["defer"]; ok {
+ if opts.Bool("defer") {
return newDeferSocket(proto, addr)
}
diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go
new file mode 100644
index 0000000..d09df12
--- /dev/null
+++ b/pkg/server/socket/sys.go
@@ -0,0 +1,131 @@
+package socket
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "golang.org/x/sys/unix"
+ "net"
+ "strconv"
+ "syscall"
+ "unsafe"
+)
+
+func be16toh(n uint16) uint16 {
+ b := (*[2]byte)(unsafe.Pointer(&n))
+ return binary.BigEndian.Uint16(b[:])
+}
+
+func getsockopt(s int, level, name int, ptr unsafe.Pointer, size int) error {
+ tmpsize := uint32(size)
+
+ _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT,
+ uintptr(s),
+ uintptr(level),
+ uintptr(name),
+ uintptr(ptr),
+ uintptr(unsafe.Pointer(&tmpsize)),
+ 0)
+
+ if errno != 0 {
+ return fmt.Errorf("getsockopt: %w", errno)
+ }
+
+ return nil
+}
+
+func ioctl(fd int, req int, ptr unsafe.Pointer) error {
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL,
+ uintptr(fd),
+ uintptr(req),
+ uintptr(ptr))
+
+ if errno != 0 {
+ return fmt.Errorf("ioctl: %w", errno)
+ }
+
+ return nil
+}
+
+func getRawConn(conn net.Conn) (syscall.RawConn, error) {
+ switch c := conn.(type) {
+ case *net.TCPConn:
+ return c.SyscallConn()
+ default:
+ return nil, errors.New("unknown connection type")
+ }
+}
+
+func withConnControl(conn net.Conn, f func(fd int) error) (err error) {
+ var c syscall.RawConn
+ var ferr error
+
+ c, err = getRawConn(conn)
+ if err != nil {
+ return
+ }
+
+ err = c.Control(func(fd uintptr) {
+ ferr = f(int(fd))
+ })
+ if ferr != nil {
+ err = ferr
+ }
+
+ return
+}
+
+func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error {
+ const SO_ORIGINAL_DST = 80
+
+ p, n := unsafe.Pointer(sa), int(unsafe.Sizeof(*sa))
+
+ family, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
+ if err != nil {
+ return err
+ }
+
+ switch family {
+ case unix.AF_INET6:
+ err := getsockopt(fd, unix.SOL_IPV6, SO_ORIGINAL_DST, p, n)
+ if !errors.Is(err, unix.ENOENT) {
+ return err
+ }
+ // skipped check for ipv4 encoded as ipv6 address
+ fallthrough
+ case unix.AF_INET:
+ return getsockopt(fd, unix.SOL_IP, SO_ORIGINAL_DST, p, n)
+ default:
+ return errors.New("unknown address family")
+ }
+}
+
+func getConnOriginalAddr(conn net.Conn, addr *string) error {
+ var sa unix.RawSockaddrAny
+
+ f := func(fd int) error {
+ return getSocketOriginalDst(fd, &sa)
+ }
+
+ if err := withConnControl(conn, f); err != nil {
+ return fmt.Errorf("get-original-addr: %w", err)
+ }
+
+ var host net.IP
+ var port uint16
+
+ switch sa.Addr.Family {
+ case unix.AF_INET:
+ sin := (*unix.RawSockaddrInet4)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ case unix.AF_INET6:
+ sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa))
+ host, port = sin.Addr[:], sin.Port
+ default:
+ return errors.New("get-original-addr: unknown address family")
+ }
+
+ *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port))))
+
+ return nil
+}
diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go
index c14125b..7336c04 100644
--- a/pkg/server/socket/tun.go
+++ b/pkg/server/socket/tun.go
@@ -34,15 +34,6 @@ type tunConn struct {
once sync.Once
}
-func ioctl(fd int, req uintptr, ptr unsafe.Pointer) error {
- _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(ptr))
- if errno != 0 {
- return errno
- }
-
- return nil
-}
-
func newTunSocket(name string) (S, error) {
return &tunSocket{name: name}, nil