package socket import ( "fmt" "log" "net" "strings" "sync" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) type Channel interface { Send(wq queue.Q) error Recv(rq queue.Q) error Close() } type S interface { Open(env env.Env) (Channel, error) Close() } type listenSocket struct { proto, addr string listen net.Listener } type dialSocket struct { proto, addr string } type connChannel struct { conn net.Conn once sync.Once } func newConnChannel(conn net.Conn) Channel { return &connChannel{conn: conn} } func (cc *connChannel) Send(wq queue.Q) (err error) { defer cc.shutdown(&err) return queue.IoCopy(cc.conn, wq.Writer()) } func (cc *connChannel) Recv(rq queue.Q) (err error) { defer cc.shutdown(&err) return queue.IoCopy(rq.Reader(), cc.conn) } func (cc *connChannel) String() string { local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr() return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local) } func (cc *connChannel) shutdown(err *error) { miss := true cc.once.Do(func() { miss = false log.Println("close", cc) if e := cc.conn.Close(); e != nil && *err != nil { *err = e } }) if miss { *err = nil } } func (cc *connChannel) Close() { var err error cc.shutdown(&err) } func newListenSocket(proto, addr string) (S, error) { if proto == "tcp" { if !strings.Contains(addr, ":") { addr = ":" + addr } } listen, err := net.Listen(proto, addr) if err != nil { return nil, err } s := &listenSocket{ proto: proto, addr: addr, listen: listen, } return s, nil } func (s *listenSocket) Open(env env.Env) (Channel, error) { conn, err := s.listen.Accept() if err != nil { return nil, err } return newConnChannel(conn), nil } func (s *listenSocket) String() string { return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) } func (s *listenSocket) Close() { s.listen.Close() } func newDialSocket(proto, addr string) (S, error) { switch proto { case "tcp", "udp": if !strings.Contains(addr, ":") { addr = "localhost:" + addr } } return &dialSocket{proto: proto, addr: addr}, nil } func (s *dialSocket) String() string { return fmt.Sprintf("%s/%s", s.proto, s.addr) } func (s *dialSocket) Open(env env.Env) (Channel, error) { conn, err := net.Dial(s.proto, s.addr) if err != nil { return nil, err } return newConnChannel(conn), nil } func (s *dialSocket) Close() { } func New(desc string, env env.Env) (S, error) { base, opts := opts.Parse(desc) args := strings.Split(base, "/") var proto string var addr string if len(args) < 2 { addr = args[0] } else { proto, addr = args[0], args[1] } if proto == "" { proto = "tcp" } if addr == "" { return nil, fmt.Errorf("bad socket '%s'", desc) } if _, ok := opts["listen"]; ok { return newListenSocket(proto, addr) } return newDialSocket(proto, addr) }