summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-03-08 01:33:06 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-03-08 01:33:06 +0300
commit45009e12dd8c8dda711c08f91bc8f6c925966d93 (patch)
tree1c9efeaa980b2c3a16779b591cda56d2aae5f86f /pkg/server
parentc83b04c10c3d1126f295a72f9e6d96bf1924238a (diff)
mono, force and simpler channels
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/socket/socket.go29
-rw-r--r--pkg/server/tunnel.go32
2 files changed, 39 insertions, 22 deletions
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 0060ce6..1bb7549 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -14,8 +14,12 @@ import (
var errAlreadyClosed = errors.New("already closed")
+type exportChannel struct {
+ info string
+ Channel
+}
+
type Channel interface {
- Origin() string
Send(wq queue.Q) error
Recv(rq queue.Q) error
Close() error
@@ -36,7 +40,6 @@ type dialSocket struct {
}
type connChannel struct {
- origin string
conn net.Conn
once sync.Once
}
@@ -47,8 +50,12 @@ type loopChannel struct {
q queue.Q
}
-func newConnChannel(origin string, conn net.Conn) Channel {
- return &connChannel{origin: origin, conn: conn}
+func (c exportChannel) String() string {
+ return c.info
+}
+
+func newConnChannel(conn net.Conn) Channel {
+ return &connChannel{conn: conn}
}
func (c *connChannel) final(f func() error, err error) error {
@@ -63,10 +70,6 @@ func (c *connChannel) final(f func() error, err error) error {
return err
}
-func (c *connChannel) Origin() string {
- return c.origin
-}
-
func (c *connChannel) Send(wq queue.Q) error {
err := queue.IoCopy(c.conn, wq.Writer())
return c.final(c.Close, err)
@@ -121,9 +124,9 @@ func (s *listenSocket) Open(env env.Env) (Channel, error) {
}
addr := conn.RemoteAddr()
- origin := fmt.Sprintf("%s/%s", addr.Network(), addr)
+ info := fmt.Sprintf("%s/%s", addr.Network(), addr)
- return newConnChannel(origin, conn), nil
+ return exportChannel{info, newConnChannel(conn)}, nil
}
func (s *listenSocket) String() string {
@@ -154,16 +157,12 @@ func (s *dialSocket) Open(env env.Env) (Channel, error) {
if err != nil {
return nil, err
}
- return newConnChannel("-", conn), nil
+ return exportChannel{"-", newConnChannel(conn)}, nil
}
func (s *dialSocket) Close() {
}
-func (c *loopChannel) Origin() string {
- return "loop"
-}
-
func (c *loopChannel) Send(wq queue.Q) error {
return queue.Copy(c.q, wq)
}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 01d7c22..72d4c13 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -56,6 +56,8 @@ type tunnel struct {
quit chan struct{}
done chan struct{}
+ mono bool
+
in, out socket.S
hooks []hook.H
@@ -123,6 +125,11 @@ func (t *tunnel) serve() {
t.handle(in)
wg.Done()
}()
+
+ if t.mono {
+ wg.Wait()
+ t.wg.Wait()
+ }
}
}
@@ -240,7 +247,7 @@ func (s *stream) channel(c socket.Channel, m *metric, rq, wq queue.Q) {
q := queue.New()
go counter(&m.rx, rq, q)
watch(q, c.Recv)
- rq.Dry()
+ q.Dry()
}()
}
@@ -311,7 +318,7 @@ func parseHooks(args []string, env env.Env) ([]hook.H, error) {
return hooks, nil
}
-func newTunnel(args []string, env env.Env) (*tunnel, error) {
+func newTunnel(mono bool, args []string, env env.Env) (*tunnel, error) {
var in, out socket.S
var hooks []hook.H
var err error
@@ -337,6 +344,7 @@ func newTunnel(args []string, env env.Env) (*tunnel, error) {
args: strings.Join(args, " "),
quit: make(chan struct{}),
done: make(chan struct{}),
+ mono: mono,
hooks: hooks,
in: in,
out: out,
@@ -356,9 +364,10 @@ func isOkTunnelName(s string) bool {
func tunnelAdd(r *request) {
args := r.args
name := ""
+ mono := false
- if len(args) >= 2 {
- if args[0] == "name" {
+ for len(args) > 1 {
+ if args[0] == "name" && len(args) > 1 {
name = args[1]
if !isOkTunnelName(name) {
r.Fatal("bad name")
@@ -369,14 +378,23 @@ func tunnelAdd(r *request) {
}
args = args[2:]
+ continue
}
+
+ if args[0] == "mono" {
+ mono = true
+ args = args[1:]
+ continue
+ }
+
+ break
}
if len(args) < 2 {
r.Fatal("not enough args")
}
- t, err := newTunnel(args, r.c.s.env)
+ t, err := newTunnel(mono, args, r.c.s.env)
if err != nil {
r.Fatal(err)
}
@@ -459,7 +477,7 @@ func showActive(r *request) {
defer t.mu.Unlock()
foreachStream(t.streams, func(s *stream) {
- r.Println(t.id, s.id, s.in.Origin(), s.out.Origin(), s.info())
+ r.Println(t.id, s.id, s.in, s.out, s.info())
})
})
}
@@ -479,7 +497,7 @@ func showRecent(r *request) {
for _, s := range streams {
when := s.until.Format(config.TimeFormat)
- r.Println(when, s.t.id, s.id, s.in.Origin(), s.out.Origin(), s.info())
+ r.Println(when, s.t.id, s.id, s.in, s.out, s.info())
}
}