summaryrefslogtreecommitdiff
path: root/pkg/server/socket/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket/socket.go')
-rw-r--r--pkg/server/socket/socket.go106
1 files changed, 71 insertions, 35 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index f097a80..cad1ad3 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -21,6 +21,7 @@ type S interface {
}
type listenSocket struct {
+ proto, addr string
listen net.Listener
}
@@ -31,11 +32,10 @@ type dialSocket struct {
type connChannel struct {
conn net.Conn
once sync.Once
- cancel chan struct{}
}
func newConnChannel(conn net.Conn) Channel {
- return &connChannel{conn: conn, cancel: make(chan struct{})}
+ return &connChannel{conn: conn}
}
func (cc *connChannel) Send(wq queue.Q) (err error) {
@@ -60,31 +60,23 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) {
}
func (cc *connChannel) String() string {
- addr := cc.conn.RemoteAddr()
- return fmt.Sprintf("%s/%s", addr.Network(), addr.String())
-}
-
-func (cc *connChannel) isCanceled() bool {
- select {
- case <- cc.cancel:
- return true
- default:
- return false
- }
+ local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr()
+ return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
}
func (cc *connChannel) shutdown(err *error) {
- select {
- case <- cc.cancel:
+ miss := true
+
+ cc.once.Do(func () {
+ miss = false
+ log.Println("close", cc)
+ if e := cc.conn.Close(); e != nil && *err != nil {
+ *err = e
+ }
+ })
+
+ if miss {
*err = nil
- default:
- cc.once.Do(func () {
- close(cc.cancel)
- log.Println("close", cc)
- if e := cc.conn.Close(); e != nil && *err != nil {
- *err = e
- }
- })
}
}
@@ -94,8 +86,10 @@ func (cc *connChannel) Close() {
}
func newListenSocket(proto, addr string) (S, error) {
- if !strings.Contains(addr, ":") {
- addr = ":" + addr
+ if proto == "tcp" {
+ if !strings.Contains(addr, ":") {
+ addr = ":" + addr
+ }
}
listen, err := net.Listen(proto, addr)
@@ -103,7 +97,13 @@ func newListenSocket(proto, addr string) (S, error) {
return nil, err
}
- return &listenSocket{listen: listen}, nil
+ s := &listenSocket{
+ proto: proto,
+ addr: addr,
+ listen: listen,
+ }
+
+ return s, nil
}
func (s *listenSocket) Open() (Channel, error) {
@@ -114,14 +114,29 @@ func (s *listenSocket) Open() (Channel, error) {
return newConnChannel(conn), nil
}
+func (s *listenSocket) String() string {
+ return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
+}
+
func (s *listenSocket) Close() {
s.listen.Close()
}
func newDialSocket(proto, addr string) (S, error) {
+ switch proto {
+ case "tcp", "udp":
+ if !strings.Contains(addr, ":") {
+ addr = "localhost:" + addr
+ }
+ }
+
return &dialSocket{proto: proto, addr: addr}, nil
}
+func (s *dialSocket) String() string {
+ return fmt.Sprintf("%s/%s", s.proto, s.addr)
+}
+
func (s *dialSocket) Open() (Channel, error) {
conn, err := net.Dial(s.proto, s.addr)
if err != nil {
@@ -133,19 +148,40 @@ func (s *dialSocket) Open() (Channel, error) {
func (s *dialSocket) Close() {
}
-func New(desc string) (S, error) {
- args := strings.Split(desc, "/")
+func New(name string) (S, error) {
+ vv := strings.Split(name, ",")
+ args := strings.Split(vv[0], "/")
+ opts := map[string]string{}
- if len(args) != 2 {
- return nil, fmt.Errorf("bad socket '%s'", desc)
+ for _, v := range vv[1:] {
+ ss := strings.SplitN(v, "=", 2)
+ if len(ss) < 2 {
+ opts[ss[0]] = ""
+ } else {
+ opts[ss[0]] = ss[1]
+ }
}
- proto, addr := args[0], args[1]
+ var proto string
+ var addr string
- switch proto {
- case "tcp-listen": return newListenSocket("tcp", addr)
- case "tcp": return newDialSocket("tcp", addr)
+ if len(args) < 2 {
+ addr = args[0]
+ } else {
+ proto, addr = args[0], args[1]
+ }
+
+ if proto == "" {
+ proto = "tcp"
+ }
+
+ if addr == "" {
+ return nil, fmt.Errorf("bad socket '%s'", name)
+ }
+
+ if _, ok := opts["listen"]; ok {
+ return newListenSocket(proto, addr)
}
- return nil, fmt.Errorf("bad socket '%s': unknown type", desc)
+ return newDialSocket(proto, addr)
}