// Package socket provides listening and dialing sockets that hand out
// Channels backed by net.Conn connections.
package socket

import (
	"fmt"
	"log"
	"net"
	"strings"
	"sync"

	"tunnel/pkg/server/queue"
)

// Channel is a single established connection that moves data between the
// connection and a pair of queues.
type Channel interface {
	// Send passes wq's writer to the connection copy loop.
	// NOTE(review): presumably data read from the connection ends up in wq;
	// confirm against queue.IoCopy.
	Send(wq queue.Q) error
	// Recv drains rq and writes every buffer to the connection.
	Recv(rq queue.Q) error
	// Close closes the underlying connection.
	Close()
}

// S is a socket: a factory of Channels.
type S interface {
	// Open returns a new Channel or an error.
	Open() (Channel, error)
	// Close releases whatever the socket itself holds.
	Close()
}

// listenSocket accepts incoming connections on a net.Listener.
type listenSocket struct {
	proto, addr string
	listen      net.Listener
}

// dialSocket opens an outgoing connection on every Open call.
type dialSocket struct {
	proto, addr string
}

// connChannel is the Channel implementation on top of a net.Conn.
type connChannel struct {
	conn net.Conn
	once sync.Once // guards the one-time close of conn
}

// newConnChannel wraps conn into a Channel.
func newConnChannel(conn net.Conn) Channel {
	return &connChannel{conn: conn}
}

// Send runs queue.IoCopy with the connection and wq's writer and shuts the
// channel down when the copy returns.
// NOTE(review): copy direction is defined by queue.IoCopy, which is not
// visible here; presumably connection -> wq.
func (cc *connChannel) Send(wq queue.Q) (err error) {
	defer cc.shutdown(&err)

	return queue.IoCopy(cc.conn, wq.Writer())
}

// Recv writes every buffer received from rq to the connection until rq is
// closed or a write fails, then shuts the channel down.
func (cc *connChannel) Recv(rq queue.Q) (err error) {
	defer cc.shutdown(&err)

	for b := range rq {
		// Keep writing until the whole buffer has been consumed.
		for len(b) > 0 {
			// Distinct name: do not shadow the named result that the
			// deferred shutdown inspects.
			n, werr := cc.conn.Write(b)
			if werr != nil {
				return werr
			}
			b = b[n:]
		}
	}

	return nil
}

// String renders the channel as "network/remote->local".
func (cc *connChannel) String() string {
	local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr()

	return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local)
}

// shutdown closes the connection exactly once and adjusts *err:
//
//   - The first caller closes the connection. If the close fails and *err is
//     still nil, the close error is stored in *err; an already recorded
//     error is kept, because it describes the original failure.
//   - Every later caller gets *err reset to nil.
//     NOTE(review): presumably because errors observed after the connection
//     was already closed are a consequence of that close; confirm intent.
func (cc *connChannel) shutdown(err *error) {
	first := false

	cc.once.Do(func() {
		first = true
		log.Println("close", cc)
		if e := cc.conn.Close(); e != nil && *err == nil {
			*err = e
		}
	})

	if !first {
		*err = nil
	}
}

// Close closes the connection if it has not been closed yet. The Channel
// interface has no error return, so any close error is dropped here.
func (cc *connChannel) Close() {
	var err error
	cc.shutdown(&err)
}

// newListenSocket starts listening on proto/addr. For "tcp" an address
// without ':' gets a ':' prefix, so "8080" becomes ":8080".
func newListenSocket(proto, addr string) (S, error) {
	if proto == "tcp" && !strings.Contains(addr, ":") {
		addr = ":" + addr
	}

	listen, err := net.Listen(proto, addr)
	if err != nil {
		return nil, err
	}

	return &listenSocket{proto: proto, addr: addr, listen: listen}, nil
}

// Open blocks in Accept and wraps the accepted connection into a Channel.
func (s *listenSocket) Open() (Channel, error) {
	conn, err := s.listen.Accept()
	if err != nil {
		return nil, err
	}

	return newConnChannel(conn), nil
}

// String renders the socket as "proto/addr,listen".
func (s *listenSocket) String() string {
	return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
}

// Close closes the listener.
func (s *listenSocket) Close() {
	// S.Close has no error return; closing is best-effort.
	_ = s.listen.Close()
}

// newDialSocket prepares a dialing socket; nothing is connected until Open.
// For "tcp" and "udp" an address without ':' gets a "localhost:" prefix.
func newDialSocket(proto, addr string) (S, error) {
	if (proto == "tcp" || proto == "udp") && !strings.Contains(addr, ":") {
		addr = "localhost:" + addr
	}

	return &dialSocket{proto: proto, addr: addr}, nil
}

// String renders the socket as "proto/addr".
func (s *dialSocket) String() string {
	return fmt.Sprintf("%s/%s", s.proto, s.addr)
}

// Open dials proto/addr and wraps the new connection into a Channel.
func (s *dialSocket) Open() (Channel, error) {
	conn, err := net.Dial(s.proto, s.addr)
	if err != nil {
		return nil, err
	}

	return newConnChannel(conn), nil
}

// Close does nothing: a dialSocket holds no resources of its own.
func (s *dialSocket) Close() {
}

// New builds a socket from a description of the form
//
//	[proto/]addr[,opt[=value]...]
//
// The protocol defaults to "tcp". Everything after the first '/' is the
// address, so addresses that contain '/' themselves (for example
// "unix//tmp/x.sock") are kept intact. The "listen" option selects a
// listening socket; without it a dialing socket is returned. Other options
// are parsed but not used here. An empty address is an error.
func New(name string) (S, error) {
	fields := strings.Split(name, ",")

	opts := map[string]string{}
	for _, field := range fields[1:] {
		kv := strings.SplitN(field, "=", 2)
		val := ""
		if len(kv) == 2 {
			val = kv[1]
		}
		opts[kv[0]] = val
	}

	// Split only once: the address part may legitimately contain '/'.
	var proto, addr string
	if args := strings.SplitN(fields[0], "/", 2); len(args) < 2 {
		addr = args[0]
	} else {
		proto, addr = args[0], args[1]
	}

	if proto == "" {
		proto = "tcp"
	}

	if addr == "" {
		return nil, fmt.Errorf("bad socket '%s'", name)
	}

	if _, ok := opts["listen"]; ok {
		return newListenSocket(proto, addr)
	}

	return newDialSocket(proto, addr)
}