summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TODO4
-rw-r--r--pkg/server/echo.go3
-rw-r--r--pkg/server/env.go105
-rw-r--r--pkg/server/server.go201
-rw-r--r--pkg/server/sleep.go9
-rw-r--r--pkg/server/status.go5
6 files changed, 276 insertions, 51 deletions
diff --git a/TODO b/TODO
index c9c83b6..ed039ac 100644
--- a/TODO
+++ b/TODO
@@ -1,4 +1,4 @@
1. DONE ./pkg/server/server.go make request
-2. make chain commands
+2. DONE make chain commands
3. add help command
-4. env set/get
+4. DONE env set/get
diff --git a/pkg/server/echo.go b/pkg/server/echo.go
index c8dce31..00f0c2d 100644
--- a/pkg/server/echo.go
+++ b/pkg/server/echo.go
@@ -2,7 +2,6 @@ package server
import (
"strings"
- "fmt"
)
func init() {
@@ -10,5 +9,5 @@ func init() {
}
func echo(r *request) {
- fmt.Fprint(r.out, strings.Join(r.args, " "))
+ r.Print(strings.Join(r.args, " "))
}
diff --git a/pkg/server/env.go b/pkg/server/env.go
new file mode 100644
index 0000000..769342a
--- /dev/null
+++ b/pkg/server/env.go
@@ -0,0 +1,105 @@
+package server
+
+import (
+ "sync"
+)
+
+func init() {
+ setHandler(envSet, "env", "set")
+ setHandler(envGet, "env", "get")
+ setHandler(envShow, "env", "show")
+ setHandler(envShow, "env", "print")
+ setHandler(envUnset, "env", "unset")
+}
+
+type env struct {
+ m map[string]string
+ sync.Mutex
+}
+
+func (e *env) get(key string) (string, bool) {
+ e.Lock()
+ defer e.Unlock()
+
+ v, ok := e.m[key]
+
+ return v, ok
+}
+
+func (e *env) set(key string, value string) {
+ e.Lock()
+ defer e.Unlock()
+
+ if e.m == nil {
+ e.m = make(map[string]string)
+ }
+
+ e.m[key] = value
+}
+
+func (e *env) unset(key string) bool {
+ e.Lock()
+ defer e.Unlock()
+
+ if e.m == nil {
+ return false
+ }
+
+ if _, ok := e.m[key]; !ok {
+ return false
+ }
+
+ delete(e.m, key)
+
+ return true
+}
+
+func (e *env) each(f func (string, string) bool) {
+ e.Lock()
+ defer e.Unlock()
+
+ for k, v := range e.m {
+ if !f(k, v) {
+ break
+ }
+ }
+}
+
+func envSet(r *request) {
+ r.expect(2)
+
+ r.c.s.env.set(r.args[0], r.args[1])
+}
+
+func envGet(r *request) {
+ r.expect(1)
+
+ if v, ok := r.c.s.env.get(r.args[0]); ok {
+ r.Print(v)
+ } else {
+ r.Print("no such variable")
+ }
+}
+
+func envUnset(r *request) {
+ r.expect(1)
+
+ if !r.c.s.env.unset(r.args[0]) {
+ r.Print("no such variable")
+ }
+}
+
+func envShow(r *request) {
+ r.expect(0, 1)
+
+ switch r.argc {
+ case 0:
+ r.c.s.env.each(func (k string, v string) bool {
+ r.Println(k, v)
+ return true
+ })
+
+ case 1:
+ envGet(r)
+ }
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index f9ec602..0c3d2af 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -12,43 +12,51 @@ import (
"io"
)
-type handler func (r *request)
+type cmd struct {
+ name string
+ f func (r *request)
+}
type node struct {
- h handler
- nodes map[string]node
+ c *cmd
+ m map[string]*node
}
-var handlers = map[string]handler{}
+var cmds = newNode()
type Server struct {
listen net.Listener
since time.Time
wg sync.WaitGroup
- m sync.Mutex
- done bool
+ once sync.Once
+ done chan struct{}
+
+ env env
nextCid uint64
}
type client struct {
- s *Server
-
id uint64
+ s *Server
+
conn net.Conn
nextRid uint64
}
type request struct {
+ id uint64
+
c *client
- id uint64
+ cmd *cmd
name string
+ argc int
args []string
out *bytes.Buffer
@@ -62,27 +70,121 @@ func (r *request) String() string {
return fmt.Sprintf("request(%d)", r.id)
}
-func setHandler(h handler, names ...string) {
- var path []string
+func (r *request) Print(v ...interface{}) {
+ fmt.Fprint(r.out, v...)
+}
- for _, s := range names {
- if _, ok := handlers[s]; ok {
- err := fmt.Sprintf("handler '%s' already registered at '%s'",
- s, strings.Join(path, " "))
- panic(err)
+func (r *request) Printf(format string, v ...interface{}) {
+ fmt.Fprintf(r.out, format, v...)
+}
+
+func (r *request) Println(v ...interface{}) {
+ fmt.Fprintln(r.out, v...)
+}
+
+func (r *request) Fatal(v ...interface{}) {
+ panic(fmt.Sprint(v...))
+}
+
+func (r *request) Fatalf(format string, v ...interface{}) {
+ panic(fmt.Sprintf(format, v...))
+}
+
+func (r *request) Fatalln(v ...interface{}) {
+ panic(fmt.Sprintln(v...))
+}
+
+func (r *request) expect(c ...int) {
+ desc := func (n int) string {
+ var sep string
+
+ if n == 1 {
+ sep = "is"
+ } else {
+ sep = "are"
}
- path = append(path, s)
+ return fmt.Sprintf("%d %s expected", n, sep)
+ }
+
+ switch len(c) {
+ case 0:
+ if r.argc > 0 {
+ r.Fatal("args are not expected")
+ }
- handlers[s] = h
- break
+ case 1:
+ if r.argc < c[0] {
+ r.Fatal("not enough args: ", desc(c[0]))
+ }
+
+ if r.argc > c[0] {
+ r.Fatal("too many args: ", desc(c[0]))
+ }
+
+ case 2:
+ if r.argc < c[0] {
+ r.Fatal("not enough args: at least ", desc(c[0]))
+ }
+
+ if r.argc > c[1] {
+ r.Fatal("too many args: no more than ", desc(c[1]))
+ }
}
}
+func newNode() *node {
+ return &node{m: map[string]*node{}}
+}
+
+func setHandler(f func (r *request), path ...string) {
+ node := cmds
+
+ for _, name := range path {
+ v := node.m[name]
+ if v == nil {
+ v = newNode()
+ node.m[name] = v
+ }
+
+ node = v
+ }
+
+ if node.c != nil {
+ s := strings.Join(path, " ")
+ log.Panicf("handler already registered at '%s'", s)
+ }
+
+ node.c = &cmd{
+ name: strings.Join(path, " "),
+ f: f,
+ }
+}
+
+func getHandler(path []string) (*cmd, []string) {
+ node := cmds
+
+ for n, name := range path {
+ node = node.m[name]
+ if node == nil {
+ return nil, nil
+ }
+
+ if node.c != nil {
+ return node.c, path[n + 1:]
+ }
+ }
+
+ return nil, nil
+}
+
func (s *Server) isDone() bool {
- s.m.Lock()
- defer s.m.Unlock()
- return s.done
+ select {
+ case <- s.done:
+ return true
+ default:
+ return false
+ }
}
func New() (*Server, error) {
@@ -94,6 +196,7 @@ func New() (*Server, error) {
s := &Server{
listen: listen,
since: time.Now(),
+ done: make(chan struct{}),
}
return s, nil
@@ -124,11 +227,10 @@ func (s *Server) Run() {
}
func (s *Server) Stop() {
- s.m.Lock()
- s.done = true
- s.m.Unlock()
-
- s.listen.Close()
+ s.once.Do(func () {
+ close(s.done)
+ s.listen.Close()
+ })
}
func (s *Server) newClient(conn net.Conn) *client {
@@ -143,14 +245,10 @@ func (s *Server) newClient(conn net.Conn) *client {
return c
}
-func (c *client) newRequest(msg string) *request {
- args := strings.Split(msg, " ")
-
+func (c *client) newRequest() *request {
r := &request{
c: c,
id: c.nextRid,
- name: args[0],
- args: args[1:],
out: bytes.NewBuffer(nil),
}
@@ -173,14 +271,12 @@ func (c *client) handle() {
break
}
- msg := string(buf[:nr])
- r := c.newRequest(msg)
+ query := string(buf[:nr])
- if h, ok := handlers[r.name]; ok {
- log.Println(c, r, "run:", msg)
- h(r)
- } else {
- fmt.Fprint(r.out, "unknown command")
+ r := c.newRequest()
+
+ if r.parse(query) {
+ r.run()
}
if r.out.Len() == 0 {
@@ -195,6 +291,33 @@ func (c *client) handle() {
}
}
+func (r *request) parse(query string) bool {
+ c, args := getHandler(strings.Split(query, " "))
+
+ if c == nil {
+ r.Print("unknown command")
+ return false
+ }
+
+ r.args = args
+ r.argc = len(args)
+ r.cmd = c
+
+ return true
+}
+
+func (r *request) run() {
+ log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
+
+ defer func () {
+ if e := recover(); e != nil {
+ r.Print(e)
+ }
+ }()
+
+ r.cmd.f(r)
+}
+
func (c *client) close() {
log.Println(c, "close")
diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go
index 53c85ea..fc94079 100644
--- a/pkg/server/sleep.go
+++ b/pkg/server/sleep.go
@@ -3,7 +3,6 @@ package server
import (
"strconv"
"time"
- "fmt"
)
const maxSleep = 10
@@ -13,18 +12,16 @@ func init() {
}
func sleep(r *request) {
- if len(r.args) == 0 {
- return
- }
+ r.expect(1)
n, err := strconv.Atoi(r.args[0])
if err != nil || n < 0 {
- fmt.Fprintf(r.out, "invalid time interval '%s'", r.args[0])
+ r.Printf("invalid time interval '%s'", r.args[0])
return
}
if n > maxSleep {
- fmt.Fprintf(r.out, "no more than %d", maxSleep)
+ r.Printf("no more than %d", maxSleep)
return
}
diff --git a/pkg/server/status.go b/pkg/server/status.go
index 462ac76..bc725a7 100644
--- a/pkg/server/status.go
+++ b/pkg/server/status.go
@@ -2,7 +2,6 @@ package server
import (
"tunnel/pkg/config"
- "fmt"
)
func init() {
@@ -10,5 +9,7 @@ func init() {
}
func status(r *request) {
- fmt.Fprintf(r.out, "since %s", r.c.s.since.Format(config.TimeFormat))
+ r.expect()
+
+ r.Printf("since %s", r.c.s.since.Format(config.TimeFormat))
}