diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-03-03 10:48:04 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-03-03 13:45:33 +0300 |
| commit | 61287a0faab06105c382b5d24c4942a8b4cc960a (patch) | |
| tree | 0966780eb2a0a4ffd9d2738f486a0a984fd49476 /pkg/server/socket/socket.go | |
| parent | e6a63987f6963241dcfa981bf1081206a06f2990 (diff) | |
handle channel close
Diffstat (limited to 'pkg/server/socket/socket.go')
| -rw-r--r-- | pkg/server/socket/socket.go | 60 |
1 files changed, 33 insertions, 27 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 4629e22..cebfe47 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net" + "errors" "strings" "sync" "tunnel/pkg/server/env" @@ -11,10 +12,12 @@ import ( "tunnel/pkg/server/queue" ) +var errAlreadyClosed = errors.New("already closed") + type Channel interface { Send(wq queue.Q) error Recv(rq queue.Q) error - Close() + Close() error } type S interface { @@ -46,40 +49,42 @@ func newConnChannel(conn net.Conn) Channel { return &connChannel{conn: conn} } -func (cc *connChannel) Send(wq queue.Q) (err error) { - defer cc.shutdown(&err) - return queue.IoCopy(cc.conn, wq.Writer()) +func (c *connChannel) final(f func() error, err error) error { + if e := f(); e != nil { + if e == errAlreadyClosed { + return nil + } else { + return e + } + } + + return err } -func (cc *connChannel) Recv(rq queue.Q) (err error) { - defer cc.shutdown(&err) - return queue.IoCopy(rq.Reader(), cc.conn) +func (c *connChannel) Send(wq queue.Q) error { + err := queue.IoCopy(c.conn, wq.Writer()) + return c.final(c.Close, err) } -func (cc *connChannel) String() string { - local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr() - return fmt.Sprintf("%s/%s->%s", local.Network(), remote, local) +func (c *connChannel) Recv(rq queue.Q) error { + err := queue.IoCopy(rq.Reader(), c.conn) + return c.final(c.Close, err) } -func (cc *connChannel) shutdown(err *error) { - miss := true +func (c *connChannel) String() string { + local, remote := c.conn.LocalAddr(), c.conn.RemoteAddr() + return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) +} - cc.once.Do(func() { - miss = false - log.Println("close", cc) - if e := cc.conn.Close(); e != nil && *err != nil { - *err = e - } - }) +func (c *connChannel) Close() error { + err := errAlreadyClosed - if miss { - *err = nil - } -} + c.once.Do(func() { + log.Println("close", c) + err = c.conn.Close() + }) -func (cc *connChannel) Close() { - var err error - cc.shutdown(&err) + return err } func newListenSocket(proto, addr string) (S, error) { @@ -154,7 +159,8 @@ func (c *loopChannel) Recv(rq queue.Q) error { return queue.Copy(rq, c.q) } -func (c *loopChannel) Close() { +func (c *loopChannel) Close() error { + return nil } func (c *loopChannel) String() string { |
