summaryrefslogtreecommitdiff
path: root/pkg/server/socket/socket.go
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-12-16 15:27:48 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-12-16 15:27:48 +0300
commit6fed9dd0dd62718f78eca11e30a71c2712636fbd (patch)
tree8d1f90b96efbe8ea8aea350c283325adc216ef9d /pkg/server/socket/socket.go
parent050ea053dd549f0dd01beddfcd74989858391fd7 (diff)
hook and socket args check fix, tests
Diffstat (limited to 'pkg/server/socket/socket.go')
-rw-r--r--pkg/server/socket/socket.go76
1 files changed, 49 insertions, 27 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 62ce5cf..03b73d9 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -5,8 +5,11 @@ import (
"fmt"
"log"
"net"
+ "reflect"
+ "sort"
"strings"
"sync"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
@@ -21,7 +24,7 @@ type Conn interface {
}
type S interface {
- Open(env env.Env) (Conn, error)
+ New(env env.Env) (Conn, error)
Close()
}
@@ -66,43 +69,62 @@ func (c *conn) Close() error {
return err
}
-func New(desc string) (S, error) {
- base, opts := opts.Parse(desc)
- args := strings.SplitN(base, "/", 2)
-
- var proto string
- var addr string
+func New(desc string, e env.Env) (S, error) {
+ name, opts := opts.Parse(desc)
- if len(args) < 2 {
- addr = args[0]
- } else {
- proto, addr = args[0], args[1]
+ t, ok := sockets[name]
+ if !ok {
+ return nil, fmt.Errorf("%s: unknown type", name)
}
- if proto == "" {
- proto = "tcp"
+ s := reflect.New(t).Interface()
+ if err := opts.Configure(s); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
}
- switch addr {
- case "loop":
- return newLoopSocket()
- case "proxy":
- return newProxySocket(proto)
- case "":
- return nil, fmt.Errorf("bad socket '%s'", desc)
+ if i, ok := s.(interface{ Prepare(env.Env) error }); ok {
+ if err := i.Prepare(e); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
+ }
}
- if proto == "tun" {
- return newTunSocket(addr)
+ return s.(S), nil
+}
+
+func parseProtoAddr(proto, addr string) (string, string) {
+ if proto == "tcp" || proto == "udp" {
+ if strings.HasPrefix(addr, "-:") {
+ addr = "localhost" + addr[1:]
+ }
}
- if opts.Bool("listen") {
- return newListenSocket(proto, addr, opts)
+ return proto, addr
+}
+
+var sockets = map[string]reflect.Type{}
+
+func register(name string, i interface{}) {
+ t := reflect.TypeOf(i)
+ if t.Kind() != reflect.Struct {
+ log.Panicf("non-struct type '%s'", t.String())
+ }
+ if _, ok := reflect.New(t).Interface().(S); !ok {
+ log.Panicf("uncompatible socket type '%s'", t.String())
+ }
+ if _, ok := sockets[name]; ok {
+ log.Panicf("duplicate socket name '%s'", name)
}
+ sockets[name] = t
+}
+
+func GetList() []string {
+ var list []string
- if opts.Bool("defer") {
- return newDeferSocket(proto, addr)
+ for k := range sockets {
+ list = append(list, k)
}
- return newDialSocket(proto, addr)
+ sort.Strings(list)
+
+ return list
}