package main import ( "bufio" "errors" "flag" "fmt" "io" "log" "log/syslog" "os" "os/exec" "os/signal" "os/user" "path" "strconv" "strings" "sync" "syscall" "tunnel/pkg/config" "tunnel/pkg/server" ) var BuildVersion string const trimSize = 32 type programArgs struct { debug bool force bool syslog bool output string config string ctlrpath string ctlrname string version bool } func parseArgs() programArgs { var args programArgs flag.BoolVar(&args.debug, "d", false, "debug: print time and source info") flag.BoolVar(&args.force, "f", false, "try start with force") flag.BoolVar(&args.syslog, "syslog", false, "write log to syslog instead of stdout") flag.StringVar(&args.output, "o", "", "write log to file instead of stdout") flag.StringVar(&args.config, "c", "", "path to configuration file") flag.StringVar(&args.ctlrpath, "S", "", "path to control socket") flag.StringVar(&args.ctlrname, "s", "", "name of control socket") flag.BoolVar(&args.version, "version", false, "print version and exit") flag.Parse() return args } func initLog(args programArgs) { if args.output != "" && args.syslog { log.Fatal("bad usage: duplicate log write flag") } if args.debug && args.syslog { log.Fatal("bad usage: debug with syslog is out of sense") } if args.output != "" { const fileFlags = os.O_APPEND | os.O_CREATE | os.O_WRONLY f, err := os.OpenFile(args.output, fileFlags, 0644) if err != nil { log.Fatalf("log: %s", err) } log.SetOutput(f) } if args.syslog { sysLog, err := syslog.New(syslog.LOG_INFO, "tunneld") if err != nil { log.Fatalf("log: %s", err) } log.SetOutput(sysLog) } var logFlags int if args.debug { logFlags |= log.Ldate | log.Ltime | log.Lshortfile } if args.output != "" { logFlags |= log.Ldate | log.Ltime } log.SetFlags(logFlags) } func sighandler(c chan os.Signal, s *server.Server) { var try bool for sig := range c { if try { log.Fatal("force exit") } log.Printf("catch signal: %s", sig) try = true s.Stop() } } func initSignals(s *server.Server) { var c = make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go sighandler(c, s) } func getSocketPath(args programArgs) string { if args.ctlrpath != "" { return args.ctlrpath } s := config.GetSocketPath(args.ctlrname) if err := os.Mkdir(path.Dir(s), 0700); err != nil { if !errors.Is(err, syscall.EEXIST) { log.Fatal(err) } } return s } type parser struct { s *server.Server m map[string]bool f []string } func newParser(s *server.Server) *parser { return &parser{s: s, m: map[string]bool{}} } func (p *parser) parse(file string, skip bool) error { if p.m[file] { log.Printf("skip reading '%s' twice", file) return nil } p.m[file] = true p.f = append(p.f, file) defer func() { p.f = p.f[:len(p.f)-1] }() log.Println("read config", file) fp, err := os.Open(file) if err != nil { if skip && errors.Is(err, syscall.ENOENT) { return nil } return err } defer fp.Close() return p.read(fp) } func (p *parser) read(fp *os.File) error { scanner := bufio.NewScanner(fp) var args []string var line int var block bool for n := 1; scanner.Scan(); n++ { if line == 0 { line = n } parts := strings.SplitN(scanner.Text(), "#", 2) s := strings.TrimSpace(parts[0]) if strings.HasSuffix(s, " --") { if block { return fmt.Errorf("%s:%d: duplicate block", fp.Name(), n) } s = s[:len(s)-3] block = true } if s == "--" { if !block { return fmt.Errorf("%s:%d: no block context", fp.Name(), n) } block = false s = "" } args = append(args, strings.Fields(s)...) if block { continue } t := strings.Join(args, " ") if err := p.apply(args); err != nil { return fmt.Errorf("%s:%d: %s: %w", fp.Name(), line, t, err) } args = []string{} line = 0 } if len(args) > 0 { return fmt.Errorf("%s:%d: unexpected end of file", fp.Name(), line) } return nil } func (p *parser) apply(args []string) error { if len(args) == 0 { return nil } if args[0] == "read" { if len(args) < 2 { return errors.New("argument expected") } file := args[1] if !path.IsAbs(file) { now := p.f[len(p.f)-1] file = path.Join(path.Dir(now), file) } return p.parse(file, false) } return p.s.Command(args) } func updateSocketGroup(s *server.Server, group string) error { var gid int if g, err := user.LookupGroup(group); err != nil { return err } else { var err error if gid, err = strconv.Atoi(g.Gid); err != nil { return fmt.Errorf("bad group id %s: %w", g.Gid, err) } } f := s.Socket() d := path.Dir(f) if err := os.Chown(d, -1, gid); err != nil { return err } if err := os.Chmod(d, 0750); err != nil { return err } if err := os.Chown(f, -1, gid); err != nil { return err } if err := os.Chmod(f, 0770); err != nil { return err } return nil } func runCommand(s *server.Server, name string) error { cmd := s.Env().Value(name) if cmd == "" { return nil } log.Printf("cmd \"%s\" try", cmd) args := strings.Fields(cmd) c := exec.Command(args[0], args[1:]...) if env := s.Env().Value("cmd.env"); env != "" { c.Env = append(os.Environ(), strings.Split(env, ",")...) } stdout, _ := c.StdoutPipe() stderr, _ := c.StderrPipe() if err := c.Start(); err != nil { return fmt.Errorf("cmd \"%s\" start failed: %w", cmd, err) } var wg sync.WaitGroup logger := func(s string, r io.Reader) { for scanner := bufio.NewScanner(r); scanner.Scan(); { log.Println(s, ">", scanner.Text()) } wg.Done() } wg.Add(2) go logger(args[0], stdout) go logger(args[0], stderr) if err := c.Wait(); err != nil { return fmt.Errorf("cmd \"%s\" failed: %w", cmd, err) } wg.Wait() log.Printf("cmd \"%s\" ok", cmd) return nil } func configure(s *server.Server, args programArgs) error { var file string var skip bool if args.config != "" { file = args.config } else { file = config.GetConfigPath() skip = true } if err := newParser(s).parse(file, skip); err != nil { return err } return postconfigure(s) } func postconfigure(s *server.Server) error { if group, ok := s.Env().Find("server.socket.group"); ok { if err := updateSocketGroup(s, group); err != nil { return err } } if err := runCommand(s, "cmd.init"); err != nil { return err } return nil } func deconfigure(s *server.Server) error { return runCommand(s, "cmd.fini") } func main() { log.SetFlags(0) args := parseArgs() if args.version { fmt.Fprintln(os.Stderr, BuildVersion) os.Exit(1) } initLog(args) if flag.NArg() > 0 { log.Fatal("bad usage: extra args") } socket := getSocketPath(args) if args.force { if err := os.Remove(socket); err != nil { if !errors.Is(err, syscall.ENOENT) { log.Fatal(err) } } } s, err := server.New(socket, BuildVersion) if err != nil { log.Fatal(err) } if err := configure(s, args); err != nil { s.Stop() log.Fatal(err) } initSignals(s) log.Print("ready") s.Serve() if err := deconfigure(s); err != nil { log.Println(err) } log.Print("exit") }