summaryrefslogtreecommitdiff
path: root/pkg/server/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket')
-rw-r--r--pkg/server/socket/defer.go23
-rw-r--r--pkg/server/socket/dial.go29
-rw-r--r--pkg/server/socket/listen.go66
-rw-r--r--pkg/server/socket/loop.go6
-rw-r--r--pkg/server/socket/proxy.go21
-rw-r--r--pkg/server/socket/socket.go76
-rw-r--r--pkg/server/socket/tun.go24
7 files changed, 125 insertions, 120 deletions
diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go
index 7ed303d..7c1436e 100644
--- a/pkg/server/socket/defer.go
+++ b/pkg/server/socket/defer.go
@@ -6,28 +6,19 @@ import (
)
type deferSocket struct {
- S
+ dialSocket `opts:"inline"`
}
type deferConn struct {
- sock *deferSocket
+ sock S
wait chan bool
env env.Env
conn Conn
}
-func newDeferSocket(proto, addr string) (S, error) {
- s, err := newDialSocket(proto, addr)
- if err != nil {
- return s, err
- }
-
- return &deferSocket{s}, nil
-}
-
-func (s *deferSocket) Open(env env.Env) (Conn, error) {
+func (s *deferSocket) New(env env.Env) (Conn, error) {
c := &deferConn{
- sock: s,
+ sock: &s.dialSocket,
wait: make(chan bool),
env: env,
}
@@ -54,7 +45,7 @@ func (c *deferConn) Recv(rq queue.Q) error {
return nil
}
- conn, err := c.sock.S.Open(c.env)
+ conn, err := c.sock.New(c.env)
if err != nil {
c.wait <- false
return err
@@ -83,3 +74,7 @@ func (c *deferConn) Close() (err error) {
return
}
+
+func init() {
+ register("defer", deferSocket{})
+}
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index 7623084..b2df3b7 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -4,36 +4,25 @@ import (
"fmt"
"log"
"net"
- "strings"
"time"
+
"tunnel/pkg/server/env"
)
const defaultTimeout = 5 * time.Second
type dialSocket struct {
- proto, addr string
-}
-
-func newDialSocket(proto, addr string) (S, error) {
- return &dialSocket{proto: proto, addr: addr}, nil
+ Proto string `opts:"default:tcp"`
+ Addr string `opts:"required"`
}
func (s *dialSocket) String() string {
- return fmt.Sprintf("%s/%s", s.proto, s.addr)
+ return fmt.Sprintf("%s/%s", s.Proto, s.Addr)
}
-func (s *dialSocket) Open(e env.Env) (Conn, error) {
- addr := e.Expand(s.addr)
-
- switch s.proto {
- case "tcp", "udp":
- if !strings.Contains(addr, ":") {
- addr = "localhost:" + addr
- }
- }
-
- conn, err := net.DialTimeout(s.proto, addr, defaultTimeout)
+func (s *dialSocket) New(e env.Env) (Conn, error) {
+ proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr))
+ conn, err := net.DialTimeout(proto, addr, defaultTimeout)
if err != nil {
return nil, err
}
@@ -49,3 +38,7 @@ func (s *dialSocket) Open(e env.Env) (Conn, error) {
func (s *dialSocket) Close() {
}
+
+func init() {
+ register("dial", dialSocket{})
+}
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go
index 910e5de..2c2f184 100644
--- a/pkg/server/socket/listen.go
+++ b/pkg/server/socket/listen.go
@@ -5,64 +5,54 @@ import (
"fmt"
"log"
"net"
- "strings"
+
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
)
type listenSocket struct {
- proto, addr string
- listen net.Listener
- redirect bool
- tproxy bool
-}
+ Proto string `opts:"default:tcp"`
+ Addr string `opts:"required"`
-func newListenSocket(proto, addr string, opts opts.Opts) (S, error) {
- redirect := opts.Bool("redirect")
- tproxy := opts.Bool("tproxy")
+ Redirect bool
+ Tproxy bool
- if proto == "tcp" {
- if !strings.Contains(addr, ":") {
- addr = ":" + addr
- }
- }
+ listen net.Listener
+}
- if redirect && proto != "tcp" {
- return nil, errors.New("redirect not supported")
+func (s *listenSocket) Prepare(e env.Env) error {
+ if s.Redirect && s.Proto != "tcp" {
+ return errors.New("redirect not supported")
}
- if tproxy && proto != "tcp" {
- return nil, errors.New("tproxy not supported")
+ if s.Tproxy && s.Proto != "tcp" {
+ return errors.New("tproxy not supported")
}
- if redirect && tproxy {
- return nil, errors.New("redirect and tproxy cannot be used together")
+ if s.Redirect && s.Tproxy {
+ return errors.New("redirect and tproxy cannot be used together")
}
+ proto, addr := parseProtoAddr(s.Proto, s.Addr)
listen, err := net.Listen(proto, addr)
if err != nil {
- return nil, err
+ return err
}
- if tproxy {
+ e.Set("listen", listen.Addr().String())
+
+ if s.Tproxy {
if err := setConnTransparent(listen); err != nil {
listen.Close()
- return nil, err
+ return err
}
}
- s := &listenSocket{
- proto: proto,
- addr: addr,
- listen: listen,
- redirect: redirect,
- tproxy: tproxy,
- }
+ s.listen = listen
- return s, nil
+ return nil
}
-func (s *listenSocket) Open(env env.Env) (Conn, error) {
+func (s *listenSocket) New(env env.Env) (Conn, error) {
var original string
conn, err := s.listen.Accept()
@@ -74,7 +64,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la)
info := fmt.Sprintf("<%s/%s", ra.Network(), ra)
- if s.redirect {
+ if s.Redirect {
if err := getConnOriginalAddr(conn, &original); err != nil {
log.Println("accept", desc, "failed")
conn.Close()
@@ -84,7 +74,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
}
}
- if s.tproxy {
+ if s.Tproxy {
env.Set("original", la.String())
}
@@ -98,9 +88,13 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
}
func (s *listenSocket) String() string {
- return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
+ return fmt.Sprintf("%s/%s,listen", s.Proto, s.Addr)
}
func (s *listenSocket) Close() {
s.listen.Close()
}
+
+func init() {
+ register("listen", listenSocket{})
+}
diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go
index a06448a..c442140 100644
--- a/pkg/server/socket/loop.go
+++ b/pkg/server/socket/loop.go
@@ -30,7 +30,7 @@ func (c *loopConn) Close() error {
return nil
}
-func (s *loopSocket) Open(env.Env) (Conn, error) {
+func (s *loopSocket) New(env.Env) (Conn, error) {
return &loopConn{make(chan queue.Q), make(chan error)}, nil
}
@@ -41,6 +41,6 @@ func (s *loopSocket) String() string {
func (s *loopSocket) Close() {
}
-func newLoopSocket() (S, error) {
- return &loopSocket{}, nil
+func init() {
+ register("loop", loopSocket{})
}
diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go
index e4baec2..1be4bba 100644
--- a/pkg/server/socket/proxy.go
+++ b/pkg/server/socket/proxy.go
@@ -5,6 +5,7 @@ import (
"bytes"
"errors"
"fmt"
+
"tunnel/pkg/http"
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
@@ -16,7 +17,7 @@ type status struct {
}
type proxySocket struct {
- proto string
+ Proto string `opts:"default:tcp"`
}
type proxyServer struct {
@@ -28,11 +29,7 @@ type proxyServer struct {
conn Conn
}
-func newProxySocket(proto string) (S, error) {
- return &proxySocket{proto}, nil
-}
-
-func (sock *proxySocket) Open(env env.Env) (Conn, error) {
+func (sock *proxySocket) New(env env.Env) (Conn, error) {
s := &proxyServer{
sock: sock,
auth: env.Value("proxy.auth"),
@@ -78,12 +75,12 @@ func (s *proxyServer) Send(wq queue.Q) error {
}
func (s *proxyServer) initConn(addr string) error {
- dial, err := newDialSocket(s.sock.proto, addr)
- if err != nil {
- return err
+ dial := dialSocket{
+ Proto: s.sock.Proto,
+ Addr: addr,
}
- conn, err := dial.Open(s.env)
+ conn, err := dial.New(s.env)
if err != nil {
dial.Close()
return err
@@ -138,3 +135,7 @@ func (s *proxyServer) Close() (err error) {
return
}
+
+func init() {
+ register("proxy", proxySocket{})
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 62ce5cf..03b73d9 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -5,8 +5,11 @@ import (
"fmt"
"log"
"net"
+ "reflect"
+ "sort"
"strings"
"sync"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
@@ -21,7 +24,7 @@ type Conn interface {
}
type S interface {
- Open(env env.Env) (Conn, error)
+ New(env env.Env) (Conn, error)
Close()
}
@@ -66,43 +69,62 @@ func (c *conn) Close() error {
return err
}
-func New(desc string) (S, error) {
- base, opts := opts.Parse(desc)
- args := strings.SplitN(base, "/", 2)
-
- var proto string
- var addr string
+func New(desc string, e env.Env) (S, error) {
+ name, opts := opts.Parse(desc)
- if len(args) < 2 {
- addr = args[0]
- } else {
- proto, addr = args[0], args[1]
+ t, ok := sockets[name]
+ if !ok {
+ return nil, fmt.Errorf("%s: unknown type", name)
}
- if proto == "" {
- proto = "tcp"
+ s := reflect.New(t).Interface()
+ if err := opts.Configure(s); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
}
- switch addr {
- case "loop":
- return newLoopSocket()
- case "proxy":
- return newProxySocket(proto)
- case "":
- return nil, fmt.Errorf("bad socket '%s'", desc)
+ if i, ok := s.(interface{ Prepare(env.Env) error }); ok {
+ if err := i.Prepare(e); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
+ }
}
- if proto == "tun" {
- return newTunSocket(addr)
+ return s.(S), nil
+}
+
+func parseProtoAddr(proto, addr string) (string, string) {
+ if proto == "tcp" || proto == "udp" {
+ if strings.HasPrefix(addr, "-:") {
+ addr = "localhost" + addr[1:]
+ }
}
- if opts.Bool("listen") {
- return newListenSocket(proto, addr, opts)
+ return proto, addr
+}
+
+var sockets = map[string]reflect.Type{}
+
+func register(name string, i interface{}) {
+ t := reflect.TypeOf(i)
+ if t.Kind() != reflect.Struct {
+ log.Panicf("non-struct type '%s'", t.String())
+ }
+ if _, ok := reflect.New(t).Interface().(S); !ok {
+ log.Panicf("uncompatible socket type '%s'", t.String())
+ }
+ if _, ok := sockets[name]; ok {
+ log.Panicf("duplicate socket name '%s'", name)
}
+ sockets[name] = t
+}
+
+func GetList() []string {
+ var list []string
- if opts.Bool("defer") {
- return newDeferSocket(proto, addr)
+ for k := range sockets {
+ list = append(list, k)
}
- return newDialSocket(proto, addr)
+ sort.Strings(list)
+
+ return list
}
diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go
index d48c30c..3e673eb 100644
--- a/pkg/server/socket/tun.go
+++ b/pkg/server/socket/tun.go
@@ -8,10 +8,11 @@ import (
"os"
"strings"
"sync"
+ "unsafe"
+
"tunnel/pkg/pack"
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
- "unsafe"
)
const maxTunBufSize = 65535
@@ -24,7 +25,7 @@ type ifReq struct {
}
type tunSocket struct {
- name string
+ Name string `opts:"required"`
}
type tunConn struct {
@@ -34,35 +35,30 @@ type tunConn struct {
once sync.Once
}
-func newTunSocket(name string) (S, error) {
- return &tunSocket{name: name}, nil
-
-}
-
func (s *tunSocket) String() string {
- return fmt.Sprintf("tun/%s", s.name)
+ return fmt.Sprintf("tun/%s", s.Name)
}
func (s *tunSocket) Single() {}
-func (s *tunSocket) Open(env.Env) (Conn, error) {
+func (s *tunSocket) New(env.Env) (Conn, error) {
fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
if err != nil {
return nil, err
}
ifr := &ifReq{}
- copy(ifr.name[:], s.name)
+ copy(ifr.name[:], s.Name)
ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil {
unix.Close(fd)
- return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err)
+ return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.Name, err)
}
if err := unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
- return nil, fmt.Errorf("set nonblock %s: %w", s.name, err)
+ return nil, fmt.Errorf("set nonblock %s: %w", s.Name, err)
}
c := &tunConn{
@@ -124,3 +120,7 @@ func (c *tunConn) Close() error {
return err
}
+
+func init() {
+ register("tun", tunSocket{})
+}