summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'cmd')
-rw-r--r--cmd/tunnel/main.go14
-rw-r--r--cmd/tunneld/main.go129
2 files changed, 124 insertions, 19 deletions
diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go
index 44eac45..dc94974 100644
--- a/cmd/tunnel/main.go
+++ b/cmd/tunnel/main.go
@@ -22,22 +22,14 @@ func init() {
func getSocketPath() string {
if *systemFlag {
- return getSystemSocketPath()
+ return config.GetSystemSocketPath()
}
- s, err := config.GetSocketPath()
- if err != nil {
- log.Fatal(err)
- }
- return s
+ return config.GetSocketPath()
}
func getSystemSocketPath() string {
- s, err := config.GetSystemSocketPath()
- if err != nil {
- log.Fatal(err)
- }
- return s
+ return config.GetSystemSocketPath()
}
func main() {
diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go
index 6a16dec..aa0a370 100644
--- a/cmd/tunneld/main.go
+++ b/cmd/tunneld/main.go
@@ -1,13 +1,18 @@
package main
import (
+ "bufio"
"errors"
"flag"
+ "fmt"
"log"
"log/syslog"
"os"
"os/signal"
+ "os/user"
"path"
+ "strconv"
+ "strings"
"syscall"
"tunnel/pkg/config"
"tunnel/pkg/server"
@@ -16,6 +21,7 @@ import (
var (
debugFlag = flag.Bool("d", false, "debug: print time and source info")
syslogFlag = flag.Bool("s", false, "log output to syslog instead of stdout")
+ configFlag = flag.String("c", "", "path to configuration file")
)
func initLog() {
@@ -51,11 +57,16 @@ func sighandler(c chan os.Signal, s *server.Server) {
}
}
+func initSignals(s *server.Server) {
+ var c = make(chan os.Signal)
+
+ signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+
+ go sighandler(c, s)
+}
+
func getSocketPath() string {
- s, err := config.GetSocketPath()
- if err != nil {
- log.Fatal(err)
- }
+ s := config.GetSocketPath()
if err := os.Mkdir(path.Dir(s), 0700); err != nil {
if !errors.Is(err, syscall.EEXIST) {
@@ -66,6 +77,107 @@ func getSocketPath() string {
return s
}
+func openConfig() (*os.File, error) {
+ var c string
+
+ if len(*configFlag) > 0 {
+ c = *configFlag
+ } else {
+ c = config.GetConfigPath()
+ }
+
+ if c == "" {
+ return nil, nil
+ }
+
+ fp, err := os.Open(c)
+ if err != nil {
+ if *configFlag == "" && errors.Is(err, syscall.ENOENT) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ return fp, nil
+}
+
+func readConfig(s *server.Server) error {
+ fp, err := openConfig()
+ if fp == nil || err != nil {
+ return err
+ }
+
+ defer fp.Close()
+
+ scanner := bufio.NewScanner(fp)
+ scanner.Split(bufio.ScanLines)
+
+ for nline := 1; scanner.Scan(); nline++ {
+ args := strings.SplitN(scanner.Text(), "#", 2)
+ cmd := strings.TrimSpace(args[0])
+
+ if cmd == "" {
+ continue
+ }
+
+ if err := s.Command(cmd); err != nil {
+ return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, cmd, err)
+ }
+ }
+
+ return nil
+}
+
+func updateSocketGroup(s *server.Server, group string) error {
+ var gid int
+
+ if g, err := user.LookupGroup(group); err != nil {
+ return err
+ } else {
+ var err error
+
+ if gid, err = strconv.Atoi(g.Gid); err != nil {
+ return fmt.Errorf("bad group id %s: %w", g.Gid, err)
+ }
+ }
+
+ f := s.Socket()
+ d := path.Dir(f)
+
+ if err := os.Chown(d, -1, gid); err != nil {
+ return err
+ }
+
+ if err := os.Chmod(d, 0750); err != nil {
+ return err
+ }
+
+ if err := os.Chown(f, -1, gid); err != nil {
+ return err
+ }
+
+ if err := os.Chmod(f, 0770); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func configure(s *server.Server) error {
+ if err := readConfig(s); err != nil {
+ return err
+ }
+
+ if group, ok := s.Env().Find("server.socket.group"); ok {
+ if err := updateSocketGroup(s, group); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
func main() {
flag.Parse()
initLog()
@@ -75,11 +187,12 @@ func main() {
log.Fatal(err)
}
- var c = make(chan os.Signal)
-
- signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+ if err := configure(s); err != nil {
+ s.Stop()
+ log.Fatal(err)
+ }
- go sighandler(c, s)
+ initSignals(s)
log.Print("ready")