summaryrefslogtreecommitdiff
path: root/pkg/server/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/socket')
-rw-r--r--pkg/server/socket/socket.go151
1 files changed, 151 insertions, 0 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
new file mode 100644
index 0000000..f097a80
--- /dev/null
+++ b/pkg/server/socket/socket.go
@@ -0,0 +1,151 @@
+package socket
+
+import (
+ "tunnel/pkg/server/queue"
+ "strings"
+ "sync"
+ "fmt"
+ "log"
+ "net"
+)
+
+type Channel interface {
+ Send(wq queue.Q) error
+ Recv(rq queue.Q) error
+ Close()
+}
+
+type S interface {
+ Open() (Channel, error)
+ Close()
+}
+
+type listenSocket struct {
+ listen net.Listener
+}
+
+type dialSocket struct {
+ proto, addr string
+}
+
+type connChannel struct {
+ conn net.Conn
+ once sync.Once
+ cancel chan struct{}
+}
+
+func newConnChannel(conn net.Conn) Channel {
+ return &connChannel{conn: conn, cancel: make(chan struct{})}
+}
+
+func (cc *connChannel) Send(wq queue.Q) (err error) {
+ defer cc.shutdown(&err)
+ return queue.IoCopy(cc.conn, wq.Writer())
+}
+
+func (cc *connChannel) Recv(rq queue.Q) (err error) {
+ defer cc.shutdown(&err)
+
+ for b := range rq {
+ for len(b) > 0 {
+ n, err := cc.conn.Write(b)
+ if err != nil {
+ return err
+ }
+ b = b[n:]
+ }
+ }
+
+ return nil
+}
+
+func (cc *connChannel) String() string {
+ addr := cc.conn.RemoteAddr()
+ return fmt.Sprintf("%s/%s", addr.Network(), addr.String())
+}
+
+func (cc *connChannel) isCanceled() bool {
+ select {
+ case <- cc.cancel:
+ return true
+ default:
+ return false
+ }
+}
+
+func (cc *connChannel) shutdown(err *error) {
+ select {
+ case <- cc.cancel:
+ *err = nil
+ default:
+ cc.once.Do(func () {
+ close(cc.cancel)
+ log.Println("close", cc)
+ if e := cc.conn.Close(); e != nil && *err != nil {
+ *err = e
+ }
+ })
+ }
+}
+
+func (cc *connChannel) Close() {
+ var err error
+ cc.shutdown(&err)
+}
+
+func newListenSocket(proto, addr string) (S, error) {
+ if !strings.Contains(addr, ":") {
+ addr = ":" + addr
+ }
+
+ listen, err := net.Listen(proto, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &listenSocket{listen: listen}, nil
+}
+
+func (s *listenSocket) Open() (Channel, error) {
+ conn, err := s.listen.Accept()
+ if err != nil {
+ return nil, err
+ }
+ return newConnChannel(conn), nil
+}
+
+func (s *listenSocket) Close() {
+ s.listen.Close()
+}
+
+func newDialSocket(proto, addr string) (S, error) {
+ return &dialSocket{proto: proto, addr: addr}, nil
+}
+
+func (s *dialSocket) Open() (Channel, error) {
+ conn, err := net.Dial(s.proto, s.addr)
+ if err != nil {
+ return nil, err
+ }
+ return newConnChannel(conn), nil
+}
+
+func (s *dialSocket) Close() {
+}
+
+func New(desc string) (S, error) {
+ args := strings.Split(desc, "/")
+
+ if len(args) != 2 {
+ return nil, fmt.Errorf("bad socket '%s'", desc)
+ }
+
+ proto, addr := args[0], args[1]
+
+ switch proto {
+ case "tcp-listen": return newListenSocket("tcp", addr)
+ case "tcp": return newDialSocket("tcp", addr)
+ }
+
+ return nil, fmt.Errorf("bad socket '%s': unknown type", desc)
+}