diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-01-23 06:24:21 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-01-23 06:26:25 +0300 |
| commit | 3145a06d233dfdd4a70cfc706eaaae3abfb937db (patch) | |
| tree | d5550568b223ff46b880a1708926a49d910bab94 /pkg/server | |
| parent | c03851d36298d24e2949a3de688cf2ed2f55b064 (diff) | |
fix client/server protocol
Diffstat (limited to 'pkg/server')
| -rw-r--r-- | pkg/server/env.go | 23 | ||||
| -rw-r--r-- | pkg/server/server.go | 55 | ||||
| -rw-r--r-- | pkg/server/sleep.go | 6 |
3 files changed, 53 insertions, 31 deletions
diff --git a/pkg/server/env.go b/pkg/server/env.go index eed6e67..1c24f0b 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -7,9 +7,9 @@ import ( func init() { newCmd(envSet, "env", "set") newCmd(envGet, "env", "get") - newCmd(envShow, "env", "show") - newCmd(envShow, "env", "print") newCmd(envUnset, "env", "unset") + + newCmd(envShow, "show", "env") } type env struct { @@ -77,7 +77,7 @@ func envGet(r *request) { if v, ok := r.c.s.env.get(r.args[0]); ok { r.Print(v) } else { - r.Print("no such variable") + r.Fatal("no such variable") } } @@ -85,20 +85,15 @@ func envUnset(r *request) { r.expect(1) if !r.c.s.env.unset(r.args[0]) { - r.Print("no such variable") + r.Fatal("no such variable") } } func envShow(r *request) { - r.expect(0, 1) + r.expect(0) - switch r.argc { - case 0: - r.c.s.env.each(func (k string, v string) bool { - r.Println(k, v) - return true - }) - case 1: - envGet(r) - } + r.c.s.env.each(func (k string, v string) bool { + r.Println(k, v) + return true + }) } diff --git a/pkg/server/server.go b/pkg/server/server.go index 1248b1d..e2515cb 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -2,6 +2,7 @@ package server import ( "tunnel/pkg/config" + "tunnel/pkg/netstring" "strings" "bytes" "sync" @@ -46,6 +47,8 @@ type request struct { argc int args []string + failed bool + out *bytes.Buffer } @@ -183,7 +186,7 @@ func (c *client) newRequest() *request { r := &request{ c: c, id: c.nextRid, - out: bytes.NewBuffer(nil), + out: new(bytes.Buffer), } c.nextRid++ @@ -209,9 +212,7 @@ func (c *client) handle() { r := c.newRequest() - if r.parse(query) { - r.run() - } + r.run(query) if r.out.Len() == 0 { r.out.Write([]byte("\n")) @@ -225,30 +226,58 @@ func (c *client) handle() { } } -func (r *request) parse(query string) bool { - c, args := getCmd(strings.Split(query, " ")) +func (r *request) decode(query string) []string { + dec := netstring.NewDecoder(bytes.NewReader([]byte(query))) + var t []string + + for { + if s, err := dec.Decode(); err == nil { + t = append(t, s) + } else { + if err == io.EOF { + break + } + + r.Fatal("failed to parse request") + } + } + + return t +} +func (r *request) parse(query string) { + c, args := getCmd(r.decode(query)) if c == nil { - r.Print("command not found") - return false + r.Fatal("command not found") + } + + for n, s := range args { + if strings.HasPrefix(s, "%") { + if v, ok := r.c.s.env.get(s[1:]); ok { + args[n] = v + } else { + r.Fatal("unbound variable ", s) + } + } } r.args = args r.argc = len(args) r.cmd = c - - return true } -func (r *request) run() { - log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " ")) - +func (r *request) run(query string) { defer func () { if e := recover(); e != nil { + r.failed = true r.Print(e) } }() + r.parse(query) + + log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " ")) + r.cmd.f(r) } diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go index 14d22ad..bab9d9b 100644 --- a/pkg/server/sleep.go +++ b/pkg/server/sleep.go @@ -16,13 +16,11 @@ func sleep(r *request) { n, err := strconv.Atoi(r.args[0]) if err != nil || n < 0 { - r.Printf("invalid time interval '%s'", r.args[0]) - return + r.Fatalf("invalid time interval '%s'", r.args[0]) } if n > maxSleep { - r.Printf("no more than %d", maxSleep) - return + r.Fatalf("no more than %d", maxSleep) } time.Sleep(time.Duration(n) * time.Second) |
