summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TODO2
-rw-r--r--pkg/server/env.go68
-rw-r--r--pkg/server/env/env.go114
-rw-r--r--pkg/server/module/alpha.go6
-rw-r--r--pkg/server/module/auth.go32
-rw-r--r--pkg/server/module/hex.go4
-rw-r--r--pkg/server/module/module.go55
-rw-r--r--pkg/server/opts/opts.go21
-rw-r--r--pkg/server/queue/queue.go11
-rw-r--r--pkg/server/server.go32
-rw-r--r--pkg/server/socket/socket.go26
-rw-r--r--pkg/server/tunnel.go127
12 files changed, 322 insertions, 176 deletions
diff --git a/TODO b/TODO
index 3e92628..4e079dd 100644
--- a/TODO
+++ b/TODO
@@ -37,3 +37,5 @@ note:
9. tunnel enable/disable
10. config from file
11. system/user? unix control socket location
+12. modules: auth(chap), enc, dec
+13. print module name when stream closed by error
diff --git a/pkg/server/env.go b/pkg/server/env.go
index 818310c..7b6b36d 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -1,63 +1,9 @@
package server
-import (
- "regexp"
-)
-
-type env struct {
- m map[string]string
-}
-
-const varNamePattern = "[a-zA-Z][a-zA-Z0-9]*"
-
-var isValidVarName = regexp.MustCompile("^" + varNamePattern + "$").MatchString
-
-var varTokenRe = regexp.MustCompile("@" + varNamePattern)
-
-func (e *env) get(key string) (string, bool) {
- v, ok := e.m[key]
-
- return v, ok
-}
-
-func (e *env) set(key string, value string) {
- if e.m == nil {
- e.m = make(map[string]string)
- }
-
- e.m[key] = value
-}
-
-func (e *env) del(key string) bool {
- if e.m == nil {
- return false
- }
-
- if _, ok := e.m[key]; !ok {
- return false
- }
-
- delete(e.m, key)
-
- return true
-}
-
-func (e *env) each(f func (string, string) bool) {
- for k, v := range e.m {
- if !f(k, v) {
- break
- }
- }
-}
-
-func (e *env) clear() {
- e.m = nil
-}
-
func varGet(r *request) {
r.expect(1)
- if v, ok := r.c.s.env.get(r.args[0]); ok {
+ if v, ok := r.c.s.env.Find(r.args[0]); ok {
r.Print(v)
} else {
r.Fatal("no such variable")
@@ -67,30 +13,28 @@ func varGet(r *request) {
func varSet(r *request) {
r.expect(2)
- if !isValidVarName(r.args[0]) {
- r.Fatal("bad variable name")
+ if err := r.c.s.env.Set(r.args[0], r.args[1]); err != nil {
+ r.Fatal(err)
}
-
- r.c.s.env.set(r.args[0], r.args[1])
}
func varDel(r *request) {
r.expect(1)
- if !r.c.s.env.del(r.args[0]) {
+ if !r.c.s.env.Del(r.args[0]) {
r.Fatal("no such variable")
}
}
func varShow(r *request) {
- r.c.s.env.each(func (k string, v string) bool {
+ r.c.s.env.Each(func (k string, v string) bool {
r.Println(k, v)
return true
})
}
func varClear(r *request) {
- r.c.s.env.clear()
+ r.c.s.env.Clear()
}
func init() {
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
new file mode 100644
index 0000000..ab47ae8
--- /dev/null
+++ b/pkg/server/env/env.go
@@ -0,0 +1,114 @@
+package env
+
+import (
+ "errors"
+ "regexp"
+ "sync"
+)
+
+type env struct {
+ m map[string]string
+ sync.Mutex
+}
+
+type Env struct {
+ *env
+}
+
+const namePattern = "[a-zA-Z][a-zA-Z0-9]*"
+var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString
+var namePatternRe = regexp.MustCompile("@" + namePattern)
+
+var errBadVariable = errors.New("bad variable name")
+
+func New() Env {
+ return Env{new(env)}
+}
+
+func (e *env) Find(key string) (string, bool) {
+ e.Lock()
+ defer e.Unlock()
+
+ v, ok := e.m[key]
+
+ return v, ok
+}
+
+func (e *env) Get(key string) string {
+ v, _ := e.Find(key)
+ return v
+}
+
+func (e *env) Set(key string, value string) error {
+ if !isNamePattern(key) {
+ return errBadVariable
+ }
+
+ e.Lock()
+ defer e.Unlock()
+
+ if e.m == nil {
+ e.m = make(map[string]string)
+ }
+
+ e.m[key] = value
+
+ return nil
+}
+
+func (e *env) Del(key string) bool {
+ e.Lock()
+ defer e.Unlock()
+
+ if e.m == nil {
+ return false
+ }
+
+ if _, ok := e.m[key]; !ok {
+ return false
+ }
+
+ delete(e.m, key)
+
+ return true
+}
+
+func (e *env) Each(f func (string, string) bool) {
+ e.Lock()
+ defer e.Unlock()
+
+ for k, v := range e.m {
+ if !f(k, v) {
+ break
+ }
+ }
+}
+
+func (e *env) Clear() {
+ e.Lock()
+ defer e.Unlock()
+
+ e.m = nil
+}
+
+func (e *env) Eval(s string) string {
+ e.Lock()
+ defer e.Unlock()
+
+ repl := func (v string) string {
+ if v, ok := e.m[v[1:]]; ok {
+ return v
+ }
+ return ""
+ }
+
+ for {
+ if t := namePatternRe.ReplaceAllStringFunc(s, repl); t == s {
+ break
+ } else {
+ s = t
+ }
+ }
+
+ return s
+}
diff --git a/pkg/server/module/alpha.go b/pkg/server/module/alpha.go
index be9032c..9eb1e2c 100644
--- a/pkg/server/module/alpha.go
+++ b/pkg/server/module/alpha.go
@@ -7,7 +7,7 @@ import (
"io"
)
-func alpha(cb func (rune) rune) pipe {
+func alpha(cb func (rune) rune) Pipe {
return func (rq, wq queue.Q) error {
r := bufio.NewReader(rq.Reader())
@@ -27,6 +27,6 @@ func alpha(cb func (rune) rune) pipe {
}
func init() {
- register("lower", alpha(unicode.ToLower))
- register("upper", alpha(unicode.ToUpper))
+ registerPipe("lower", alpha(unicode.ToLower))
+ registerPipe("upper", alpha(unicode.ToUpper))
}
diff --git a/pkg/server/module/auth.go b/pkg/server/module/auth.go
new file mode 100644
index 0000000..05761ed
--- /dev/null
+++ b/pkg/server/module/auth.go
@@ -0,0 +1,32 @@
+package module
+
+import (
+ "tunnel/pkg/server/queue"
+ "tunnel/pkg/server/opts"
+ "tunnel/pkg/server/env"
+)
+
+type auth struct {
+ secret string
+}
+
+type authModule struct{}
+
+func (a *auth) Send(rq, wq queue.Q) error {
+ return queue.Copy(rq, wq)
+}
+
+func (a *auth) Recv(rq, wq queue.Q) error {
+ return queue.Copy(rq, wq)
+}
+
+func (m authModule) Open(env env.Env) (Pipe, Pipe) {
+ a := &auth{env.Get("secret")}
+ return a.Send, a.Recv
+}
+
+func init() {
+ register("auth", func (opts.Opts, env.Env) (module, error) {
+ return authModule{}, nil
+ })
+}
diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go
index 2ffd1fc..9b80e0d 100644
--- a/pkg/server/module/hex.go
+++ b/pkg/server/module/hex.go
@@ -22,6 +22,6 @@ func hexDecoder(rq, wq queue.Q) error {
}
func init() {
- register("hex", pipe(hexEncoder))
- register("unhex", pipe(hexDecoder))
+ registerPipe("hex", Pipe(hexEncoder))
+ registerPipe("unhex", Pipe(hexDecoder))
}
diff --git a/pkg/server/module/module.go b/pkg/server/module/module.go
index 768a87b..87bdd20 100644
--- a/pkg/server/module/module.go
+++ b/pkg/server/module/module.go
@@ -2,16 +2,29 @@ package module
import (
"tunnel/pkg/server/queue"
+ "tunnel/pkg/server/opts"
+ "tunnel/pkg/server/env"
"fmt"
"log"
)
-var modules = map[string]M{}
+type moduleInitFunc func (opts.Opts, env.Env) (module, error)
-type pipe func (rq, wq queue.Q) error
+var modules = map[string]moduleInitFunc{}
+
+type module interface {
+ Open(env env.Env) (Pipe, Pipe)
+}
type M interface {
- Open() (pipe, pipe)
+ module
+ String() string
+}
+
+type Pipe func (rq, wq queue.Q) error
+
+func (p Pipe) Open(env env.Env) (Pipe, Pipe) {
+ return p, nil
}
type reverse struct {
@@ -22,27 +35,43 @@ func Reverse(m M) M {
return &reverse{m}
}
-func (r *reverse) Open() (pipe, pipe) {
- p1, p2 := r.M.Open()
+func (r *reverse) Open(env env.Env) (Pipe, Pipe) {
+ p1, p2 := r.M.Open(env)
return p2, p1
}
-func (p pipe) Open() (pipe, pipe) {
- return p, nil
+type named struct {
+ name string
+ module
+}
+
+func (m *named) String() string {
+ return fmt.Sprintf("module:%s", m.name)
}
-func register(name string, m M) {
+func register(name string, f moduleInitFunc) {
if _, ok := modules[name]; ok {
log.Panicf("duplicate module name '%s'", name)
}
- modules[name] = m
+ modules[name] = f
}
-func New(name string) (M, error) {
- if m, ok := modules[name]; ok {
- return m, nil
+func registerPipe(name string, p Pipe) {
+ register(name, func (opts.Opts, env.Env) (module, error) {
+ return p, nil
+ })
+}
+
+func New(desc string, env env.Env) (M, error) {
+ name, opts := opts.Parse(desc)
+
+ if f, ok := modules[name]; !ok {
+ return nil, fmt.Errorf("unknown module '%s'", name)
+ } else if m, err := f(opts, env); err != nil {
+ return nil, err
+ } else {
+ return &named{name: name, module: m}, nil
}
- return nil, fmt.Errorf("unknown module '%s'", name)
}
diff --git a/pkg/server/opts/opts.go b/pkg/server/opts/opts.go
new file mode 100644
index 0000000..25dd8e6
--- /dev/null
+++ b/pkg/server/opts/opts.go
@@ -0,0 +1,21 @@
+package opts
+
+import "strings"
+
+type Opts map[string]string
+
+func Parse(s string) (string, Opts) {
+ v := strings.Split(s, ",")
+ m := map[string]string{}
+
+ for _, t := range v[1:] {
+ kv := strings.SplitN(t, "=", 2)
+ if len(kv) < 2 {
+ m[kv[0]] = ""
+ } else {
+ m[kv[0]] = kv[1]
+ }
+ }
+
+ return v[0], m
+}
diff --git a/pkg/server/queue/queue.go b/pkg/server/queue/queue.go
index 8d0f395..745d971 100644
--- a/pkg/server/queue/queue.go
+++ b/pkg/server/queue/queue.go
@@ -41,6 +41,10 @@ func (q Q) Writer() io.Writer {
return &writer{q: q}
}
+func (q Q) Dry() {
+ for _ = range q {}
+}
+
func (w *writer) Write(p []byte) (int, error) {
buf := make([]byte, len(p))
copy(buf, p)
@@ -58,3 +62,10 @@ func IoCopy(r io.Reader, w io.Writer) error {
return nil
}
+
+func Copy(rq, wq Q) error {
+ for b := range rq {
+ wq <- b
+ }
+ return nil
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index ce910f3..d380fb4 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -3,6 +3,7 @@ package server
import (
"tunnel/pkg/config"
"tunnel/pkg/netstring"
+ "tunnel/pkg/server/env"
"strings"
"errors"
"bytes"
@@ -25,7 +26,7 @@ type Server struct {
done chan struct{}
tunnels automap
- env env
+ env env.Env
nextCid int
}
@@ -58,11 +59,11 @@ type requestError string
var errNotImplemented = errors.New("not implemented")
func (c *client) String() string {
- return fmt.Sprintf("client(%d)", c.id)
+ return fmt.Sprintf("client:%d", c.id)
}
func (r *request) String() string {
- return fmt.Sprintf("request(%d)", r.id)
+ return fmt.Sprintf("request:%d", r.id)
}
func (r *request) Print(v ...interface{}) {
@@ -132,6 +133,7 @@ func New() (*Server, error) {
}
s := &Server{
+ env: env.New(),
listen: listen,
since: time.Now(),
done: make(chan struct{}),
@@ -248,33 +250,11 @@ func (r *request) decode(query string) []string {
}
func (r *request) eval(args []string) []string {
- repl := func (v string) string {
- if v, ok := r.c.s.env.get(v[1:]); ok {
- return v
- }
-
- r.Fatal("unbound variable ", v)
-
- return v
- }
-
- eval := func (s string) string {
- var t string
-
- for ;; s = t {
- t = varTokenRe.ReplaceAllStringFunc(s, repl)
-
- if s == t {
- return s
- }
- }
- }
-
for n, s := range args {
if strings.HasPrefix(s, "^") {
args[n] = s[1:]
} else {
- args[n] = eval(s)
+ args[n] = r.c.s.env.Eval(s)
}
}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index bf754cf..c6219a5 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -2,6 +2,8 @@ package socket
import (
"tunnel/pkg/server/queue"
+ "tunnel/pkg/server/opts"
+ "tunnel/pkg/server/env"
"strings"
"sync"
"fmt"
@@ -16,7 +18,7 @@ type Channel interface {
}
type S interface {
- Open() (Channel, error)
+ Open(env env.Env) (Channel, error)
Close()
}
@@ -106,7 +108,7 @@ func newListenSocket(proto, addr string) (S, error) {
return s, nil
}
-func (s *listenSocket) Open() (Channel, error) {
+func (s *listenSocket) Open(env env.Env) (Channel, error) {
conn, err := s.listen.Accept()
if err != nil {
return nil, err
@@ -137,7 +139,7 @@ func (s *dialSocket) String() string {
return fmt.Sprintf("%s/%s", s.proto, s.addr)
}
-func (s *dialSocket) Open() (Channel, error) {
+func (s *dialSocket) Open(env env.Env) (Channel, error) {
conn, err := net.Dial(s.proto, s.addr)
if err != nil {
return nil, err
@@ -148,19 +150,9 @@ func (s *dialSocket) Open() (Channel, error) {
func (s *dialSocket) Close() {
}
-func New(name string) (S, error) {
- vv := strings.Split(name, ",")
- args := strings.Split(vv[0], "/")
- opts := map[string]string{}
-
- for _, v := range vv[1:] {
- ss := strings.SplitN(v, "=", 2)
- if len(ss) < 2 {
- opts[ss[0]] = ""
- } else {
- opts[ss[0]] = ss[1]
- }
- }
+func New(desc string, env env.Env) (S, error) {
+ base, opts := opts.Parse(desc)
+ args := strings.Split(base, "/")
var proto string
var addr string
@@ -176,7 +168,7 @@ func New(name string) (S, error) {
}
if addr == "" {
- return nil, fmt.Errorf("bad socket '%s'", name)
+ return nil, fmt.Errorf("bad socket '%s'", desc)
}
if _, ok := opts["listen"]; ok {
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index e4a324c..a7854bc 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -1,9 +1,10 @@
package server
import (
+ "tunnel/pkg/server/socket"
"tunnel/pkg/server/module"
"tunnel/pkg/server/queue"
- "tunnel/pkg/server/socket"
+ "tunnel/pkg/server/env"
"tunnel/pkg/config"
"strings"
"time"
@@ -37,14 +38,16 @@ type tunnel struct {
in, out socket.S
m []module.M
+
+ env env.Env
}
func (s *stream) String() string {
- return fmt.Sprintf("stream(%d)", s.id)
+ return fmt.Sprintf("stream:%d", s.id)
}
func (t *tunnel) String() string {
- return fmt.Sprintf("tunnel(%s)", t.id)
+ return fmt.Sprintf("tunnel:%s", t.id)
}
func (t *tunnel) stopServe() {
@@ -84,7 +87,7 @@ func (t *tunnel) serve() {
var wg sync.WaitGroup
for {
- if in, err := t.in.Open(); err != nil {
+ if in, err := t.in.Open(t.env); err != nil {
if t.isQuit() {
break
}
@@ -109,7 +112,7 @@ func (t *tunnel) serve() {
}
func (t *tunnel) handle(in socket.Channel) {
- out, err := t.out.Open()
+ out, err := t.out.Open(t.env)
if err != nil {
log.Println(t, err)
in.Close()
@@ -132,13 +135,13 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream {
since: time.Now(),
}
+ s.run()
+
t.mu.Lock()
t.nextSid++
t.streams[s.id] = s
t.mu.Unlock()
- s.run()
-
go func () {
s.wg.Wait()
@@ -154,7 +157,7 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream {
return s
}
-func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
+func (s *stream) channel(c socket.Channel, rq, wq queue.Q) {
watch := func (q queue.Q, f func (q queue.Q) error) {
defer s.wg.Done()
@@ -170,20 +173,24 @@ func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
close(wq)
}()
- go watch(rq, c.Recv)
+ go func () {
+ watch(rq, c.Recv)
+ rq.Dry()
+ }()
}
-func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
+func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) {
s.wg.Add(1)
go func () {
defer s.wg.Done()
- if err := f(rq, wq); err != nil {
- log.Println(s.t, s, err)
+ if err := p(rq, wq); err != nil {
+ log.Println(s.t, s, m, err)
}
close(wq)
+ rq.Dry()
}()
}
@@ -192,23 +199,23 @@ func (s *stream) run() {
rq, wq := queue.New(), queue.New()
- s.watchChannel(rq, wq, s.in)
+ s.channel(s.in, rq, wq)
for _, m := range s.t.m {
- send, recv := m.Open()
+ send, recv := m.Open(s.t.env)
if send != nil {
q := queue.New()
- s.watchPipe(wq, q, send)
+ s.pipe(m, send, wq, q)
wq = q
}
if recv != nil {
q := queue.New()
- s.watchPipe(q, rq, recv)
+ s.pipe(m, recv, q, rq)
rq = q
}
}
- s.watchChannel(wq, rq, s.out)
+ s.channel(s.out, wq, rq)
}
func (s *stream) stop() {
@@ -216,34 +223,14 @@ func (s *stream) stop() {
s.out.Close()
}
-func newTunnel(args []string) (*tunnel, error) {
- var in, out socket.S
- var err error
-
- n := len(args) - 1
-
- if in, err = socket.New(args[0]); err != nil {
- return nil, err
- }
-
- if out, err = socket.New(args[n]); err != nil {
- in.Close()
- return nil, err
- }
-
- t := &tunnel{
- args: strings.Join(args, " "),
- quit: make(chan struct{}),
- done: make(chan struct{}),
- in: in,
- out: out,
- streams: make(map[int]*stream),
- }
+func parseModules(args []string, env env.Env) ([]module.M, error) {
+ var mm []module.M
reverse := false
- for _, arg := range args[1:n] {
+ for _, arg := range args {
var m module.M
+ var err error
if arg == "-" {
reverse = true
@@ -255,8 +242,7 @@ func newTunnel(args []string) (*tunnel, error) {
continue
}
- if m, err = module.New(arg); err != nil {
- t.Close()
+ if m, err = module.New(arg, env); err != nil {
return nil, err
}
@@ -265,14 +251,49 @@ func newTunnel(args []string) (*tunnel, error) {
reverse = false
}
- t.m = append(t.m, m)
+ mm = append(mm, m)
}
if reverse {
- t.Close()
return nil, fmt.Errorf("bad '-' usage")
}
+ return mm, nil
+}
+
+func newTunnel(args []string, env env.Env) (*tunnel, error) {
+ var in, out socket.S
+ var mm []module.M
+ var err error
+
+ n := len(args) - 1
+
+ if in, err = socket.New(args[0], env); err != nil {
+ return nil, err
+ }
+
+ if out, err = socket.New(args[n], env); err != nil {
+ in.Close()
+ return nil, err
+ }
+
+ if mm, err = parseModules(args[1:n], env); err != nil {
+ in.Close()
+ out.Close()
+ return nil, err
+ }
+
+ t := &tunnel{
+ args: strings.Join(args, " "),
+ quit: make(chan struct{}),
+ done: make(chan struct{}),
+ m: mm,
+ in: in,
+ out: out,
+ env: env,
+ streams: make(map[int]*stream),
+ }
+
go t.serve()
return t, nil
@@ -305,7 +326,7 @@ func tunnelAdd(r *request) {
r.Fatal("not enough args")
}
- t, err := newTunnel(args)
+ t, err := newTunnel(args, r.c.s.env)
if err != nil {
r.Fatal(err)
}
@@ -384,18 +405,17 @@ func tunnelShow(r *request) {
func streamShow(r *request) {
foreachTunnel(r.c.s.tunnels, func (t *tunnel) {
- r.Println(t.id, t.args)
-
t.mu.Lock()
- if len(t.streams) == 0 {
- r.Println("\t", "nothing")
- } else {
+ defer t.mu.Unlock()
+
+ if len(t.streams) > 0 {
+ r.Println(t.id, t.args)
+
foreachStream(t.streams, func (s *stream) {
tm := s.since.Format(config.TimeFormat)
r.Println("\t", s.id, tm, s.in, s.out)
})
}
- t.mu.Unlock()
})
}
@@ -406,5 +426,6 @@ func init() {
newCmd(tunnelRename, "rename")
newCmd(tunnelShow, "show")
- newCmd(streamShow, "stream show")
+
+ newCmd(streamShow, "streams")
}