diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-03-08 01:33:06 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-03-08 01:33:06 +0300 |
| commit | 45009e12dd8c8dda711c08f91bc8f6c925966d93 (patch) | |
| tree | 1c9efeaa980b2c3a16779b591cda56d2aae5f86f | |
| parent | c83b04c10c3d1126f295a72f9e6d96bf1924238a (diff) | |
mono, force and simpler channels
| -rw-r--r-- | cmd/tunneld/main.go | 14 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 29 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 32 |
3 files changed, 51 insertions, 24 deletions
diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go index aa0a370..ffa6c43 100644 --- a/cmd/tunneld/main.go +++ b/cmd/tunneld/main.go @@ -20,7 +20,8 @@ import ( var ( debugFlag = flag.Bool("d", false, "debug: print time and source info") - syslogFlag = flag.Bool("s", false, "log output to syslog instead of stdout") + forceFlag = flag.Bool("f", false, "try start with force") + syslogFlag = flag.Bool("s", false, "write log to syslog instead of stdout") configFlag = flag.String("c", "", "path to configuration file") ) @@ -182,7 +183,16 @@ func main() { flag.Parse() initLog() - s, err := server.New(getSocketPath()) + socket := getSocketPath() + if *forceFlag { + if err := os.Remove(socket); err != nil { + if !errors.Is(err, syscall.ENOENT) { + log.Fatal(err) + } + } + } + + s, err := server.New(socket) if err != nil { log.Fatal(err) } diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 0060ce6..1bb7549 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -14,8 +14,12 @@ import ( var errAlreadyClosed = errors.New("already closed") +type exportChannel struct { + info string + Channel +} + type Channel interface { - Origin() string Send(wq queue.Q) error Recv(rq queue.Q) error Close() error @@ -36,7 +40,6 @@ type dialSocket struct { } type connChannel struct { - origin string conn net.Conn once sync.Once } @@ -47,8 +50,12 @@ type loopChannel struct { q queue.Q } -func newConnChannel(origin string, conn net.Conn) Channel { - return &connChannel{origin: origin, conn: conn} +func (c exportChannel) String() string { + return c.info +} + +func newConnChannel(conn net.Conn) Channel { + return &connChannel{conn: conn} } func (c *connChannel) final(f func() error, err error) error { @@ -63,10 +70,6 @@ func (c *connChannel) final(f func() error, err error) error { return err } -func (c *connChannel) Origin() string { - return c.origin -} - func (c *connChannel) Send(wq queue.Q) error { err := queue.IoCopy(c.conn, wq.Writer()) return c.final(c.Close, err) @@ -121,9 +124,9 @@ func (s *listenSocket) Open(env env.Env) (Channel, error) { } addr := conn.RemoteAddr() - origin := fmt.Sprintf("%s/%s", addr.Network(), addr) + info := fmt.Sprintf("%s/%s", addr.Network(), addr) - return newConnChannel(origin, conn), nil + return exportChannel{info, newConnChannel(conn)}, nil } func (s *listenSocket) String() string { @@ -154,16 +157,12 @@ func (s *dialSocket) Open(env env.Env) (Channel, error) { if err != nil { return nil, err } - return newConnChannel("-", conn), nil + return exportChannel{"-", newConnChannel(conn)}, nil } func (s *dialSocket) Close() { } -func (c *loopChannel) Origin() string { - return "loop" -} - func (c *loopChannel) Send(wq queue.Q) error { return queue.Copy(c.q, wq) } diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 01d7c22..72d4c13 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -56,6 +56,8 @@ type tunnel struct { quit chan struct{} done chan struct{} + mono bool + in, out socket.S hooks []hook.H @@ -123,6 +125,11 @@ func (t *tunnel) serve() { t.handle(in) wg.Done() }() + + if t.mono { + wg.Wait() + t.wg.Wait() + } } } @@ -240,7 +247,7 @@ func (s *stream) channel(c socket.Channel, m *metric, rq, wq queue.Q) { q := queue.New() go counter(&m.rx, rq, q) watch(q, c.Recv) - rq.Dry() + q.Dry() }() } @@ -311,7 +318,7 @@ func parseHooks(args []string, env env.Env) ([]hook.H, error) { return hooks, nil } -func newTunnel(args []string, env env.Env) (*tunnel, error) { +func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) { var in, out socket.S var hooks []hook.H var err error @@ -337,6 +344,7 @@ func newTunnel(args []string, env env.Env) (*tunnel, error) { args: strings.Join(args, " "), quit: make(chan struct{}), done: make(chan struct{}), + mono: mono, hooks: hooks, in: in, out: out, @@ -356,9 +364,10 @@ func isOkTunnelName(s string) bool { func tunnelAdd(r *request) { args := r.args name := "" + mono := false - if len(args) >= 2 { - if args[0] == "name" { + for len(args) > 1 { + if args[0] == "name" && len(args) > 1 { name = args[1] if !isOkTunnelName(name) { r.Fatal("bad name") @@ -369,14 +378,23 @@ func tunnelAdd(r *request) { } args = args[2:] + continue } + + if args[0] == "mono" { + mono = true + args = args[1:] + continue + } + + break } if len(args) < 2 { r.Fatal("not enough args") } - t, err := newTunnel(args, r.c.s.env) + t, err := newTunnel(mono, args, r.c.s.env) if err != nil { r.Fatal(err) } @@ -459,7 +477,7 @@ func showActive(r *request) { defer t.mu.Unlock() foreachStream(t.streams, func(s *stream) { - r.Println(t.id, s.id, s.in.Origin(), s.out.Origin(), s.info()) + r.Println(t.id, s.id, s.in, s.out, s.info()) }) }) } @@ -479,7 +497,7 @@ func showRecent(r *request) { for _, s := range streams { when := s.until.Format(config.TimeFormat) - r.Println(when, s.t.id, s.id, s.in.Origin(), s.out.Origin(), s.info()) + r.Println(when, s.t.id, s.id, s.in, s.out, s.info()) } } |
