summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/env.go23
-rw-r--r--pkg/server/server.go55
-rw-r--r--pkg/server/sleep.go6
3 files changed, 53 insertions, 31 deletions
diff --git a/pkg/server/env.go b/pkg/server/env.go
index eed6e67..1c24f0b 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -7,9 +7,9 @@ import (
func init() {
newCmd(envSet, "env", "set")
newCmd(envGet, "env", "get")
- newCmd(envShow, "env", "show")
- newCmd(envShow, "env", "print")
newCmd(envUnset, "env", "unset")
+
+ newCmd(envShow, "show", "env")
}
type env struct {
@@ -77,7 +77,7 @@ func envGet(r *request) {
if v, ok := r.c.s.env.get(r.args[0]); ok {
r.Print(v)
} else {
- r.Print("no such variable")
+ r.Fatal("no such variable")
}
}
@@ -85,20 +85,15 @@ func envUnset(r *request) {
r.expect(1)
if !r.c.s.env.unset(r.args[0]) {
- r.Print("no such variable")
+ r.Fatal("no such variable")
}
}
func envShow(r *request) {
- r.expect(0, 1)
+ r.expect(0)
- switch r.argc {
- case 0:
- r.c.s.env.each(func (k string, v string) bool {
- r.Println(k, v)
- return true
- })
- case 1:
- envGet(r)
- }
+ r.c.s.env.each(func (k string, v string) bool {
+ r.Println(k, v)
+ return true
+ })
}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 1248b1d..e2515cb 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -2,6 +2,7 @@ package server
import (
"tunnel/pkg/config"
+ "tunnel/pkg/netstring"
"strings"
"bytes"
"sync"
@@ -46,6 +47,8 @@ type request struct {
argc int
args []string
+ failed bool
+
out *bytes.Buffer
}
@@ -183,7 +186,7 @@ func (c *client) newRequest() *request {
r := &request{
c: c,
id: c.nextRid,
- out: bytes.NewBuffer(nil),
+ out: new(bytes.Buffer),
}
c.nextRid++
@@ -209,9 +212,7 @@ func (c *client) handle() {
r := c.newRequest()
- if r.parse(query) {
- r.run()
- }
+ r.run(query)
if r.out.Len() == 0 {
r.out.Write([]byte("\n"))
@@ -225,30 +226,58 @@ func (c *client) handle() {
}
}
-func (r *request) parse(query string) bool {
- c, args := getCmd(strings.Split(query, " "))
+func (r *request) decode(query string) []string {
+ dec := netstring.NewDecoder(bytes.NewReader([]byte(query)))
+ var t []string
+
+ for {
+ if s, err := dec.Decode(); err == nil {
+ t = append(t, s)
+ } else {
+ if err == io.EOF {
+ break
+ }
+
+ r.Fatal("failed to parse request")
+ }
+ }
+
+ return t
+}
+func (r *request) parse(query string) {
+ c, args := getCmd(r.decode(query))
if c == nil {
- r.Print("command not found")
- return false
+ r.Fatal("command not found")
+ }
+
+ for n, s := range args {
+ if strings.HasPrefix(s, "%") {
+ if v, ok := r.c.s.env.get(s[1:]); ok {
+ args[n] = v
+ } else {
+ r.Fatal("unbound variable ", s)
+ }
+ }
}
r.args = args
r.argc = len(args)
r.cmd = c
-
- return true
}
-func (r *request) run() {
- log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
-
+func (r *request) run(query string) {
defer func () {
if e := recover(); e != nil {
+ r.failed = true
r.Print(e)
}
}()
+ r.parse(query)
+
+ log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
+
r.cmd.f(r)
}
diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go
index 14d22ad..bab9d9b 100644
--- a/pkg/server/sleep.go
+++ b/pkg/server/sleep.go
@@ -16,13 +16,11 @@ func sleep(r *request) {
n, err := strconv.Atoi(r.args[0])
if err != nil || n < 0 {
- r.Printf("invalid time interval '%s'", r.args[0])
- return
+ r.Fatalf("invalid time interval '%s'", r.args[0])
}
if n > maxSleep {
- r.Printf("no more than %d", maxSleep)
- return
+ r.Fatalf("no more than %d", maxSleep)
}
time.Sleep(time.Duration(n) * time.Second)