diff options
Diffstat (limited to 'pkg/server/socket')
| -rw-r--r-- | pkg/server/socket/defer.go | 23 | ||||
| -rw-r--r-- | pkg/server/socket/dial.go | 29 | ||||
| -rw-r--r-- | pkg/server/socket/listen.go | 66 | ||||
| -rw-r--r-- | pkg/server/socket/loop.go | 6 | ||||
| -rw-r--r-- | pkg/server/socket/proxy.go | 21 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 76 | ||||
| -rw-r--r-- | pkg/server/socket/tun.go | 24 |
7 files changed, 125 insertions, 120 deletions
diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go index 7ed303d..7c1436e 100644 --- a/pkg/server/socket/defer.go +++ b/pkg/server/socket/defer.go @@ -6,28 +6,19 @@ import ( ) type deferSocket struct { - S + dialSocket `opts:"inline"` } type deferConn struct { - sock *deferSocket + sock S wait chan bool env env.Env conn Conn } -func newDeferSocket(proto, addr string) (S, error) { - s, err := newDialSocket(proto, addr) - if err != nil { - return s, err - } - - return &deferSocket{s}, nil -} - -func (s *deferSocket) Open(env env.Env) (Conn, error) { +func (s *deferSocket) New(env env.Env) (Conn, error) { c := &deferConn{ - sock: s, + sock: &s.dialSocket, wait: make(chan bool), env: env, } @@ -54,7 +45,7 @@ func (c *deferConn) Recv(rq queue.Q) error { return nil } - conn, err := c.sock.S.Open(c.env) + conn, err := c.sock.New(c.env) if err != nil { c.wait <- false return err @@ -83,3 +74,7 @@ func (c *deferConn) Close() (err error) { return } + +func init() { + register("defer", deferSocket{}) +} diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go index 7623084..b2df3b7 100644 --- a/pkg/server/socket/dial.go +++ b/pkg/server/socket/dial.go @@ -4,36 +4,25 @@ import ( "fmt" "log" "net" - "strings" "time" + "tunnel/pkg/server/env" ) const defaultTimeout = 5 * time.Second type dialSocket struct { - proto, addr string -} - -func newDialSocket(proto, addr string) (S, error) { - return &dialSocket{proto: proto, addr: addr}, nil + Proto string `opts:"default:tcp"` + Addr string `opts:"required"` } func (s *dialSocket) String() string { - return fmt.Sprintf("%s/%s", s.proto, s.addr) + return fmt.Sprintf("%s/%s", s.Proto, s.Addr) } -func (s *dialSocket) Open(e env.Env) (Conn, error) { - addr := e.Expand(s.addr) - - switch s.proto { - case "tcp", "udp": - if !strings.Contains(addr, ":") { - addr = "localhost:" + addr - } - } - - conn, err := net.DialTimeout(s.proto, addr, defaultTimeout) +func (s *dialSocket) New(e env.Env) (Conn, error) { + proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr)) + conn, err := net.DialTimeout(proto, addr, defaultTimeout) if err != nil { return nil, err } @@ -49,3 +38,7 @@ func (s *dialSocket) Open(e env.Env) (Conn, error) { func (s *dialSocket) Close() { } + +func init() { + register("dial", dialSocket{}) +} diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index 910e5de..2c2f184 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -5,64 +5,54 @@ import ( "fmt" "log" "net" - "strings" + "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" ) type listenSocket struct { - proto, addr string - listen net.Listener - redirect bool - tproxy bool -} + Proto string `opts:"default:tcp"` + Addr string `opts:"required"` -func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { - redirect := opts.Bool("redirect") - tproxy := opts.Bool("tproxy") + Redirect bool + Tproxy bool - if proto == "tcp" { - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - } + listen net.Listener +} - if redirect && proto != "tcp" { - return nil, errors.New("redirect not supported") +func (s *listenSocket) Prepare(e env.Env) error { + if s.Redirect && s.Proto != "tcp" { + return errors.New("redirect not supported") } - if tproxy && proto != "tcp" { - return nil, errors.New("tproxy not supported") + if s.Tproxy && s.Proto != "tcp" { + return errors.New("tproxy not supported") } - if redirect && tproxy { - return nil, errors.New("redirect and tproxy cannot be used together") + if s.Redirect && s.Tproxy { + return errors.New("redirect and tproxy cannot be used together") } + proto, addr := parseProtoAddr(s.Proto, s.Addr) listen, err := net.Listen(proto, addr) if err != nil { - return nil, err + return err } - if tproxy { + e.Set("listen", listen.Addr().String()) + + if s.Tproxy { if err := setConnTransparent(listen); err != nil { listen.Close() - return nil, err + return err } } - s := &listenSocket{ - proto: proto, - addr: addr, - listen: listen, - redirect: redirect, - tproxy: tproxy, - } + s.listen = listen - return s, nil + return nil } -func (s *listenSocket) Open(env env.Env) (Conn, error) { +func (s *listenSocket) New(env env.Env) (Conn, error) { var original string conn, err := s.listen.Accept() @@ -74,7 +64,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la) info := fmt.Sprintf("<%s/%s", ra.Network(), ra) - if s.redirect { + if s.Redirect { if err := getConnOriginalAddr(conn, &original); err != nil { log.Println("accept", desc, "failed") conn.Close() @@ -84,7 +74,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } } - if s.tproxy { + if s.Tproxy { env.Set("original", la.String()) } @@ -98,9 +88,13 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } func (s *listenSocket) String() string { - return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) + return fmt.Sprintf("%s/%s,listen", s.Proto, s.Addr) } func (s *listenSocket) Close() { s.listen.Close() } + +func init() { + register("listen", listenSocket{}) +} diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go index a06448a..c442140 100644 --- a/pkg/server/socket/loop.go +++ b/pkg/server/socket/loop.go @@ -30,7 +30,7 @@ func (c *loopConn) Close() error { return nil } -func (s *loopSocket) Open(env.Env) (Conn, error) { +func (s *loopSocket) New(env.Env) (Conn, error) { return &loopConn{make(chan queue.Q), make(chan error)}, nil } @@ -41,6 +41,6 @@ func (s *loopSocket) String() string { func (s *loopSocket) Close() { } -func newLoopSocket() (S, error) { - return &loopSocket{}, nil +func init() { + register("loop", loopSocket{}) } diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go index e4baec2..1be4bba 100644 --- a/pkg/server/socket/proxy.go +++ b/pkg/server/socket/proxy.go @@ -5,6 +5,7 @@ import ( "bytes" "errors" "fmt" + "tunnel/pkg/http" "tunnel/pkg/server/env" "tunnel/pkg/server/queue" @@ -16,7 +17,7 @@ type status struct { } type proxySocket struct { - proto string + Proto string `opts:"default:tcp"` } type proxyServer struct { @@ -28,11 +29,7 @@ type proxyServer struct { conn Conn } -func newProxySocket(proto string) (S, error) { - return &proxySocket{proto}, nil -} - -func (sock *proxySocket) Open(env env.Env) (Conn, error) { +func (sock *proxySocket) New(env env.Env) (Conn, error) { s := &proxyServer{ sock: sock, auth: env.Value("proxy.auth"), @@ -78,12 +75,12 @@ func (s *proxyServer) Send(wq queue.Q) error { } func (s *proxyServer) initConn(addr string) error { - dial, err := newDialSocket(s.sock.proto, addr) - if err != nil { - return err + dial := dialSocket{ + Proto: s.sock.Proto, + Addr: addr, } - conn, err := dial.Open(s.env) + conn, err := dial.New(s.env) if err != nil { dial.Close() return err @@ -138,3 +135,7 @@ func (s *proxyServer) Close() (err error) { return } + +func init() { + register("proxy", proxySocket{}) +} diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 62ce5cf..03b73d9 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -5,8 +5,11 @@ import ( "fmt" "log" "net" + "reflect" + "sort" "strings" "sync" + "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" @@ -21,7 +24,7 @@ type Conn interface { } type S interface { - Open(env env.Env) (Conn, error) + New(env env.Env) (Conn, error) Close() } @@ -66,43 +69,62 @@ func (c *conn) Close() error { return err } -func New(desc string) (S, error) { - base, opts := opts.Parse(desc) - args := strings.SplitN(base, "/", 2) - - var proto string - var addr string +func New(desc string, e env.Env) (S, error) { + name, opts := opts.Parse(desc) - if len(args) < 2 { - addr = args[0] - } else { - proto, addr = args[0], args[1] + t, ok := sockets[name] + if !ok { + return nil, fmt.Errorf("%s: unknown type", name) } - if proto == "" { - proto = "tcp" + s := reflect.New(t).Interface() + if err := opts.Configure(s); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) } - switch addr { - case "loop": - return newLoopSocket() - case "proxy": - return newProxySocket(proto) - case "": - return nil, fmt.Errorf("bad socket '%s'", desc) + if i, ok := s.(interface{ Prepare(env.Env) error }); ok { + if err := i.Prepare(e); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) + } } - if proto == "tun" { - return newTunSocket(addr) + return s.(S), nil +} + +func parseProtoAddr(proto, addr string) (string, string) { + if proto == "tcp" || proto == "udp" { + if strings.HasPrefix(addr, "-:") { + addr = "localhost" + addr[1:] + } } - if opts.Bool("listen") { - return newListenSocket(proto, addr, opts) + return proto, addr +} + +var sockets = map[string]reflect.Type{} + +func register(name string, i interface{}) { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Struct { + log.Panicf("non-struct type '%s'", t.String()) + } + if _, ok := reflect.New(t).Interface().(S); !ok { + log.Panicf("uncompatible socket type '%s'", t.String()) + } + if _, ok := sockets[name]; ok { + log.Panicf("duplicate socket name '%s'", name) } + sockets[name] = t +} + +func GetList() []string { + var list []string - if opts.Bool("defer") { - return newDeferSocket(proto, addr) + for k := range sockets { + list = append(list, k) } - return newDialSocket(proto, addr) + sort.Strings(list) + + return list } diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go index d48c30c..3e673eb 100644 --- a/pkg/server/socket/tun.go +++ b/pkg/server/socket/tun.go @@ -8,10 +8,11 @@ import ( "os" "strings" "sync" + "unsafe" + "tunnel/pkg/pack" "tunnel/pkg/server/env" "tunnel/pkg/server/queue" - "unsafe" ) const maxTunBufSize = 65535 @@ -24,7 +25,7 @@ type ifReq struct { } type tunSocket struct { - name string + Name string `opts:"required"` } type tunConn struct { @@ -34,35 +35,30 @@ type tunConn struct { once sync.Once } -func newTunSocket(name string) (S, error) { - return &tunSocket{name: name}, nil - -} - func (s *tunSocket) String() string { - return fmt.Sprintf("tun/%s", s.name) + return fmt.Sprintf("tun/%s", s.Name) } func (s *tunSocket) Single() {} -func (s *tunSocket) Open(env.Env) (Conn, error) { +func (s *tunSocket) New(env.Env) (Conn, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) if err != nil { return nil, err } ifr := &ifReq{} - copy(ifr.name[:], s.name) + copy(ifr.name[:], s.Name) ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil { unix.Close(fd) - return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err) + return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.Name, err) } if err := unix.SetNonblock(fd, true); err != nil { unix.Close(fd) - return nil, fmt.Errorf("set nonblock %s: %w", s.name, err) + return nil, fmt.Errorf("set nonblock %s: %w", s.Name, err) } c := &tunConn{ @@ -124,3 +120,7 @@ func (c *tunConn) Close() error { return err } + +func init() { + register("tun", tunSocket{}) +} |
