package test

import (
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"tunnel/pkg/client"
	"tunnel/pkg/server"
)

const dummy = "Hello, World!"

// env wraps a *testing.T and collects cleanup callbacks that Free runs.
type env struct {
	*testing.T
	trash []func()
}

// newEnv returns an env bound to t with no registered cleanups.
func newEnv(t *testing.T) *env {
	return &env{T: t}
}

// Free runs every registered cleanup in reverse registration order (like
// defer), so resources are released before the ones they were built on —
// e.g. connections are closed before the server or listener behind them.
// The cleanup list is then cleared, making a second Free a no-op.
func (e *env) Free() {
	for i := len(e.trash) - 1; i >= 0; i-- {
		e.trash[i]()
	}
	e.trash = nil
}

// add registers f to be run by Free.
func (e *env) add(f func()) {
	e.trash = append(e.trash, f)
}

// getSocketPath returns a socket path in the OS temp directory that is unique
// per process (via the PID) and per id. Path separators in id are replaced
// with '_' because subtest names (t.Name()) contain '/', which would
// otherwise be treated as a directory component that does not exist.
func getSocketPath(id string) string {
	id = strings.Map(func(r rune) rune {
		if r == '/' || r == filepath.Separator {
			return '_'
		}
		return r
	}, id)
	s := fmt.Sprintf("tunnel.%d.test.%s", os.Getpid(), id)
	return filepath.Join(os.TempDir(), s)
}

// Client is a tunnel client bound to the env whose test it fails on error.
type Client struct {
	*client.Client
	e *env
}

// Server wraps a tunnel server.
type Server struct {
	*server.Server
}

// newInstance starts a server on a per-test socket path, connects a client
// to it and registers both for shutdown (client first, then server) on Free.
// Any setup error fails the test.
func (e *env) newInstance() *Client {
	socket := getSocketPath(e.Name())
	s, err := server.New(socket, "test")
	if err != nil {
		e.Fatal(err)
	}
	go s.Serve()
	c, err := client.New(socket)
	if err != nil {
		s.Stop()
		e.Fatal(err)
	}
	e.add(func() {
		c.Close()
		s.Stop()
	})
	return &Client{c, e}
}

// Send formats a command, splits it into whitespace-separated fields, sends
// the fields to the server and returns the reply. Formatted arguments that
// contain whitespace end up as several fields. A send error fails the test.
func (c *Client) Send(format string, args ...interface{}) string {
	s := fmt.Sprintf(format, args...)
	t := strings.Fields(s)
	r, err := c.Client.Send(t)
	if err != nil {
		c.e.Fatal(err)
	}
	return r
}

// Get sends "get <name>" and returns the reply.
func (c *Client) Get(name string) string {
	return c.Send("get %s", name)
}

// Set sends "set <name> <value>" and fails the test on a non-empty reply.
// value is formatted with %v so non-string values (ints, bools, durations…)
// are rendered as their natural text instead of a %!s(...) error marker.
func (c *Client) Set(name string, value interface{}) {
	c.Exec("set %s %v", name, value)
}

// Exec sends a command and fails the test, using the reply as the message,
// if the reply is not empty.
func (c *Client) Exec(format string, args ...interface{}) {
	s := c.Send(format, args...)
	if s != "" {
		c.e.Fatal(s)
	}
}

// Listen opens a listener and registers it for closing on Free.
// A listen error fails the test.
func (e *env) Listen(network, address string) net.Listener {
	listen, err := net.Listen(network, address)
	if err != nil {
		e.Fatal(err)
	}
	e.add(func() { listen.Close() })
	return listen
}

// Dial connects with a 100ms timeout and registers the connection for
// closing on Free. A dial error fails the test.
func (e *env) Dial(network, address string) net.Conn {
	d := net.Dialer{Timeout: 100 * time.Millisecond}
	conn, err := d.Dial(network, address)
	if err != nil {
		e.Fatal(err)
	}
	e.add(func() { conn.Close() })
	return conn
}

// Accept waits up to 100ms for a connection on listen and registers it for
// closing on Free. An Accept error or a timeout fails the test.
func (e *env) Accept(listen net.Listener) net.Conn {
	type result struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accepting goroutine can always deliver its result,
	// even after this function has given up waiting.
	ch := make(chan result, 1)
	go func() {
		conn, err := listen.Accept()
		ch <- result{conn, err}
	}()

	timer := time.NewTimer(100 * time.Millisecond)
	defer timer.Stop()

	select {
	case r := <-ch:
		if r.err != nil {
			e.Fatal(r.err)
		}
		e.add(func() { r.conn.Close() })
		return r.conn
	case <-timer.C:
		// Accept is still pending. If it later succeeds, close that
		// connection rather than leaking it.
		go func() {
			if r := <-ch; r.err == nil {
				r.conn.Close()
			}
		}()
		e.Fatal("accept timeout")
		return nil
	}
}

// Write writes i (a string or []byte) to conn with a 100ms write deadline.
// Any other type, or a write error, fails the test. Only the write deadline
// is set and reset, so a concurrent ReadFull keeps its own read deadline.
func (e *env) Write(conn net.Conn, i interface{}) {
	var buf []byte
	switch v := i.(type) {
	case string:
		buf = []byte(v)
	case []byte:
		buf = v
	default:
		e.Fatalf("unexpected type %T", i)
	}
	conn.SetWriteDeadline(time.Now().Add(100 * time.Millisecond))
	defer conn.SetWriteDeadline(time.Time{})
	if _, err := conn.Write(buf); err != nil {
		e.Fatal("write to conn:", err)
	}
}

// ReadFull reads exactly len(buf) bytes from conn with a 100ms read
// deadline. A short read or error fails the test. Only the read deadline
// is set and reset, so a concurrent Write keeps its own write deadline.
func (e *env) ReadFull(conn net.Conn, buf []byte) {
	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	defer conn.SetReadDeadline(time.Time{})
	if _, err := io.ReadFull(conn, buf); err != nil {
		e.Fatal("read from conn:", err)
	}
}

// NewTempFile creates a temp file named after pattern, writes data to it,
// and returns its path. The file is removed on Free. Create, write and
// close errors fail the test.
func (e *env) NewTempFile(pattern string, data string) string {
	f, err := os.CreateTemp("", pattern)
	if err != nil {
		e.Fatalf("create temp: %v", err)
	}
	name := f.Name()
	e.add(func() { os.Remove(name) })
	if _, err := io.WriteString(f, data); err != nil {
		f.Close()
		e.Fatalf("write temp %s: %v", name, err)
	}
	// Close explicitly (not deferred) so a failed flush is reported.
	if err := f.Close(); err != nil {
		e.Fatalf("close temp %s: %v", name, err)
	}
	return name
}