summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-03-03 10:48:04 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-03-03 23:37:57 +0300
commitff51af3c48cee0d15c9802b2adb44f59a77e4c75 (patch)
tree0425f1b881cf6708ccb071dd18c6a5072879ddb4 /pkg
parente6a63987f6963241dcfa981bf1081206a06f2990 (diff)
handle channel close
Diffstat (limited to 'pkg')
-rw-r--r--pkg/server/socket/socket.go60
1 files changed, 33 insertions, 27 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 4629e22..a9fa319 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -1,6 +1,7 @@
package socket
import (
+ "errors"
"fmt"
"log"
"net"
@@ -11,10 +12,12 @@ import (
"tunnel/pkg/server/queue"
)
+var errAlreadyClosed = errors.New("already closed")
+
type Channel interface {
Send(wq queue.Q) error
Recv(rq queue.Q) error
- Close()
+ Close() error
}
type S interface {
@@ -46,40 +49,42 @@ func newConnChannel(conn net.Conn) Channel {
return &connChannel{conn: conn}
}
-func (cc *connChannel) Send(wq queue.Q) (err error) {
- defer cc.shutdown(&err)
- return queue.IoCopy(cc.conn, wq.Writer())
+func (c *connChannel) final(f func() error, err error) error {
+ if e := f(); e != nil {
+ if e == errAlreadyClosed {
+ return nil
+ } else {
+ return e
+ }
+ }
+
+ return err
}
-func (cc *connChannel) Recv(rq queue.Q) (err error) {
- defer cc.shutdown(&err)
- return queue.IoCopy(rq.Reader(), cc.conn)
+func (c *connChannel) Send(wq queue.Q) error {
+ err := queue.IoCopy(c.conn, wq.Writer())
+ return c.final(c.Close, err)
}
-func (cc *connChannel) String() string {
- local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr()
- return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local)
+func (c *connChannel) Recv(rq queue.Q) error {
+ err := queue.IoCopy(rq.Reader(), c.conn)
+ return c.final(c.Close, err)
}
-func (cc *connChannel) shutdown(err *error) {
- miss := true
+func (c *connChannel) String() string {
+ local, remote := c.conn.LocalAddr(), c.conn.RemoteAddr()
+ return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
+}
- cc.once.Do(func() {
- miss = false
- log.Println("close", cc)
- if e := cc.conn.Close(); e != nil && *err != nil {
- *err = e
- }
- })
+func (c *connChannel) Close() error {
+ err := errAlreadyClosed
- if miss {
- *err = nil
- }
-}
+ c.once.Do(func() {
+ log.Println("close", c)
+ err = c.conn.Close()
+ })
-func (cc *connChannel) Close() {
- var err error
- cc.shutdown(&err)
+ return err
}
func newListenSocket(proto, addr string) (S, error) {
@@ -154,7 +159,8 @@ func (c *loopChannel) Recv(rq queue.Q) error {
return queue.Copy(rq, c.q)
}
-func (c *loopChannel) Close() {
+func (c *loopChannel) Close() error {
+ return nil
}
func (c *loopChannel) String() string {