package socket import ( "errors" "fmt" "log" "net" "strings" "sync" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) var errAlreadyClosed = errors.New("already closed") type Channel interface { Origin() string Send(wq queue.Q) error Recv(rq queue.Q) error Close() error } type S interface { Open(env env.Env) (Channel, error) Close() } type listenSocket struct { proto, addr string listen net.Listener } type dialSocket struct { proto, addr string } type connChannel struct { origin string conn net.Conn once sync.Once } type loopSocket struct{} type loopChannel struct { q queue.Q } func newConnChannel(origin string, conn net.Conn) Channel { return &connChannel{origin: origin, conn: conn} } func (c *connChannel) final(f func() error, err error) error { if e := f(); e != nil { if e == errAlreadyClosed { return nil } else { return e } } return err } func (c *connChannel) Origin() string { return c.origin } func (c *connChannel) Send(wq queue.Q) error { err := queue.IoCopy(c.conn, wq.Writer()) return c.final(c.Close, err) } func (c *connChannel) Recv(rq queue.Q) error { err := queue.IoCopy(rq.Reader(), c.conn) return c.final(c.Close, err) } func (c *connChannel) String() string { local, remote := c.conn.LocalAddr(), c.conn.RemoteAddr() return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) } func (c *connChannel) Close() error { err := errAlreadyClosed c.once.Do(func() { log.Println("close", c) err = c.conn.Close() }) return err } func newListenSocket(proto, addr string) (S, error) { if proto == "tcp" { if !strings.Contains(addr, ":") { addr = ":" + addr } } listen, err := net.Listen(proto, addr) if err != nil { return nil, err } s := &listenSocket{ proto: proto, addr: addr, listen: listen, } return s, nil } func (s *listenSocket) Open(env env.Env) (Channel, error) { conn, err := s.listen.Accept() if err != nil { return nil, err } addr := conn.RemoteAddr() origin := fmt.Sprintf("%s/%s", addr.Network(), addr) return newConnChannel(origin, conn), nil } func (s *listenSocket) String() string { return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) } func (s *listenSocket) Close() { s.listen.Close() } func newDialSocket(proto, addr string) (S, error) { switch proto { case "tcp", "udp": if !strings.Contains(addr, ":") { addr = "localhost:" + addr } } return &dialSocket{proto: proto, addr: addr}, nil } func (s *dialSocket) String() string { return fmt.Sprintf("%s/%s", s.proto, s.addr) } func (s *dialSocket) Open(env env.Env) (Channel, error) { conn, err := net.Dial(s.proto, s.addr) if err != nil { return nil, err } return newConnChannel("-", conn), nil } func (s *dialSocket) Close() { } func (c *loopChannel) Origin() string { return "loop" } func (c *loopChannel) Send(wq queue.Q) error { return queue.Copy(c.q, wq) } func (c *loopChannel) Recv(rq queue.Q) error { defer close(c.q) return queue.Copy(rq, c.q) } func (c *loopChannel) Close() error { return nil } func (c *loopChannel) String() string { return "loop" } func (s *loopSocket) Open(env.Env) (Channel, error) { return &loopChannel{queue.New()}, nil } func (s *loopSocket) String() string { return "loop" } func (s *loopSocket) Close() { } func newLoopSocket() (S, error) { return &loopSocket{}, nil } func New(desc string, env env.Env) (S, error) { base, opts := opts.Parse(desc) args := strings.SplitN(base, "/", 2) var proto string var addr string if len(args) < 2 { addr = args[0] } else { proto, addr = args[0], args[1] } if addr == "loop" { return newLoopSocket() } if proto == "" { proto = "tcp" } if addr == "" { return nil, fmt.Errorf("bad socket '%s'", desc) } if _, ok := opts["listen"]; ok { return newListenSocket(proto, addr) } return newDialSocket(proto, addr) }