summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-02-17 11:56:43 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-02-19 23:51:31 +0300
commitbd5339bff8bf5f5e877e94dfef265a22570a69c7 (patch)
tree5902df7a9f21c00d9c414f6b0c2b79aadfd84752 /pkg
parentdf935315c7201b7d42eb361b3ac3d36fe83e53e6 (diff)
first working version
Diffstat (limited to 'pkg')
-rw-r--r--pkg/server/cmds.go12
-rw-r--r--pkg/server/echo.go8
-rw-r--r--pkg/server/env.go35
-rw-r--r--pkg/server/exit.go8
-rw-r--r--pkg/server/module/alpha.go32
-rw-r--r--pkg/server/module/hex.go27
-rw-r--r--pkg/server/module/module.go48
-rw-r--r--pkg/server/queue/queue.go60
-rw-r--r--pkg/server/server.go30
-rw-r--r--pkg/server/sleep.go8
-rw-r--r--pkg/server/socket/socket.go151
-rw-r--r--pkg/server/status.go8
-rw-r--r--pkg/server/stream.go193
13 files changed, 563 insertions, 57 deletions
diff --git a/pkg/server/cmds.go b/pkg/server/cmds.go
index e383dac..6eabbde 100644
--- a/pkg/server/cmds.go
+++ b/pkg/server/cmds.go
@@ -26,10 +26,15 @@ func newNode() *node {
return &node{m: map[string]*node{}}
}
-func newCmd(f func (r *request), path ...string) {
+func newCmd(f func (r *request), where string) {
+ path := strings.Split(where, " ")
node := cmds
for _, name := range path {
+ if name == "" {
+ panic("invalid command path")
+ }
+
v := node.m[name]
if v == nil {
v = newNode()
@@ -40,12 +45,11 @@ func newCmd(f func (r *request), path ...string) {
}
if node.c != nil {
- s := strings.Join(path, " ")
- log.Panicf("handler already registered at '%s'", s)
+ log.Panicf("handler already registered at '%s'", where)
}
node.c = &cmd{
- name: strings.Join(path, " "),
+ name: where,
f: f,
}
}
diff --git a/pkg/server/echo.go b/pkg/server/echo.go
index 8980ed1..0387a3e 100644
--- a/pkg/server/echo.go
+++ b/pkg/server/echo.go
@@ -4,10 +4,10 @@ import (
"strings"
)
-func init() {
- newCmd(echo, "echo")
-}
-
func echo(r *request) {
r.Print(strings.Join(r.args, " "))
}
+
+func init() {
+ newCmd(echo, "echo")
+}
diff --git a/pkg/server/env.go b/pkg/server/env.go
index 1618c62..818310c 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -2,41 +2,25 @@ package server
import (
"regexp"
- "sync"
)
-func init() {
- newCmd(varGet, "var", "get")
- newCmd(varSet, "var", "set")
- newCmd(varDel, "var", "del")
- newCmd(varShow, "var", "show")
- newCmd(varClear, "var", "clear")
-}
-
type env struct {
m map[string]string
- sync.Mutex
}
const varNamePattern = "[a-zA-Z][a-zA-Z0-9]*"
var isValidVarName = regexp.MustCompile("^" + varNamePattern + "$").MatchString
-var varTokenRe = regexp.MustCompile("%" + varNamePattern)
+var varTokenRe = regexp.MustCompile("@" + varNamePattern)
func (e *env) get(key string) (string, bool) {
- e.Lock()
- defer e.Unlock()
-
v, ok := e.m[key]
return v, ok
}
func (e *env) set(key string, value string) {
- e.Lock()
- defer e.Unlock()
-
if e.m == nil {
e.m = make(map[string]string)
}
@@ -45,9 +29,6 @@ func (e *env) set(key string, value string) {
}
func (e *env) del(key string) bool {
- e.Lock()
- defer e.Unlock()
-
if e.m == nil {
return false
}
@@ -62,9 +43,6 @@ func (e *env) del(key string) bool {
}
func (e *env) each(f func (string, string) bool) {
- e.Lock()
- defer e.Unlock()
-
for k, v := range e.m {
if !f(k, v) {
break
@@ -73,9 +51,6 @@ func (e *env) each(f func (string, string) bool) {
}
func (e *env) clear() {
- e.Lock()
- defer e.Unlock()
-
e.m = nil
}
@@ -117,3 +92,11 @@ func varShow(r *request) {
func varClear(r *request) {
r.c.s.env.clear()
}
+
+func init() {
+ newCmd(varGet, "var get")
+ newCmd(varSet, "var set")
+ newCmd(varDel, "var del")
+ newCmd(varShow, "var show")
+ newCmd(varClear, "var clear")
+}
diff --git a/pkg/server/exit.go b/pkg/server/exit.go
index 5226ac3..785568d 100644
--- a/pkg/server/exit.go
+++ b/pkg/server/exit.go
@@ -1,9 +1,9 @@
package server
-func init() {
- newCmd(exit, "exit")
-}
-
func exit(r *request) {
r.c.s.Stop()
}
+
+func init() {
+ newCmd(exit, "exit")
+}
diff --git a/pkg/server/module/alpha.go b/pkg/server/module/alpha.go
new file mode 100644
index 0000000..be9032c
--- /dev/null
+++ b/pkg/server/module/alpha.go
@@ -0,0 +1,32 @@
+package module
+
+import (
+ "tunnel/pkg/server/queue"
+ "unicode"
+ "bufio"
+ "io"
+)
+
+func alpha(cb func (rune) rune) pipe {
+ return func (rq, wq queue.Q) error {
+ r := bufio.NewReader(rq.Reader())
+
+ for {
+ c, _, err := r.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+ wq <- []byte(string(cb(c)))
+ }
+
+ return nil
+ }
+}
+
+func init() {
+ register("lower", alpha(unicode.ToLower))
+ register("upper", alpha(unicode.ToUpper))
+}
diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go
new file mode 100644
index 0000000..2ffd1fc
--- /dev/null
+++ b/pkg/server/module/hex.go
@@ -0,0 +1,27 @@
+package module
+
+import (
+ "tunnel/pkg/server/queue"
+ "encoding/hex"
+)
+
+func hexEncoder(rq, wq queue.Q) error {
+ enc := hex.NewEncoder(wq.Writer())
+
+ for b := range rq {
+ enc.Write(b)
+ }
+
+ return nil
+}
+
+func hexDecoder(rq, wq queue.Q) error {
+ r := hex.NewDecoder(rq.Reader())
+ w := wq.Writer()
+ return queue.IoCopy(r, w)
+}
+
+func init() {
+ register("hex", pipe(hexEncoder))
+ register("unhex", pipe(hexDecoder))
+}
diff --git a/pkg/server/module/module.go b/pkg/server/module/module.go
new file mode 100644
index 0000000..768a87b
--- /dev/null
+++ b/pkg/server/module/module.go
@@ -0,0 +1,48 @@
+package module
+
+import (
+ "tunnel/pkg/server/queue"
+ "fmt"
+ "log"
+)
+
+var modules = map[string]M{}
+
+type pipe func (rq, wq queue.Q) error
+
+type M interface {
+ Open() (pipe, pipe)
+}
+
+type reverse struct {
+ M
+}
+
+func Reverse(m M) M {
+ return &reverse{m}
+}
+
+func (r *reverse) Open() (pipe, pipe) {
+ p1, p2 := r.M.Open()
+ return p2, p1
+}
+
+func (p pipe) Open() (pipe, pipe) {
+ return p, nil
+}
+
+func register(name string, m M) {
+ if _, ok := modules[name]; ok {
+ log.Panicf("duplicate module name '%s'", name)
+ }
+
+ modules[name] = m
+}
+
+func New(name string) (M, error) {
+ if m, ok := modules[name]; ok {
+ return m, nil
+ }
+
+ return nil, fmt.Errorf("unknown module '%s'", name)
+}
diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go
new file mode 100644
index 0000000..8d0f395
--- /dev/null
+++ b/pkg/server/queue/queue.go
@@ -0,0 +1,60 @@
+package queue
+
+import (
+ "io"
+)
+
+type Q chan []byte
+
+type reader struct {
+ b []byte
+ q Q
+}
+
+type writer struct {
+ q Q
+}
+
+func New() Q {
+ return make(Q)
+}
+
+func (q Q) Reader() io.Reader {
+ return &reader{q: q}
+}
+
+func (r *reader) Read(p []byte) (int, error) {
+ if len(r.b) == 0 {
+ r.b = <-r.q
+ if r.b == nil {
+ return 0, io.EOF
+ }
+ }
+
+ n := copy(p, r.b)
+ r.b = r.b[n:]
+
+ return n, nil
+}
+
+func (q Q) Writer() io.Writer {
+ return &writer{q: q}
+}
+
+func (w *writer) Write(p []byte) (int, error) {
+ buf := make([]byte, len(p))
+ copy(buf, p)
+ w.q <- buf
+
+ return len(p), nil
+}
+
+func IoCopy(r io.Reader, w io.Writer) error {
+ if _, err := io.Copy(w, r); err != nil {
+ if err != io.EOF {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 0e1bf24..4f012d0 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -4,6 +4,7 @@ import (
"tunnel/pkg/config"
"tunnel/pkg/netstring"
"strings"
+ "errors"
"bytes"
"sync"
"time"
@@ -18,10 +19,12 @@ type Server struct {
since time.Time
wg sync.WaitGroup
+ mu sync.Mutex
once sync.Once
done chan struct{}
+ streams streams
env env
nextCid uint64
@@ -47,11 +50,13 @@ type request struct {
argc int
args []string
- failed bool
-
out *bytes.Buffer
}
+type requestError string
+
+var errNotImplemented = errors.New("not implemented")
+
func (c *client) String() string {
return fmt.Sprintf("client(%d)", c.id)
}
@@ -73,15 +78,11 @@ func (r *request) Println(v ...interface{}) {
}
func (r *request) Fatal(v ...interface{}) {
- panic(fmt.Sprint(v...))
+ panic(requestError(fmt.Sprint(v...)))
}
func (r *request) Fatalf(format string, v ...interface{}) {
- panic(fmt.Sprintf(format, v...))
-}
-
-func (r *request) Fatalln(v ...interface{}) {
- panic(fmt.Sprintln(v...))
+ panic(requestError(fmt.Sprintf(format, v...)))
}
func (r *request) expect(c ...int) {
@@ -134,6 +135,7 @@ func New() (*Server, error) {
listen: listen,
since: time.Now(),
done: make(chan struct{}),
+ streams: make(streams),
}
return s, nil
@@ -294,9 +296,12 @@ func (r *request) parse(query string) {
func (r *request) run(query string) {
defer func () {
- if e := recover(); e != nil {
- r.failed = true
- r.Print(e)
+ switch err := recover().(type) {
+ case requestError:
+ r.Print(err)
+ default:
+ panic(err)
+ case nil:
}
}()
@@ -304,6 +309,9 @@ func (r *request) run(query string) {
log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
+ r.c.s.mu.Lock()
+ defer r.c.s.mu.Unlock()
+
r.cmd.f(r)
}
diff --git a/pkg/server/sleep.go b/pkg/server/sleep.go
index bab9d9b..7d21135 100644
--- a/pkg/server/sleep.go
+++ b/pkg/server/sleep.go
@@ -7,10 +7,6 @@ import (
const maxSleep = 10
-func init() {
- newCmd(sleep, "sleep")
-}
-
func sleep(r *request) {
r.expect(1)
@@ -25,3 +21,7 @@ func sleep(r *request) {
time.Sleep(time.Duration(n) * time.Second)
}
+
+func init() {
+ newCmd(sleep, "sleep")
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
new file mode 100644
index 0000000..f097a80
--- /dev/null
+++ b/pkg/server/socket/socket.go
@@ -0,0 +1,151 @@
+package socket
+
+import (
+ "tunnel/pkg/server/queue"
+ "strings"
+ "sync"
+ "fmt"
+ "log"
+ "net"
+)
+
+type Channel interface {
+ Send(wq queue.Q) error
+ Recv(rq queue.Q) error
+ Close()
+}
+
+type S interface {
+ Open() (Channel, error)
+ Close()
+}
+
+type listenSocket struct {
+ listen net.Listener
+}
+
+type dialSocket struct {
+ proto, addr string
+}
+
+type connChannel struct {
+ conn net.Conn
+ once sync.Once
+ cancel chan struct{}
+}
+
+func newConnChannel(conn net.Conn) Channel {
+ return &connChannel{conn: conn, cancel: make(chan struct{})}
+}
+
+func (cc *connChannel) Send(wq queue.Q) (err error) {
+ defer cc.shutdown(&err)
+ return queue.IoCopy(cc.conn, wq.Writer())
+}
+
+func (cc *connChannel) Recv(rq queue.Q) (err error) {
+ defer cc.shutdown(&err)
+
+ for b := range rq {
+ for len(b) > 0 {
+ n, err := cc.conn.Write(b)
+ if err != nil {
+ return err
+ }
+ b = b[n:]
+ }
+ }
+
+ return nil
+}
+
+func (cc *connChannel) String() string {
+ addr := cc.conn.RemoteAddr()
+ return fmt.Sprintf("%s/%s", addr.Network(), addr.String())
+}
+
+func (cc *connChannel) isCanceled() bool {
+ select {
+ case <- cc.cancel:
+ return true
+ default:
+ return false
+ }
+}
+
+func (cc *connChannel) shutdown(err *error) {
+ select {
+ case <- cc.cancel:
+ *err = nil
+ default:
+ cc.once.Do(func () {
+ close(cc.cancel)
+ log.Println("close", cc)
+ if e := cc.conn.Close(); e != nil && *err != nil {
+ *err = e
+ }
+ })
+ }
+}
+
+func (cc *connChannel) Close() {
+ var err error
+ cc.shutdown(&err)
+}
+
+func newListenSocket(proto, addr string) (S, error) {
+ if !strings.Contains(addr, ":") {
+ addr = ":" + addr
+ }
+
+ listen, err := net.Listen(proto, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &listenSocket{listen: listen}, nil
+}
+
+func (s *listenSocket) Open() (Channel, error) {
+ conn, err := s.listen.Accept()
+ if err != nil {
+ return nil, err
+ }
+ return newConnChannel(conn), nil
+}
+
+func (s *listenSocket) Close() {
+ s.listen.Close()
+}
+
+func newDialSocket(proto, addr string) (S, error) {
+ return &dialSocket{proto: proto, addr: addr}, nil
+}
+
+func (s *dialSocket) Open() (Channel, error) {
+ conn, err := net.Dial(s.proto, s.addr)
+ if err != nil {
+ return nil, err
+ }
+ return newConnChannel(conn), nil
+}
+
+func (s *dialSocket) Close() {
+}
+
+func New(desc string) (S, error) {
+ args := strings.Split(desc, "/")
+
+ if len(args) != 2 {
+ return nil, fmt.Errorf("bad socket '%s'", desc)
+ }
+
+ proto, addr := args[0], args[1]
+
+ switch proto {
+ case "tcp-listen": return newListenSocket("tcp", addr)
+ case "tcp": return newDialSocket("tcp", addr)
+ }
+
+ return nil, fmt.Errorf("bad socket '%s': unknown type", desc)
+}
diff --git a/pkg/server/status.go b/pkg/server/status.go
index aff3844..4689274 100644
--- a/pkg/server/status.go
+++ b/pkg/server/status.go
@@ -4,12 +4,12 @@ import (
"tunnel/pkg/config"
)
-func init() {
- newCmd(status, "status")
-}
-
func status(r *request) {
r.expect()
r.Printf("since %s", r.c.s.since.Format(config.TimeFormat))
}
+
+func init() {
+ newCmd(status, "status")
+}
diff --git a/pkg/server/stream.go b/pkg/server/stream.go
new file mode 100644
index 0000000..7c9cc82
--- /dev/null
+++ b/pkg/server/stream.go
@@ -0,0 +1,193 @@
+package server
+
+import (
+ "tunnel/pkg/server/module"
+ "tunnel/pkg/server/queue"
+ "tunnel/pkg/server/socket"
+ "strings"
+ "sort"
+ "fmt"
+ "log"
+)
+
+type stream struct {
+ id string
+ args string
+
+ in, out socket.S
+ m []module.M
+}
+
+type streams map[string]*stream
+
+func (s *stream) String() string {
+ return fmt.Sprintf("stream(%s)", s.id)
+}
+
+func (s *stream) Close() {
+ s.in.Close()
+ s.out.Close()
+}
+
+func (s *stream) run() {
+ for {
+ if in, err := s.in.Open(); err != nil {
+ log.Println(s, err)
+ } else {
+ log.Printf("%s accept %s", s, in)
+ go s.run2(in)
+ }
+ }
+}
+
+func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
+ watch := func (q queue.Q, f func (q queue.Q) error) {
+ if err := f(q); err != nil {
+ log.Println(s, err)
+ }
+ }
+
+ go func () {
+ watch(wq, c.Send)
+ close(wq)
+ }()
+
+ go watch(rq, c.Recv)
+}
+
+func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
+ go func () {
+ if err := f(rq, wq); err != nil {
+ log.Println(s, err)
+ }
+
+ close(wq)
+ }()
+}
+
+func (s *stream) run2(in socket.Channel) {
+ out, err := s.out.Open()
+ if err != nil {
+ log.Println(s, err)
+ in.Close()
+ return
+ }
+
+ rq, wq := queue.New(), queue.New()
+
+ s.watchChannel(rq, wq, in)
+
+ for _, m := range s.m {
+ send, recv := m.Open()
+ if send != nil {
+ q := queue.New()
+ s.watchPipe(wq, q, send)
+ wq = q
+ }
+ if recv != nil {
+ q := queue.New()
+ s.watchPipe(q, rq, recv)
+ rq = q
+ }
+ }
+
+ s.watchChannel(wq, rq, out)
+}
+
+func newStream(id string, args []string) (*stream, error) {
+ var in, out socket.S
+ var err error
+
+ n := len(args) - 1
+
+ if in, err = socket.New(args[0]); err != nil {
+ return nil, err
+ }
+
+ if out, err = socket.New(args[n]); err != nil {
+ in.Close()
+ return nil, err
+ }
+
+ s := &stream{
+ id: id,
+ args: strings.Join(args, " "),
+ in: in,
+ out: out,
+ }
+
+ reverse := false
+
+ for _, arg := range args[1:n] {
+ var m module.M
+
+ if arg == "-" {
+ reverse = true
+ continue
+ }
+
+ if arg == "+" {
+ reverse = false
+ continue
+ }
+
+ if m, err = module.New(arg); err != nil {
+ s.Close()
+ return nil, err
+ }
+
+ if reverse {
+ m = module.Reverse(m)
+ reverse = false
+ }
+
+ s.m = append(s.m, m)
+ }
+
+ if reverse {
+ s.Close()
+ return nil, fmt.Errorf("bad '-' usage")
+ }
+
+ go s.run()
+
+ return s, nil
+}
+
+func streamAdd(r *request) {
+ if r.argc < 3 {
+ r.Fatal("not enough args")
+ }
+
+ id := r.args[0]
+ if _, ok := r.c.s.streams[id]; ok {
+ r.Fatal("duplicate id")
+ }
+
+ s, err := newStream(id, r.args[1:])
+ if err != nil {
+ r.Fatal(err)
+ }
+
+ r.c.s.streams[id] = s
+}
+
+func streamShow(r *request) {
+ var keys []string
+
+ for k := range r.c.s.streams {
+ keys = append(keys, k)
+ }
+
+ sort.Strings(keys)
+
+ for _, k := range keys {
+ s := r.c.s.streams[k]
+ r.Println(s.id, s.args)
+ }
+}
+
+func init() {
+ newCmd(streamAdd, "add")
+ newCmd(streamShow, "show")
+}