package socket import ( "errors" "fmt" "log" "net" "reflect" "sort" "strings" "sync" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) var ErrAlreadyClosed = errors.New("already closed") type Conn interface { Send(wq queue.Q) error Recv(rq queue.Q) error Close() error } type S interface { New(env env.Env) (Conn, error) Close() } type Single interface { Single() } type conn struct { net.Conn desc string info string once sync.Once } func newConn(cn net.Conn, desc, info string) *conn { c := &conn{Conn: cn, desc: desc, info: info} return c } func (c *conn) Send(wq queue.Q) error { return queue.IoCopy(c, wq.Writer()) } func (c *conn) Recv(rq queue.Q) error { return queue.IoCopy(rq.Reader(), c) } func (c *conn) String() string { return c.info } func (c *conn) Close() error { err := ErrAlreadyClosed c.once.Do(func() { log.Println("close", c.desc) err = c.Conn.Close() }) return err } func New(desc string, e env.Env) (S, error) { name, opts := opts.Parse(desc) sockType, ok := sockets[name] if !ok { return nil, fmt.Errorf("%s: unknown type", name) } s := reflect.New(sockType.t).Interface() if err := opts.Configure(s); err != nil { return nil, fmt.Errorf("%s: %w", name, err) } if i, ok := s.(interface{ Prepare(env.Env) error }); ok { if err := i.Prepare(e); err != nil { return nil, fmt.Errorf("%s: %w", name, err) } } return s.(S), nil } func parseProtoAddr(proto, addr string) (string, string) { if proto == "tcp" || proto == "udp" { if strings.HasPrefix(addr, "-:") { addr = "localhost" + addr[1:] } } return proto, addr } type Type struct { Desc string Name string Param []opts.Param t reflect.Type } var sockets = map[string]*Type{} func register(name string, desc string, i interface{}) { t := reflect.TypeOf(i) if t.Kind() != reflect.Struct { log.Panicf("non-struct type '%s'", t.String()) } if _, ok := reflect.New(t).Interface().(S); !ok { log.Panicf("uncompatible socket type '%s'", t.String()) } if _, ok := sockets[name]; ok { log.Panicf("duplicate socket name '%s'", name) } sockets[name] = &Type{ Desc: desc, Name: name, Param: opts.Parametrize(t), t: t, } } func GetList() []string { var list []string for k := range sockets { list = append(list, k) } sort.Strings(list) return list } func GetType(name string) *Type { return sockets[name] }