package hook import ( "crypto/md5" "crypto/rand" "errors" "io" "sync" "time" "tunnel/pkg/server/env" "tunnel/pkg/server/queue" ) const authTimeout = 5 * time.Second const saltSize = 16 const hashSize = md5.Size type authHook struct { m sync.Map } type auth struct { h *authHook secret string salt struct { self string peer string } hash string readSaltDone chan struct{} readHashDone chan struct{} ready chan struct{} cancel chan struct{} tmr *time.Timer } var errDupSalt = errors.New("peer repeats salt") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") func (a *auth) Init() error { b := make([]byte, saltSize) if _, err := rand.Read(b); err != nil { return err } a.salt.self = string(b) a.h.m.Store(a.salt.self, struct{}{}) return nil } func (a *auth) hashSum(c string) string { h := md5.New() io.WriteString(h, c) io.WriteString(h, a.secret) return string(h.Sum(nil)) } func (a *auth) sync(c chan struct{}) error { select { case <-a.tmr.C: return errTimeout case <-a.cancel: return io.EOF case <-c: return nil } } func (a *auth) Send(rq, wq queue.Q) error { w := wq.Writer() io.WriteString(w, a.salt.self) if err := a.sync(a.readSaltDone); err != nil { return err } if _, ok := a.h.m.Load(a.salt.peer); ok { return errDupSalt } io.WriteString(w, a.hashSum(a.salt.peer)) if err := a.sync(a.readHashDone); err != nil { return err } if a.hash != a.hashSum(a.salt.self) { close(a.cancel) return errAuthFail } close(a.ready) return queue.IoCopy(rq.Reader(), w) } func (a *auth) read(r io.Reader, n int, s *string) error { b := make([]byte, n) if _, err := io.ReadFull(r, b); err != nil { close(a.cancel) return err } *s = string(b) return nil } func (a *auth) Recv(rq, wq queue.Q) error { r := rq.Reader() if err := a.read(r, saltSize, &a.salt.peer); err != nil { return err } close(a.readSaltDone) if err := a.read(r, hashSize, &a.hash); err != nil { return err } close(a.readHashDone) if err := a.sync(a.ready); err != nil { return err } a.tmr.Stop() return queue.IoCopy(r, wq.Writer()) } func (a *auth) Close() { a.h.m.Delete(a.salt.self) } func (h *authHook) New(env env.Env) (interface{}, error) { a := &auth{ h: h, secret: env.Value("secret"), tmr: time.NewTimer(authTimeout), cancel: make(chan struct{}), ready: make(chan struct{}), readSaltDone: make(chan struct{}), readHashDone: make(chan struct{}), } if err := a.Init(); err != nil { return nil, err } return a, nil } func init() { register("auth", authHook{}) }