From bd5339bff8bf5f5e877e94dfef265a22570a69c7 Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Mon, 17 Feb 2020 11:56:43 +0300 Subject: first working version --- pkg/server/cmds.go | 12 ++- pkg/server/echo.go | 8 +- pkg/server/env.go | 35 +++----- pkg/server/exit.go | 8 +- pkg/server/module/alpha.go | 32 ++++++++ pkg/server/module/hex.go | 27 +++++++ pkg/server/module/module.go | 48 +++++++++++ pkg/server/queue/queue.go | 60 ++++++++++++++ pkg/server/server.go | 30 ++++--- pkg/server/sleep.go | 8 +- pkg/server/socket/socket.go | 151 ++++++++++++++++++++++++++++++++++ pkg/server/status.go | 8 +- pkg/server/stream.go | 193 ++++++++++++++++++++++++++++++++++++++++++++ 13 files changed, 563 insertions(+), 57 deletions(-) create mode 100644 pkg/server/module/alpha.go create mode 100644 pkg/server/module/hex.go create mode 100644 pkg/server/module/module.go create mode 100644 pkg/server/queue/queue.go create mode 100644 pkg/server/socket/socket.go create mode 100644 pkg/server/stream.go (limited to 'pkg/server') diff --git a/pkg/server/cmds.go b/pkg/server/cmds.go index e383dac..6eabbde 100644 --- a/pkg/server/cmds.go +++ b/pkg/server/cmds.go @@ -26,10 +26,15 @@ func newNode() *node { return &node{m: map[string]*node{}} } -func newCmd(f func (r *request), path ...string) { +func newCmd(f func (r *request), where string) { + path := strings.Split(where, " ") node := cmds for _, name := range path { + if name == "" { + panic("invalid command path") + } + v := node.m[name] if v == nil { v = newNode() @@ -40,12 +45,11 @@ func newCmd(f func (r *request), path ...string) { } if node.c != nil { - s := strings.Join(path, " ") - log.Panicf("handler already registered at '%s'", s) + log.Panicf("handler already registered at '%s'", where) } node.c = &cmd{ - name: strings.Join(path, " "), + name: where, f: f, } } diff --git a/pkg/server/echo.go b/pkg/server/echo.go index 8980ed1..0387a3e 100644 --- a/pkg/server/echo.go +++ b/pkg/server/echo.go @@ -4,10 +4,10 @@ import ( "strings" ) -func init() { - newCmd(echo, "echo") -} - func echo(r *request) { r.Print(strings.Join(r.args, " ")) } + +func init() { + newCmd(echo, "echo") +} diff --git a/pkg/server/env.go b/pkg/server/env.go index 1618c62..818310c 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -2,41 +2,25 @@ package server import ( "regexp" - "sync" ) -func init() { - newCmd(varGet, "var", "get") - newCmd(varSet, "var", "set") - newCmd(varDel, "var", "del") - newCmd(varShow, "var", "show") - newCmd(varClear, "var", "clear") -} - type env struct { m map[string]string - sync.Mutex } const varNamePattern = "[a-zA-Z][a-zA-Z0-9]*" var isValidVarName = regexp.MustCompile("^" + varNamePattern + "$").MatchString -var varTokenRe = regexp.MustCompile("%" + varNamePattern) +var varTokenRe = regexp.MustCompile("@" + varNamePattern) func (e *env) get(key string) (string, bool) { - e.Lock() - defer e.Unlock() - v, ok := e.m[key] return v, ok } func (e *env) set(key string, value string) { - e.Lock() - defer e.Unlock() - if e.m == nil { e.m = make(map[string]string) } @@ -45,9 +29,6 @@ func (e *env) set(key string, value string) { } func (e *env) del(key string) bool { - e.Lock() - defer e.Unlock() - if e.m == nil { return false } @@ -62,9 +43,6 @@ func (e *env) del(key string) bool { } func (e *env) each(f func (string, string) bool) { - e.Lock() - defer e.Unlock() - for k, v := range e.m { if !f(k, v) { break @@ -73,9 +51,6 @@ func (e *env) each(f func (string, string) bool) { } func (e *env) clear() { - e.Lock() - defer e.Unlock() - e.m = nil } @@ -117,3 +92,11 @@ func varShow(r *request) { func varClear(r *request) { r.c.s.env.clear() } + +func init() { + newCmd(varGet, "var get") + newCmd(varSet, "var set") + newCmd(varDel, "var del") + newCmd(varShow, "var show") + newCmd(varClear, "var clear") +} diff --git a/pkg/server/exit.go b/pkg/server/exit.go index 5226ac3..785568d 100644 --- a/pkg/server/exit.go +++ b/pkg/server/exit.go @@ -1,9 +1,9 @@ package server -func init() { - newCmd(exit, "exit") -} - func exit(r *request) { r.c.s.Stop() } + +func init() { + newCmd(exit, "exit") +} diff --git a/pkg/server/module/alpha.go b/pkg/server/module/alpha.go new file mode 100644 index 0000000..be9032c --- /dev/null +++ b/pkg/server/module/alpha.go @@ -0,0 +1,32 @@ +package module + +import ( + "tunnel/pkg/server/queue" + "unicode" + "bufio" + "io" +) + +func alpha(cb func (rune) rune) pipe { + return func (rq, wq queue.Q) error { + r := bufio.NewReader(rq.Reader()) + + for { + c, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + break + } + return err + } + wq <- []byte(string(cb(c))) + } + + return nil + } +} + +func init() { + register("lower", alpha(unicode.ToLower)) + register("upper", alpha(unicode.ToUpper)) +} diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go new file mode 100644 index 0000000..2ffd1fc --- /dev/null +++ b/pkg/server/module/hex.go @@ -0,0 +1,27 @@ +package module + +import ( + "tunnel/pkg/server/queue" + "encoding/hex" +) + +func hexEncoder(rq, wq queue.Q) error { + enc := hex.NewEncoder(wq.Writer()) + + for b := range rq { + enc.Write(b) + } + + return nil +} + +func hexDecoder(rq, wq queue.Q) error { + r := hex.NewDecoder(rq.Reader()) + w := wq.Writer() + return queue.IoCopy(r, w) +} + +func init() { + register("hex", pipe(hexEncoder)) + register("unhex", pipe(hexDecoder)) +} diff --git a/pkg/server/module/module.go b/pkg/server/module/module.go new file mode 100644 index 0000000..768a87b --- /dev/null +++ b/pkg/server/module/module.go @@ -0,0 +1,48 @@ +package module + +import ( + "tunnel/pkg/server/queue" + "fmt" + "log" +) + +var modules = map[string]M{} + +type pipe func (rq, wq queue.Q) error + +type M interface { + Open() (pipe, pipe) +} + +type reverse struct { + M +} + +func Reverse(m M) M { + return &reverse{m} +} + +func (r *reverse) Open() (pipe, pipe) { + p1, p2 := r.M.Open() + return p2, p1 +} + +func (p pipe) Open() (pipe, pipe) { + return p, nil +} + +func register(name string, m M) { + if _, ok := modules[name]; ok { + log.Panicf("duplicate module name '%s'", name) + } + + modules[name] = m +} + +func New(name string) (M, error) { + if m, ok := modules[name]; ok { + return m, nil + } + + return nil, fmt.Errorf("unknown module '%s'", name) +} diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go new file mode 100644 index 0000000..8d0f395 --- /dev/null +++ b/pkg/server/queue/queue.go @@ -0,0 +1,60 @@ +package queue + +import ( + "io" +) + +type Q chan []byte + +type reader struct { + b []byte + q Q +} + +type writer struct { + q Q +} + +func New() Q { + return make(Q) +} + +func (q Q) Reader() io.Reader { + return &reader{q: q} +} + +func (r *reader) Read(p []byte) (int, error) { + if len(r.b) == 0 { + r.b = <-r.q + if r.b == nil { + return 0, io.EOF + } + } + + n := copy(p, r.b) + r.b = r.b[n:] + + return n, nil +} + +func (q Q) Writer() io.Writer { + return &writer{q: q} +} + +func (w *writer) Write(p []byte) (int, error) { + buf := make([]byte, len(p)) + copy(buf, p) + w.q <- buf + + return len(p), nil +} + +func IoCopy(r io.Reader, w io.Writer) error { + if _, err := io.Copy(w, r); err != nil { + if err != io.EOF { + return err + } + } + + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 0e1bf24..4f012d0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -4,6 +4,7 @@ import ( "tunnel/pkg/config" "tunnel/pkg/netstring" "strings" + "errors" "bytes" "sync" "time" @@ -18,10 +19,12 @@ type Server struct { since time.Time wg sync.WaitGroup + mu sync.Mutex once sync.Once done chan struct{} + streams streams env env nextCid uint64 @@ -47,11 +50,13 @@ type request struct { argc int args []string - failed bool - out *bytes.Buffer } +type requestError string + +var errNotImplemented = errors.New("not implemented") + func (c *client) String() string { return fmt.Sprintf("client(%d)", c.id) } @@ -73,15 +78,11 @@ func (r *request) Println(v ...interface{}) { } func (r *request) Fatal(v ...interface{}) { - panic(fmt.Sprint(v...)) + panic(requestError(fmt.Sprint(v...))) } func (r *request) Fatalf(format string, v ...interface{}) { - panic(fmt.Sprintf(format, v...)) -} - -func (r *request) Fatalln(v ...interface{}) { - panic(fmt.Sprintln(v...)) + panic(requestError(fmt.Sprintf(format, v...))) } func (r *request) expect(c ...int) { @@ -134,6 +135,7 @@ func New() (*Server, error) { listen: listen, since: time.Now(), done: make(chan struct{}), + streams: make(streams), } return s, nil @@ -294,9 +296,12 @@ func (r *request) parse(query string) { func (r *request) run(query string) { defer func () { - if e := recover(); e != nil { - r.failed = true - r.Print(e) + switch err := recover().(type) { + case requestError: + r.Print(err) + default: + panic(err) + case nil: } }() @@ -304,6 +309,9 @@ func (r *request) run(query string) { log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " ")) + r.c.s.mu.Lock() + defer r.c.s.mu.Unlock() + r.cmd.f(r) } diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go index bab9d9b..7d21135 100644 --- a/pkg/server/sleep.go +++ b/pkg/server/sleep.go @@ -7,10 +7,6 @@ import ( const maxSleep = 10 -func init() { - newCmd(sleep, "sleep") -} - func sleep(r *request) { r.expect(1) @@ -25,3 +21,7 @@ func sleep(r *request) { time.Sleep(time.Duration(n) * time.Second) } + +func init() { + newCmd(sleep, "sleep") +} diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go new file mode 100644 index 0000000..f097a80 --- /dev/null +++ b/pkg/server/socket/socket.go @@ -0,0 +1,151 @@ +package socket + +import ( + "tunnel/pkg/server/queue" + "strings" + "sync" + "fmt" + "log" + "net" +) + +type Channel interface { + Send(wq queue.Q) error + Recv(rq queue.Q) error + Close() +} + +type S interface { + Open() (Channel, error) + Close() +} + +type listenSocket struct { + listen net.Listener +} + +type dialSocket struct { + proto, addr string +} + +type connChannel struct { + conn net.Conn + once sync.Once + cancel chan struct{} +} + +func newConnChannel(conn net.Conn) Channel { + return &connChannel{conn: conn, cancel: make(chan struct{})} +} + +func (cc *connChannel) Send(wq queue.Q) (err error) { + defer cc.shutdown(&err) + return queue.IoCopy(cc.conn, wq.Writer()) +} + +func (cc *connChannel) Recv(rq queue.Q) (err error) { + defer cc.shutdown(&err) + + for b := range rq { + for len(b) > 0 { + n, err := cc.conn.Write(b) + if err != nil { + return err + } + b = b[n:] + } + } + + return nil +} + +func (cc *connChannel) String() string { + addr := cc.conn.RemoteAddr() + return fmt.Sprintf("%s/%s", addr.Network(), addr.String()) +} + +func (cc *connChannel) isCanceled() bool { + select { + case <- cc.cancel: + return true + default: + return false + } +} + +func (cc *connChannel) shutdown(err *error) { + select { + case <- cc.cancel: + *err = nil + default: + cc.once.Do(func () { + close(cc.cancel) + log.Println("close", cc) + if e := cc.conn.Close(); e != nil && *err != nil { + *err = e + } + }) + } +} + +func (cc *connChannel) Close() { + var err error + cc.shutdown(&err) +} + +func newListenSocket(proto, addr string) (S, error) { + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + listen, err := net.Listen(proto, addr) + if err != nil { + return nil, err + } + + return &listenSocket{listen: listen}, nil +} + +func (s *listenSocket) Open() (Channel, error) { + conn, err := s.listen.Accept() + if err != nil { + return nil, err + } + return newConnChannel(conn), nil +} + +func (s *listenSocket) Close() { + s.listen.Close() +} + +func newDialSocket(proto, addr string) (S, error) { + return &dialSocket{proto: proto, addr: addr}, nil +} + +func (s *dialSocket) Open() (Channel, error) { + conn, err := net.Dial(s.proto, s.addr) + if err != nil { + return nil, err + } + return newConnChannel(conn), nil +} + +func (s *dialSocket) Close() { +} + +func New(desc string) (S, error) { + args := strings.Split(desc, "/") + + if len(args) != 2 { + return nil, fmt.Errorf("bad socket '%s'", desc) + } + + proto, addr := args[0], args[1] + + switch proto { + case "tcp-listen": return newListenSocket("tcp", addr) + case "tcp": return newDialSocket("tcp", addr) + } + + return nil, fmt.Errorf("bad socket '%s': unknown type", desc) +} diff --git a/pkg/server/status.go b/pkg/server/status.go index aff3844..4689274 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -4,12 +4,12 @@ import ( "tunnel/pkg/config" ) -func init() { - newCmd(status, "status") -} - func status(r *request) { r.expect() r.Printf("since %s", r.c.s.since.Format(config.TimeFormat)) } + +func init() { + newCmd(status, "status") +} diff --git a/pkg/server/stream.go b/pkg/server/stream.go new file mode 100644 index 0000000..7c9cc82 --- /dev/null +++ b/pkg/server/stream.go @@ -0,0 +1,193 @@ +package server + +import ( + "tunnel/pkg/server/module" + "tunnel/pkg/server/queue" + "tunnel/pkg/server/socket" + "strings" + "sort" + "fmt" + "log" +) + +type stream struct { + id string + args string + + in, out socket.S + m []module.M +} + +type streams map[string]*stream + +func (s *stream) String() string { + return fmt.Sprintf("stream(%s)", s.id) +} + +func (s *stream) Close() { + s.in.Close() + s.out.Close() +} + +func (s *stream) run() { + for { + if in, err := s.in.Open(); err != nil { + log.Println(s, err) + } else { + log.Printf("%s accept %s", s, in) + go s.run2(in) + } + } +} + +func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { + watch := func (q queue.Q, f func (q queue.Q) error) { + if err := f(q); err != nil { + log.Println(s, err) + } + } + + go func () { + watch(wq, c.Send) + close(wq) + }() + + go watch(rq, c.Recv) +} + +func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { + go func () { + if err := f(rq, wq); err != nil { + log.Println(s, err) + } + + close(wq) + }() +} + +func (s *stream) run2(in socket.Channel) { + out, err := s.out.Open() + if err != nil { + log.Println(s, err) + in.Close() + return + } + + rq, wq := queue.New(), queue.New() + + s.watchChannel(rq, wq, in) + + for _, m := range s.m { + send, recv := m.Open() + if send != nil { + q := queue.New() + s.watchPipe(wq, q, send) + wq = q + } + if recv != nil { + q := queue.New() + s.watchPipe(q, rq, recv) + rq = q + } + } + + s.watchChannel(wq, rq, out) +} + +func newStream(id string, args []string) (*stream, error) { + var in, out socket.S + var err error + + n := len(args) - 1 + + if in, err = socket.New(args[0]); err != nil { + return nil, err + } + + if out, err = socket.New(args[n]); err != nil { + in.Close() + return nil, err + } + + s := &stream{ + id: id, + args: strings.Join(args, " "), + in: in, + out: out, + } + + reverse := false + + for _, arg := range args[1:n] { + var m module.M + + if arg == "-" { + reverse = true + continue + } + + if arg == "+" { + reverse = false + continue + } + + if m, err = module.New(arg); err != nil { + s.Close() + return nil, err + } + + if reverse { + m = module.Reverse(m) + reverse = false + } + + s.m = append(s.m, m) + } + + if reverse { + s.Close() + return nil, fmt.Errorf("bad '-' usage") + } + + go s.run() + + return s, nil +} + +func streamAdd(r *request) { + if r.argc < 3 { + r.Fatal("not enough args") + } + + id := r.args[0] + if _, ok := r.c.s.streams[id]; ok { + r.Fatal("duplicate id") + } + + s, err := newStream(id, r.args[1:]) + if err != nil { + r.Fatal(err) + } + + r.c.s.streams[id] = s +} + +func streamShow(r *request) { + var keys []string + + for k := range r.c.s.streams { + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, k := range keys { + s := r.c.s.streams[k] + r.Println(s.id, s.args) + } +} + +func init() { + newCmd(streamAdd, "add") + newCmd(streamShow, "show") +} -- cgit v1.2.3-70-g09d2