package main import ( "bufio" "errors" "flag" "fmt" "log" "log/syslog" "os" "os/signal" "os/user" "path" "strconv" "strings" "syscall" "tunnel/pkg/config" "tunnel/pkg/server" ) var ( debugFlag = flag.Bool("d", false, "debug: print time and source info") forceFlag = flag.Bool("f", false, "try start with force") syslogFlag = flag.Bool("s", false, "write log to syslog instead of stdout") configFlag = flag.String("c", "", "path to configuration file") ) func initLog() { var logFlags int if *debugFlag { logFlags |= log.Ldate | log.Ltime | log.Lshortfile } log.SetFlags(logFlags) if *syslogFlag { sysLog, err := syslog.New(syslog.LOG_INFO, "tunneld") if err != nil { log.Fatal(err) } log.SetOutput(sysLog) } } func sighandler(c chan os.Signal, s *server.Server) { var try bool for sig := range c { if try { log.Fatal("force exit") } log.Printf("catch signal: %s", sig) try = true s.Stop() } } func initSignals(s *server.Server) { var c = make(chan os.Signal) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go sighandler(c, s) } func getSocketPath() string { s := config.GetSocketPath() if err := os.Mkdir(path.Dir(s), 0700); err != nil { if !errors.Is(err, syscall.EEXIST) { log.Fatal(err) } } return s } func openConfig() (*os.File, error) { var c string if len(*configFlag) > 0 { c = *configFlag } else { c = config.GetConfigPath() } if c == "" { return nil, nil } fp, err := os.Open(c) if err != nil { if *configFlag == "" && errors.Is(err, syscall.ENOENT) { return nil, nil } return nil, err } return fp, nil } func readConfig(s *server.Server) error { fp, err := openConfig() if fp == nil || err != nil { return err } defer fp.Close() scanner := bufio.NewScanner(fp) scanner.Split(bufio.ScanLines) for nline := 1; scanner.Scan(); nline++ { args := strings.SplitN(scanner.Text(), "#", 2) cmd := strings.TrimSpace(args[0]) if cmd == "" { continue } if err := s.Command(cmd); err != nil { return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, cmd, err) } } return nil } func updateSocketGroup(s *server.Server, group string) error { var gid int if g, err := user.LookupGroup(group); err != nil { return err } else { var err error if gid, err = strconv.Atoi(g.Gid); err != nil { return fmt.Errorf("bad group id %s: %w", g.Gid, err) } } f := s.Socket() d := path.Dir(f) if err := os.Chown(d, -1, gid); err != nil { return err } if err := os.Chmod(d, 0750); err != nil { return err } if err := os.Chown(f, -1, gid); err != nil { return err } if err := os.Chmod(f, 0770); err != nil { return err } return nil } func configure(s *server.Server) error { if err := readConfig(s); err != nil { return err } if group, ok := s.Env().Find("server.socket.group"); ok { if err := updateSocketGroup(s, group); err != nil { return err } } return nil } func main() { flag.Parse() initLog() socket := getSocketPath() if *forceFlag { if err := os.Remove(socket); err != nil { if !errors.Is(err, syscall.ENOENT) { log.Fatal(err) } } } s, err := server.New(socket) if err != nil { log.Fatal(err) } if err := configure(s); err != nil { s.Stop() log.Fatal(err) } initSignals(s) log.Print("ready") s.Run() log.Print("exit") }