diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-12-16 15:27:48 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-12-16 15:27:48 +0300 |
| commit | 6fed9dd0dd62718f78eca11e30a71c2712636fbd (patch) | |
| tree | 8d1f90b96efbe8ea8aea350c283325adc216ef9d /pkg/server/socket/socket.go | |
| parent | 050ea053dd549f0dd01beddfcd74989858391fd7 (diff) | |
hook and socket args check fix, tests
Diffstat (limited to 'pkg/server/socket/socket.go')
| -rw-r--r-- | pkg/server/socket/socket.go | 76 |
1 files changed, 49 insertions, 27 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 62ce5cf..03b73d9 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -5,8 +5,11 @@ import ( "fmt" "log" "net" + "reflect" + "sort" "strings" "sync" + "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" @@ -21,7 +24,7 @@ type Conn interface { } type S interface { - Open(env env.Env) (Conn, error) + New(env env.Env) (Conn, error) Close() } @@ -66,43 +69,62 @@ func (c *conn) Close() error { return err } -func New(desc string) (S, error) { - base, opts := opts.Parse(desc) - args := strings.SplitN(base, "/", 2) - - var proto string - var addr string +func New(desc string, e env.Env) (S, error) { + name, opts := opts.Parse(desc) - if len(args) < 2 { - addr = args[0] - } else { - proto, addr = args[0], args[1] + t, ok := sockets[name] + if !ok { + return nil, fmt.Errorf("%s: unknown type", name) } - if proto == "" { - proto = "tcp" + s := reflect.New(t).Interface() + if err := opts.Configure(s); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) } - switch addr { - case "loop": - return newLoopSocket() - case "proxy": - return newProxySocket(proto) - case "": - return nil, fmt.Errorf("bad socket '%s'", desc) + if i, ok := s.(interface{ Prepare(env.Env) error }); ok { + if err := i.Prepare(e); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) + } } - if proto == "tun" { - return newTunSocket(addr) + return s.(S), nil +} + +func parseProtoAddr(proto, addr string) (string, string) { + if proto == "tcp" || proto == "udp" { + if strings.HasPrefix(addr, "-:") { + addr = "localhost" + addr[1:] + } } - if opts.Bool("listen") { - return newListenSocket(proto, addr, opts) + return proto, addr +} + +var sockets = map[string]reflect.Type{} + +func register(name string, i interface{}) { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Struct { + log.Panicf("non-struct type '%s'", t.String()) + } + if _, ok := reflect.New(t).Interface().(S); !ok { + log.Panicf("uncompatible socket type '%s'", t.String()) + } + if _, ok := sockets[name]; ok { + log.Panicf("duplicate socket name '%s'", name) } + sockets[name] = t +} + +func GetList() []string { + var list []string - if opts.Bool("defer") { - return newDeferSocket(proto, addr) + for k := range sockets { + list = append(list, k) } - return newDialSocket(proto, addr) + sort.Strings(list) + + return list } |
