package main import ( "bufio" "errors" "flag" "fmt" "io" "log" "log/syslog" "os" "os/exec" "os/signal" "os/user" "path" "regexp" "strconv" "strings" "syscall" "tunnel/pkg/config" "tunnel/pkg/server" ) var ( debugFlag = flag.Bool("d", false, "debug: print time and source info") forceFlag = flag.Bool("f", false, "try start with force") syslogFlag = flag.Bool("s", false, "write log to syslog instead of stdout") configFlag = flag.String("c", "", "path to configuration file") socketFlag = flag.String("S", "", "path to control socket") ) const trimSize = 32 var wordsRe = regexp.MustCompile("[[:^space:]]+") func initLog() { var logFlags int if *debugFlag { logFlags |= log.Ldate | log.Ltime | log.Lshortfile } log.SetFlags(logFlags) if *syslogFlag { sysLog, err := syslog.New(syslog.LOG_INFO, "tunneld") if err != nil { log.Fatal(err) } log.SetOutput(sysLog) } } func sighandler(c chan os.Signal, s *server.Server) { var try bool for sig := range c { if try { log.Fatal("force exit") } log.Printf("catch signal: %s", sig) try = true s.Stop() } } func initSignals(s *server.Server) { var c = make(chan os.Signal) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go sighandler(c, s) } func getSocketPath() string { if *socketFlag != "" { return *socketFlag } s := config.GetSocketPath() if err := os.Mkdir(path.Dir(s), 0700); err != nil { if !errors.Is(err, syscall.EEXIST) { log.Fatal(err) } } return s } type parser struct { s *server.Server m map[string]bool f []string } func newParser(s *server.Server) *parser { return &parser{s: s, m: map[string]bool{}} } func (p *parser) parse(file string, skip bool) error { if p.m[file] { log.Printf("skip reading '%s' twice", file) return nil } p.m[file] = true p.f = append(p.f, file) defer func() { p.f = p.f[:len(p.f)-1] }() log.Println("read config", file) fp, err := os.Open(file) if err != nil { if skip && errors.Is(err, syscall.ENOENT) { return nil } return err } defer fp.Close() return p.read(fp) } func (p *parser) read(fp *os.File) error { scanner := bufio.NewScanner(fp) for nline := 1; scanner.Scan(); nline++ { s := strings.SplitN(scanner.Text(), "#", 2) t := strings.TrimSpace(s[0]) args := wordsRe.FindAllString(s[0], -1) if err := p.apply(args); err != nil { return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, t, err) } } return nil } func (p *parser) apply(args []string) error { if len(args) == 0 { return nil } if args[0] == "include" { if len(args) < 2 { return errors.New("argument expected") } file := args[1] if !path.IsAbs(file) { now := p.f[len(p.f)-1] file = path.Join(path.Dir(now), file) } return p.parse(file, false) } return p.s.Command(args) } func updateSocketGroup(s *server.Server, group string) error { var gid int if g, err := user.LookupGroup(group); err != nil { return err } else { var err error if gid, err = strconv.Atoi(g.Gid); err != nil { return fmt.Errorf("bad group id %s: %w", g.Gid, err) } } f := s.Socket() d := path.Dir(f) if err := os.Chown(d, -1, gid); err != nil { return err } if err := os.Chmod(d, 0750); err != nil { return err } if err := os.Chown(f, -1, gid); err != nil { return err } if err := os.Chmod(f, 0770); err != nil { return err } return nil } func runCommand(s *server.Server, name string) error { cmd := s.Env().Value(name) if cmd == "" { return nil } log.Printf("cmd \"%s\" try", cmd) args := strings.Split(cmd, " ") c := exec.Command(args[0], args[1:]...) stdout, _ := c.StdoutPipe() stderr, _ := c.StderrPipe() if err := c.Start(); err != nil { return fmt.Errorf("cmd \"%s\" start failed: %w", cmd, err) } logger := func(s string, r io.Reader) { for scanner := bufio.NewScanner(r); scanner.Scan(); { log.Println(s, ">", scanner.Text()) } } go logger(args[0], stdout) go logger(args[0], stderr) if err := c.Wait(); err != nil { return fmt.Errorf("cmd \"%s\" failed: %w", cmd, err) } log.Printf("cmd \"%s\" ok", cmd) return nil } func configure(s *server.Server) error { var file string var skip bool if *configFlag != "" { file = *configFlag } else { file = config.GetConfigPath() skip = true } if err := newParser(s).parse(file, skip); err != nil { return err } return postconfigure(s) } func postconfigure(s *server.Server) error { if group, ok := s.Env().Find("server.socket.group"); ok { if err := updateSocketGroup(s, group); err != nil { return err } } if err := runCommand(s, "cmd.init"); err != nil { return err } return nil } func deconfigure(s *server.Server) error { return runCommand(s, "cmd.fini") } func main() { flag.Parse() initLog() socket := getSocketPath() if *forceFlag { if err := os.Remove(socket); err != nil { if !errors.Is(err, syscall.ENOENT) { log.Fatal(err) } } } s, err := server.New(socket) if err != nil { log.Fatal(err) } if err := configure(s); err != nil { s.Stop() log.Fatal(err) } initSignals(s) log.Print("ready") s.Run() if err := deconfigure(s); err != nil { log.Println(err) } log.Print("exit") }