summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/client/client.go2
-rw-r--r--pkg/config/config.go38
-rw-r--r--pkg/server/env.go12
-rw-r--r--pkg/server/env/env.go22
-rw-r--r--pkg/server/server.go70
5 files changed, 103 insertions, 41 deletions
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 29d956e..66aa745 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -42,7 +42,7 @@ func (c *Client) Send(args []string) (string, error) {
buf := make([]byte, config.BufSize)
- _, ew := c.conn.Write([]byte(out.Bytes()))
+ _, ew := c.conn.Write(out.Bytes())
if ew != nil {
return "", ew
}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5483192..906c4b9 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -1,7 +1,6 @@
package config
import (
- "errors"
"fmt"
"os"
"time"
@@ -13,21 +12,40 @@ const BufSize = 1024
const IoTimeout = 5 * time.Second
-var errNegativeUid = errors.New("negative uid")
+func GetSystemSocketPath() string {
+ return "/run/tunnel/socket"
+}
-func GetSystemSocketPath() (string, error) {
- return "/run/tunnel/socket", nil
+func getuid() int {
+ uid := os.Getuid()
+ if uid < 0 {
+ panic("os.Getuid() returns negative uid")
+ }
+ return uid
}
-func GetSocketPath() (string, error) {
+func runAsRoot() bool {
uid := os.Getuid()
+ if uid < 0 {
+ panic("os.Getuid() returns negative uid")
+ }
+ return uid == 0
+}
- switch uid {
- case -1:
- return "", errNegativeUid
- case 0:
+func GetSocketPath() string {
+ if uid := getuid(); uid == 0 {
return GetSystemSocketPath()
+ } else {
+ return fmt.Sprintf("/run/user/%d/tunnel/socket", uid)
}
+}
- return fmt.Sprintf("/run/user/%d/tunnel/socket", uid), nil
+func GetConfigPath() string {
+ if uid := getuid(); uid == 0 {
+ return "/etc/tunnel.conf"
+ } else if s, err := os.UserConfigDir(); err == nil {
+ return s + "/tunnel/config"
+ } else {
+ return ""
+ }
}
diff --git a/pkg/server/env.go b/pkg/server/env.go
index bc9d4bf..a3c9f49 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -18,7 +18,7 @@ func varSet(r *request) {
}
}
-func varDel(r *request) {
+func varUnset(r *request) {
r.expect(1)
if !r.c.s.env.Del(r.args[0]) {
@@ -38,9 +38,9 @@ func varClear(r *request) {
}
func init() {
- newCmd(varGet, "var get")
- newCmd(varSet, "var set")
- newCmd(varDel, "var del")
- newCmd(varShow, "var show")
- newCmd(varClear, "var clear")
+ newCmd(varGet, "get")
+ newCmd(varSet, "set")
+ newCmd(varUnset, "unset")
+ newCmd(varShow, "env")
+ newCmd(varClear, "clear")
}
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index fe8af25..237c0c5 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -3,6 +3,7 @@ package env
import (
"errors"
"regexp"
+ "sort"
"sync"
)
@@ -18,7 +19,7 @@ type Env struct {
const namePattern = "[a-zA-Z][a-zA-Z0-9.]*"
var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString
-var namePatternRe = regexp.MustCompile("@" + namePattern)
+var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})")
var errBadVariable = errors.New("bad variable name")
@@ -93,11 +94,19 @@ func (e *env) Del(key string) bool {
}
func (e *env) Each(f func(string, string) bool) {
+ var keys []string
+
e.Lock()
defer e.Unlock()
- for k, v := range e.m {
- if !f(k, v) {
+ for k := range e.m {
+ keys = append(keys, k)
+ }
+
+ sort.Strings(keys)
+
+ for _, k := range keys {
+ if !f(k, e.m[k]) {
break
}
}
@@ -115,7 +124,12 @@ func (e *env) Eval(s string) string {
defer e.Unlock()
repl := func(v string) string {
- if v, ok := e.m[v[1:]]; ok {
+ key := v[1:]
+ if key[0] == '{' {
+ key = key[1 : len(key)-1]
+ }
+
+ if v, ok := e.m[key]; ok {
return v
}
return ""
diff --git a/pkg/server/server.go b/pkg/server/server.go
index a9a50a4..badf5c0 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -51,7 +51,9 @@ type request struct {
argc int
args []string
- out *bytes.Buffer
+ fail bool
+
+ out bytes.Buffer
}
type requestError string
@@ -67,15 +69,15 @@ func (r *request) String() string {
}
func (r *request) Print(v ...interface{}) {
- fmt.Fprint(r.out, v...)
+ fmt.Fprint(&r.out, v...)
}
func (r *request) Printf(format string, v ...interface{}) {
- fmt.Fprintf(r.out, format, v...)
+ fmt.Fprintf(&r.out, format, v...)
}
func (r *request) Println(v ...interface{}) {
- fmt.Fprintln(r.out, v...)
+ fmt.Fprintln(&r.out, v...)
}
func (r *request) Fatal(v ...interface{}) {
@@ -145,6 +147,14 @@ func New(path string) (*Server, error) {
return s, nil
}
+func (s *Server) Env() env.Env {
+ return s.env
+}
+
+func (s *Server) Socket() string {
+ return s.listen.Addr().String()
+}
+
func (s *Server) Run() {
for {
conn, err := s.listen.Accept()
@@ -188,11 +198,28 @@ func (s *Server) newClient(conn net.Conn) *client {
return c
}
+func (s *Server) Command(query string) error {
+ args := strings.Split(query, " ")
+
+ r := &request{c: &client{s: s}}
+
+ r.run(args)
+
+ if r.fail {
+ if r.out.Len() == 0 {
+ return errors.New("failed")
+ }
+
+ return errors.New(r.out.String())
+ }
+
+ return nil
+}
+
func (c *client) newRequest() *request {
r := &request{
- c: c,
- id: c.nextRid,
- out: new(bytes.Buffer),
+ c: c,
+ id: c.nextRid,
}
c.nextRid++
@@ -214,11 +241,15 @@ func (c *client) handle() {
break
}
- query := string(buf[:nr])
+ args, err := c.decode(buf[:nr])
+ if err != nil {
+ log.Println(c, "decode:", err)
+ break
+ }
r := c.newRequest()
- r.run(query)
+ r.run(args)
if r.out.Len() == 0 {
r.out.Write([]byte("\n"))
@@ -232,8 +263,8 @@ func (c *client) handle() {
}
}
-func (r *request) decode(query string) []string {
- dec := netstring.NewDecoder(bytes.NewReader([]byte(query)))
+func (c *client) decode(b []byte) ([]string, error) {
+ dec := netstring.NewDecoder(bytes.NewReader(b))
var t []string
for {
@@ -241,14 +272,14 @@ func (r *request) decode(query string) []string {
t = append(t, s)
} else {
if !errors.Is(err, io.EOF) {
- r.Fatal("request parse failed")
+ return nil, err
}
break
}
}
- return t
+ return t, nil
}
func (r *request) eval(args []string) []string {
@@ -263,10 +294,8 @@ func (r *request) eval(args []string) []string {
return args
}
-func (r *request) parse(query string) {
- args := r.eval(r.decode(query))
-
- c, args := getCmd(args)
+func (r *request) parse(args []string) {
+ c, args := getCmd(r.eval(args))
if c == nil {
r.Fatal("command not found")
}
@@ -276,20 +305,21 @@ func (r *request) parse(query string) {
r.cmd = c
}
-func (r *request) run(query string) {
+func (r *request) run(args []string) {
defer func() {
switch err := recover().(type) {
case requestError:
r.Print(err)
+ r.fail = true
default:
panic(err)
case nil:
}
}()
- r.parse(query)
+ log.Println(r.c, r, ">", strings.Join(args, " "))
- log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " "))
+ r.parse(args)
r.c.s.mu.Lock()
defer r.c.s.mu.Unlock()