From 27e13f14f4dba71b417ea530bfe035adbd8f0a93 Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Thu, 27 Feb 2020 01:51:55 +0300 Subject: add config file support --- TODO | 6 +-- cmd/tunnel/main.go | 14 ++---- cmd/tunneld/main.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++---- pkg/client/client.go | 2 +- pkg/config/config.go | 38 +++++++++++---- pkg/server/env.go | 12 ++--- pkg/server/env/env.go | 22 +++++++-- pkg/server/server.go | 70 +++++++++++++++++++-------- test/auth.sh | 6 +-- test/env.sh | 10 ++-- 10 files changed, 238 insertions(+), 71 deletions(-) diff --git a/TODO b/TODO index a97e8f9..188afda 100644 --- a/TODO +++ b/TODO @@ -35,14 +35,14 @@ note: 7. DONE substitute variable over query (maybe) 8. CloseRead or CloseWriter 9. tunnel enable/disable -10. config from file -11. system/user? unix control socket location +10. DONE config from file +11. DONE system/user? unix control socket location 12. DONE modules: auth(chap), enc, dec 13. DONE print module name when stream closed by error 14. stream.run check errors on Open 15. per tunnel, per stream statistics 16. http connect proxy module 17. unix socket path from config file - group owner of socket dir/file + DONE group owner of socket dir/file command line option for socket 18. systemd socket activation diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go index 44eac45..dc94974 100644 --- a/cmd/tunnel/main.go +++ b/cmd/tunnel/main.go @@ -22,22 +22,14 @@ func init() { func getSocketPath() string { if *systemFlag { - return getSystemSocketPath() + return config.GetSystemSocketPath() } - s, err := config.GetSocketPath() - if err != nil { - log.Fatal(err) - } - return s + return config.GetSocketPath() } func getSystemSocketPath() string { - s, err := config.GetSystemSocketPath() - if err != nil { - log.Fatal(err) - } - return s + return config.GetSystemSocketPath() } func main() { diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go index 6a16dec..aa0a370 100644 --- a/cmd/tunneld/main.go +++ b/cmd/tunneld/main.go @@ -1,13 +1,18 @@ package main import ( + "bufio" "errors" "flag" + "fmt" "log" "log/syslog" "os" "os/signal" + "os/user" "path" + "strconv" + "strings" "syscall" "tunnel/pkg/config" "tunnel/pkg/server" @@ -16,6 +21,7 @@ import ( var ( debugFlag = flag.Bool("d", false, "debug: print time and source info") syslogFlag = flag.Bool("s", false, "log output to syslog instead of stdout") + configFlag = flag.String("c", "", "path to configuration file") ) func initLog() { @@ -51,11 +57,16 @@ func sighandler(c chan os.Signal, s *server.Server) { } } +func initSignals(s *server.Server) { + var c = make(chan os.Signal) + + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + + go sighandler(c, s) +} + func getSocketPath() string { - s, err := config.GetSocketPath() - if err != nil { - log.Fatal(err) - } + s := config.GetSocketPath() if err := os.Mkdir(path.Dir(s), 0700); err != nil { if !errors.Is(err, syscall.EEXIST) { @@ -66,6 +77,107 @@ func getSocketPath() string { return s } +func openConfig() (*os.File, error) { + var c string + + if len(*configFlag) > 0 { + c = *configFlag + } else { + c = config.GetConfigPath() + } + + if c == "" { + return nil, nil + } + + fp, err := os.Open(c) + if err != nil { + if *configFlag == "" && errors.Is(err, syscall.ENOENT) { + return nil, nil + } + + return nil, err + } + + return fp, nil +} + +func readConfig(s *server.Server) error { + fp, err := openConfig() + if fp == nil || err != nil { + return err + } + + defer fp.Close() + + scanner := bufio.NewScanner(fp) + scanner.Split(bufio.ScanLines) + + for nline := 1; scanner.Scan(); nline++ { + args := strings.SplitN(scanner.Text(), "#", 2) + cmd := strings.TrimSpace(args[0]) + + if cmd == "" { + continue + } + + if err := s.Command(cmd); err != nil { + return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, cmd, err) + } + } + + return nil +} + +func updateSocketGroup(s *server.Server, group string) error { + var gid int + + if g, err := user.LookupGroup(group); err != nil { + return err + } else { + var err error + + if gid, err = strconv.Atoi(g.Gid); err != nil { + return fmt.Errorf("bad group id %s: %w", g.Gid, err) + } + } + + f := s.Socket() + d := path.Dir(f) + + if err := os.Chown(d, -1, gid); err != nil { + return err + } + + if err := os.Chmod(d, 0750); err != nil { + return err + } + + if err := os.Chown(f, -1, gid); err != nil { + return err + } + + if err := os.Chmod(f, 0770); err != nil { + return err + } + + return nil +} + +func configure(s *server.Server) error { + if err := readConfig(s); err != nil { + return err + } + + if group, ok := s.Env().Find("server.socket.group"); ok { + if err := updateSocketGroup(s, group); err != nil { + return err + } + } + + return nil +} + func main() { flag.Parse() initLog() @@ -75,11 +187,12 @@ func main() { log.Fatal(err) } - var c = make(chan os.Signal) - - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + if err := configure(s); err != nil { + s.Stop() + log.Fatal(err) + } - go sighandler(c, s) + initSignals(s) log.Print("ready") diff --git a/pkg/client/client.go b/pkg/client/client.go index 29d956e..66aa745 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -42,7 +42,7 @@ func (c *Client) Send(args []string) (string, error) { buf := make([]byte, config.BufSize) - _, ew := c.conn.Write([]byte(out.Bytes())) + _, ew := c.conn.Write(out.Bytes()) if ew != nil { return "", ew } diff --git a/pkg/config/config.go b/pkg/config/config.go index 5483192..906c4b9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,7 +1,6 @@ package config import ( - "errors" "fmt" "os" "time" @@ -13,21 +12,40 @@ const BufSize = 1024 const IoTimeout = 5 * time.Second -var errNegativeUid = errors.New("negative uid") +func GetSystemSocketPath() string { + return "/run/tunnel/socket" +} -func GetSystemSocketPath() (string, error) { - return "/run/tunnel/socket", nil +func getuid() int { + uid := os.Getuid() + if uid < 0 { + panic("os.Getuid() returns negative uid") + } + return uid } -func GetSocketPath() (string, error) { +func runAsRoot() bool { uid := os.Getuid() + if uid < 0 { + panic("os.Getuid() returns negative uid") + } + return uid == 0 +} - switch uid { - case -1: - return "", errNegativeUid - case 0: +func GetSocketPath() string { + if uid := getuid(); uid == 0 { return GetSystemSocketPath() + } else { + return fmt.Sprintf("/run/user/%d/tunnel/socket", uid) } +} - return fmt.Sprintf("/run/user/%d/tunnel/socket", uid), nil +func GetConfigPath() string { + if uid := getuid(); uid == 0 { + return "/etc/tunnel.conf" + } else if s, err := os.UserConfigDir(); err == nil { + return s + "/tunnel/config" + } else { + return "" + } } diff --git a/pkg/server/env.go b/pkg/server/env.go index bc9d4bf..a3c9f49 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -18,7 +18,7 @@ func varSet(r *request) { } } -func varDel(r *request) { +func varUnset(r *request) { r.expect(1) if !r.c.s.env.Del(r.args[0]) { @@ -38,9 +38,9 @@ func varClear(r *request) { } func init() { - newCmd(varGet, "var get") - newCmd(varSet, "var set") - newCmd(varDel, "var del") - newCmd(varShow, "var show") - newCmd(varClear, "var clear") + newCmd(varGet, "get") + newCmd(varSet, "set") + newCmd(varUnset, "unset") + newCmd(varShow, "env") + newCmd(varClear, "clear") } diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index fe8af25..237c0c5 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -3,6 +3,7 @@ package env import ( "errors" "regexp" + "sort" "sync" ) @@ -18,7 +19,7 @@ type Env struct { const namePattern = "[a-zA-Z][a-zA-Z0-9.]*" var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString -var namePatternRe = regexp.MustCompile("@" + namePattern) +var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})") var errBadVariable = errors.New("bad variable name") @@ -93,11 +94,19 @@ func (e *env) Del(key string) bool { } func (e *env) Each(f func(string, string) bool) { + var keys []string + e.Lock() defer e.Unlock() - for k, v := range e.m { - if !f(k, v) { + for k := range e.m { + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, k := range keys { + if !f(k, e.m[k]) { break } } @@ -115,7 +124,12 @@ func (e *env) Eval(s string) string { defer e.Unlock() repl := func(v string) string { - if v, ok := e.m[v[1:]]; ok { + key := v[1:] + if key[0] == '{' { + key = key[1 : len(key)-1] + } + + if v, ok := e.m[key]; ok { return v } return "" diff --git a/pkg/server/server.go b/pkg/server/server.go index a9a50a4..badf5c0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -51,7 +51,9 @@ type request struct { argc int args []string - out *bytes.Buffer + fail bool + + out bytes.Buffer } type requestError string @@ -67,15 +69,15 @@ func (r *request) String() string { } func (r *request) Print(v ...interface{}) { - fmt.Fprint(r.out, v...) + fmt.Fprint(&r.out, v...) } func (r *request) Printf(format string, v ...interface{}) { - fmt.Fprintf(r.out, format, v...) + fmt.Fprintf(&r.out, format, v...) } func (r *request) Println(v ...interface{}) { - fmt.Fprintln(r.out, v...) + fmt.Fprintln(&r.out, v...) } func (r *request) Fatal(v ...interface{}) { @@ -145,6 +147,14 @@ func New(path string) (*Server, error) { return s, nil } +func (s *Server) Env() env.Env { + return s.env +} + +func (s *Server) Socket() string { + return s.listen.Addr().String() +} + func (s *Server) Run() { for { conn, err := s.listen.Accept() @@ -188,11 +198,28 @@ func (s *Server) newClient(conn net.Conn) *client { return c } +func (s *Server) Command(query string) error { + args := strings.Split(query, " ") + + r := &request{c: &client{s: s}} + + r.run(args) + + if r.fail { + if r.out.Len() == 0 { + return errors.New("failed") + } + + return errors.New(r.out.String()) + } + + return nil +} + func (c *client) newRequest() *request { r := &request{ - c: c, - id: c.nextRid, - out: new(bytes.Buffer), + c: c, + id: c.nextRid, } c.nextRid++ @@ -214,11 +241,15 @@ func (c *client) handle() { break } - query := string(buf[:nr]) + args, err := c.decode(buf[:nr]) + if err != nil { + log.Println(c, "decode:", err) + break + } r := c.newRequest() - r.run(query) + r.run(args) if r.out.Len() == 0 { r.out.Write([]byte("\n")) @@ -232,8 +263,8 @@ func (c *client) handle() { } } -func (r *request) decode(query string) []string { - dec := netstring.NewDecoder(bytes.NewReader([]byte(query))) +func (c *client) decode(b []byte) ([]string, error) { + dec := netstring.NewDecoder(bytes.NewReader(b)) var t []string for { @@ -241,14 +272,14 @@ func (r *request) decode(query string) []string { t = append(t, s) } else { if !errors.Is(err, io.EOF) { - r.Fatal("request parse failed") + return nil, err } break } } - return t + return t, nil } func (r *request) eval(args []string) []string { @@ -263,10 +294,8 @@ func (r *request) eval(args []string) []string { return args } -func (r *request) parse(query string) { - args := r.eval(r.decode(query)) - - c, args := getCmd(args) +func (r *request) parse(args []string) { + c, args := getCmd(r.eval(args)) if c == nil { r.Fatal("command not found") } @@ -276,20 +305,21 @@ func (r *request) parse(query string) { r.cmd = c } -func (r *request) run(query string) { +func (r *request) run(args []string) { defer func() { switch err := recover().(type) { case requestError: r.Print(err) + r.fail = true default: panic(err) case nil: } }() - r.parse(query) + log.Println(r.c, r, ">", strings.Join(args, " ")) - log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " ")) + r.parse(args) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() diff --git a/test/auth.sh b/test/auth.sh index 311b8ff..a807c96 100755 --- a/test/auth.sh +++ b/test/auth.sh @@ -5,10 +5,10 @@ PATH=$PATH:$ROOT/cmd/tunnel tunnel add name T 2000,listen auth aes 3000 tunnel add name X 3000,listen -aes -auth 4000 -tunnel var set tunnel.T.secret secret -tunnel var set tunnel.X.secret secret +tunnel set tunnel.T.secret secret +tunnel set tunnel.X.secret secret nc -l 4000 & echo "Hello, World!" | nc -N localhost 2000 -tunnel var clear +tunnel clear tunnel del T tunnel del X diff --git a/test/env.sh b/test/env.sh index bcff497..7e60718 100755 --- a/test/env.sh +++ b/test/env.sh @@ -3,9 +3,9 @@ ROOT=$(dirname $0)/.. PATH=$ROOT/cmd/tunnel -tunnel var set cmd echo -tunnel var set args ^"@x, @y!" -tunnel var set x Hello -tunnel var set y World +tunnel set cmd echo +tunnel set args ^"@x, @y!" +tunnel set x Hello +tunnel set y World tunnel @cmd @args -tunnel var clear +tunnel clear -- cgit v1.2.3-70-g09d2