diff options
| -rw-r--r-- | TODO | 2 | ||||
| -rw-r--r-- | pkg/server/env.go | 68 | ||||
| -rw-r--r-- | pkg/server/env/env.go | 114 | ||||
| -rw-r--r-- | pkg/server/module/alpha.go | 6 | ||||
| -rw-r--r-- | pkg/server/module/auth.go | 32 | ||||
| -rw-r--r-- | pkg/server/module/hex.go | 4 | ||||
| -rw-r--r-- | pkg/server/module/module.go | 55 | ||||
| -rw-r--r-- | pkg/server/opts/opts.go | 21 | ||||
| -rw-r--r-- | pkg/server/queue/queue.go | 11 | ||||
| -rw-r--r-- | pkg/server/server.go | 32 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 26 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 127 |
12 files changed, 322 insertions, 176 deletions
@@ -37,3 +37,5 @@ note: 9. tunnel enable/disable 10. config from file 11. system/user? unix control socket location +12. modules: auth(chap), enc, dec +13. print module name when stream closed by error diff --git a/pkg/server/env.go b/pkg/server/env.go index 818310c..7b6b36d 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -1,63 +1,9 @@ package server -import ( - "regexp" -) - -type env struct { - m map[string]string -} - -const varNamePattern = "[a-zA-Z][a-zA-Z0-9]*" - -var isValidVarName = regexp.MustCompile("^" + varNamePattern + "$").MatchString - -var varTokenRe = regexp.MustCompile("@" + varNamePattern) - -func (e *env) get(key string) (string, bool) { - v, ok := e.m[key] - - return v, ok -} - -func (e *env) set(key string, value string) { - if e.m == nil { - e.m = make(map[string]string) - } - - e.m[key] = value -} - -func (e *env) del(key string) bool { - if e.m == nil { - return false - } - - if _, ok := e.m[key]; !ok { - return false - } - - delete(e.m, key) - - return true -} - -func (e *env) each(f func (string, string) bool) { - for k, v := range e.m { - if !f(k, v) { - break - } - } -} - -func (e *env) clear() { - e.m = nil -} - func varGet(r *request) { r.expect(1) - if v, ok := r.c.s.env.get(r.args[0]); ok { + if v, ok := r.c.s.env.Find(r.args[0]); ok { r.Print(v) } else { r.Fatal("no such variable") @@ -67,30 +13,28 @@ func varGet(r *request) { func varSet(r *request) { r.expect(2) - if !isValidVarName(r.args[0]) { - r.Fatal("bad variable name") + if err := r.c.s.env.Set(r.args[0], r.args[1]); err != nil { + r.Fatal(err) } - - r.c.s.env.set(r.args[0], r.args[1]) } func varDel(r *request) { r.expect(1) - if !r.c.s.env.del(r.args[0]) { + if !r.c.s.env.Del(r.args[0]) { r.Fatal("no such variable") } } func varShow(r *request) { - r.c.s.env.each(func (k string, v string) bool { + r.c.s.env.Each(func (k string, v string) bool { r.Println(k, v) return true }) } func varClear(r *request) { - r.c.s.env.clear() + r.c.s.env.Clear() } func init() { diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go new file mode 100644 index 0000000..ab47ae8 --- /dev/null +++ b/pkg/server/env/env.go @@ -0,0 +1,114 @@ +package env + +import ( + "errors" + "regexp" + "sync" +) + +type env struct { + m map[string]string + sync.Mutex +} + +type Env struct { + *env +} + +const namePattern = "[a-zA-Z][a-zA-Z0-9]*" +var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString +var namePatternRe = regexp.MustCompile("@" + namePattern) + +var errBadVariable = errors.New("bad variable name") + +func New() Env { + return Env{new(env)} +} + +func (e *env) Find(key string) (string, bool) { + e.Lock() + defer e.Unlock() + + v, ok := e.m[key] + + return v, ok +} + +func (e *env) Get(key string) string { + v, _ := e.Find(key) + return v +} + +func (e *env) Set(key string, value string) error { + if !isNamePattern(key) { + return errBadVariable + } + + e.Lock() + defer e.Unlock() + + if e.m == nil { + e.m = make(map[string]string) + } + + e.m[key] = value + + return nil +} + +func (e *env) Del(key string) bool { + e.Lock() + defer e.Unlock() + + if e.m == nil { + return false + } + + if _, ok := e.m[key]; !ok { + return false + } + + delete(e.m, key) + + return true +} + +func (e *env) Each(f func (string, string) bool) { + e.Lock() + defer e.Unlock() + + for k, v := range e.m { + if !f(k, v) { + break + } + } +} + +func (e *env) Clear() { + e.Lock() + defer e.Unlock() + + e.m = nil +} + +func (e *env) Eval(s string) string { + e.Lock() + defer e.Unlock() + + repl := func (v string) string { + if v, ok := e.m[v[1:]]; ok { + return v + } + return "" + } + + for { + if t := namePatternRe.ReplaceAllStringFunc(s, repl); t == s { + break + } else { + s = t + } + } + + return s +} diff --git a/pkg/server/module/alpha.go b/pkg/server/module/alpha.go index be9032c..9eb1e2c 100644 --- a/pkg/server/module/alpha.go +++ b/pkg/server/module/alpha.go @@ -7,7 +7,7 @@ import ( "io" ) -func alpha(cb func (rune) rune) pipe { +func alpha(cb func (rune) rune) Pipe { return func (rq, wq queue.Q) error { r := bufio.NewReader(rq.Reader()) @@ -27,6 +27,6 @@ func alpha(cb func (rune) rune) pipe { } func init() { - register("lower", alpha(unicode.ToLower)) - register("upper", alpha(unicode.ToUpper)) + registerPipe("lower", alpha(unicode.ToLower)) + registerPipe("upper", alpha(unicode.ToUpper)) } diff --git a/pkg/server/module/auth.go b/pkg/server/module/auth.go new file mode 100644 index 0000000..05761ed --- /dev/null +++ b/pkg/server/module/auth.go @@ -0,0 +1,32 @@ +package module + +import ( + "tunnel/pkg/server/queue" + "tunnel/pkg/server/opts" + "tunnel/pkg/server/env" +) + +type auth struct { + secret string +} + +type authModule struct{} + +func (a *auth) Send(rq, wq queue.Q) error { + return queue.Copy(rq, wq) +} + +func (a *auth) Recv(rq, wq queue.Q) error { + return queue.Copy(rq, wq) +} + +func (m authModule) Open(env env.Env) (Pipe, Pipe) { + a := &auth{env.Get("secret")} + return a.Send, a.Recv +} + +func init() { + register("auth", func (opts.Opts, env.Env) (module, error) { + return authModule{}, nil + }) +} diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go index 2ffd1fc..9b80e0d 100644 --- a/pkg/server/module/hex.go +++ b/pkg/server/module/hex.go @@ -22,6 +22,6 @@ func hexDecoder(rq, wq queue.Q) error { } func init() { - register("hex", pipe(hexEncoder)) - register("unhex", pipe(hexDecoder)) + registerPipe("hex", Pipe(hexEncoder)) + registerPipe("unhex", Pipe(hexDecoder)) } diff --git a/pkg/server/module/module.go b/pkg/server/module/module.go index 768a87b..87bdd20 100644 --- a/pkg/server/module/module.go +++ b/pkg/server/module/module.go @@ -2,16 +2,29 @@ package module import ( "tunnel/pkg/server/queue" + "tunnel/pkg/server/opts" + "tunnel/pkg/server/env" "fmt" "log" ) -var modules = map[string]M{} +type moduleInitFunc func (opts.Opts, env.Env) (module, error) -type pipe func (rq, wq queue.Q) error +var modules = map[string]moduleInitFunc{} + +type module interface { + Open(env env.Env) (Pipe, Pipe) +} type M interface { - Open() (pipe, pipe) + module + String() string +} + +type Pipe func (rq, wq queue.Q) error + +func (p Pipe) Open(env env.Env) (Pipe, Pipe) { + return p, nil } type reverse struct { @@ -22,27 +35,43 @@ func Reverse(m M) M { return &reverse{m} } -func (r *reverse) Open() (pipe, pipe) { - p1, p2 := r.M.Open() +func (r *reverse) Open(env env.Env) (Pipe, Pipe) { + p1, p2 := r.M.Open(env) return p2, p1 } -func (p pipe) Open() (pipe, pipe) { - return p, nil +type named struct { + name string + module +} + +func (m *named) String() string { + return fmt.Sprintf("module:%s", m.name) } -func register(name string, m M) { +func register(name string, f moduleInitFunc) { if _, ok := modules[name]; ok { log.Panicf("duplicate module name '%s'", name) } - modules[name] = m + modules[name] = f } -func New(name string) (M, error) { - if m, ok := modules[name]; ok { - return m, nil +func registerPipe(name string, p Pipe) { + register(name, func (opts.Opts, env.Env) (module, error) { + return p, nil + }) +} + +func New(desc string, env env.Env) (M, error) { + name, opts := opts.Parse(desc) + + if f, ok := modules[name]; !ok { + return nil, fmt.Errorf("unknown module '%s'", name) + } else if m, err := f(opts, env); err != nil { + return nil, err + } else { + return &named{name: name, module: m}, nil } - return nil, fmt.Errorf("unknown module '%s'", name) } diff --git a/pkg/server/opts/opts.go b/pkg/server/opts/opts.go new file mode 100644 index 0000000..25dd8e6 --- /dev/null +++ b/pkg/server/opts/opts.go @@ -0,0 +1,21 @@ +package opts + +import "strings" + +type Opts map[string]string + +func Parse(s string) (string, Opts) { + v := strings.Split(s, ",") + m := map[string]string{} + + for _, t := range v[1:] { + kv := strings.SplitN(t, "=", 2) + if len(kv) < 2 { + m[kv[0]] = "" + } else { + m[kv[0]] = kv[1] + } + } + + return v[0], m +} diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go index 8d0f395..745d971 100644 --- a/pkg/server/queue/queue.go +++ b/pkg/server/queue/queue.go @@ -41,6 +41,10 @@ func (q Q) Writer() io.Writer { return &writer{q: q} } +func (q Q) Dry() { + for _ = range q {} +} + func (w *writer) Write(p []byte) (int, error) { buf := make([]byte, len(p)) copy(buf, p) @@ -58,3 +62,10 @@ func IoCopy(r io.Reader, w io.Writer) error { return nil } + +func Copy(rq, wq Q) error { + for b := range rq { + wq <- b + } + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index ce910f3..d380fb4 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,6 +3,7 @@ package server import ( "tunnel/pkg/config" "tunnel/pkg/netstring" + "tunnel/pkg/server/env" "strings" "errors" "bytes" @@ -25,7 +26,7 @@ type Server struct { done chan struct{} tunnels automap - env env + env env.Env nextCid int } @@ -58,11 +59,11 @@ type requestError string var errNotImplemented = errors.New("not implemented") func (c *client) String() string { - return fmt.Sprintf("client(%d)", c.id) + return fmt.Sprintf("client:%d", c.id) } func (r *request) String() string { - return fmt.Sprintf("request(%d)", r.id) + return fmt.Sprintf("request:%d", r.id) } func (r *request) Print(v ...interface{}) { @@ -132,6 +133,7 @@ func New() (*Server, error) { } s := &Server{ + env: env.New(), listen: listen, since: time.Now(), done: make(chan struct{}), @@ -248,33 +250,11 @@ func (r *request) decode(query string) []string { } func (r *request) eval(args []string) []string { - repl := func (v string) string { - if v, ok := r.c.s.env.get(v[1:]); ok { - return v - } - - r.Fatal("unbound variable ", v) - - return v - } - - eval := func (s string) string { - var t string - - for ;; s = t { - t = varTokenRe.ReplaceAllStringFunc(s, repl) - - if s == t { - return s - } - } - } - for n, s := range args { if strings.HasPrefix(s, "^") { args[n] = s[1:] } else { - args[n] = eval(s) + args[n] = r.c.s.env.Eval(s) } } diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index bf754cf..c6219a5 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -2,6 +2,8 @@ package socket import ( "tunnel/pkg/server/queue" + "tunnel/pkg/server/opts" + "tunnel/pkg/server/env" "strings" "sync" "fmt" @@ -16,7 +18,7 @@ type Channel interface { } type S interface { - Open() (Channel, error) + Open(env env.Env) (Channel, error) Close() } @@ -106,7 +108,7 @@ func newListenSocket(proto, addr string) (S, error) { return s, nil } -func (s *listenSocket) Open() (Channel, error) { +func (s *listenSocket) Open(env env.Env) (Channel, error) { conn, err := s.listen.Accept() if err != nil { return nil, err @@ -137,7 +139,7 @@ func (s *dialSocket) String() string { return fmt.Sprintf("%s/%s", s.proto, s.addr) } -func (s *dialSocket) Open() (Channel, error) { +func (s *dialSocket) Open(env env.Env) (Channel, error) { conn, err := net.Dial(s.proto, s.addr) if err != nil { return nil, err @@ -148,19 +150,9 @@ func (s *dialSocket) Open() (Channel, error) { func (s *dialSocket) Close() { } -func New(name string) (S, error) { - vv := strings.Split(name, ",") - args := strings.Split(vv[0], "/") - opts := map[string]string{} - - for _, v := range vv[1:] { - ss := strings.SplitN(v, "=", 2) - if len(ss) < 2 { - opts[ss[0]] = "" - } else { - opts[ss[0]] = ss[1] - } - } +func New(desc string, env env.Env) (S, error) { + base, opts := opts.Parse(desc) + args := strings.Split(base, "/") var proto string var addr string @@ -176,7 +168,7 @@ func New(name string) (S, error) { } if addr == "" { - return nil, fmt.Errorf("bad socket '%s'", name) + return nil, fmt.Errorf("bad socket '%s'", desc) } if _, ok := opts["listen"]; ok { diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index e4a324c..a7854bc 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -1,9 +1,10 @@ package server import ( + "tunnel/pkg/server/socket" "tunnel/pkg/server/module" "tunnel/pkg/server/queue" - "tunnel/pkg/server/socket" + "tunnel/pkg/server/env" "tunnel/pkg/config" "strings" "time" @@ -37,14 +38,16 @@ type tunnel struct { in, out socket.S m []module.M + + env env.Env } func (s *stream) String() string { - return fmt.Sprintf("stream(%d)", s.id) + return fmt.Sprintf("stream:%d", s.id) } func (t *tunnel) String() string { - return fmt.Sprintf("tunnel(%s)", t.id) + return fmt.Sprintf("tunnel:%s", t.id) } func (t *tunnel) stopServe() { @@ -84,7 +87,7 @@ func (t *tunnel) serve() { var wg sync.WaitGroup for { - if in, err := t.in.Open(); err != nil { + if in, err := t.in.Open(t.env); err != nil { if t.isQuit() { break } @@ -109,7 +112,7 @@ func (t *tunnel) serve() { } func (t *tunnel) handle(in socket.Channel) { - out, err := t.out.Open() + out, err := t.out.Open(t.env) if err != nil { log.Println(t, err) in.Close() @@ -132,13 +135,13 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { since: time.Now(), } + s.run() + t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() - s.run() - go func () { s.wg.Wait() @@ -154,7 +157,7 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { return s } -func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { +func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { watch := func (q queue.Q, f func (q queue.Q) error) { defer s.wg.Done() @@ -170,20 +173,24 @@ func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { close(wq) }() - go watch(rq, c.Recv) + go func () { + watch(rq, c.Recv) + rq.Dry() + }() } -func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { +func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) { s.wg.Add(1) go func () { defer s.wg.Done() - if err := f(rq, wq); err != nil { - log.Println(s.t, s, err) + if err := p(rq, wq); err != nil { + log.Println(s.t, s, m, err) } close(wq) + rq.Dry() }() } @@ -192,23 +199,23 @@ func (s *stream) run() { rq, wq := queue.New(), queue.New() - s.watchChannel(rq, wq, s.in) + s.channel(s.in, rq, wq) for _, m := range s.t.m { - send, recv := m.Open() + send, recv := m.Open(s.t.env) if send != nil { q := queue.New() - s.watchPipe(wq, q, send) + s.pipe(m, send, wq, q) wq = q } if recv != nil { q := queue.New() - s.watchPipe(q, rq, recv) + s.pipe(m, recv, q, rq) rq = q } } - s.watchChannel(wq, rq, s.out) + s.channel(s.out, wq, rq) } func (s *stream) stop() { @@ -216,34 +223,14 @@ func (s *stream) stop() { s.out.Close() } -func newTunnel(args []string) (*tunnel, error) { - var in, out socket.S - var err error - - n := len(args) - 1 - - if in, err = socket.New(args[0]); err != nil { - return nil, err - } - - if out, err = socket.New(args[n]); err != nil { - in.Close() - return nil, err - } - - t := &tunnel{ - args: strings.Join(args, " "), - quit: make(chan struct{}), - done: make(chan struct{}), - in: in, - out: out, - streams: make(map[int]*stream), - } +func parseModules(args []string, env env.Env) ([]module.M, error) { + var mm []module.M reverse := false - for _, arg := range args[1:n] { + for _, arg := range args { var m module.M + var err error if arg == "-" { reverse = true @@ -255,8 +242,7 @@ func newTunnel(args []string) (*tunnel, error) { continue } - if m, err = module.New(arg); err != nil { - t.Close() + if m, err = module.New(arg, env); err != nil { return nil, err } @@ -265,14 +251,49 @@ func newTunnel(args []string) (*tunnel, error) { reverse = false } - t.m = append(t.m, m) + mm = append(mm, m) } if reverse { - t.Close() return nil, fmt.Errorf("bad '-' usage") } + return mm, nil +} + +func newTunnel(args []string, env env.Env) (*tunnel, error) { + var in, out socket.S + var mm []module.M + var err error + + n := len(args) - 1 + + if in, err = socket.New(args[0], env); err != nil { + return nil, err + } + + if out, err = socket.New(args[n], env); err != nil { + in.Close() + return nil, err + } + + if mm, err = parseModules(args[1:n], env); err != nil { + in.Close() + out.Close() + return nil, err + } + + t := &tunnel{ + args: strings.Join(args, " "), + quit: make(chan struct{}), + done: make(chan struct{}), + m: mm, + in: in, + out: out, + env: env, + streams: make(map[int]*stream), + } + go t.serve() return t, nil @@ -305,7 +326,7 @@ func tunnelAdd(r *request) { r.Fatal("not enough args") } - t, err := newTunnel(args) + t, err := newTunnel(args, r.c.s.env) if err != nil { r.Fatal(err) } @@ -384,18 +405,17 @@ func tunnelShow(r *request) { func streamShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { - r.Println(t.id, t.args) - t.mu.Lock() - if len(t.streams) == 0 { - r.Println("\t", "nothing") - } else { + defer t.mu.Unlock() + + if len(t.streams) > 0 { + r.Println(t.id, t.args) + foreachStream(t.streams, func (s *stream) { tm := s.since.Format(config.TimeFormat) r.Println("\t", s.id, tm, s.in, s.out) }) } - t.mu.Unlock() }) } @@ -406,5 +426,6 @@ func init() { newCmd(tunnelRename, "rename") newCmd(tunnelShow, "show") - newCmd(streamShow, "stream show") + + newCmd(streamShow, "streams") } |
