summaryrefslogtreecommitdiff
path: root/pkg/server/hook
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server/hook')
-rw-r--r--pkg/server/hook/auth.go118
-rw-r--r--pkg/server/hook/hook.go6
2 files changed, 64 insertions, 60 deletions
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index a30e257..5910b56 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -7,37 +7,38 @@ import (
"io"
"sync"
"time"
- "tunnel/pkg/pack"
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
const authTimeout = 5 * time.Second
-const challengeLen = 16
+
+const saltSize = 16
+const hashSize = md5.Size
type auth struct {
h *authHook
secret string
- challenge struct {
+ salt struct {
self string
peer string
}
hash string
- recvChallenge chan struct{}
- recvHash chan struct{}
+ readSaltDone chan struct{}
+ readHashDone chan struct{}
- tmr *time.Timer
+ ready chan struct{}
+ cancel chan struct{}
- fail chan struct{}
- ok chan struct{}
+ tmr *time.Timer
}
-var errDupChallenge = errors.New("peer repeats challenge")
+var errDupSalt = errors.New("peer repeats salt")
var errAuthFail = errors.New("peer auth fail")
var errTimeout = errors.New("timeout")
@@ -45,24 +46,20 @@ type authHook struct {
m sync.Map
}
-func (a *auth) generateChallenge() error {
- b := make([]byte, challengeLen)
+func (a *auth) Init() error {
+ b := make([]byte, saltSize)
if _, err := rand.Read(b); err != nil {
return err
}
- a.challenge.self = string(b)
+ a.salt.self = string(b)
- a.h.m.Store(a.challenge.self, struct{}{})
+ a.h.m.Store(a.salt.self, struct{}{})
return nil
}
-func (a *auth) deleteChallenge() {
- a.h.m.Delete(a.challenge.self)
-}
-
-func (a *auth) getHash(c string) string {
+func (a *auth) hashSum(c string) string {
h := md5.New()
io.WriteString(h, c)
@@ -71,11 +68,11 @@ func (a *auth) getHash(c string) string {
return string(h.Sum(nil))
}
-func (a *auth) wait(c chan struct{}) error {
+func (a *auth) sync(c chan struct{}) error {
select {
case <-a.tmr.C:
return errTimeout
- case <-a.fail:
+ case <-a.cancel:
return io.EOF
case <-c:
return nil
@@ -83,63 +80,63 @@ func (a *auth) wait(c chan struct{}) error {
}
func (a *auth) Send(rq, wq queue.Q) error {
- e := pack.NewEncoder(wq.Writer())
+ w := wq.Writer()
+
+ io.WriteString(w, a.salt.self)
- if err := a.generateChallenge(); err != nil {
+ if err := a.sync(a.readSaltDone); err != nil {
return err
}
- defer a.deleteChallenge()
+ if _, ok := a.h.m.Load(a.salt.peer); ok {
+ return errDupSalt
+ }
- e.Lps([]byte(a.challenge.self))
+ io.WriteString(w, a.hashSum(a.salt.peer))
- if err := a.wait(a.recvChallenge); err != nil {
+ if err := a.sync(a.readHashDone); err != nil {
return err
}
- if _, ok := a.h.m.Load(a.challenge.peer); ok {
- return errDupChallenge
+ if a.hash != a.hashSum(a.salt.self) {
+ close(a.cancel)
+ return errAuthFail
}
- e.Lps([]byte(a.getHash(a.challenge.peer)))
+ close(a.ready)
- if err := a.wait(a.recvHash); err != nil {
- return err
- }
+ return queue.IoCopy(rq.Reader(), w)
+}
- if a.hash != a.getHash(a.challenge.self) {
- close(a.fail)
- return errAuthFail
+func (a *auth) read(r io.Reader, n int, s *string) error {
+ b := make([]byte, n)
+
+ if _, err := io.ReadFull(r, b); err != nil {
+ close(a.cancel)
+ return err
}
- close(a.ok)
+ *s = string(b)
- return queue.Copy(rq, wq)
+ return nil
}
func (a *auth) Recv(rq, wq queue.Q) error {
r := rq.Reader()
- d := pack.NewDecoder(r)
- if b, err := d.Lps(); err != nil {
- close(a.fail)
+ if err := a.read(r, saltSize, &a.salt.peer); err != nil {
return err
- } else {
- a.challenge.peer = string(b)
}
- close(a.recvChallenge)
+ close(a.readSaltDone)
- if b, err := d.Lps(); err != nil {
- close(a.fail)
+ if err := a.read(r, hashSize, &a.hash); err != nil {
return err
- } else {
- a.hash = string(b)
}
- close(a.recvHash)
+ close(a.readHashDone)
- if err := a.wait(a.ok); err != nil {
+ if err := a.sync(a.ready); err != nil {
return err
}
@@ -148,16 +145,27 @@ func (a *auth) Recv(rq, wq queue.Q) error {
return queue.IoCopy(r, wq.Writer())
}
+func (a *auth) Close() {
+ a.h.m.Delete(a.salt.self)
+}
+
func (h *authHook) Open(env env.Env) (interface{}, error) {
a := &auth{
- h: h,
- secret: env.Value("secret"),
- recvChallenge: make(chan struct{}),
- recvHash: make(chan struct{}),
- fail: make(chan struct{}),
- ok: make(chan struct{}),
- tmr: time.NewTimer(authTimeout),
+ h: h,
+ secret: env.Value("secret"),
+ tmr: time.NewTimer(authTimeout),
+
+ cancel: make(chan struct{}),
+ ready: make(chan struct{}),
+
+ readSaltDone: make(chan struct{}),
+ readHashDone: make(chan struct{}),
}
+
+ if err := a.Init(); err != nil {
+ return nil, err
+ }
+
return a, nil
}
diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go
index 3065cbe..69aa237 100644
--- a/pkg/server/hook/hook.go
+++ b/pkg/server/hook/hook.go
@@ -38,10 +38,6 @@ type Recver interface {
Recv(rq, wq queue.Q) error
}
-type Closer interface {
- Close()
-}
-
type Func func(rq, wq queue.Q) error
func (f Func) Send(rq, wq queue.Q) error {
@@ -86,7 +82,7 @@ func (w *wrapper) Open(env env.Env) (*Pipe, error) {
}
func (p *Pipe) Close() {
- if c, ok := p.priv.(Closer); ok {
+ if c, ok := p.priv.(interface{ Close() }); ok {
c.Close()
}
}