diff options
Diffstat (limited to 'pkg/server/hook/auth.go')
| -rw-r--r-- | pkg/server/hook/auth.go | 118 |
1 files changed, 63 insertions, 55 deletions
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go index a30e257..5910b56 100644 --- a/pkg/server/hook/auth.go +++ b/pkg/server/hook/auth.go @@ -7,37 +7,38 @@ import ( "io" "sync" "time" - "tunnel/pkg/pack" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) const authTimeout = 5 * time.Second -const challengeLen = 16 + +const saltSize = 16 +const hashSize = md5.Size type auth struct { h *authHook secret string - challenge struct { + salt struct { self string peer string } hash string - recvChallenge chan struct{} - recvHash chan struct{} + readSaltDone chan struct{} + readHashDone chan struct{} - tmr *time.Timer + ready chan struct{} + cancel chan struct{} - fail chan struct{} - ok chan struct{} + tmr *time.Timer } -var errDupChallenge = errors.New("peer repeats challenge") +var errDupSalt = errors.New("peer repeats salt") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") @@ -45,24 +46,20 @@ type authHook struct { m sync.Map } -func (a *auth) generateChallenge() error { - b := make([]byte, challengeLen) +func (a *auth) Init() error { + b := make([]byte, saltSize) if _, err := rand.Read(b); err != nil { return err } - a.challenge.self = string(b) + a.salt.self = string(b) - a.h.m.Store(a.challenge.self, struct{}{}) + a.h.m.Store(a.salt.self, struct{}{}) return nil } -func (a *auth) deleteChallenge() { - a.h.m.Delete(a.challenge.self) -} - -func (a *auth) getHash(c string) string { +func (a *auth) hashSum(c string) string { h := md5.New() io.WriteString(h, c) @@ -71,11 +68,11 @@ func (a *auth) getHash(c string) string { return string(h.Sum(nil)) } -func (a *auth) wait(c chan struct{}) error { +func (a *auth) sync(c chan struct{}) error { select { case <-a.tmr.C: return errTimeout - case <-a.fail: + case <-a.cancel: return io.EOF case <-c: return nil @@ -83,63 +80,63 @@ func (a *auth) wait(c chan struct{}) error { } func (a *auth) Send(rq, wq queue.Q) error { - e := pack.NewEncoder(wq.Writer()) + w := wq.Writer() + + io.WriteString(w, a.salt.self) - if err := a.generateChallenge(); err != nil { + if err := a.sync(a.readSaltDone); err != nil { return err } - defer a.deleteChallenge() + if _, ok := a.h.m.Load(a.salt.peer); ok { + return errDupSalt + } - e.Lps([]byte(a.challenge.self)) + io.WriteString(w, a.hashSum(a.salt.peer)) - if err := a.wait(a.recvChallenge); err != nil { + if err := a.sync(a.readHashDone); err != nil { return err } - if _, ok := a.h.m.Load(a.challenge.peer); ok { - return errDupChallenge + if a.hash != a.hashSum(a.salt.self) { + close(a.cancel) + return errAuthFail } - e.Lps([]byte(a.getHash(a.challenge.peer))) + close(a.ready) - if err := a.wait(a.recvHash); err != nil { - return err - } + return queue.IoCopy(rq.Reader(), w) +} - if a.hash != a.getHash(a.challenge.self) { - close(a.fail) - return errAuthFail +func (a *auth) read(r io.Reader, n int, s *string) error { + b := make([]byte, n) + + if _, err := io.ReadFull(r, b); err != nil { + close(a.cancel) + return err } - close(a.ok) + *s = string(b) - return queue.Copy(rq, wq) + return nil } func (a *auth) Recv(rq, wq queue.Q) error { r := rq.Reader() - d := pack.NewDecoder(r) - if b, err := d.Lps(); err != nil { - close(a.fail) + if err := a.read(r, saltSize, &a.salt.peer); err != nil { return err - } else { - a.challenge.peer = string(b) } - close(a.recvChallenge) + close(a.readSaltDone) - if b, err := d.Lps(); err != nil { - close(a.fail) + if err := a.read(r, hashSize, &a.hash); err != nil { return err - } else { - a.hash = string(b) } - close(a.recvHash) + close(a.readHashDone) - if err := a.wait(a.ok); err != nil { + if err := a.sync(a.ready); err != nil { return err } @@ -148,16 +145,27 @@ func (a *auth) Recv(rq, wq queue.Q) error { return queue.IoCopy(r, wq.Writer()) } +func (a *auth) Close() { + a.h.m.Delete(a.salt.self) +} + func (h *authHook) Open(env env.Env) (interface{}, error) { a := &auth{ - h: h, - secret: env.Value("secret"), - recvChallenge: make(chan struct{}), - recvHash: make(chan struct{}), - fail: make(chan struct{}), - ok: make(chan struct{}), - tmr: time.NewTimer(authTimeout), + h: h, + secret: env.Value("secret"), + tmr: time.NewTimer(authTimeout), + + cancel: make(chan struct{}), + ready: make(chan struct{}), + + readSaltDone: make(chan struct{}), + readHashDone: make(chan struct{}), } + + if err := a.Init(); err != nil { + return nil, err + } + return a, nil } |
