diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-26 02:38:02 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-26 02:38:57 +0300 |
| commit | e6631acaa5af979d87645d74997955a1304cd648 (patch) | |
| tree | ba1a13102d437bc908a7c32ce604a802d5098868 | |
| parent | d6b87d6ad43219bf5b9cdfef50697e1b066dc4ea (diff) | |
[server] prepare module.Open to return error
| -rw-r--r-- | pkg/server/module/aes.go | 19 | ||||
| -rw-r--r-- | pkg/server/module/alpha.go | 6 | ||||
| -rw-r--r-- | pkg/server/module/auth.go | 30 | ||||
| -rw-r--r-- | pkg/server/module/hex.go | 8 | ||||
| -rw-r--r-- | pkg/server/module/module.go | 88 | ||||
| -rw-r--r-- | pkg/server/module/split.go | 4 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 34 |
7 files changed, 104 insertions, 85 deletions
diff --git a/pkg/server/module/aes.go b/pkg/server/module/aes.go index 68dcc7c..13153bc 100644 --- a/pkg/server/module/aes.go +++ b/pkg/server/module/aes.go @@ -11,13 +11,13 @@ import ( "tunnel/pkg/server/queue" ) -type aesInfo struct { +type aesModule struct{} + +type aesPipe struct { key []byte } -type aesModule struct{} - -func (a *aesInfo) Send(rq, wq queue.Q) error { +func (a *aesPipe) Send(rq, wq queue.Q) error { block, err := aes.NewCipher(a.key) if err != nil { return err @@ -39,7 +39,7 @@ func (a *aesInfo) Send(rq, wq queue.Q) error { return queue.IoCopy(rq.Reader(), writer) } -func (a *aesInfo) Recv(rq, wq queue.Q) error { +func (a *aesPipe) Recv(rq, wq queue.Q) error { block, err := aes.NewCipher(a.key) if err != nil { return err @@ -64,19 +64,18 @@ func (a *aesInfo) Recv(rq, wq queue.Q) error { return queue.IoCopy(reader, wq.Writer()) } -func newAes(env env.Env) *aesInfo { +func newAes(env env.Env) *aesPipe { s := getAuthSecret(env) h := md5.Sum([]byte(s)) - a := &aesInfo{key: make([]byte, 16)} + a := &aesPipe{key: make([]byte, 16)} copy(a.key, h[:]) return a } -func (m aesModule) Open(env env.Env) (Pipe, Pipe) { - a := newAes(env) - return a.Send, a.Recv +func (m aesModule) Open(env env.Env) (interface{}, error) { + return newAes(env), nil } func init() { diff --git a/pkg/server/module/alpha.go b/pkg/server/module/alpha.go index 4552c50..8174b25 100644 --- a/pkg/server/module/alpha.go +++ b/pkg/server/module/alpha.go @@ -7,7 +7,7 @@ import ( "unicode" ) -func alpha(cb func(rune) rune) Pipe { +func alpha(cb func(rune) rune) Func { return func(rq, wq queue.Q) error { r := bufio.NewReader(rq.Reader()) @@ -27,6 +27,6 @@ func alpha(cb func(rune) rune) Pipe { } func init() { - registerPipe("lower", alpha(unicode.ToLower)) - registerPipe("upper", alpha(unicode.ToUpper)) + registerFunc("lower", alpha(unicode.ToLower)) + registerFunc("upper", alpha(unicode.ToUpper)) } diff --git a/pkg/server/module/auth.go b/pkg/server/module/auth.go index d269bf6..de58e82 100644 --- a/pkg/server/module/auth.go +++ b/pkg/server/module/auth.go @@ -47,12 +47,6 @@ func (a *auth) generateChallenge() error { return nil } -func (a *auth) sendChallenge(q queue.Q) bool { - enc := netstring.NewEncoder(q.Writer()) - enc.Encode(a.challenge.self) - return a.wait(a.recvChallenge) -} - func (a *auth) getHash(c string) string { h := md5.New() @@ -62,13 +56,7 @@ func (a *auth) getHash(c string) string { return string(h.Sum(nil)) } -func (a *auth) sendHash(q queue.Q) bool { - enc := netstring.NewEncoder(q.Writer()) - enc.Encode(a.getHash(a.challenge.peer)) - return a.wait(a.recvHash) -} - -func (a *auth) wait(c chan struct{}) bool { +func (a *auth) isReady(c chan struct{}) bool { select { case <-a.fail: return false @@ -78,11 +66,15 @@ func (a *auth) wait(c chan struct{}) bool { } func (a *auth) Send(rq, wq queue.Q) error { + e := netstring.NewEncoder(wq.Writer()) + if err := a.generateChallenge(); err != nil { return err } - if !a.sendChallenge(wq) { + e.Encode(a.challenge.self) + + if !a.isReady(a.recvChallenge) { return nil } @@ -90,7 +82,9 @@ func (a *auth) Send(rq, wq queue.Q) error { return errDupChallenge } - if !a.sendHash(wq) { + e.Encode(a.getHash(a.challenge.peer)) + + if !a.isReady(a.recvHash) { return nil } @@ -122,7 +116,7 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) { close(a.recvHash) - if !a.wait(a.ok) { + if !a.isReady(a.ok) { return nil } @@ -141,7 +135,7 @@ func getAuthSecret(env env.Env) string { return env.Get("secret") } -func (m authModule) Open(env env.Env) (Pipe, Pipe) { +func (m authModule) Open(env env.Env) (interface{}, error) { a := &auth{ secret: getAuthSecret(env), recvChallenge: make(chan struct{}), @@ -149,7 +143,7 @@ func (m authModule) Open(env env.Env) (Pipe, Pipe) { fail: make(chan struct{}), ok: make(chan struct{}), } - return a.Send, a.Recv + return a, nil } func init() { diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go index 8a25e50..c3ad0db 100644 --- a/pkg/server/module/hex.go +++ b/pkg/server/module/hex.go @@ -9,7 +9,7 @@ import ( type hexModule struct{} -func hexEncoder(rq, wq queue.Q) error { +func (h hexModule) Send(rq, wq queue.Q) error { enc := hex.NewEncoder(wq.Writer()) for b := range rq { @@ -19,14 +19,14 @@ func hexEncoder(rq, wq queue.Q) error { return nil } -func hexDecoder(rq, wq queue.Q) error { +func (h hexModule) Recv(rq, wq queue.Q) error { r := hex.NewDecoder(rq.Reader()) w := wq.Writer() return queue.IoCopy(r, w) } -func (m hexModule) Open(env env.Env) (Pipe, Pipe) { - return hexEncoder, hexDecoder +func (m hexModule) Open(env env.Env) (interface{}, error) { + return m, nil } func init() { diff --git a/pkg/server/module/module.go b/pkg/server/module/module.go index 69fb90b..3df6eb5 100644 --- a/pkg/server/module/module.go +++ b/pkg/server/module/module.go @@ -3,6 +3,7 @@ package module import ( "fmt" "log" + "strings" "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" @@ -13,7 +14,7 @@ type moduleInitFunc func(opts.Opts, env.Env) (module, error) var modules = map[string]moduleInitFunc{} type module interface { - Open(env env.Env) (Pipe, Pipe) + Open(env env.Env) (interface{}, error) } type M interface { @@ -21,57 +22,92 @@ type M interface { String() string } -type Pipe func(rq, wq queue.Q) error - -func (p Pipe) Open(env env.Env) (Pipe, Pipe) { - return p, nil +type Sender interface { + Send(rq, wq queue.Q) error } -type reverse struct { - M +type Recver interface { + Recv(rq, wq queue.Q) error } -func Reverse(m M) M { - return &reverse{m} +type Func func(rq, wq queue.Q) error + +func (f Func) Send(rq, wq queue.Q) error { + return f(rq, wq) } -func (r *reverse) Open(env env.Env) (Pipe, Pipe) { - p1, p2 := r.M.Open(env) - return p2, p1 +func (f Func) Open(env env.Env) (interface{}, error) { + return f, nil } -type named struct { - name string +type wrapper struct { module + name string + reverse bool } -func (m *named) String() string { - return fmt.Sprintf("module:%s", m.name) +func (w *wrapper) String() string { + return fmt.Sprintf("module:%s", w.name) } -func register(name string, f moduleInitFunc) { - if _, ok := modules[name]; ok { - log.Panicf("duplicate module name '%s'", name) +func Open(m M, env env.Env) (Func, Func, error) { + var send, recv Func + + w := m.(*wrapper) + + it, err := m.Open(env) + if err != nil { + return nil, nil, err } - modules[name] = f -} + if sender, ok := it.(Sender); ok { + send = sender.Send + } -func registerPipe(name string, p Pipe) { - register(name, func(opts.Opts, env.Env) (module, error) { - return p, nil - }) + if recver, ok := it.(Recver); ok { + recv = recver.Recv + } + + if w.reverse { + send, recv = recv, send + } + + return send, recv, nil } func New(desc string, env env.Env) (M, error) { name, opts := opts.Parse(desc) + reverse := false + + if strings.HasPrefix(name, "-") { + name = name[1:] + reverse = true + } if f, ok := modules[name]; !ok { return nil, fmt.Errorf("unknown module '%s'", name) } else if m, err := f(opts, env); err != nil { return nil, err } else { - return &named{name: name, module: m}, nil + w := &wrapper{ + module: m, + name: name, + reverse: reverse, + } + return w, nil } +} +func register(name string, f moduleInitFunc) { + if _, ok := modules[name]; ok { + log.Panicf("duplicate module name '%s'", name) + } + + modules[name] = f +} + +func registerFunc(name string, p Func) { + register(name, func(opts.Opts, env.Env) (module, error) { + return p, nil + }) } diff --git a/pkg/server/module/split.go b/pkg/server/module/split.go index 138814f..139d062 100644 --- a/pkg/server/module/split.go +++ b/pkg/server/module/split.go @@ -34,8 +34,8 @@ func (m *splitModule) Send(rq, wq queue.Q) error { return nil } -func (m *splitModule) Open(env.Env) (Pipe, Pipe) { - return m.Send, nil +func (m *splitModule) Open(env.Env) (interface{}, error) { + return m, nil } func newSplitModule(opts opts.Opts, env env.Env) (module, error) { diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 49f684e..5a5e302 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -19,6 +19,7 @@ import ( type stream struct { id int t *tunnel + env env.Env since time.Time wg sync.WaitGroup in, out socket.Channel @@ -134,9 +135,12 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { in: in, out: out, id: t.nextSid, + env: t.env.Copy(), since: time.Now(), } + s.env.Set("tunnel", t.id) + s.run() t.mu.Lock() @@ -181,13 +185,13 @@ func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { }() } -func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) { +func (s *stream) pipe(m module.M, f module.Func, rq, wq queue.Q) { s.wg.Add(1) go func() { defer s.wg.Done() - if err := p(rq, wq); err != nil && !errors.Is(err, io.EOF) { + if err := f(rq, wq); err != nil && !errors.Is(err, io.EOF) { log.Println(s.t, s, m, err) } @@ -197,9 +201,6 @@ func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) { } func (s *stream) run() { - env := s.t.env.Copy() - env.Set("tunnel", s.t.id) - s.t.wg.Add(1) rq, wq := queue.New(), queue.New() @@ -207,12 +208,14 @@ func (s *stream) run() { s.channel(s.in, rq, wq) for _, m := range s.t.m { - send, recv := m.Open(env) + send, recv, _ := module.Open(m, s.env) + if send != nil { q := queue.New() s.pipe(m, send, wq, q) wq = q } + if recv != nil { q := queue.New() s.pipe(m, recv, q, rq) @@ -232,24 +235,11 @@ func parseModules(args []string, env env.Env) ([]module.M, error) { var mm []module.M for _, arg := range args { - var reverse bool - var m module.M - var err error - - if strings.HasPrefix(arg, "-") { - reverse = true - arg = arg[1:] - } - - if m, err = module.New(arg, env); err != nil { + if m, err := module.New(arg, env); err != nil { return nil, err + } else { + mm = append(mm, m) } - - if reverse { - m = module.Reverse(m) - } - - mm = append(mm, m) } return mm, nil |
