From 9e04a8bee8492cb662ebc8b7fd50a23c48c7d03f Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Thu, 20 Feb 2020 04:56:03 +0300 Subject: streams and tunnels --- pkg/server/socket/socket.go | 106 +++++++++++++++++++++++++++++--------------- 1 file changed, 71 insertions(+), 35 deletions(-) (limited to 'pkg/server/socket') diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index f097a80..cad1ad3 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -21,6 +21,7 @@ type S interface { } type listenSocket struct { + proto, addr string listen net.Listener } @@ -31,11 +32,10 @@ type dialSocket struct { type connChannel struct { conn net.Conn once sync.Once - cancel chan struct{} } func newConnChannel(conn net.Conn) Channel { - return &connChannel{conn: conn, cancel: make(chan struct{})} + return &connChannel{conn: conn} } func (cc *connChannel) Send(wq queue.Q) (err error) { @@ -60,31 +60,23 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) { } func (cc *connChannel) String() string { - addr := cc.conn.RemoteAddr() - return fmt.Sprintf("%s/%s", addr.Network(), addr.String()) -} - -func (cc *connChannel) isCanceled() bool { - select { - case <- cc.cancel: - return true - default: - return false - } + local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr() + return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) } func (cc *connChannel) shutdown(err *error) { - select { - case <- cc.cancel: + miss := true + + cc.once.Do(func () { + miss = false + log.Println("close", cc) + if e := cc.conn.Close(); e != nil && *err != nil { + *err = e + } + }) + + if miss { *err = nil - default: - cc.once.Do(func () { - close(cc.cancel) - log.Println("close", cc) - if e := cc.conn.Close(); e != nil && *err != nil { - *err = e - } - }) } } @@ -94,8 +86,10 @@ func (cc *connChannel) Close() { } func newListenSocket(proto, addr string) (S, error) { - if !strings.Contains(addr, ":") { - addr = ":" + addr + if proto == "tcp" { + if !strings.Contains(addr, ":") { + addr = ":" + addr + } } listen, err := net.Listen(proto, addr) @@ -103,7 +97,13 @@ func newListenSocket(proto, addr string) (S, error) { return nil, err } - return &listenSocket{listen: listen}, nil + s := &listenSocket{ + proto: proto, + addr: addr, + listen: listen, + } + + return s, nil } func (s *listenSocket) Open() (Channel, error) { @@ -114,14 +114,29 @@ func (s *listenSocket) Open() (Channel, error) { return newConnChannel(conn), nil } +func (s *listenSocket) String() string { + return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) +} + func (s *listenSocket) Close() { s.listen.Close() } func newDialSocket(proto, addr string) (S, error) { + switch proto { + case "tcp", "udp": + if !strings.Contains(addr, ":") { + addr = "localhost:" + addr + } + } + return &dialSocket{proto: proto, addr: addr}, nil } +func (s *dialSocket) String() string { + return fmt.Sprintf("%s/%s", s.proto, s.addr) +} + func (s *dialSocket) Open() (Channel, error) { conn, err := net.Dial(s.proto, s.addr) if err != nil { @@ -133,19 +148,40 @@ func (s *dialSocket) Open() (Channel, error) { func (s *dialSocket) Close() { } -func New(desc string) (S, error) { - args := strings.Split(desc, "/") +func New(name string) (S, error) { + vv := strings.Split(name, ",") + args := strings.Split(vv[0], "/") + opts := map[string]string{} - if len(args) != 2 { - return nil, fmt.Errorf("bad socket '%s'", desc) + for _, v := range vv[1:] { + ss := strings.SplitN(v, "=", 2) + if len(ss) < 2 { + opts[ss[0]] = "" + } else { + opts[ss[0]] = ss[1] + } } - proto, addr := args[0], args[1] + var proto string + var addr string - switch proto { - case "tcp-listen": return newListenSocket("tcp", addr) - case "tcp": return newDialSocket("tcp", addr) + if len(args) < 2 { + addr = args[0] + } else { + proto, addr = args[0], args[1] + } + + if proto == "" { + proto = "tcp" + } + + if addr == "" { + return nil, fmt.Errorf("bad socket '%s'", name) + } + + if _, ok := opts["listen"]; ok { + return newListenSocket(proto, addr) } - return nil, fmt.Errorf("bad socket '%s': unknown type", desc) + return newDialSocket(proto, addr) } -- cgit v1.2.3-70-g09d2