diff options
| -rw-r--r-- | pkg/server/module/auth.go | 23 | ||||
| -rw-r--r-- | pkg/server/queue/queue.go | 24 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 13 |
3 files changed, 36 insertions, 24 deletions
diff --git a/pkg/server/module/auth.go b/pkg/server/module/auth.go index d4bc8b3..d269bf6 100644 --- a/pkg/server/module/auth.go +++ b/pkg/server/module/auth.go @@ -104,30 +104,29 @@ func (a *auth) Send(rq, wq queue.Q) error { return queue.Copy(rq, wq) } -func (a *auth) Recv(rq, wq queue.Q) error { - dec := netstring.NewDecoder(rq.Reader()) +func (a *auth) Recv(rq, wq queue.Q) (err error) { + r := rq.Reader() + d := netstring.NewDecoder(r) - if c, err := dec.Decode(); err != nil { + if a.challenge.peer, err = d.Decode(); err != nil { close(a.fail) - return err - } else { - a.challenge.peer = c - close(a.recvChallenge) + return } - if h, err := dec.Decode(); err != nil { + close(a.recvChallenge) + + if a.hash, err = d.Decode(); err != nil { close(a.fail) return err - } else { - a.hash = h - close(a.recvHash) } + close(a.recvHash) + if !a.wait(a.ok) { return nil } - return queue.Copy(rq, wq) + return queue.IoCopy(r, wq.Writer()) } func getAuthSecret(env env.Env) string { diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go index 979fa33..4b69e3a 100644 --- a/pkg/server/queue/queue.go +++ b/pkg/server/queue/queue.go @@ -37,6 +37,30 @@ func (r *reader) Read(p []byte) (int, error) { return n, nil } +func (r *reader) WriteTo(w io.Writer) (int64, error) { + if writer, ok := w.(*writer); ok { + if len(r.b) > 0 { + writer.q <- r.b + } + + return 0, Copy(r.q, writer.q) + } + + if len(r.b) > 0 { + if _, err := w.Write(r.b); err != nil { + return 0, err + } + } + + for b := range r.q { + if _, err := w.Write(b); err != nil { + return 0, err + } + } + + return 0, nil +} + func (q Q) Writer() io.Writer { return &writer{q: q} } diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 48a650c..3db4310 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -47,18 +47,7 @@ func (cc *connChannel) Send(wq queue.Q) (err error) { func (cc *connChannel) Recv(rq queue.Q) (err error) { defer cc.shutdown(&err) - - for b := range rq { - for len(b) > 0 { - n, err := cc.conn.Write(b) - if err != nil { - return err - } - b = b[n:] - } - } - - return nil + return queue.IoCopy(rq.Reader(), cc.conn) } func (cc *connChannel) String() string { |
