summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/server/hook/auth.go32
-rw-r--r--pkg/server/socket/socket.go4
-rw-r--r--pkg/server/tunnel.go4
3 files changed, 25 insertions, 15 deletions
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index 11068e5..7f02816 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -5,13 +5,15 @@ import (
"crypto/rand"
"errors"
"io"
+ "time"
"tunnel/pkg/netstring"
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
-const ChallengeLen = 16
+const authTimeout = 5 * time.Second
+const challengeLen = 16
type auth struct {
secret string
@@ -26,17 +28,20 @@ type auth struct {
recvChallenge chan struct{}
recvHash chan struct{}
+ tmr *time.Timer
+
fail chan struct{}
ok chan struct{}
}
var errDupChallenge = errors.New("peer duplicates challenge")
var errAuthFail = errors.New("peer auth fail")
+var errTimeout = errors.New("timeout")
type authHook struct{}
func (a *auth) generateChallenge() error {
- b := make([]byte, ChallengeLen)
+ b := make([]byte, challengeLen)
if _, err := rand.Read(b); err != nil {
return err
}
@@ -55,12 +60,14 @@ func (a *auth) getHash(c string) string {
return string(h.Sum(nil))
}
-func (a *auth) isReady(c chan struct{}) bool {
+func (a *auth) wait(c chan struct{}) error {
select {
+ case <-a.tmr.C:
+ return errTimeout
case <-a.fail:
- return false
+ return io.EOF
case <-c:
- return true
+ return nil
}
}
@@ -73,8 +80,8 @@ func (a *auth) Send(rq, wq queue.Q) error {
e.Encode(a.challenge.self)
- if !a.isReady(a.recvChallenge) {
- return nil
+ if err := a.wait(a.recvChallenge); err != nil {
+ return err
}
if a.challenge.self == a.challenge.peer {
@@ -83,8 +90,8 @@ func (a *auth) Send(rq, wq queue.Q) error {
e.Encode(a.getHash(a.challenge.peer))
- if !a.isReady(a.recvHash) {
- return nil
+ if err := a.wait(a.recvHash); err != nil {
+ return err
}
if a.hash != a.getHash(a.challenge.self) {
@@ -115,10 +122,12 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) {
close(a.recvHash)
- if !a.isReady(a.ok) {
- return nil
+ if err = a.wait(a.ok); err != nil {
+ return
}
+ a.tmr.Stop()
+
return queue.IoCopy(r, wq.Writer())
}
@@ -129,6 +138,7 @@ func (authHook) Open(env env.Env) (interface{}, error) {
recvHash: make(chan struct{}),
fail: make(chan struct{}),
ok: make(chan struct{}),
+ tmr: time.NewTimer(authTimeout),
}
return a, nil
}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 1bb7549..5ccf1bd 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -40,8 +40,8 @@ type dialSocket struct {
}
type connChannel struct {
- conn net.Conn
- once sync.Once
+ conn net.Conn
+ once sync.Once
}
type loopSocket struct{}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 72d4c13..d59d9db 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -46,7 +46,7 @@ type tunnel struct {
args string
streams map[int]*stream
- recent []*stream
+ recent []*stream
mu sync.Mutex
wg sync.WaitGroup
@@ -186,7 +186,7 @@ func (t *tunnel) delStream(s *stream) {
t.recent = append(t.recent, s)
if len(t.recent) > maxRecentSize {
- t.recent = t.recent[len(t.recent) - maxRecentSize:]
+ t.recent = t.recent[len(t.recent)-maxRecentSize:]
}
}