summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/server/socket/exec.go112
-rw-r--r--pkg/test/exec_test.go27
2 files changed, 139 insertions, 0 deletions
diff --git a/pkg/server/socket/exec.go b/pkg/server/socket/exec.go
new file mode 100644
index 0000000..bb3bcd0
--- /dev/null
+++ b/pkg/server/socket/exec.go
@@ -0,0 +1,112 @@
+package socket
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "os/exec"
+ "strings"
+ "sync"
+
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+)
+
+type execSocket struct {
+ Cmd string `opts:"required"`
+}
+
+type execConn struct {
+ s *execSocket
+ cmd *exec.Cmd
+
+ stdin io.WriteCloser
+ stdout io.ReadCloser
+
+ once sync.Once
+ wg sync.WaitGroup
+}
+
+func (s *execSocket) String() string {
+ return fmt.Sprintf("exec(%s)", s.Cmd)
+}
+
+func (s *execSocket) Close() {
+}
+
+func (s *execSocket) New(env env.Env) (Conn, error) {
+ tunnel, stream := env.Get("tunnel"), env.Get("stream")
+
+ args := strings.Fields(s.Cmd)
+ if len(args) == 0 {
+ return nil, errors.New("bad command")
+ }
+
+ cmd := exec.Command(args[0], args[1:]...)
+
+ stdin, _ := cmd.StdinPipe()
+ stdout, _ := cmd.StdoutPipe()
+ stderr, _ := cmd.StderrPipe()
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+
+ c := &execConn{
+ s: s,
+ cmd: cmd,
+ stdin: stdin,
+ stdout: stdout,
+ }
+
+ c.wg.Add(2)
+
+ go func(s string, r io.Reader) {
+ for scanner := bufio.NewScanner(r); scanner.Scan(); {
+ log.Printf("tunnel:%s stream:%s %s > %s", tunnel, stream, s, scanner.Text())
+ }
+ c.wg.Done()
+ }(args[0], stderr)
+
+ go func() {
+ c.wg.Wait()
+ c.cmd.Wait()
+ }()
+
+ return c, nil
+}
+
+func (c *execConn) String() string {
+ return c.s.String()
+}
+
+func (c *execConn) Send(wq queue.Q) error {
+ c.wg.Add(1)
+ defer c.wg.Done()
+ return queue.IoCopy(c.stdout, wq.Writer())
+}
+
+func (c *execConn) Recv(rq queue.Q) error {
+ c.wg.Add(1)
+ defer c.wg.Done()
+ return queue.IoCopy(rq.Reader(), c.stdin)
+}
+
+func (c *execConn) Close() error {
+ err := ErrAlreadyClosed
+
+ c.once.Do(func() {
+ log.Println("close", c.s)
+ c.cmd.Process.Kill()
+ c.wg.Done()
+ err = nil
+ })
+
+ return err
+}
+
+func init() {
+ register("exec", "in/out throw external process", execSocket{})
+}
diff --git a/pkg/test/exec_test.go b/pkg/test/exec_test.go
new file mode 100644
index 0000000..c8d61a0
--- /dev/null
+++ b/pkg/test/exec_test.go
@@ -0,0 +1,27 @@
+package test
+
+import (
+ "testing"
+
+ "strings"
+)
+
+func TestExec(t *testing.T) {
+ e := newEnv(t)
+ defer e.Free()
+
+ c := e.newInstance()
+
+ c.Exec("add name T listen,addr=-:0 upper exec,cmd=cat")
+
+ conn := e.Dial("tcp", c.Get("tunnel.T.listen"))
+
+ e.Write(conn, dummy)
+
+ buf := make([]byte, len(dummy))
+ e.ReadFull(conn, buf)
+
+ if r := string(buf); r != strings.ToUpper(dummy) {
+ t.Fatalf("wrong reply: send '%s', recv '%s'", dummy, r)
+ }
+}