diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-23 05:20:04 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-23 05:20:04 +0300 |
| commit | de868930e2301b68a50bde088dd83dc575b72c54 (patch) | |
| tree | 9f3df0bbdb5035100311fe7194aef66186478ec4 /pkg/server/tunnel.go | |
| parent | 7c7fafefef94c5fb8bfe319e7745d80a1e88205d (diff) | |
prepare to auth
Diffstat (limited to 'pkg/server/tunnel.go')
| -rw-r--r-- | pkg/server/tunnel.go | 127 |
1 files changed, 74 insertions, 53 deletions
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index e4a324c..a7854bc 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -1,9 +1,10 @@ package server import ( + "tunnel/pkg/server/socket" "tunnel/pkg/server/module" "tunnel/pkg/server/queue" - "tunnel/pkg/server/socket" + "tunnel/pkg/server/env" "tunnel/pkg/config" "strings" "time" @@ -37,14 +38,16 @@ type tunnel struct { in, out socket.S m []module.M + + env env.Env } func (s *stream) String() string { - return fmt.Sprintf("stream(%d)", s.id) + return fmt.Sprintf("stream:%d", s.id) } func (t *tunnel) String() string { - return fmt.Sprintf("tunnel(%s)", t.id) + return fmt.Sprintf("tunnel:%s", t.id) } func (t *tunnel) stopServe() { @@ -84,7 +87,7 @@ func (t *tunnel) serve() { var wg sync.WaitGroup for { - if in, err := t.in.Open(); err != nil { + if in, err := t.in.Open(t.env); err != nil { if t.isQuit() { break } @@ -109,7 +112,7 @@ func (t *tunnel) serve() { } func (t *tunnel) handle(in socket.Channel) { - out, err := t.out.Open() + out, err := t.out.Open(t.env) if err != nil { log.Println(t, err) in.Close() @@ -132,13 +135,13 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { since: time.Now(), } + s.run() + t.mu.Lock() t.nextSid++ t.streams[s.id] = s t.mu.Unlock() - s.run() - go func () { s.wg.Wait() @@ -154,7 +157,7 @@ func (t *tunnel) newStream(in, out socket.Channel) *stream { return s } -func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { +func (s *stream) channel(c socket.Channel, rq, wq queue.Q) { watch := func (q queue.Q, f func (q queue.Q) error) { defer s.wg.Done() @@ -170,20 +173,24 @@ func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { close(wq) }() - go watch(rq, c.Recv) + go func () { + watch(rq, c.Recv) + rq.Dry() + }() } -func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { +func (s *stream) pipe(m module.M, p module.Pipe, rq, wq queue.Q) { s.wg.Add(1) go func () { defer s.wg.Done() - if err := f(rq, wq); err != nil { - log.Println(s.t, s, err) + if err := p(rq, wq); err != nil { + log.Println(s.t, s, m, err) } close(wq) + rq.Dry() }() } @@ -192,23 +199,23 @@ func (s *stream) run() { rq, wq := queue.New(), queue.New() - s.watchChannel(rq, wq, s.in) + s.channel(s.in, rq, wq) for _, m := range s.t.m { - send, recv := m.Open() + send, recv := m.Open(s.t.env) if send != nil { q := queue.New() - s.watchPipe(wq, q, send) + s.pipe(m, send, wq, q) wq = q } if recv != nil { q := queue.New() - s.watchPipe(q, rq, recv) + s.pipe(m, recv, q, rq) rq = q } } - s.watchChannel(wq, rq, s.out) + s.channel(s.out, wq, rq) } func (s *stream) stop() { @@ -216,34 +223,14 @@ func (s *stream) stop() { s.out.Close() } -func newTunnel(args []string) (*tunnel, error) { - var in, out socket.S - var err error - - n := len(args) - 1 - - if in, err = socket.New(args[0]); err != nil { - return nil, err - } - - if out, err = socket.New(args[n]); err != nil { - in.Close() - return nil, err - } - - t := &tunnel{ - args: strings.Join(args, " "), - quit: make(chan struct{}), - done: make(chan struct{}), - in: in, - out: out, - streams: make(map[int]*stream), - } +func parseModules(args []string, env env.Env) ([]module.M, error) { + var mm []module.M reverse := false - for _, arg := range args[1:n] { + for _, arg := range args { var m module.M + var err error if arg == "-" { reverse = true @@ -255,8 +242,7 @@ func newTunnel(args []string) (*tunnel, error) { continue } - if m, err = module.New(arg); err != nil { - t.Close() + if m, err = module.New(arg, env); err != nil { return nil, err } @@ -265,14 +251,49 @@ func newTunnel(args []string) (*tunnel, error) { reverse = false } - t.m = append(t.m, m) + mm = append(mm, m) } if reverse { - t.Close() return nil, fmt.Errorf("bad '-' usage") } + return mm, nil +} + +func newTunnel(args []string, env env.Env) (*tunnel, error) { + var in, out socket.S + var mm []module.M + var err error + + n := len(args) - 1 + + if in, err = socket.New(args[0], env); err != nil { + return nil, err + } + + if out, err = socket.New(args[n], env); err != nil { + in.Close() + return nil, err + } + + if mm, err = parseModules(args[1:n], env); err != nil { + in.Close() + out.Close() + return nil, err + } + + t := &tunnel{ + args: strings.Join(args, " "), + quit: make(chan struct{}), + done: make(chan struct{}), + m: mm, + in: in, + out: out, + env: env, + streams: make(map[int]*stream), + } + go t.serve() return t, nil @@ -305,7 +326,7 @@ func tunnelAdd(r *request) { r.Fatal("not enough args") } - t, err := newTunnel(args) + t, err := newTunnel(args, r.c.s.env) if err != nil { r.Fatal(err) } @@ -384,18 +405,17 @@ func tunnelShow(r *request) { func streamShow(r *request) { foreachTunnel(r.c.s.tunnels, func (t *tunnel) { - r.Println(t.id, t.args) - t.mu.Lock() - if len(t.streams) == 0 { - r.Println("\t", "nothing") - } else { + defer t.mu.Unlock() + + if len(t.streams) > 0 { + r.Println(t.id, t.args) + foreachStream(t.streams, func (s *stream) { tm := s.since.Format(config.TimeFormat) r.Println("\t", s.id, tm, s.in, s.out) }) } - t.mu.Unlock() }) } @@ -406,5 +426,6 @@ func init() { newCmd(tunnelRename, "rename") newCmd(tunnelShow, "show") - newCmd(streamShow, "stream show") + + newCmd(streamShow, "streams") } |
