package hook import ( "crypto/md5" "crypto/rand" "errors" "io" "sync" "time" "tunnel/pkg/netstring" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) const authTimeout = 5 * time.Second const challengeLen = 16 type auth struct { h *authHook secret string challenge struct { self string peer string } hash string recvChallenge chan struct{} recvHash chan struct{} tmr *time.Timer fail chan struct{} ok chan struct{} } var errDupChallenge = errors.New("peer repeats challenge") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") type authHook struct { m sync.Map } func (a *auth) generateChallenge() error { b := make([]byte, challengeLen) if _, err := rand.Read(b); err != nil { return err } a.challenge.self = string(b) a.h.m.Store(a.challenge.self, struct{}{}) return nil } func (a *auth) deleteChallenge() { a.h.m.Delete(a.challenge.self) } func (a *auth) getHash(c string) string { h := md5.New() io.WriteString(h, c) io.WriteString(h, a.secret) return string(h.Sum(nil)) } func (a *auth) wait(c chan struct{}) error { select { case <-a.tmr.C: return errTimeout case <-a.fail: return io.EOF case <-c: return nil } } func (a *auth) Send(rq, wq queue.Q) error { e := netstring.NewEncoder(wq.Writer()) if err := a.generateChallenge(); err != nil { return err } defer a.deleteChallenge() e.Encode(a.challenge.self) if err := a.wait(a.recvChallenge); err != nil { return err } if _, ok := a.h.m.Load(a.challenge.peer); ok { return errDupChallenge } e.Encode(a.getHash(a.challenge.peer)) if err := a.wait(a.recvHash); err != nil { return err } if a.hash != a.getHash(a.challenge.self) { close(a.fail) return errAuthFail } close(a.ok) return queue.Copy(rq, wq) } func (a *auth) Recv(rq, wq queue.Q) (err error) { r := rq.Reader() d := netstring.NewDecoder(r) if a.challenge.peer, err = d.Decode(); err != nil { close(a.fail) return } close(a.recvChallenge) if a.hash, err = d.Decode(); err != nil { close(a.fail) return err } close(a.recvHash) if err = a.wait(a.ok); err != nil { return } a.tmr.Stop() return queue.IoCopy(r, wq.Writer()) } func (h *authHook) Open(env env.Env) (interface{}, error) { a := &auth{ h: h, secret: getHookVar(env, "secret"), recvChallenge: make(chan struct{}), recvHash: make(chan struct{}), fail: make(chan struct{}), ok: make(chan struct{}), tmr: time.NewTimer(authTimeout), } return a, nil } func newAuthHook(opts.Opts, env.Env) (hook, error) { return &authHook{}, nil } func init() { register("auth", newAuthHook) }