summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikhail Osipov <mike.osipov@gmail.com>2020-02-27 01:51:55 +0300
committerMikhail Osipov <mike.osipov@gmail.com>2020-02-27 02:11:12 +0300
commit27e13f14f4dba71b417ea530bfe035adbd8f0a93 (patch)
tree0ed374c02a7f1b5516a8e6e9ab1d37b42860e299
parent085bdfb75eb1c4b90a25a792815f8b80ed06dccb (diff)
add config file support
-rw-r--r--TODO6
-rw-r--r--cmd/tunnel/main.go14
-rw-r--r--cmd/tunneld/main.go129
-rw-r--r--pkg/client/client.go2
-rw-r--r--pkg/config/config.go38
-rw-r--r--pkg/server/env.go12
-rw-r--r--pkg/server/env/env.go22
-rw-r--r--pkg/server/server.go70
-rwxr-xr-xtest/auth.sh6
-rwxr-xr-xtest/env.sh10
10 files changed, 238 insertions, 71 deletions
diff --git a/TODO b/TODO
index a97e8f9..188afda 100644
--- a/TODO
+++ b/TODO
@@ -35,14 +35,14 @@ note:
7. DONE substitute variable over query (maybe)
8. CloseRead or CloseWriter
9. tunnel enable/disable
-10. config from file
-11. system/user? unix control socket location
+10. DONE config from file
+11. DONE system/user? unix control socket location
12. DONE modules: auth(chap), enc, dec
13. DONE print module name when stream closed by error
14. stream.run check errors on Open
15. per tunnel, per stream statistics
16. http connect proxy module
17. unix socket path from config file
- group owner of socket dir/file
+ DONE group owner of socket dir/file
command line option for socket
18. systemd socket activation
diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go
index 44eac45..dc94974 100644
--- a/cmd/tunnel/main.go
+++ b/cmd/tunnel/main.go
@@ -22,22 +22,14 @@ func init() {
func getSocketPath() string {
if *systemFlag {
- return getSystemSocketPath()
+ return config.GetSystemSocketPath()
}
- s, err := config.GetSocketPath()
- if err != nil {
- log.Fatal(err)
- }
- return s
+ return config.GetSocketPath()
}
func getSystemSocketPath() string {
- s, err := config.GetSystemSocketPath()
- if err != nil {
- log.Fatal(err)
- }
- return s
+ return config.GetSystemSocketPath()
}
func main() {
diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go
index 6a16dec..aa0a370 100644
--- a/cmd/tunneld/main.go
+++ b/cmd/tunneld/main.go
@@ -1,13 +1,18 @@
package main
import (
+ "bufio"
"errors"
"flag"
+ "fmt"
"log"
"log/syslog"
"os"
"os/signal"
+ "os/user"
"path"
+ "strconv"
+ "strings"
"syscall"
"tunnel/pkg/config"
"tunnel/pkg/server"
@@ -16,6 +21,7 @@ import (
var (
debugFlag = flag.Bool("d", false, "debug: print time and source info")
syslogFlag = flag.Bool("s", false, "log output to syslog instead of stdout")
+ configFlag = flag.String("c", "", "path to configuration file")
)
func initLog() {
@@ -51,11 +57,16 @@ func sighandler(c chan os.Signal, s *server.Server) {
}
}
+func initSignals(s *server.Server) {
+ var c = make(chan os.Signal)
+
+ signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+
+ go sighandler(c, s)
+}
+
func getSocketPath() string {
- s, err := config.GetSocketPath()
- if err != nil {
- log.Fatal(err)
- }
+ s := config.GetSocketPath()
if err := os.Mkdir(path.Dir(s), 0700); err != nil {
if !errors.Is(err, syscall.EEXIST) {
@@ -66,6 +77,107 @@ func getSocketPath() string {
return s
}
+func openConfig() (*os.File, error) {
+ var c string
+
+ if len(*configFlag) > 0 {
+ c = *configFlag
+ } else {
+ c = config.GetConfigPath()
+ }
+
+ if c == "" {
+ return nil, nil
+ }
+
+ fp, err := os.Open(c)
+ if err != nil {
+ if *configFlag == "" && errors.Is(err, syscall.ENOENT) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ return fp, nil
+}
+
+func readConfig(s *server.Server) error {
+ fp, err := openConfig()
+ if fp == nil || err != nil {
+ return err
+ }
+
+ defer fp.Close()
+
+ scanner := bufio.NewScanner(fp)
+ scanner.Split(bufio.ScanLines)
+
+ for nline := 1; scanner.Scan(); nline++ {
+ args := strings.SplitN(scanner.Text(), "#", 2)
+ cmd := strings.TrimSpace(args[0])
+
+ if cmd == "" {
+ continue
+ }
+
+ if err := s.Command(cmd); err != nil {
+ return fmt.Errorf("%s:%d: %s: %w", fp.Name(), nline, cmd, err)
+ }
+ }
+
+ return nil
+}
+
+func updateSocketGroup(s *server.Server, group string) error {
+ var gid int
+
+ if g, err := user.LookupGroup(group); err != nil {
+ return err
+ } else {
+ var err error
+
+ if gid, err = strconv.Atoi(g.Gid); err != nil {
+ return fmt.Errorf("bad group id %s: %w", g.Gid, err)
+ }
+ }
+
+ f := s.Socket()
+ d := path.Dir(f)
+
+ if err := os.Chown(d, -1, gid); err != nil {
+ return err
+ }
+
+ if err := os.Chmod(d, 0750); err != nil {
+ return err
+ }
+
+ if err := os.Chown(f, -1, gid); err != nil {
+ return err
+ }
+
+ if err := os.Chmod(f, 0770); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func configure(s *server.Server) error {
+ if err := readConfig(s); err != nil {
+ return err
+ }
+
+ if group, ok := s.Env().Find("server.socket.group"); ok {
+ if err := updateSocketGroup(s, group); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
func main() {
flag.Parse()
initLog()
@@ -75,11 +187,12 @@ func main() {
log.Fatal(err)
}
- var c = make(chan os.Signal)
-
- signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+ if err := configure(s); err != nil {
+ s.Stop()
+ log.Fatal(err)
+ }
- go sighandler(c, s)
+ initSignals(s)
log.Print("ready")
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 29d956e..66aa745 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -42,7 +42,7 @@ func (c *Client) Send(args []string) (string, error) {
buf := make([]byte, config.BufSize)
- _, ew := c.conn.Write([]byte(out.Bytes()))
+ _, ew := c.conn.Write(out.Bytes())
if ew != nil {
return "", ew
}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5483192..906c4b9 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -1,7 +1,6 @@
package config
import (
- "errors"
"fmt"
"os"
"time"
@@ -13,21 +12,40 @@ const BufSize = 1024
const IoTimeout = 5 * time.Second
-var errNegativeUid = errors.New("negative uid")
+func GetSystemSocketPath() string {
+ return "/run/tunnel/socket"
+}
-func GetSystemSocketPath() (string, error) {
- return "/run/tunnel/socket", nil
+func getuid() int {
+ uid := os.Getuid()
+ if uid < 0 {
+ panic("os.Getuid() returns negative uid")
+ }
+ return uid
}
-func GetSocketPath() (string, error) {
+func runAsRoot() bool {
uid := os.Getuid()
+ if uid < 0 {
+ panic("os.Getuid() returns negative uid")
+ }
+ return uid == 0
+}
- switch uid {
- case -1:
- return "", errNegativeUid
- case 0:
+func GetSocketPath() string {
+ if uid := getuid(); uid == 0 {
return GetSystemSocketPath()
+ } else {
+ return fmt.Sprintf("/run/user/%d/tunnel/socket", uid)
}
+}
- return fmt.Sprintf("/run/user/%d/tunnel/socket", uid), nil
+func GetConfigPath() string {
+ if uid := getuid(); uid == 0 {
+ return "/etc/tunnel.conf"
+ } else if s, err := os.UserConfigDir(); err == nil {
+ return s + "/tunnel/config"
+ } else {
+ return ""
+ }
}
diff --git a/pkg/server/env.go b/pkg/server/env.go
index bc9d4bf..a3c9f49 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -18,7 +18,7 @@ func varSet(r *request) {
}
}
-func varDel(r *request) {
+func varUnset(r *request) {
r.expect(1)
if !r.c.s.env.Del(r.args[0]) {
@@ -38,9 +38,9 @@ func varClear(r *request) {
}
func init() {
- newCmd(varGet, "var get")
- newCmd(varSet, "var set")
- newCmd(varDel, "var del")
- newCmd(varShow, "var show")
- newCmd(varClear, "var clear")
+ newCmd(varGet, "get")
+ newCmd(varSet, "set")
+ newCmd(varUnset, "unset")
+ newCmd(varShow, "env")
+ newCmd(varClear, "clear")
}
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index fe8af25..237c0c5 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -3,6 +3,7 @@ package env
import (
"errors"
"regexp"
+ "sort"
"sync"
)
@@ -18,7 +19,7 @@ type Env struct {
const namePattern = "[a-zA-Z][a-zA-Z0-9.]*"
var isNamePattern = regexp.MustCompile("^" + namePattern + "$").MatchString
-var namePatternRe = regexp.MustCompile("@" + namePattern)
+var namePatternRe = regexp.MustCompile("@(" + namePattern + "|{" + namePattern + "})")
var errBadVariable = errors.New("bad variable name")
@@ -93,11 +94,19 @@ func (e *env) Del(key string) bool {
}
func (e *env) Each(f func(string, string) bool) {
+ var keys []string
+
e.Lock()
defer e.Unlock()
- for k, v := range e.m {
- if !f(k, v) {
+ for k := range e.m {
+ keys = append(keys, k)
+ }
+
+ sort.Strings(keys)
+
+ for _, k := range keys {
+ if !f(k, e.m[k]) {
break
}
}
@@ -115,7 +124,12 @@ func (e *env) Eval(s string) string {
defer e.Unlock()
repl := func(v string) string {
- if v, ok := e.m[v[1:]]; ok {
+ key := v[1:]
+ if key[0] == '{' {
+ key = key[1 : len(key)-1]
+ }
+
+ if v, ok := e.m[key]; ok {
return v
}
return ""
diff --git a/pkg/server/server.go b/pkg/server/server.go
index a9a50a4..badf5c0 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -51,7 +51,9 @@ type request struct {
argc int
args []string
- out *bytes.Buffer
+ fail bool
+
+ out bytes.Buffer
}
type requestError string
@@ -67,15 +69,15 @@ func (r *request) String() string {
}
func (r *request) Print(v ...interface{}) {
- fmt.Fprint(r.out, v...)
+ fmt.Fprint(&r.out, v...)
}
func (r *request) Printf(format string, v ...interface{}) {
- fmt.Fprintf(r.out, format, v...)
+ fmt.Fprintf(&r.out, format, v...)
}
func (r *request) Println(v ...interface{}) {
- fmt.Fprintln(r.out, v...)
+ fmt.Fprintln(&r.out, v...)
}
func (r *request) Fatal(v ...interface{}) {
@@ -145,6 +147,14 @@ func New(path string) (*Server, error) {
return s, nil
}
+func (s *Server) Env() env.Env {
+ return s.env
+}
+
+func (s *Server) Socket() string {
+ return s.listen.Addr().String()
+}
+
func (s *Server) Run() {
for {
conn, err := s.listen.Accept()
@@ -188,11 +198,28 @@ func (s *Server) newClient(conn net.Conn) *client {
return c
}
+func (s *Server) Command(query string) error {
+ args := strings.Split(query, " ")
+
+ r := &request{c: &client{s: s}}
+
+ r.run(args)
+
+ if r.fail {
+ if r.out.Len() == 0 {
+ return errors.New("failed")
+ }
+
+ return errors.New(r.out.String())
+ }
+
+ return nil
+}
+
func (c *client) newRequest() *request {
r := &request{
- c: c,
- id: c.nextRid,
- out: new(bytes.Buffer),
+ c: c,
+ id: c.nextRid,
}
c.nextRid++
@@ -214,11 +241,15 @@ func (c *client) handle() {
break
}
- query := string(buf[:nr])
+ args, err := c.decode(buf[:nr])
+ if err != nil {
+ log.Println(c, "decode:", err)
+ break
+ }
r := c.newRequest()
- r.run(query)
+ r.run(args)
if r.out.Len() == 0 {
r.out.Write([]byte("\n"))
@@ -232,8 +263,8 @@ func (c *client) handle() {
}
}
-func (r *request) decode(query string) []string {
- dec := netstring.NewDecoder(bytes.NewReader([]byte(query)))
+func (c *client) decode(b []byte) ([]string, error) {
+ dec := netstring.NewDecoder(bytes.NewReader(b))
var t []string
for {
@@ -241,14 +272,14 @@ func (r *request) decode(query string) []string {
t = append(t, s)
} else {
if !errors.Is(err, io.EOF) {
- r.Fatal("request parse failed")
+ return nil, err
}
break
}
}
- return t
+ return t, nil
}
func (r *request) eval(args []string) []string {
@@ -263,10 +294,8 @@ func (r *request) eval(args []string) []string {
return args
}
-func (r *request) parse(query string) {
- args := r.eval(r.decode(query))
-
- c, args := getCmd(args)
+func (r *request) parse(args []string) {
+ c, args := getCmd(r.eval(args))
if c == nil {
r.Fatal("command not found")
}
@@ -276,20 +305,21 @@ func (r *request) parse(query string) {
r.cmd = c
}
-func (r *request) run(query string) {
+func (r *request) run(args []string) {
defer func() {
switch err := recover().(type) {
case requestError:
r.Print(err)
+ r.fail = true
default:
panic(err)
case nil:
}
}()
- r.parse(query)
+ log.Println(r.c, r, ">", strings.Join(args, " "))
- log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " "))
+ r.parse(args)
r.c.s.mu.Lock()
defer r.c.s.mu.Unlock()
diff --git a/test/auth.sh b/test/auth.sh
index 311b8ff..a807c96 100755
--- a/test/auth.sh
+++ b/test/auth.sh
@@ -5,10 +5,10 @@ PATH=$PATH:$ROOT/cmd/tunnel
tunnel add name T 2000,listen auth aes 3000
tunnel add name X 3000,listen -aes -auth 4000
-tunnel var set tunnel.T.secret secret
-tunnel var set tunnel.X.secret secret
+tunnel set tunnel.T.secret secret
+tunnel set tunnel.X.secret secret
nc -l 4000 &
echo "Hello, World!" | nc -N localhost 2000
-tunnel var clear
+tunnel clear
tunnel del T
tunnel del X
diff --git a/test/env.sh b/test/env.sh
index bcff497..7e60718 100755
--- a/test/env.sh
+++ b/test/env.sh
@@ -3,9 +3,9 @@
ROOT=$(dirname $0)/..
PATH=$ROOT/cmd/tunnel
-tunnel var set cmd echo
-tunnel var set args ^"@x, @y!"
-tunnel var set x Hello
-tunnel var set y World
+tunnel set cmd echo
+tunnel set args ^"@x, @y!"
+tunnel set x Hello
+tunnel set y World
tunnel @cmd @args
-tunnel var clear
+tunnel clear