diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/server/env/env.go | 5 | ||||
| -rw-r--r-- | pkg/server/module/auth.go | 9 | ||||
| -rw-r--r-- | pkg/server/module/hex.go | 2 | ||||
| -rw-r--r-- | pkg/server/module/tee.go | 112 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 9 |
5 files changed, 128 insertions, 9 deletions
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index 237c0c5..118f853 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -22,6 +22,7 @@ var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})") var errBadVariable = errors.New("bad variable name") +var errEmptyVariable = errors.New("empty variable") func New() Env { return Env{new(env)} @@ -66,6 +67,10 @@ func (e *env) Set(key string, value string) error { return errBadVariable } + if value == "" { + return errEmptyVariable + } + e.Lock() defer e.Unlock() diff --git a/pkg/server/module/auth.go b/pkg/server/module/auth.go index de58e82..5e5caeb 100644 --- a/pkg/server/module/auth.go +++ b/pkg/server/module/auth.go @@ -4,7 +4,6 @@ import ( "crypto/md5" "crypto/rand" "errors" - "fmt" "io" "tunnel/pkg/netstring" "tunnel/pkg/server/env" @@ -124,12 +123,8 @@ func (a *auth) Recv(rq, wq queue.Q) (err error) { } func getAuthSecret(env env.Env) string { - if id, ok := env.Find("tunnel"); ok { - k := fmt.Sprintf("tunnel.%s.secret", id) - - if v, ok := env.Find(k); ok { - return v - } + if v := env.Eval("@{tunnel.@{tunnel}.secret}"); v != "" { + return v } return env.Get("secret") diff --git a/pkg/server/module/hex.go b/pkg/server/module/hex.go index e71688c..ef4ff37 100644 --- a/pkg/server/module/hex.go +++ b/pkg/server/module/hex.go @@ -24,7 +24,7 @@ func (m hexModule) Recv(rq, wq queue.Q) error { return queue.IoCopy(r, wq.Writer()) } -func (m hexModule) Open(env env.Env) (interface{}, error) { +func (m hexModule) Open(env.Env) (interface{}, error) { return m, nil } diff --git a/pkg/server/module/tee.go b/pkg/server/module/tee.go new file mode 100644 index 0000000..7953247 --- /dev/null +++ b/pkg/server/module/tee.go @@ -0,0 +1,112 @@ +package module + +import ( + "bytes" + "fmt" + "os" + "encoding/hex" + "sync" + "tunnel/pkg/server/env" + "tunnel/pkg/server/opts" + "tunnel/pkg/server/queue" +) + +const teeDefaultPath = "/tmp/tunnel.dump" + +type tee struct { + f *os.File + mu sync.Mutex + wg sync.WaitGroup +} + +type teeModule struct { + path string +} + +func (t *tee) dump(s string, p []byte) error { + var out bytes.Buffer + + fmt.Fprintln(&out, s, len(p)) + + w := hex.Dumper(&out) + w.Write(p) + w.Close() + + if _, err := t.f.Write(out.Bytes()); err != nil { + return err + } + + return nil +} + +func (t *tee) Send(rq, wq queue.Q) error { + defer t.wg.Done() + + for b := range rq { + t.dump(">", b) + wq <- b + } + + return nil +} + +func (t *tee) Recv(rq, wq queue.Q) error { + defer t.wg.Done() + + for b := range rq { + t.dump("<", b) + wq <- b + } + + return nil +} + +func (m *teeModule) where(env env.Env) string { + if m.path != "" { + return m.path + } + + if v := env.Eval("@{tunnel.@{tunnel}.tee.path}"); v != "" { + return v + } + + if v, ok := env.Find("module.tee.path"); ok { + return v + } + + return teeDefaultPath +} + +func (m *teeModule) Open(env env.Env) (interface{}, error) { + tid, sid := env.Get("tunnel"), env.Get("stream") + name := fmt.Sprintf("%s.%s.%s", m.where(env), tid, sid) + + var t tee + + if f, err := os.Create(name); err != nil { + return nil, err + } else { + t.f = f + } + + t.wg.Add(2) + + go func() { + t.wg.Wait() + t.f.Close() + }() + + return &t, nil +} + +func newTeeModule(opts opts.Opts, env env.Env) (module, error) { + m := &teeModule{} + if path, ok := opts["path"]; ok { + m.path = path + } + return m, nil +} + +func init() { + register("tee", newTeeModule) +} diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 189703d..91d7533 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -6,6 +6,7 @@ import ( "io" "log" "sort" + "strconv" "strings" "sync" "time" @@ -140,6 +141,7 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { } s.env.Set("tunnel", t.id) + s.env.Set("stream", strconv.Itoa(s.id)) s.run() @@ -208,7 +210,12 @@ func (s *stream) run() { s.channel(s.in, rq, wq) for _, m := range s.t.m { - send, recv, _ := module.Open(m, s.env) + send, recv, err := module.Open(m, s.env) + if err != nil { + // FIXME: abort stream on error + log.Println(s.t, s, m, err) + continue + } if send != nil { q := queue.New() |
