summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TODO3
-rw-r--r--pkg/client/client.go13
-rw-r--r--pkg/netstring/netstring.go55
-rw-r--r--pkg/server/env.go23
-rw-r--r--pkg/server/server.go55
-rw-r--r--pkg/server/sleep.go6
6 files changed, 121 insertions, 34 deletions
diff --git a/TODO b/TODO
index 36ad1ae..80f783d 100644
--- a/TODO
+++ b/TODO
@@ -11,3 +11,6 @@ note:
client connection maybe created on accept client or
created on stream creation.
In the latter case client will be multiplexer
+
+6. check variable name
+7. substitute variable over query (maybe)
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 2f0a16e..8f5fc58 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -2,8 +2,9 @@ package client
import (
"tunnel/pkg/config"
- "strings"
+ "tunnel/pkg/netstring"
"errors"
+ "bytes"
"time"
"net"
"io"
@@ -32,10 +33,16 @@ func (c *Client) Send(args []string) (string, error) {
c.conn.SetDeadline(t)
}()
- msg := strings.Join(args, " ")
+ out := new(bytes.Buffer)
+ enc := netstring.NewEncoder(out)
+
+ for _, s := range args {
+ enc.Encode(s)
+ }
+
buf := make([]byte, config.BufSize)
- _, ew := c.conn.Write([]byte(msg))
+ _, ew := c.conn.Write([]byte(out.Bytes()))
if ew != nil {
return "", ew
}
diff --git a/pkg/netstring/netstring.go b/pkg/netstring/netstring.go
new file mode 100644
index 0000000..01b179a
--- /dev/null
+++ b/pkg/netstring/netstring.go
@@ -0,0 +1,55 @@
+package netstring
+
+import (
+ "errors"
+ "fmt"
+ "io"
+)
+
+type Encoder struct {
+ w io.Writer
+}
+
+type Decoder struct {
+ r io.Reader
+}
+
+var errBadFormat = errors.New("netstring: bad format")
+
+func NewEncoder(w io.Writer) *Encoder {
+ return &Encoder{w: w}
+}
+
+func NewDecoder(r io.Reader) *Decoder {
+ return &Decoder{r: r}
+}
+
+func (e *Encoder) Encode(s string) error {
+ _, err := fmt.Fprintf(e.w, "%d:%s,", len(s), s)
+ return err
+}
+
+func (d *Decoder) Decode() (out string, err error) {
+ var n int
+
+ _, err = fmt.Fscanf(d.r, "%d:", &n)
+ if err != nil {
+ return
+ }
+
+ buf := make([]byte, n + 1)
+
+ _, err = io.ReadFull(d.r, buf)
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return
+ }
+
+ if buf[n] != ',' {
+ return "", errBadFormat
+ }
+
+ return string(buf[:n]), nil
+}
diff --git a/pkg/server/env.go b/pkg/server/env.go
index eed6e67..1c24f0b 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -7,9 +7,9 @@ import (
func init() {
newCmd(envSet, "env", "set")
newCmd(envGet, "env", "get")
- newCmd(envShow, "env", "show")
- newCmd(envShow, "env", "print")
newCmd(envUnset, "env", "unset")
+
+ newCmd(envShow, "show", "env")
}
type env struct {
@@ -77,7 +77,7 @@ func envGet(r *request) {
if v, ok := r.c.s.env.get(r.args[0]); ok {
r.Print(v)
} else {
- r.Print("no such variable")
+ r.Fatal("no such variable")
}
}
@@ -85,20 +85,15 @@ func envUnset(r *request) {
r.expect(1)
if !r.c.s.env.unset(r.args[0]) {
- r.Print("no such variable")
+ r.Fatal("no such variable")
}
}
func envShow(r *request) {
- r.expect(0, 1)
+ r.expect(0)
- switch r.argc {
- case 0:
- r.c.s.env.each(func (k string, v string) bool {
- r.Println(k, v)
- return true
- })
- case 1:
- envGet(r)
- }
+ r.c.s.env.each(func (k string, v string) bool {
+ r.Println(k, v)
+ return true
+ })
}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 1248b1d..e2515cb 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -2,6 +2,7 @@ package server
import (
"tunnel/pkg/config"
+ "tunnel/pkg/netstring"
"strings"
"bytes"
"sync"
@@ -46,6 +47,8 @@ type request struct {
argc int
args []string
+ failed bool
+
out *bytes.Buffer
}
@@ -183,7 +186,7 @@ func (c *client) newRequest() *request {
r := &request{
c: c,
id: c.nextRid,
- out: bytes.NewBuffer(nil),
+ out: new(bytes.Buffer),
}
c.nextRid++
@@ -209,9 +212,7 @@ func (c *client) handle() {
r := c.newRequest()
- if r.parse(query) {
- r.run()
- }
+ r.run(query)
if r.out.Len() == 0 {
r.out.Write([]byte("\n"))
@@ -225,30 +226,58 @@ func (c *client) handle() {
}
}
-func (r *request) parse(query string) bool {
- c, args := getCmd(strings.Split(query, " "))
+func (r *request) decode(query string) []string {
+ dec := netstring.NewDecoder(bytes.NewReader([]byte(query)))
+ var t []string
+
+ for {
+ if s, err := dec.Decode(); err == nil {
+ t = append(t, s)
+ } else {
+ if err == io.EOF {
+ break
+ }
+
+ r.Fatal("failed to parse request")
+ }
+ }
+
+ return t
+}
+func (r *request) parse(query string) {
+ c, args := getCmd(r.decode(query))
if c == nil {
- r.Print("command not found")
- return false
+ r.Fatal("command not found")
+ }
+
+ for n, s := range args {
+ if strings.HasPrefix(s, "%") {
+ if v, ok := r.c.s.env.get(s[1:]); ok {
+ args[n] = v
+ } else {
+ r.Fatal("unbound variable ", s)
+ }
+ }
}
r.args = args
r.argc = len(args)
r.cmd = c
-
- return true
}
-func (r *request) run() {
- log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
-
+func (r *request) run(query string) {
defer func () {
if e := recover(); e != nil {
+ r.failed = true
r.Print(e)
}
}()
+ r.parse(query)
+
+ log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
+
r.cmd.f(r)
}
diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go
index 14d22ad..bab9d9b 100644
--- a/pkg/server/sleep.go
+++ b/pkg/server/sleep.go
@@ -16,13 +16,11 @@ func sleep(r *request) {
n, err := strconv.Atoi(r.args[0])
if err != nil || n < 0 {
- r.Printf("invalid time interval '%s'", r.args[0])
- return
+ r.Fatalf("invalid time interval '%s'", r.args[0])
}
if n > maxSleep {
- r.Printf("no more than %d", maxSleep)
- return
+ r.Fatalf("no more than %d", maxSleep)
}
time.Sleep(time.Duration(n) * time.Second)