From ba26de9f705831cf4d6f932b0ddebca82c96bf58 Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Mon, 19 Oct 2020 10:33:07 +0300 Subject: expand tunnel.* vars, auth proto fix, go test --- TODO | 1 - cmd/tunneld/main.go | 2 +- pkg/server/env/env.go | 64 ++++++++++++++------------ pkg/server/hook/auth.go | 118 ++++++++++++++++++++++++++---------------------- pkg/server/hook/hook.go | 6 +-- pkg/server/server.go | 2 +- pkg/test/env_test.go | 26 +++++++++++ pkg/test/test.go | 34 ++++++++++++++ 8 files changed, 161 insertions(+), 92 deletions(-) create mode 100644 pkg/test/env_test.go create mode 100644 pkg/test/test.go diff --git a/TODO b/TODO index 595409b..852b397 100644 --- a/TODO +++ b/TODO @@ -17,4 +17,3 @@ - proxy socket with connect wait timeout - fix: set a @[a] - info for proxy -- env recursive expand (with Value) diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go index 50c7c82..8f9606a 100644 --- a/cmd/tunneld/main.go +++ b/cmd/tunneld/main.go @@ -298,7 +298,7 @@ func main() { log.Print("ready") - s.Run() + s.Serve() if err := deconfigure(s); err != nil { log.Println(err) diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index b75d9a8..2c97669 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -10,11 +10,11 @@ import ( type env struct { m map[string]string sync.Mutex + p *env } type Env struct { *env - p *env } const namePattern = "[a-zA-Z][a-zA-Z0-9.]*" @@ -30,15 +30,15 @@ func New() Env { return Env{env: new(env)} } -func (e *env) init() { +func (e Env) init() { if e.m == nil { e.m = make(map[string]string) } } -func (e *env) Fork() Env { +func (e Env) Fork() Env { t := New() - t.p = e + t.p = e.env return t } @@ -47,12 +47,7 @@ func (e *env) Find(key string) (string, bool) { defer e.Unlock() v, ok := e.m[key] - - return v, ok -} - -func (e Env) Find(key string) (string, bool) { - if v, ok := e.env.Find(key); ok { + if ok { return v, ok } @@ -68,11 +63,6 @@ func (e Env) Get(key string) string { return v } -func (e Env) Has(key string) bool { - _, ok := e.Find(key) - return ok -} - func validKeyValue(key string, value string) error { if !isGoodName(key) { return errBadVariable @@ -119,7 +109,7 @@ func (e Env) Push(key string, value string) error { return nil } -func (e *env) Del(key string) bool { +func (e Env) Del(key string) bool { e.Lock() defer e.Unlock() @@ -136,7 +126,7 @@ func (e *env) Del(key string) bool { return true } -func (e *env) Each(f func(string, string) bool) { +func (e Env) Each(f func(string, string) bool) { var keys []string e.Lock() @@ -155,7 +145,7 @@ func (e *env) Each(f func(string, string) bool) { } } -func (e *env) Clear() { +func (e Env) Clear() { e.Lock() defer e.Unlock() @@ -163,10 +153,22 @@ func (e *env) Clear() { } func (e Env) replaceWith(r *regexp.Regexp, s string, f func(string) string) string { - for { - t := r.ReplaceAllStringFunc(s, f) + const maxIter = 16 + + for iter := 0; ; iter++ { + found := false + + t := r.ReplaceAllStringFunc(s, func(v string) string { + found = true + + if iter >= maxIter { + return "" + } + + return f(v) + }) - if t == s { + if !found && t == s { break } @@ -194,18 +196,22 @@ func (e Env) EvalStrings(s []string) []string { return t } +func (e Env) getLocal(key string) string { + if tunnel, ok := e.Find("tunnel"); ok { + if v, ok := e.Find("tunnel." + tunnel + "." + key); ok { + return v + } + } + + return e.Get(key) +} + func (e Env) Expand(s string) string { return e.replaceWith(expandRe, s, func(v string) string { - return e.Get(v[2 : len(v)-1]) + return e.getLocal(v[2 : len(v)-1]) }) } func (e Env) Value(key string) string { - if e.Has("tunnel") { - if v := e.Expand("@[tunnel.@[tunnel]." + key + "]"); v != "" { - return v - } - } - - return e.Expand(e.Get(key)) + return e.Expand(e.getLocal(key)) } diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go index a30e257..5910b56 100644 --- a/pkg/server/hook/auth.go +++ b/pkg/server/hook/auth.go @@ -7,37 +7,38 @@ import ( "io" "sync" "time" - "tunnel/pkg/pack" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) const authTimeout = 5 * time.Second -const challengeLen = 16 + +const saltSize = 16 +const hashSize = md5.Size type auth struct { h *authHook secret string - challenge struct { + salt struct { self string peer string } hash string - recvChallenge chan struct{} - recvHash chan struct{} + readSaltDone chan struct{} + readHashDone chan struct{} - tmr *time.Timer + ready chan struct{} + cancel chan struct{} - fail chan struct{} - ok chan struct{} + tmr *time.Timer } -var errDupChallenge = errors.New("peer repeats challenge") +var errDupSalt = errors.New("peer repeats salt") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") @@ -45,24 +46,20 @@ type authHook struct { m sync.Map } -func (a *auth) generateChallenge() error { - b := make([]byte, challengeLen) +func (a *auth) Init() error { + b := make([]byte, saltSize) if _, err := rand.Read(b); err != nil { return err } - a.challenge.self = string(b) + a.salt.self = string(b) - a.h.m.Store(a.challenge.self, struct{}{}) + a.h.m.Store(a.salt.self, struct{}{}) return nil } -func (a *auth) deleteChallenge() { - a.h.m.Delete(a.challenge.self) -} - -func (a *auth) getHash(c string) string { +func (a *auth) hashSum(c string) string { h := md5.New() io.WriteString(h, c) @@ -71,11 +68,11 @@ func (a *auth) getHash(c string) string { return string(h.Sum(nil)) } -func (a *auth) wait(c chan struct{}) error { +func (a *auth) sync(c chan struct{}) error { select { case <-a.tmr.C: return errTimeout - case <-a.fail: + case <-a.cancel: return io.EOF case <-c: return nil @@ -83,63 +80,63 @@ func (a *auth) wait(c chan struct{}) error { } func (a *auth) Send(rq, wq queue.Q) error { - e := pack.NewEncoder(wq.Writer()) + w := wq.Writer() + + io.WriteString(w, a.salt.self) - if err := a.generateChallenge(); err != nil { + if err := a.sync(a.readSaltDone); err != nil { return err } - defer a.deleteChallenge() + if _, ok := a.h.m.Load(a.salt.peer); ok { + return errDupSalt + } - e.Lps([]byte(a.challenge.self)) + io.WriteString(w, a.hashSum(a.salt.peer)) - if err := a.wait(a.recvChallenge); err != nil { + if err := a.sync(a.readHashDone); err != nil { return err } - if _, ok := a.h.m.Load(a.challenge.peer); ok { - return errDupChallenge + if a.hash != a.hashSum(a.salt.self) { + close(a.cancel) + return errAuthFail } - e.Lps([]byte(a.getHash(a.challenge.peer))) + close(a.ready) - if err := a.wait(a.recvHash); err != nil { - return err - } + return queue.IoCopy(rq.Reader(), w) +} - if a.hash != a.getHash(a.challenge.self) { - close(a.fail) - return errAuthFail +func (a *auth) read(r io.Reader, n int, s *string) error { + b := make([]byte, n) + + if _, err := io.ReadFull(r, b); err != nil { + close(a.cancel) + return err } - close(a.ok) + *s = string(b) - return queue.Copy(rq, wq) + return nil } func (a *auth) Recv(rq, wq queue.Q) error { r := rq.Reader() - d := pack.NewDecoder(r) - if b, err := d.Lps(); err != nil { - close(a.fail) + if err := a.read(r, saltSize, &a.salt.peer); err != nil { return err - } else { - a.challenge.peer = string(b) } - close(a.recvChallenge) + close(a.readSaltDone) - if b, err := d.Lps(); err != nil { - close(a.fail) + if err := a.read(r, hashSize, &a.hash); err != nil { return err - } else { - a.hash = string(b) } - close(a.recvHash) + close(a.readHashDone) - if err := a.wait(a.ok); err != nil { + if err := a.sync(a.ready); err != nil { return err } @@ -148,16 +145,27 @@ func (a *auth) Recv(rq, wq queue.Q) error { return queue.IoCopy(r, wq.Writer()) } +func (a *auth) Close() { + a.h.m.Delete(a.salt.self) +} + func (h *authHook) Open(env env.Env) (interface{}, error) { a := &auth{ - h: h, - secret: env.Value("secret"), - recvChallenge: make(chan struct{}), - recvHash: make(chan struct{}), - fail: make(chan struct{}), - ok: make(chan struct{}), - tmr: time.NewTimer(authTimeout), + h: h, + secret: env.Value("secret"), + tmr: time.NewTimer(authTimeout), + + cancel: make(chan struct{}), + ready: make(chan struct{}), + + readSaltDone: make(chan struct{}), + readHashDone: make(chan struct{}), } + + if err := a.Init(); err != nil { + return nil, err + } + return a, nil } diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go index 3065cbe..69aa237 100644 --- a/pkg/server/hook/hook.go +++ b/pkg/server/hook/hook.go @@ -38,10 +38,6 @@ type Recver interface { Recv(rq, wq queue.Q) error } -type Closer interface { - Close() -} - type Func func(rq, wq queue.Q) error func (f Func) Send(rq, wq queue.Q) error { @@ -86,7 +82,7 @@ func (w *wrapper) Open(env env.Env) (*Pipe, error) { } func (p *Pipe) Close() { - if c, ok := p.priv.(Closer); ok { + if c, ok := p.priv.(interface{ Close() }); ok { c.Close() } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 5ca81d1..d57d9dd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -160,7 +160,7 @@ func (s *Server) Socket() string { return s.listen.Addr().String() } -func (s *Server) Run() { +func (s *Server) Serve() { for { conn, err := s.listen.Accept() if err != nil { diff --git a/pkg/test/env_test.go b/pkg/test/env_test.go new file mode 100644 index 0000000..b03a5c5 --- /dev/null +++ b/pkg/test/env_test.go @@ -0,0 +1,26 @@ +package test + +import ( + "testing" +) + +func TestEnv(t *testing.T) { + const msg = "Hello, World!" + + c, s, err := newClientServer() + if err != nil { + t.Fatal(err) + } + + defer s.Stop() + defer c.Close() + + r, err := c.Send([]string{"echo", msg}) + if err != nil { + t.Fatal(err) + } + + if r != msg { + t.Errorf("wrong reply: send '%s', recv '%s'", msg, r) + } +} diff --git a/pkg/test/test.go b/pkg/test/test.go new file mode 100644 index 0000000..6a2c776 --- /dev/null +++ b/pkg/test/test.go @@ -0,0 +1,34 @@ +package test + +import ( + "os" + "path/filepath" + "strconv" + + "tunnel/pkg/client" + "tunnel/pkg/server" +) + +func getSocketPath() string { + s := "tunnel.test." + strconv.Itoa(os.Getpid()) + return filepath.Join(os.TempDir(), s) +} + +func newClientServer() (*client.Client, *server.Server, error) { + socket := getSocketPath() + + s, err := server.New(socket) + if err != nil { + return nil, nil, err + } + + go s.Serve() + + c, err := client.New(socket) + if err != nil { + s.Stop() + return nil, nil, err + } + + return c, s, nil +} -- cgit v1.2.3-70-g09d2