summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TODO3
-rw-r--r--cmd/tunnel/main.go44
-rw-r--r--cmd/tunneld/main.go28
-rw-r--r--pkg/client/client.go4
-rw-r--r--pkg/config/config.go29
-rw-r--r--pkg/server/server.go6
-rw-r--r--pkg/server/socket/socket.go2
7 files changed, 96 insertions, 20 deletions
diff --git a/TODO b/TODO
index ced8b85..62412fa 100644
--- a/TODO
+++ b/TODO
@@ -42,3 +42,6 @@ note:
14. stream.run check errors on Open
15. per tunnel, per stream statistics
16. http connect proxy module
+17. unix socket path from config file
+ group owner of socket dir/file
+ command line option for socket
diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go
index 302b60e..5c6029e 100644
--- a/cmd/tunnel/main.go
+++ b/cmd/tunnel/main.go
@@ -1,31 +1,63 @@
package main
import (
+ "errors"
+ "flag"
"fmt"
"log"
"os"
"strings"
+ "syscall"
"tunnel/pkg/client"
+ "tunnel/pkg/config"
)
+var systemSocketVar bool
+
func init() {
log.SetFlags(0)
+
+ flag.BoolVar(&systemSocketVar, "system", false, "use system instance")
+}
+
+func getSocketPath() string {
+ if systemSocketVar {
+ return getSystemSocketPath()
+ }
+
+ s, err := config.GetSocketPath()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return s
+}
+
+func getSystemSocketPath() string {
+ s, err := config.GetSystemSocketPath()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return s
}
func main() {
- var args = os.Args
+ flag.Parse()
- if len(args) < 2 {
- fmt.Fprintln(os.Stderr, "bad usage")
- os.Exit(1)
+ if flag.NArg() < 1 {
+ log.Fatalf("Usage: %s command [arguments], try help", os.Args[0])
}
- c, err := client.New()
+ user, system := getSocketPath(), getSystemSocketPath()
+
+ c, err := client.New(user)
if err != nil {
+ if user != system && errors.Is(err, syscall.ENOENT) {
+ c, err = client.New(system)
+ }
log.Fatal(err)
}
- reply, err := c.Send(args[1:])
+ reply, err := c.Send(flag.Args())
if err != nil {
c.Close()
log.Fatal(err)
diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go
index 1ffb99e..cd8afda 100644
--- a/cmd/tunneld/main.go
+++ b/cmd/tunneld/main.go
@@ -1,11 +1,14 @@
package main
import (
+ "errors"
"flag"
"log"
"os"
"os/signal"
+ "path"
"syscall"
+ "tunnel/pkg/config"
"tunnel/pkg/server"
)
@@ -39,19 +42,34 @@ func sighandler(c chan os.Signal, s *server.Server) {
}
}
+func getSocketPath() string {
+ s, err := config.GetSocketPath()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ if err := os.Mkdir(path.Dir(s), 0700); err != nil {
+ if !errors.Is(err, syscall.EEXIST) {
+ log.Fatal(err)
+ }
+ }
+
+ return s
+}
+
func main() {
flag.Parse()
initLog()
- var c = make(chan os.Signal)
-
- signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
-
- s, err := server.New()
+ s, err := server.New(getSocketPath())
if err != nil {
log.Fatal(err)
}
+ var c = make(chan os.Signal)
+
+ signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+
go sighandler(c, s)
log.Print("ready")
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 1264853..29d956e 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -16,8 +16,8 @@ type Client struct {
conn net.Conn
}
-func New() (*Client, error) {
- conn, err := net.Dial(config.SockType, config.SockPath)
+func New(path string) (*Client, error) {
+ conn, err := net.Dial("unixpacket", path)
if err != nil {
return nil, err
}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index fb09ad0..5483192 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -1,12 +1,33 @@
package config
-import "time"
-
-const SockType = "unixpacket"
-const SockPath = "/tmp/tunnel.sock"
+import (
+ "errors"
+ "fmt"
+ "os"
+ "time"
+)
const TimeFormat = "2006-01-02/15:04:05"
const BufSize = 1024
const IoTimeout = 5 * time.Second
+
+var errNegativeUid = errors.New("negative uid")
+
+func GetSystemSocketPath() (string, error) {
+ return "/run/tunnel/socket", nil
+}
+
+func GetSocketPath() (string, error) {
+ uid := os.Getuid()
+
+ switch uid {
+ case -1:
+ return "", errNegativeUid
+ case 0:
+ return GetSystemSocketPath()
+ }
+
+ return fmt.Sprintf("/run/user/%d/tunnel/socket", uid), nil
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 571397f..a9a50a4 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -126,12 +126,14 @@ func (s *Server) isDone() bool {
}
}
-func New() (*Server, error) {
- listen, err := net.Listen(config.SockType, config.SockPath)
+func New(path string) (*Server, error) {
+ listen, err := net.Listen("unixpacket", path)
if err != nil {
return nil, err
}
+ log.Println("listen at", path)
+
s := &Server{
env: env.New(),
listen: listen,
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 3db4310..c91423e 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -141,7 +141,7 @@ func (s *dialSocket) Close() {
func New(desc string, env env.Env) (S, error) {
base, opts := opts.Parse(desc)
- args := strings.Split(base, "/")
+ args := strings.SplitN(base, "/", 2)
var proto string
var addr string