diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/test/auth_test.go | 27 | ||||
| -rw-r--r-- | pkg/test/defer_test.go | 25 | ||||
| -rw-r--r-- | pkg/test/env_test.go | 12 | ||||
| -rw-r--r-- | pkg/test/hook_test.go | 46 | ||||
| -rw-r--r-- | pkg/test/proxy_test.go | 31 | ||||
| -rw-r--r-- | pkg/test/test.go | 82 |
6 files changed, 118 insertions, 105 deletions
diff --git a/pkg/test/auth_test.go b/pkg/test/auth_test.go index 41d1c2d..3b84874 100644 --- a/pkg/test/auth_test.go +++ b/pkg/test/auth_test.go @@ -5,8 +5,10 @@ import ( ) func TestAuthHook(t *testing.T) { - c, s := newClientServer(t) - defer closeClientServer(c, s) + e := newEnv(t) + defer e.Free() + + c := e.newInstance() c.Exec("add name T listen,addr=-:0 auth aes dial,addr=@[tunnel.X.listen]") c.Exec("add name X listen,addr=-:0 -aes -auth dial,addr=@[addr]") @@ -14,23 +16,18 @@ func TestAuthHook(t *testing.T) { c.Exec("set tunnel.X.secret secret") c.Exec("set tunnel.T.secret secret") - listen := xListen(t, "tcp", "127.0.0.1:0") - defer listen.Close() - + listen := e.Listen("tcp", "127.0.0.1:0") c.Set("addr", listen.Addr()) - out := xDial(t, "tcp", c.Get("tunnel.T.listen")) - defer out.Close() - - in := xAccept(t, listen) - defer in.Close() + out := e.Dial("tcp", c.Get("tunnel.T.listen")) + in := e.Accept(listen) - xWrite(t, out, xData) + e.Write(out, dummy) - buf := make([]byte, len(xData)) - xReadFull(t, in, buf) + buf := make([]byte, len(dummy)) + e.ReadFull(in, buf) - if r := string(buf); r != xData { - t.Fatalf("wrong reply: send '%s', recv '%s'", xData, r) + if r := string(buf); r != dummy { + e.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r) } } diff --git a/pkg/test/defer_test.go b/pkg/test/defer_test.go index 215e0b1..57854b2 100644 --- a/pkg/test/defer_test.go +++ b/pkg/test/defer_test.go @@ -5,28 +5,27 @@ import ( ) func TestDeferSocket(t *testing.T) { - c, s := newClientServer(t) - defer closeClientServer(c, s) + e := newEnv(t) + defer e.Free() + + c := e.newInstance() c.Exec("add name T listen,addr=-:0 defer,addr=@[addr]") - out := xDial(t, "tcp", c.Get("tunnel.T.listen")) - defer out.Close() + out := e.Dial("tcp", c.Get("tunnel.T.listen")) - listen := xListen(t, "tcp", "127.0.0.1:0") - defer listen.Close() + listen := e.Listen("tcp", "127.0.0.1:0") c.Set("addr", listen.Addr()) - xWrite(t, out, xData) + e.Write(out, dummy) - in := xAccept(t, listen) - defer in.Close() + in := e.Accept(listen) - buf := make([]byte, len(xData)) - xReadFull(t, in, buf) + buf := make([]byte, len(dummy)) + e.ReadFull(in, buf) - if r := string(buf); r != xData { - t.Fatalf("wrong reply: send '%s', recv '%s'", xData, r) + if r := string(buf); r != dummy { + e.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r) } } diff --git a/pkg/test/env_test.go b/pkg/test/env_test.go index 7a1cd64..2862fd2 100644 --- a/pkg/test/env_test.go +++ b/pkg/test/env_test.go @@ -5,11 +5,13 @@ import ( ) func TestEnv(t *testing.T) { - c, s := newClientServer(t) - defer closeClientServer(c, s) + e := newEnv(t) + defer e.Free() - r := c.Send("echo %s", xData) - if r != xData { - t.Errorf("wrong reply: send '%s', recv '%s'", xData, r) + c := e.newInstance() + + r := c.Send("echo %s", dummy) + if r != dummy { + e.Errorf("wrong reply: send '%s', recv '%s'", dummy, r) } } diff --git a/pkg/test/hook_test.go b/pkg/test/hook_test.go index 31e74f7..7808883 100644 --- a/pkg/test/hook_test.go +++ b/pkg/test/hook_test.go @@ -8,47 +8,45 @@ import ( ) func TestUpperHook(t *testing.T) { - c, s := newClientServer(t) - defer closeClientServer(c, s) + e := newEnv(t) + defer e.Free() + + c := e.newInstance() c.Exec("add name T listen,addr=-:0 upper loop") - conn := xDial(t, "tcp", c.Get("tunnel.T.listen")) - defer conn.Close() + conn := e.Dial("tcp", c.Get("tunnel.T.listen")) - xWrite(t, conn, xData) + e.Write(conn, dummy) - buf := make([]byte, len(xData)) - xReadFull(t, conn, buf) + buf := make([]byte, len(dummy)) + e.ReadFull(conn, buf) - if r := string(buf); r != strings.ToUpper(xData) { - t.Fatalf("wrong reply: send '%s', recv '%s'", xData, r) + if r := string(buf); r != strings.ToUpper(dummy) { + t.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r) } } func TestHexHook(t *testing.T) { - c, s := newClientServer(t) - defer closeClientServer(c, s) + e := newEnv(t) + defer e.Free() - c.Exec("add name T listen,addr=-:0 hex dial,addr=@[addr]") + c := e.newInstance() - listen := xListen(t, "tcp", "127.0.0.1:0") - defer listen.Close() + c.Exec("add name T listen,addr=-:0 hex dial,addr=@[addr]") + listen := e.Listen("tcp", "127.0.0.1:0") c.Set("addr", listen.Addr()) - out := xDial(t, "tcp", c.Get("tunnel.T.listen")) - defer out.Close() - - in := xAccept(t, listen) - defer in.Close() + out := e.Dial("tcp", c.Get("tunnel.T.listen")) + in := e.Accept(listen) - xWrite(t, out, xData) + e.Write(out, dummy) - buf := make([]byte, 2*len(xData)) - xReadFull(t, in, buf) + buf := make([]byte, 2*len(dummy)) + e.ReadFull(in, buf) - if r := string(buf); r != hex.EncodeToString([]byte(xData)) { - t.Fatalf("wrong reply: send '%s', recv '%s'", xData, r) + if r := string(buf); r != hex.EncodeToString([]byte(dummy)) { + t.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r) } } diff --git a/pkg/test/proxy_test.go b/pkg/test/proxy_test.go index b2fb097..ae89b0a 100644 --- a/pkg/test/proxy_test.go +++ b/pkg/test/proxy_test.go @@ -5,32 +5,29 @@ import ( ) func TestProxyHook(t *testing.T) { - const msg = "Hello, World!" + e := newEnv(t) + defer e.Free() - c, s := newClientServer(t) - defer closeClientServer(c, s) + c := e.newInstance() - listen := xListen(t, "tcp", "127.0.0.1:0") - defer listen.Close() - - saddr := c.AddListenTunnel("S", "proxy") - caddr := c.AddListenTunnel("C", "proxy,addr=%s dial,addr=%s", listen.Addr(), saddr) + c.Exec("add name C listen,addr=-:0 proxy,addr=@[addr] dial,addr=@[tunnel.S.listen]") + c.Exec("add name S listen,addr=-:0 proxy") c.Exec("set tunnel.S.proxy.auth user:password") c.Exec("set tunnel.C.proxy.auth user:password") - out := xDial(t, "tcp", caddr) - defer out.Close() + listen := e.Listen("tcp", "127.0.0.1:0") + c.Set("addr", listen.Addr()) - in := xAccept(t, listen) - defer in.Close() + out := e.Dial("tcp", c.Get("tunnel.C.listen")) + in := e.Accept(listen) - xWrite(t, out, msg) + e.Write(out, dummy) - buf := make([]byte, len(msg)) - xReadFull(t, in, buf) + buf := make([]byte, len(dummy)) + e.ReadFull(in, buf) - if r := string(buf); r != msg { - t.Fatalf("wrong reply: send '%s', recv '%s'", msg, r) + if r := string(buf); r != dummy { + t.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r) } } diff --git a/pkg/test/test.go b/pkg/test/test.go index 44289c9..6d83ad0 100644 --- a/pkg/test/test.go +++ b/pkg/test/test.go @@ -14,10 +14,26 @@ import ( "tunnel/pkg/server" ) -const xData = "Hello, World!" +const dummy = "Hello, World!" type env struct { *testing.T + + trash []func() +} + +func newEnv(t *testing.T) *env { + return &env{T: t} +} + +func (e *env) Free() { + for _, f := range e.trash { + f() + } +} + +func (e *env) add(f func()) { + e.trash = append(e.trash, f) } func getSocketPath(id string) string { @@ -28,19 +44,19 @@ func getSocketPath(id string) string { type Client struct { *client.Client - t *testing.T + e *env } type Server struct { *server.Server } -func newClientServer(t *testing.T) (*Client, *Server) { - socket := getSocketPath(t.Name()) +func (e *env) newInstance() *Client { + socket := getSocketPath(e.Name()) s, err := server.New(socket) if err != nil { - t.Fatal(err) + e.Fatal(err) } go s.Serve() @@ -48,15 +64,15 @@ func newClientServer(t *testing.T) (*Client, *Server) { c, err := client.New(socket) if err != nil { s.Stop() - t.Fatal(err) + e.Fatal(err) } - return &Client{c, t}, &Server{s} -} + e.add(func() { + c.Close() + s.Stop() + }) -func closeClientServer(c *Client, s *Server) { - c.Close() - s.Stop() + return &Client{c, e} } func (c *Client) Send(format string, args ...interface{}) string { @@ -65,7 +81,7 @@ func (c *Client) Send(format string, args ...interface{}) string { r, err := c.Client.Send(t) if err != nil { - c.t.Fatal(err) + c.e.Fatal(err) } return r @@ -82,34 +98,34 @@ func (c *Client) Set(name string, value interface{}) { func (c *Client) Exec(format string, args ...interface{}) { s := c.Send(format, args...) if s != "" { - c.t.Fatal(s) + c.e.Fatal(s) } } -func (c *Client) AddListenTunnel(name string, format string, args ...interface{}) string { - t := append([]interface{}{name}, args...) - c.Exec("add name %s listen,addr=127.0.0.1:0 "+format, t...) - return c.Send("get tunnel.%s.listen", name) -} - -func xListen(t *testing.T, network, address string) net.Listener { +func (e *env) Listen(network, address string) net.Listener { listen, err := net.Listen(network, address) if err != nil { - t.Fatal(err) + e.Fatal(err) } + e.add(func() { + listen.Close() + }) return listen } -func xDial(t *testing.T, network, address string) net.Conn { +func (e *env) Dial(network, address string) net.Conn { d := net.Dialer{Timeout: 100 * time.Millisecond} conn, err := d.Dial(network, address) if err != nil { - t.Fatal(err) + e.Fatal(err) } + e.add(func() { + conn.Close() + }) return conn } -func xAccept(t *testing.T, listen net.Listener) net.Conn { +func (e *env) Accept(listen net.Listener) net.Conn { var conn net.Conn c := make(chan error, 1) @@ -123,16 +139,20 @@ func xAccept(t *testing.T, listen net.Listener) net.Conn { select { case err := <-c: if err != nil { - t.Fatal(err) + e.Fatal(err) } case <-timer.C: - t.Fatal("accept timeout") + e.Fatal("accept timeout") } + e.add(func() { + conn.Close() + }) + return conn } -func xWrite(t *testing.T, conn net.Conn, i interface{}) { +func (e *env) Write(conn net.Conn, i interface{}) { var buf []byte switch v := i.(type) { @@ -141,20 +161,20 @@ func xWrite(t *testing.T, conn net.Conn, i interface{}) { case []byte: buf = v default: - t.Fatalf("unexpected type %T", i) + e.Fatalf("unexpected type %T", i) } conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) if _, err := conn.Write(buf); err != nil { - t.Fatal("write to conn:", err) + e.Fatal("write to conn:", err) } conn.SetDeadline(time.Time{}) } -func xReadFull(t *testing.T, conn net.Conn, buf []byte) { +func (e *env) ReadFull(conn net.Conn, buf []byte) { conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) if _, err := io.ReadFull(conn, buf); err != nil { - t.Fatal("read from conn:", err) + e.Fatal("read from conn:", err) } conn.SetDeadline(time.Time{}) } |
