From ffb4c8b5c2641ae9ee845ddd7f7031c1111a6d45 Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Wed, 22 Jan 2020 04:32:26 +0300 Subject: add chain and env commands --- TODO | 4 +- pkg/server/echo.go | 3 +- pkg/server/env.go | 105 +++++++++++++++++++++++++++ pkg/server/server.go | 201 +++++++++++++++++++++++++++++++++++++++++---------- pkg/server/sleep.go | 9 +-- pkg/server/status.go | 5 +- 6 files changed, 276 insertions(+), 51 deletions(-) create mode 100644 pkg/server/env.go diff --git a/TODO b/TODO index c9c83b6..ed039ac 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,4 @@ 1. DONE ./pkg/server/server.go make request -2. make chain commands +2. DONE make chain commands 3. add help command -4. env set/get +4. DONE env set/get diff --git a/pkg/server/echo.go b/pkg/server/echo.go index c8dce31..00f0c2d 100644 --- a/pkg/server/echo.go +++ b/pkg/server/echo.go @@ -2,7 +2,6 @@ package server import ( "strings" - "fmt" ) func init() { @@ -10,5 +9,5 @@ func init() { } func echo(r *request) { - fmt.Fprint(r.out, strings.Join(r.args, " ")) + r.Print(strings.Join(r.args, " ")) } diff --git a/pkg/server/env.go b/pkg/server/env.go new file mode 100644 index 0000000..769342a --- /dev/null +++ b/pkg/server/env.go @@ -0,0 +1,105 @@ +package server + +import ( + "sync" +) + +func init() { + setHandler(envSet, "env", "set") + setHandler(envGet, "env", "get") + setHandler(envShow, "env", "show") + setHandler(envShow, "env", "print") + setHandler(envUnset, "env", "unset") +} + +type env struct { + m map[string]string + sync.Mutex +} + +func (e *env) get(key string) (string, bool) { + e.Lock() + defer e.Unlock() + + v, ok := e.m[key] + + return v, ok +} + +func (e *env) set(key string, value string) { + e.Lock() + defer e.Unlock() + + if e.m == nil { + e.m = make(map[string]string) + } + + e.m[key] = value +} + +func (e *env) unset(key string) bool { + e.Lock() + defer e.Unlock() + + if e.m == nil { + return false + } + + if _, ok := e.m[key]; !ok { + return false + } + + delete(e.m, key) + + return true +} + +func (e *env) each(f func (string, string) bool) { + e.Lock() + defer e.Unlock() + + for k, v := range e.m { + if !f(k, v) { + break + } + } +} + +func envSet(r *request) { + r.expect(2) + + r.c.s.env.set(r.args[0], r.args[1]) +} + +func envGet(r *request) { + r.expect(1) + + if v, ok := r.c.s.env.get(r.args[0]); ok { + r.Print(v) + } else { + r.Print("no such variable") + } +} + +func envUnset(r *request) { + r.expect(1) + + if !r.c.s.env.unset(r.args[0]) { + r.Print("no such variable") + } +} + +func envShow(r *request) { + r.expect(0, 1) + + switch r.argc { + case 0: + r.c.s.env.each(func (k string, v string) bool { + r.Println(k, v) + return true + }) + + case 1: + envGet(r) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index f9ec602..0c3d2af 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -12,43 +12,51 @@ import ( "io" ) -type handler func (r *request) +type cmd struct { + name string + f func (r *request) +} type node struct { - h handler - nodes map[string]node + c *cmd + m map[string]*node } -var handlers = map[string]handler{} +var cmds = newNode() type Server struct { listen net.Listener since time.Time wg sync.WaitGroup - m sync.Mutex - done bool + once sync.Once + done chan struct{} + + env env nextCid uint64 } type client struct { - s *Server - id uint64 + s *Server + conn net.Conn nextRid uint64 } type request struct { + id uint64 + c *client - id uint64 + cmd *cmd name string + argc int args []string out *bytes.Buffer @@ -62,27 +70,121 @@ func (r *request) String() string { return fmt.Sprintf("request(%d)", r.id) } -func setHandler(h handler, names ...string) { - var path []string +func (r *request) Print(v ...interface{}) { + fmt.Fprint(r.out, v...) +} - for _, s := range names { - if _, ok := handlers[s]; ok { - err := fmt.Sprintf("handler '%s' already registered at '%s'", - s, strings.Join(path, " ")) - panic(err) +func (r *request) Printf(format string, v ...interface{}) { + fmt.Fprintf(r.out, format, v...) +} + +func (r *request) Println(v ...interface{}) { + fmt.Fprintln(r.out, v...) +} + +func (r *request) Fatal(v ...interface{}) { + panic(fmt.Sprint(v...)) +} + +func (r *request) Fatalf(format string, v ...interface{}) { + panic(fmt.Sprintf(format, v...)) +} + +func (r *request) Fatalln(v ...interface{}) { + panic(fmt.Sprintln(v...)) +} + +func (r *request) expect(c ...int) { + desc := func (n int) string { + var sep string + + if n == 1 { + sep = "is" + } else { + sep = "are" } - path = append(path, s) + return fmt.Sprintf("%d %s expected", n, sep) + } + + switch len(c) { + case 0: + if r.argc > 0 { + r.Fatal("args are not expected") + } - handlers[s] = h - break + case 1: + if r.argc < c[0] { + r.Fatal("not enough args: ", desc(c[0])) + } + + if r.argc > c[0] { + r.Fatal("too many args: ", desc(c[0])) + } + + case 2: + if r.argc < c[0] { + r.Fatal("not enough args: at least ", desc(c[0])) + } + + if r.argc > c[1] { + r.Fatal("too many args: no more than ", desc(c[1])) + } } } +func newNode() *node { + return &node{m: map[string]*node{}} +} + +func setHandler(f func (r *request), path ...string) { + node := cmds + + for _, name := range path { + v := node.m[name] + if v == nil { + v = newNode() + node.m[name] = v + } + + node = v + } + + if node.c != nil { + s := strings.Join(path, " ") + log.Panicf("handler already registered at '%s'", s) + } + + node.c = &cmd{ + name: strings.Join(path, " "), + f: f, + } +} + +func getHandler(path []string) (*cmd, []string) { + node := cmds + + for n, name := range path { + node = node.m[name] + if node == nil { + return nil, nil + } + + if node.c != nil { + return node.c, path[n + 1:] + } + } + + return nil, nil +} + func (s *Server) isDone() bool { - s.m.Lock() - defer s.m.Unlock() - return s.done + select { + case <- s.done: + return true + default: + return false + } } func New() (*Server, error) { @@ -94,6 +196,7 @@ func New() (*Server, error) { s := &Server{ listen: listen, since: time.Now(), + done: make(chan struct{}), } return s, nil @@ -124,11 +227,10 @@ func (s *Server) Run() { } func (s *Server) Stop() { - s.m.Lock() - s.done = true - s.m.Unlock() - - s.listen.Close() + s.once.Do(func () { + close(s.done) + s.listen.Close() + }) } func (s *Server) newClient(conn net.Conn) *client { @@ -143,14 +245,10 @@ func (s *Server) newClient(conn net.Conn) *client { return c } -func (c *client) newRequest(msg string) *request { - args := strings.Split(msg, " ") - +func (c *client) newRequest() *request { r := &request{ c: c, id: c.nextRid, - name: args[0], - args: args[1:], out: bytes.NewBuffer(nil), } @@ -173,14 +271,12 @@ func (c *client) handle() { break } - msg := string(buf[:nr]) - r := c.newRequest(msg) + query := string(buf[:nr]) - if h, ok := handlers[r.name]; ok { - log.Println(c, r, "run:", msg) - h(r) - } else { - fmt.Fprint(r.out, "unknown command") + r := c.newRequest() + + if r.parse(query) { + r.run() } if r.out.Len() == 0 { @@ -195,6 +291,33 @@ func (c *client) handle() { } } +func (r *request) parse(query string) bool { + c, args := getHandler(strings.Split(query, " ")) + + if c == nil { + r.Print("unknown command") + return false + } + + r.args = args + r.argc = len(args) + r.cmd = c + + return true +} + +func (r *request) run() { + log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " ")) + + defer func () { + if e := recover(); e != nil { + r.Print(e) + } + }() + + r.cmd.f(r) +} + func (c *client) close() { log.Println(c, "close") diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go index 53c85ea..fc94079 100644 --- a/pkg/server/sleep.go +++ b/pkg/server/sleep.go @@ -3,7 +3,6 @@ package server import ( "strconv" "time" - "fmt" ) const maxSleep = 10 @@ -13,18 +12,16 @@ func init() { } func sleep(r *request) { - if len(r.args) == 0 { - return - } + r.expect(1) n, err := strconv.Atoi(r.args[0]) if err != nil || n < 0 { - fmt.Fprintf(r.out, "invalid time interval '%s'", r.args[0]) + r.Printf("invalid time interval '%s'", r.args[0]) return } if n > maxSleep { - fmt.Fprintf(r.out, "no more than %d", maxSleep) + r.Printf("no more than %d", maxSleep) return } diff --git a/pkg/server/status.go b/pkg/server/status.go index 462ac76..bc725a7 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -2,7 +2,6 @@ package server import ( "tunnel/pkg/config" - "fmt" ) func init() { @@ -10,5 +9,7 @@ func init() { } func status(r *request) { - fmt.Fprintf(r.out, "since %s", r.c.s.since.Format(config.TimeFormat)) + r.expect() + + r.Printf("since %s", r.c.s.since.Format(config.TimeFormat)) } -- cgit v1.2.3-70-g09d2