summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-10-19 10:33:07 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-10-19 10:55:16 +0300
commitba26de9f705831cf4d6f932b0ddebca82c96bf58 (patch)
tree0aae866abe9ca43481ce920b2b9426c32d7a40f8
parent0ff71a73f4ccca2fb6b366d5896886e281eeac62 (diff)
expand tunnel.* vars, auth proto fix, go test
-rw-r--r--TODO1
-rw-r--r--cmd/tunneld/main.go2
-rw-r--r--pkg/server/env/env.go64
-rw-r--r--pkg/server/hook/auth.go118
-rw-r--r--pkg/server/hook/hook.go6
-rw-r--r--pkg/server/server.go2
-rw-r--r--pkg/test/env_test.go26
-rw-r--r--pkg/test/test.go34
8 files changed, 161 insertions, 92 deletions
diff --git a/TODO b/TODO
index 595409b..852b397 100644
--- a/TODO
+++ b/TODO
@@ -17,4 +17,3 @@
- proxy socket with connect wait timeout
- fix: set a @[a]
- info for proxy
-- env recursive expand (with Value)
diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go
index 50c7c82..8f9606a 100644
--- a/cmd/tunneld/main.go
+++ b/cmd/tunneld/main.go
@@ -298,7 +298,7 @@ func main() {
log.Print("ready")
- s.Run()
+ s.Serve()
if err := deconfigure(s); err != nil {
log.Println(err)
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index b75d9a8..2c97669 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -10,11 +10,11 @@ import (
type env struct {
m map[string]string
sync.Mutex
+ p *env
}
type Env struct {
*env
- p *env
}
const namePattern = "[a-zA-Z][a-zA-Z0-9.]*"
@@ -30,15 +30,15 @@ func New() Env {
return Env{env: new(env)}
}
-func (e *env) init() {
+func (e Env) init() {
if e.m == nil {
e.m = make(map[string]string)
}
}
-func (e *env) Fork() Env {
+func (e Env) Fork() Env {
t := New()
- t.p = e
+ t.p = e.env
return t
}
@@ -47,12 +47,7 @@ func (e *env) Find(key string) (string, bool) {
defer e.Unlock()
v, ok := e.m[key]
-
- return v, ok
-}
-
-func (e Env) Find(key string) (string, bool) {
- if v, ok := e.env.Find(key); ok {
+ if ok {
return v, ok
}
@@ -68,11 +63,6 @@ func (e Env) Get(key string) string {
return v
}
-func (e Env) Has(key string) bool {
- _, ok := e.Find(key)
- return ok
-}
-
func validKeyValue(key string, value string) error {
if !isGoodName(key) {
return errBadVariable
@@ -119,7 +109,7 @@ func (e Env) Push(key string, value string) error {
return nil
}
-func (e *env) Del(key string) bool {
+func (e Env) Del(key string) bool {
e.Lock()
defer e.Unlock()
@@ -136,7 +126,7 @@ func (e *env) Del(key string) bool {
return true
}
-func (e *env) Each(f func(string, string) bool) {
+func (e Env) Each(f func(string, string) bool) {
var keys []string
e.Lock()
@@ -155,7 +145,7 @@ func (e *env) Each(f func(string, string) bool) {
}
}
-func (e *env) Clear() {
+func (e Env) Clear() {
e.Lock()
defer e.Unlock()
@@ -163,10 +153,22 @@ func (e *env) Clear() {
}
func (e Env) replaceWith(r *regexp.Regexp, s string, f func(string) string) string {
- for {
- t := r.ReplaceAllStringFunc(s, f)
+ const maxIter = 16
+
+ for iter := 0; ; iter++ {
+ found := false
+
+ t := r.ReplaceAllStringFunc(s, func(v string) string {
+ found = true
+
+ if iter >= maxIter {
+ return ""
+ }
+
+ return f(v)
+ })
- if t == s {
+ if !found && t == s {
break
}
@@ -194,18 +196,22 @@ func (e Env) EvalStrings(s []string) []string {
return t
}
+func (e Env) getLocal(key string) string {
+ if tunnel, ok := e.Find("tunnel"); ok {
+ if v, ok := e.Find("tunnel." + tunnel + "." + key); ok {
+ return v
+ }
+ }
+
+ return e.Get(key)
+}
+
func (e Env) Expand(s string) string {
return e.replaceWith(expandRe, s, func(v string) string {
- return e.Get(v[2 : len(v)-1])
+ return e.getLocal(v[2 : len(v)-1])
})
}
func (e Env) Value(key string) string {
- if e.Has("tunnel") {
- if v := e.Expand("@[tunnel.@[tunnel]." + key + "]"); v != "" {
- return v
- }
- }
-
- return e.Expand(e.Get(key))
+ return e.Expand(e.getLocal(key))
}
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index a30e257..5910b56 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -7,37 +7,38 @@ import (
"io"
"sync"
"time"
- "tunnel/pkg/pack"
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
const authTimeout = 5 * time.Second
-const challengeLen = 16
+
+const saltSize = 16
+const hashSize = md5.Size
type auth struct {
h *authHook
secret string
- challenge struct {
+ salt struct {
self string
peer string
}
hash string
- recvChallenge chan struct{}
- recvHash chan struct{}
+ readSaltDone chan struct{}
+ readHashDone chan struct{}
- tmr *time.Timer
+ ready chan struct{}
+ cancel chan struct{}
- fail chan struct{}
- ok chan struct{}
+ tmr *time.Timer
}
-var errDupChallenge = errors.New("peer repeats challenge")
+var errDupSalt = errors.New("peer repeats salt")
var errAuthFail = errors.New("peer auth fail")
var errTimeout = errors.New("timeout")
@@ -45,24 +46,20 @@ type authHook struct {
m sync.Map
}
-func (a *auth) generateChallenge() error {
- b := make([]byte, challengeLen)
+func (a *auth) Init() error {
+ b := make([]byte, saltSize)
if _, err := rand.Read(b); err != nil {
return err
}
- a.challenge.self = string(b)
+ a.salt.self = string(b)
- a.h.m.Store(a.challenge.self, struct{}{})
+ a.h.m.Store(a.salt.self, struct{}{})
return nil
}
-func (a *auth) deleteChallenge() {
- a.h.m.Delete(a.challenge.self)
-}
-
-func (a *auth) getHash(c string) string {
+func (a *auth) hashSum(c string) string {
h := md5.New()
io.WriteString(h, c)
@@ -71,11 +68,11 @@ func (a *auth) getHash(c string) string {
return string(h.Sum(nil))
}
-func (a *auth) wait(c chan struct{}) error {
+func (a *auth) sync(c chan struct{}) error {
select {
case <-a.tmr.C:
return errTimeout
- case <-a.fail:
+ case <-a.cancel:
return io.EOF
case <-c:
return nil
@@ -83,63 +80,63 @@ func (a *auth) wait(c chan struct{}) error {
}
func (a *auth) Send(rq, wq queue.Q) error {
- e := pack.NewEncoder(wq.Writer())
+ w := wq.Writer()
+
+ io.WriteString(w, a.salt.self)
- if err := a.generateChallenge(); err != nil {
+ if err := a.sync(a.readSaltDone); err != nil {
return err
}
- defer a.deleteChallenge()
+ if _, ok := a.h.m.Load(a.salt.peer); ok {
+ return errDupSalt
+ }
- e.Lps([]byte(a.challenge.self))
+ io.WriteString(w, a.hashSum(a.salt.peer))
- if err := a.wait(a.recvChallenge); err != nil {
+ if err := a.sync(a.readHashDone); err != nil {
return err
}
- if _, ok := a.h.m.Load(a.challenge.peer); ok {
- return errDupChallenge
+ if a.hash != a.hashSum(a.salt.self) {
+ close(a.cancel)
+ return errAuthFail
}
- e.Lps([]byte(a.getHash(a.challenge.peer)))
+ close(a.ready)
- if err := a.wait(a.recvHash); err != nil {
- return err
- }
+ return queue.IoCopy(rq.Reader(), w)
+}
- if a.hash != a.getHash(a.challenge.self) {
- close(a.fail)
- return errAuthFail
+func (a *auth) read(r io.Reader, n int, s *string) error {
+ b := make([]byte, n)
+
+ if _, err := io.ReadFull(r, b); err != nil {
+ close(a.cancel)
+ return err
}
- close(a.ok)
+ *s = string(b)
- return queue.Copy(rq, wq)
+ return nil
}
func (a *auth) Recv(rq, wq queue.Q) error {
r := rq.Reader()
- d := pack.NewDecoder(r)
- if b, err := d.Lps(); err != nil {
- close(a.fail)
+ if err := a.read(r, saltSize, &a.salt.peer); err != nil {
return err
- } else {
- a.challenge.peer = string(b)
}
- close(a.recvChallenge)
+ close(a.readSaltDone)
- if b, err := d.Lps(); err != nil {
- close(a.fail)
+ if err := a.read(r, hashSize, &a.hash); err != nil {
return err
- } else {
- a.hash = string(b)
}
- close(a.recvHash)
+ close(a.readHashDone)
- if err := a.wait(a.ok); err != nil {
+ if err := a.sync(a.ready); err != nil {
return err
}
@@ -148,16 +145,27 @@ func (a *auth) Recv(rq, wq queue.Q) error {
return queue.IoCopy(r, wq.Writer())
}
+func (a *auth) Close() {
+ a.h.m.Delete(a.salt.self)
+}
+
func (h *authHook) Open(env env.Env) (interface{}, error) {
a := &auth{
- h: h,
- secret: env.Value("secret"),
- recvChallenge: make(chan struct{}),
- recvHash: make(chan struct{}),
- fail: make(chan struct{}),
- ok: make(chan struct{}),
- tmr: time.NewTimer(authTimeout),
+ h: h,
+ secret: env.Value("secret"),
+ tmr: time.NewTimer(authTimeout),
+
+ cancel: make(chan struct{}),
+ ready: make(chan struct{}),
+
+ readSaltDone: make(chan struct{}),
+ readHashDone: make(chan struct{}),
}
+
+ if err := a.Init(); err != nil {
+ return nil, err
+ }
+
return a, nil
}
diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go
index 3065cbe..69aa237 100644
--- a/pkg/server/hook/hook.go
+++ b/pkg/server/hook/hook.go
@@ -38,10 +38,6 @@ type Recver interface {
Recv(rq, wq queue.Q) error
}
-type Closer interface {
- Close()
-}
-
type Func func(rq, wq queue.Q) error
func (f Func) Send(rq, wq queue.Q) error {
@@ -86,7 +82,7 @@ func (w *wrapper) Open(env env.Env) (*Pipe, error) {
}
func (p *Pipe) Close() {
- if c, ok := p.priv.(Closer); ok {
+ if c, ok := p.priv.(interface{ Close() }); ok {
c.Close()
}
}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 5ca81d1..d57d9dd 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -160,7 +160,7 @@ func (s *Server) Socket() string {
return s.listen.Addr().String()
}
-func (s *Server) Run() {
+func (s *Server) Serve() {
for {
conn, err := s.listen.Accept()
if err != nil {
diff --git a/pkg/test/env_test.go b/pkg/test/env_test.go
new file mode 100644
index 0000000..b03a5c5
--- /dev/null
+++ b/pkg/test/env_test.go
@@ -0,0 +1,26 @@
+package test
+
+import (
+ "testing"
+)
+
+func TestEnv(t *testing.T) {
+ const msg = "Hello, World!"
+
+ c, s, err := newClientServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ defer s.Stop()
+ defer c.Close()
+
+ r, err := c.Send([]string{"echo", msg})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if r != msg {
+ t.Errorf("wrong reply: send '%s', recv '%s'", msg, r)
+ }
+}
diff --git a/pkg/test/test.go b/pkg/test/test.go
new file mode 100644
index 0000000..6a2c776
--- /dev/null
+++ b/pkg/test/test.go
@@ -0,0 +1,34 @@
+package test
+
+import (
+ "os"
+ "path/filepath"
+ "strconv"
+
+ "tunnel/pkg/client"
+ "tunnel/pkg/server"
+)
+
+func getSocketPath() string {
+ s := "tunnel.test." + strconv.Itoa(os.Getpid())
+ return filepath.Join(os.TempDir(), s)
+}
+
+func newClientServer() (*client.Client, *server.Server, error) {
+ socket := getSocketPath()
+
+ s, err := server.New(socket)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ go s.Serve()
+
+ c, err := client.New(socket)
+ if err != nil {
+ s.Stop()
+ return nil, nil, err
+ }
+
+ return c, s, nil
+}