diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-27 01:51:55 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-02-27 02:11:12 +0300 |
| commit | 27e13f14f4dba71b417ea530bfe035adbd8f0a93 (patch) | |
| tree | 0ed374c02a7f1b5516a8e6e9ab1d37b42860e299 /pkg | |
| parent | 085bdfb75eb1c4b90a25a792815f8b80ed06dccb (diff) | |
add config file support
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/client/client.go | 2 | ||||
| -rw-r--r-- | pkg/config/config.go | 38 | ||||
| -rw-r--r-- | pkg/server/env.go | 12 | ||||
| -rw-r--r-- | pkg/server/env/env.go | 22 | ||||
| -rw-r--r-- | pkg/server/server.go | 70 |
5 files changed, 103 insertions, 41 deletions
diff --git a/pkg/client/client.go b/pkg/client/client.go index 29d956e..66aa745 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -42,7 +42,7 @@ func (c *Client) Send(args []string) (string, error) { buf := make([]byte, config.BufSize) - _, ew := c.conn.Write([]byte(out.Bytes())) + _, ew := c.conn.Write(out.Bytes()) if ew != nil { return "", ew } diff --git a/pkg/config/config.go b/pkg/config/config.go index 5483192..906c4b9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,7 +1,6 @@ package config import ( - "errors" "fmt" "os" "time" @@ -13,21 +12,40 @@ const BufSize = 1024 const IoTimeout = 5 * time.Second -var errNegativeUid = errors.New("negative uid") +func GetSystemSocketPath() string { + return "/run/tunnel/socket" +} -func GetSystemSocketPath() (string, error) { - return "/run/tunnel/socket", nil +func getuid() int { + uid := os.Getuid() + if uid < 0 { + panic("os.Getuid() returns negative uid") + } + return uid } -func GetSocketPath() (string, error) { +func runAsRoot() bool { uid := os.Getuid() + if uid < 0 { + panic("os.Getuid() returns negative uid") + } + return uid == 0 +} - switch uid { - case -1: - return "", errNegativeUid - case 0: +func GetSocketPath() string { + if uid := getuid(); uid == 0 { return GetSystemSocketPath() + } else { + return fmt.Sprintf("/run/user/%d/tunnel/socket", uid) } +} - return fmt.Sprintf("/run/user/%d/tunnel/socket", uid), nil +func GetConfigPath() string { + if uid := getuid(); uid == 0 { + return "/etc/tunnel.conf" + } else if s, err := os.UserConfigDir(); err == nil { + return s + "/tunnel/config" + } else { + return "" + } } diff --git a/pkg/server/env.go b/pkg/server/env.go index bc9d4bf..a3c9f49 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -18,7 +18,7 @@ func varSet(r *request) { } } -func varDel(r *request) { +func varUnset(r *request) { r.expect(1) if !r.c.s.env.Del(r.args[0]) { @@ -38,9 +38,9 @@ func varClear(r *request) { } func init() { - newCmd(varGet, "var get") - newCmd(varSet, "var set") - newCmd(varDel, "var del") - newCmd(varShow, "var show") - newCmd(varClear, "var clear") + newCmd(varGet, "get") + newCmd(varSet, "set") + newCmd(varUnset, "unset") + newCmd(varShow, "env") + newCmd(varClear, "clear") } diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index fe8af25..237c0c5 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -3,6 +3,7 @@ package env import ( "errors" "regexp" + "sort" "sync" ) @@ -18,7 +19,7 @@ type Env struct { const namePattern = "[a-zA-Z][a-zA-Z0-9.]*" var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString -var namePatternRe = regexp.MustCompile("@" + namePattern) +var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})") var errBadVariable = errors.New("bad variable name") @@ -93,11 +94,19 @@ func (e *env) Del(key string) bool { } func (e *env) Each(f func(string, string) bool) { + var keys []string + e.Lock() defer e.Unlock() - for k, v := range e.m { - if !f(k, v) { + for k := range e.m { + keys = append(keys, k) + } + + sort.Strings(keys) + + for _, k := range keys { + if !f(k, e.m[k]) { break } } @@ -115,7 +124,12 @@ func (e *env) Eval(s string) string { defer e.Unlock() repl := func(v string) string { - if v, ok := e.m[v[1:]]; ok { + key := v[1:] + if key[0] == '{' { + key = key[1 : len(key)-1] + } + + if v, ok := e.m[key]; ok { return v } return "" diff --git a/pkg/server/server.go b/pkg/server/server.go index a9a50a4..badf5c0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -51,7 +51,9 @@ type request struct { argc int args []string - out *bytes.Buffer + fail bool + + out bytes.Buffer } type requestError string @@ -67,15 +69,15 @@ func (r *request) String() string { } func (r *request) Print(v ...interface{}) { - fmt.Fprint(r.out, v...) + fmt.Fprint(&r.out, v...) } func (r *request) Printf(format string, v ...interface{}) { - fmt.Fprintf(r.out, format, v...) + fmt.Fprintf(&r.out, format, v...) } func (r *request) Println(v ...interface{}) { - fmt.Fprintln(r.out, v...) + fmt.Fprintln(&r.out, v...) } func (r *request) Fatal(v ...interface{}) { @@ -145,6 +147,14 @@ func New(path string) (*Server, error) { return s, nil } +func (s *Server) Env() env.Env { + return s.env +} + +func (s *Server) Socket() string { + return s.listen.Addr().String() +} + func (s *Server) Run() { for { conn, err := s.listen.Accept() @@ -188,11 +198,28 @@ func (s *Server) newClient(conn net.Conn) *client { return c } +func (s *Server) Command(query string) error { + args := strings.Split(query, " ") + + r := &request{c: &client{s: s}} + + r.run(args) + + if r.fail { + if r.out.Len() == 0 { + return errors.New("failed") + } + + return errors.New(r.out.String()) + } + + return nil +} + func (c *client) newRequest() *request { r := &request{ - c: c, - id: c.nextRid, - out: new(bytes.Buffer), + c: c, + id: c.nextRid, } c.nextRid++ @@ -214,11 +241,15 @@ func (c *client) handle() { break } - query := string(buf[:nr]) + args, err := c.decode(buf[:nr]) + if err != nil { + log.Println(c, "decode:", err) + break + } r := c.newRequest() - r.run(query) + r.run(args) if r.out.Len() == 0 { r.out.Write([]byte("\n")) @@ -232,8 +263,8 @@ func (c *client) handle() { } } -func (r *request) decode(query string) []string { - dec := netstring.NewDecoder(bytes.NewReader([]byte(query))) +func (c *client) decode(b []byte) ([]string, error) { + dec := netstring.NewDecoder(bytes.NewReader(b)) var t []string for { @@ -241,14 +272,14 @@ func (r *request) decode(query string) []string { t = append(t, s) } else { if !errors.Is(err, io.EOF) { - r.Fatal("request parse failed") + return nil, err } break } } - return t + return t, nil } func (r *request) eval(args []string) []string { @@ -263,10 +294,8 @@ func (r *request) eval(args []string) []string { return args } -func (r *request) parse(query string) { - args := r.eval(r.decode(query)) - - c, args := getCmd(args) +func (r *request) parse(args []string) { + c, args := getCmd(r.eval(args)) if c == nil { r.Fatal("command not found") } @@ -276,20 +305,21 @@ func (r *request) parse(query string) { r.cmd = c } -func (r *request) run(query string) { +func (r *request) run(args []string) { defer func() { switch err := recover().(type) { case requestError: r.Print(err) + r.fail = true default: panic(err) case nil: } }() - r.parse(query) + log.Println(r.c, r, ">", strings.Join(args, " ")) - log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " ")) + r.parse(args) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() |
