package module

import (
	"crypto/md5"
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"sync"

	"tunnel/pkg/netstring"
	"tunnel/pkg/server/env"
	"tunnel/pkg/server/opts"
	"tunnel/pkg/server/queue"
)

// ChallengeLen is the number of random bytes in the challenge each side sends.
const ChallengeLen = 16

// auth holds the state of one mutual challenge/response handshake.
//
// Send and Recv (returned together by authModule.Open) must run concurrently:
// each blocks on events that only the other one signals. Every channel below
// is used purely as a one-shot broadcast by being closed.
type auth struct {
	secret    string
	challenge struct {
		self string // challenge we generated and sent to the peer
		peer string // challenge the peer sent to us
	}
	hash string // the peer's answer; expected to equal getHash(challenge.self)

	recvChallenge chan struct{} // closed once Recv has stored challenge.peer
	recvHash      chan struct{} // closed once Recv has stored hash
	fail          chan struct{} // closed (only via abort) when the handshake fails
	ok            chan struct{} // closed by Send once the peer's answer is verified

	// failOnce guards close(fail): both halves have failure paths, and closing
	// an already-closed channel panics.
	failOnce sync.Once
}

var (
	errDupChallenge = errors.New("peer duplicates challenge")
	errAuthFail     = errors.New("peer auth fail")
)

type authModule struct{}

// abort marks the handshake as failed, releasing whichever half is blocked in
// wait. It is safe to call more than once and from either half.
func (a *auth) abort() {
	a.failOnce.Do(func() { close(a.fail) })
}

// generateChallenge fills challenge.self with ChallengeLen bytes from
// crypto/rand.
func (a *auth) generateChallenge() error {
	b := make([]byte, ChallengeLen)
	if _, err := rand.Read(b); err != nil {
		return err
	}
	a.challenge.self = string(b)
	return nil
}

// sendChallenge writes our challenge to q as a netstring, then blocks until
// Recv has stored the peer's challenge. It reports false if the handshake
// failed first.
func (a *auth) sendChallenge(q queue.Q) bool {
	enc := netstring.NewEncoder(q.Writer())
	// NOTE(review): any result of Encode is not checked here; confirm
	// netstring.Encoder.Encode's signature and propagate write errors.
	enc.Encode(a.challenge.self)
	return a.wait(a.recvChallenge)
}

// getHash returns the raw (not hex-encoded) MD5 digest of the secret followed
// by c. Writes to a hash.Hash never return an error, so the io.WriteString
// results are intentionally ignored.
func (a *auth) getHash(c string) string {
	h := md5.New()
	io.WriteString(h, a.secret)
	io.WriteString(h, c)
	return string(h.Sum(nil))
}

// sendHash writes our answer to the peer's challenge to q, then blocks until
// Recv has stored the peer's answer to ours. It reports false if the
// handshake failed first.
func (a *auth) sendHash(q queue.Q) bool {
	enc := netstring.NewEncoder(q.Writer())
	// NOTE(review): as in sendChallenge, any result of Encode is unchecked.
	enc.Encode(a.getHash(a.challenge.peer))
	return a.wait(a.recvHash)
}

// wait blocks until c is signalled (true) or the handshake fails (false).
func (a *auth) wait(c chan struct{}) bool {
	select {
	case <-a.fail:
		return false
	case <-c:
		return true
	}
}

// Send drives the local side of the handshake. It sends our challenge,
// answers the peer's challenge, and verifies the peer's answer to ours. On
// success it signals Recv through a.ok and then runs queue.Copy(rq, wq).
//
// It returns nil without copying when Recv has already reported a failure.
// Every failure detected here calls abort, so Recv is never left blocked on
// a.ok.
func (a *auth) Send(rq, wq queue.Q) error {
	if err := a.generateChallenge(); err != nil {
		a.abort()
		return err
	}
	if !a.sendChallenge(wq) {
		return nil
	}
	// If the peer's challenge equals ours, our answer (getHash(peer)) would be
	// exactly the answer we expect from the peer (getHash(self)). Refuse
	// before sending anything.
	if a.challenge.self == a.challenge.peer {
		a.abort()
		return errDupChallenge
	}
	if !a.sendHash(wq) {
		return nil
	}
	// Constant-time comparison: response timing must not reveal how much of
	// the expected digest the peer's answer matched.
	want := a.getHash(a.challenge.self)
	if subtle.ConstantTimeCompare([]byte(a.hash), []byte(want)) != 1 {
		a.abort()
		return errAuthFail
	}
	close(a.ok)
	return queue.Copy(rq, wq)
}

// Recv reads the peer's challenge and then the peer's answer from rq as
// netstrings, publishing each to Send as soon as it is stored. It then waits
// for Send's verdict and, if the peer was verified, runs queue.Copy(rq, wq).
//
// It returns nil without copying when Send reported a failure. A decode error
// aborts the handshake so Send does not block forever.
func (a *auth) Recv(rq, wq queue.Q) error {
	dec := netstring.NewDecoder(rq.Reader())

	c, err := dec.Decode()
	if err != nil {
		a.abort()
		return err
	}
	a.challenge.peer = c
	close(a.recvChallenge)

	h, err := dec.Decode()
	if err != nil {
		a.abort()
		return err
	}
	a.hash = h
	close(a.recvHash)

	if !a.wait(a.ok) {
		return nil
	}
	return queue.Copy(rq, wq)
}

// getAuthSecret returns the shared secret used by the auth module. If env has
// a "tunnel" entry and a matching "tunnel.<id>.secret" entry, that
// per-tunnel secret is used; otherwise it falls back to env.Get("secret").
func getAuthSecret(env env.Env) string {
	if id, ok := env.Find("tunnel"); ok {
		k := fmt.Sprintf("tunnel.%s.secret", id)
		if v, ok := env.Find(k); ok {
			return v
		}
	}
	return env.Get("secret")
}

// Open creates fresh handshake state and returns its Send and Recv methods,
// in that order, as the module's two Pipes.
func (m authModule) Open(env env.Env) (Pipe, Pipe) {
	a := &auth{
		secret:        getAuthSecret(env),
		recvChallenge: make(chan struct{}),
		recvHash:      make(chan struct{}),
		fail:          make(chan struct{}),
		ok:            make(chan struct{}),
	}
	return a.Send, a.Recv
}

// init registers the module under the name "auth"; its options and env are
// not consulted at registration time.
func init() {
	register("auth", func(opts.Opts, env.Env) (module, error) {
		return authModule{}, nil
	})
}