diff options
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/tunnel/main.go | 14 | ||||
| -rw-r--r-- | cmd/tunneld/main.go | 129 |
2 files changed, 124 insertions, 19 deletions
diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go index 44eac45..dc94974 100644 --- a/cmd/tunnel/main.go +++ b/cmd/tunnel/main.go @@ -22,22 +22,14 @@ func init() { func getSocketPath() string { if *systemFlag { - return getSystemSocketPath() + return config.GetSystemSocketPath() } - s, err := config.GetSocketPath() - if err != nil { - log.Fatal(err) - } - return s + return config.GetSocketPath() } func getSystemSocketPath() string { - s, err := config.GetSystemSocketPath() - if err != nil { - log.Fatal(err) - } - return s + return config.GetSystemSocketPath() } func main() { diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go index 6a16dec..aa0a370 100644 --- a/cmd/tunneld/main.go +++ b/cmd/tunneld/main.go @@ -1,13 +1,18 @@ package main import ( + "bufio" "errors" "flag" + "fmt" "log" "log/syslog" "os" "os/signal" + "os/user" "path" + "strconv" + "strings" "syscall" "tunnel/pkg/config" "tunnel/pkg/server" @@ -16,6 +21,7 @@ import ( var ( debugFlag = flag.Bool("d", false, "debug: print time and source info") syslogFlag = flag.Bool("s", false, "log output to syslog instead of stdout") + configFlag = flag.String("c", "", "path to configuration file") ) func initLog() { @@ -51,11 +57,16 @@ func sighandler(c chan os.Signal, s *server.Server) { } } +func initSignals(s *server.Server) { + var c = make(chan os.Signal) + + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + + go sighandler(c, s) +} + func getSocketPath() string { - s, err := config.GetSocketPath() - if err != nil { - log.Fatal(err) - } + s := config.GetSocketPath() if err := os.Mkdir(path.Dir(s), 0700); err != nil { if !errors.Is(err, syscall.EEXIST) { @@ -66,6 +77,107 @@ func getSocketPath() string { return s } +func openConfig() (*os.File, error) { + var c string + + if len(*configFlag) > 0 { + c = *configFlag + } else { + c = config.GetConfigPath() + } + + if c == "" { + return nil, nil + } + + fp, err := os.Open(c) + if err != nil { + if *configFlag == "" && errors.Is(err, syscall.ENOENT) { + return nil, nil + } + + return nil, err + } + + return fp, nil +} + +func readConfig(s *server.Server) error { + fp, err := openConfig() + if fp == nil || err != nil { + return err + } + + defer fp.Close() + + scanner := bufio.NewScanner(fp) + scanner.Split(bufio.ScanLines) + + for nline := 1; scanner.Scan(); nline++ { + args := strings.SplitN(scanner.Text(), "#", 2) + cmd := strings.TrimSpace(args[0]) + + if cmd == "" { + continue + } + + if err := s.Command(cmd); err != nil { + return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, cmd, err) + } + } + + return nil +} + +func updateSocketGroup(s *server.Server, group string) error { + var gid int + + if g, err := user.LookupGroup(group); err != nil { + return err + } else { + var err error + + if gid, err = strconv.Atoi(g.Gid); err != nil { + return fmt.Errorf("bad group id %s: %w", g.Gid, err) + } + } + + f := s.Socket() + d := path.Dir(f) + + if err := os.Chown(d, -1, gid); err != nil { + return err + } + + if err := os.Chmod(d, 0750); err != nil { + return err + } + + if err := os.Chown(f, -1, gid); err != nil { + return err + } + + if err := os.Chmod(f, 0770); err != nil { + return err + } + + return nil +} + +func configure(s *server.Server) error { + if err := readConfig(s); err != nil { + return err + } + + if group, ok := s.Env().Find("server.socket.group"); ok { + if err := updateSocketGroup(s, group); err != nil { + return err + } + } + + return nil +} + func main() { flag.Parse() initLog() @@ -75,11 +187,12 @@ func main() { log.Fatal(err) } - var c = make(chan os.Signal) - - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + if err := configure(s); err != nil { + s.Stop() + log.Fatal(err) + } - go sighandler(c, s) + initSignals(s) log.Print("ready") |
