diff options
Diffstat (limited to 'pkg/server/socket/socket.go')
| -rw-r--r-- | pkg/server/socket/socket.go | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go new file mode 100644 index 0000000..f097a80 --- /dev/null +++ b/pkg/server/socket/socket.go @@ -0,0 +1,151 @@ +package socket + +import ( + "tunnel/pkg/server/queue" + "strings" + "sync" + "fmt" + "log" + "net" +) + +type Channel interface { + Send(wq queue.Q) error + Recv(rq queue.Q) error + Close() +} + +type S interface { + Open() (Channel, error) + Close() +} + +type listenSocket struct { + listen net.Listener +} + +type dialSocket struct { + proto, addr string +} + +type connChannel struct { + conn net.Conn + once sync.Once + cancel chan struct{} +} + +func newConnChannel(conn net.Conn) Channel { + return &connChannel{conn: conn, cancel: make(chan struct{})} +} + +func (cc *connChannel) Send(wq queue.Q) (err error) { + defer cc.shutdown(&err) + return queue.IoCopy(cc.conn, wq.Writer()) +} + +func (cc *connChannel) Recv(rq queue.Q) (err error) { + defer cc.shutdown(&err) + + for b := range rq { + for len(b) > 0 { + n, err := cc.conn.Write(b) + if err != nil { + return err + } + b = b[n:] + } + } + + return nil +} + +func (cc *connChannel) String() string { + addr := cc.conn.RemoteAddr() + return fmt.Sprintf("%s/%s", addr.Network(), addr.String()) +} + +func (cc *connChannel) isCanceled() bool { + select { + case <- cc.cancel: + return true + default: + return false + } +} + +func (cc *connChannel) shutdown(err *error) { + select { + case <- cc.cancel: + *err = nil + default: + cc.once.Do(func () { + close(cc.cancel) + log.Println("close", cc) + if e := cc.conn.Close(); e != nil && *err != nil { + *err = e + } + }) + } +} + +func (cc *connChannel) Close() { + var err error + cc.shutdown(&err) +} + +func newListenSocket(proto, addr string) (S, error) { + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + listen, err := net.Listen(proto, addr) + if err != nil { + return nil, err + } + + return &listenSocket{listen: listen}, nil +} + +func (s *listenSocket) Open() (Channel, error) { + conn, err := s.listen.Accept() + if err != nil { + return nil, err + } + return newConnChannel(conn), nil +} + +func (s *listenSocket) Close() { + s.listen.Close() +} + +func newDialSocket(proto, addr string) (S, error) { + return &dialSocket{proto: proto, addr: addr}, nil +} + +func (s *dialSocket) Open() (Channel, error) { + conn, err := net.Dial(s.proto, s.addr) + if err != nil { + return nil, err + } + return newConnChannel(conn), nil +} + +func (s *dialSocket) Close() { +} + +func New(desc string) (S, error) { + args := strings.Split(desc, "/") + + if len(args) != 2 { + return nil, fmt.Errorf("bad socket '%s'", desc) + } + + proto, addr := args[0], args[1] + + switch proto { + case "tcp-listen": return newListenSocket("tcp", addr) + case "tcp": return newDialSocket("tcp", addr) + } + + return nil, fmt.Errorf("bad socket '%s': unknown type", desc) +} |
