// Package socket builds tunnel endpoints from textual descriptors such as
// "tcp/host:port" (outbound dial) or "tcp-listen/port" (inbound accept).
package socket

import (
	"fmt"
	"log"
	"net"
	"strings"
	"sync"

	"tunnel/pkg/server/queue"
)

// Channel is one established connection that can pump data between the
// network and the tunnel's queues.
type Channel interface {
	// Send moves data from the connection into wq.
	Send(wq queue.Q) error
	// Recv writes every buffer taken from rq to the connection.
	Recv(rq queue.Q) error
	// Close tears down the connection; safe to call more than once.
	Close()
}

// S is a socket endpoint that produces Channels.
type S interface {
	// Open returns the next Channel (accepting or dialing as appropriate).
	Open() (Channel, error)
	// Close releases resources held by the endpoint itself.
	Close()
}

// Compile-time checks that the concrete types satisfy their interfaces.
var (
	_ S            = (*listenSocket)(nil)
	_ S            = (*dialSocket)(nil)
	_ Channel      = (*connChannel)(nil)
	_ fmt.Stringer = (*connChannel)(nil)
)

// listenSocket accepts inbound connections on a bound listener.
type listenSocket struct {
	listen net.Listener
}

// dialSocket opens an outbound connection to proto/addr on every Open.
type dialSocket struct {
	proto, addr string
}

// connChannel adapts a net.Conn to the Channel interface. The connection is
// closed exactly once, by whichever of Send, Recv or Close finishes first;
// cancel is closed at the same moment so that state is observable.
type connChannel struct {
	conn   net.Conn
	once   sync.Once
	cancel chan struct{}
}

func newConnChannel(conn net.Conn) Channel {
	return &connChannel{conn: conn, cancel: make(chan struct{})}
}

// Send copies from the connection into wq via queue.IoCopy, then shuts the
// channel down. The argument order suggests IoCopy(src, dst); confirm in the
// queue package.
func (cc *connChannel) Send(wq queue.Q) (err error) {
	defer cc.shutdown(&err)
	return queue.IoCopy(cc.conn, wq.Writer())
}

// Recv ranges over rq and writes each buffer fully to the connection, then
// shuts the channel down. It assumes queue.Q is a channel of []byte that the
// producer closes when done (TODO confirm in the queue package).
func (cc *connChannel) Recv(rq queue.Q) (err error) {
	defer cc.shutdown(&err)
	for b := range rq {
		for len(b) > 0 {
			// Distinct name so the named result err is not shadowed.
			n, werr := cc.conn.Write(b)
			if werr != nil {
				return werr
			}
			b = b[n:]
		}
	}
	return nil
}

// String renders the remote endpoint as "network/address" for logging.
func (cc *connChannel) String() string {
	addr := cc.conn.RemoteAddr()
	return fmt.Sprintf("%s/%s", addr.Network(), addr.String())
}

// isCanceled reports whether the channel has already been shut down.
func (cc *connChannel) isCanceled() bool {
	select {
	case <-cc.cancel:
		return true
	default:
		return false
	}
}

// shutdown closes the connection exactly once and adjusts *err.
//
// For the first caller (the one that performs the close), the error of the
// operation that triggered the shutdown is kept. If that operation succeeded,
// any error from conn.Close is reported instead.
//
// Every later caller has *err cleared. Its failure is presumably a
// consequence of the connection having been closed by the other direction.
func (cc *connChannel) shutdown(err *error) {
	closedHere := false
	cc.once.Do(func() {
		closedHere = true
		close(cc.cancel)
		log.Println("close", cc)
		// Do not mask the original failure with the Close error.
		if e := cc.conn.Close(); e != nil && *err == nil {
			*err = e
		}
	})
	if !closedHere {
		// Another goroutine won the race (or finished earlier). sync.Once.Do
		// has already waited for it, so the connection is closed.
		*err = nil
	}
}

// Close shuts the channel down; any Close error is discarded.
func (cc *connChannel) Close() {
	var err error
	cc.shutdown(&err)
}

// newListenSocket binds proto on addr. A bare port ("8080") is treated as
// ":8080", i.e. all interfaces.
func newListenSocket(proto, addr string) (S, error) {
	if !strings.Contains(addr, ":") {
		addr = ":" + addr
	}
	listen, err := net.Listen(proto, addr)
	if err != nil {
		return nil, err
	}
	return &listenSocket{listen: listen}, nil
}

// Open blocks until an inbound connection is accepted.
func (s *listenSocket) Open() (Channel, error) {
	conn, err := s.listen.Accept()
	if err != nil {
		return nil, err
	}
	return newConnChannel(conn), nil
}

// Close stops listening; the listener's Close error is ignored.
func (s *listenSocket) Close() {
	s.listen.Close()
}

func newDialSocket(proto, addr string) (S, error) {
	return &dialSocket{proto: proto, addr: addr}, nil
}

// Open dials a fresh outbound connection.
func (s *dialSocket) Open() (Channel, error) {
	conn, err := net.Dial(s.proto, s.addr)
	if err != nil {
		return nil, err
	}
	return newConnChannel(conn), nil
}

// Close is a no-op: a dialSocket holds no resources between Opens.
func (s *dialSocket) Close() {
}

// New parses a descriptor of the form "<type>/<addr>" and returns the
// matching endpoint. Supported types are "tcp-listen" and "tcp".
func New(desc string) (S, error) {
	args := strings.Split(desc, "/")
	if len(args) != 2 {
		return nil, fmt.Errorf("bad socket '%s'", desc)
	}
	proto, addr := args[0], args[1]
	switch proto {
	case "tcp-listen":
		return newListenSocket("tcp", addr)
	case "tcp":
		return newDialSocket("tcp", addr)
	}
	return nil, fmt.Errorf("bad socket '%s': unknown type", desc)
}