summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/automap.go9
-rw-r--r--pkg/server/env/env.go80
-rw-r--r--pkg/server/hook/aes.go16
-rw-r--r--pkg/server/hook/auth.go18
-rw-r--r--pkg/server/hook/b64.go19
-rw-r--r--pkg/server/hook/b85.go19
-rw-r--r--pkg/server/hook/dump.go30
-rw-r--r--pkg/server/hook/hex.go19
-rw-r--r--pkg/server/hook/hook.go77
-rw-r--r--pkg/server/hook/info-http.go10
-rw-r--r--pkg/server/hook/proxy.go28
-rw-r--r--pkg/server/hook/split.go37
-rw-r--r--pkg/server/hook/zip.go19
-rw-r--r--pkg/server/opts/opts.go149
-rw-r--r--pkg/server/server.go4
-rw-r--r--pkg/server/socket/defer.go23
-rw-r--r--pkg/server/socket/dial.go29
-rw-r--r--pkg/server/socket/listen.go66
-rw-r--r--pkg/server/socket/loop.go6
-rw-r--r--pkg/server/socket/proxy.go21
-rw-r--r--pkg/server/socket/socket.go76
-rw-r--r--pkg/server/socket/tun.go24
-rw-r--r--pkg/server/tunnel.go93
23 files changed, 527 insertions, 345 deletions
diff --git a/pkg/server/automap.go b/pkg/server/automap.go
index 67ae5b0..f38f758 100644
--- a/pkg/server/automap.go
+++ b/pkg/server/automap.go
@@ -10,16 +10,21 @@ type automap map[string]interface{}
var errExists = errors.New("already exists")
var errNotFound = errors.New("no such entry")
-func (m automap) add(v interface{}) string {
+func (m automap) next() string {
for n := 1; ; n++ {
k := fmt.Sprintf("%d", n)
if _, ok := m[k]; !ok {
- m[k] = v
return k
}
}
}
+func (m automap) add(v interface{}) string {
+ k := m.next()
+ m[k] = v
+ return k
+}
+
func (m automap) rename(old string, new string) (interface{}, error) {
if _, ok := m[old]; !ok {
return nil, errNotFound
diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go
index 2c97669..7aa93d2 100644
--- a/pkg/server/env/env.go
+++ b/pkg/server/env/env.go
@@ -4,12 +4,23 @@ import (
"errors"
"regexp"
"sort"
+ "strings"
"sync"
)
type env struct {
+ // vars
m map[string]string
+
+ // forks
+ c map[string]*env
+
+ // key in parent env's map
+ k string
+
sync.Mutex
+
+ // parent env
p *env
}
@@ -36,23 +47,84 @@ func (e Env) init() {
}
}
-func (e Env) Fork() Env {
+func (e Env) Fork(path ...string) Env {
t := New()
t.p = e.env
+
+ if len(path) > 0 {
+ k := strings.Join(path, ".")
+
+ e.Lock()
+ if e.c == nil {
+ e.c = make(map[string]*env)
+ } else if _, ok := e.c[k]; ok {
+ panic("env fork already exists: " + k)
+ }
+ e.c[k] = t.env
+ e.Unlock()
+
+ t.k = k
+ }
+
return t
}
-func (e *env) Find(key string) (string, bool) {
+func (e Env) Detach() {
+ if e.p != nil {
+ e.p.Lock()
+ delete(e.p.c, e.k)
+ e.p.Unlock()
+ }
+}
+
+func (e *env) find(key string) (string, bool, func() (string, bool)) {
e.Lock()
defer e.Unlock()
v, ok := e.m[key]
if ok {
- return v, ok
+ return v, ok, nil
+ }
+
+ if e.c != nil {
+ s := strings.Split(key, ".")
+
+ for k := ""; len(s) > 1; s = s[1:] {
+ if k == "" {
+ k = s[0]
+ } else {
+ k += "." + s[0]
+ }
+
+ if c, ok := e.c[k]; ok {
+ return "", false, func() (string, bool) {
+ return c.Find(strings.Join(s[1:], "."))
+ }
+ }
+ }
}
if e.p != nil {
- return e.p.Find(key)
+ return "", false, func() (string, bool) {
+ return e.p.Find(key)
+ }
+ }
+
+ return "", false, nil
+}
+
+func (e *env) Find(key string) (string, bool) {
+ if key == "" {
+ return "", false
+ }
+
+ v, ok, f := e.find(key)
+ if ok {
+ return v, ok
+ }
+
+ if f != nil {
+ return f()
}
return "", false
diff --git a/pkg/server/hook/aes.go b/pkg/server/hook/aes.go
index e437651..dc48f49 100644
--- a/pkg/server/hook/aes.go
+++ b/pkg/server/hook/aes.go
@@ -6,8 +6,8 @@ import (
"crypto/md5"
"crypto/rand"
"io"
+
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
@@ -64,24 +64,16 @@ func (a *aesPipe) Recv(rq, wq queue.Q) error {
return queue.IoCopy(reader, wq.Writer())
}
-func newAes(env env.Env) *aesPipe {
+func (aesHook) New(env env.Env) (interface{}, error) {
s := env.Value("secret")
h := md5.Sum([]byte(s))
a := &aesPipe{key: make([]byte, 16)}
copy(a.key, h[:])
- return a
-}
-
-func (aesHook) Open(env env.Env) (interface{}, error) {
- return newAes(env), nil
-}
-
-func newAesHook(opts.Opts) (hook, error) {
- return aesHook{}, nil
+ return a, nil
}
func init() {
- register("aes", newAesHook)
+ register("aes", aesHook{})
}
diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go
index 5910b56..dbfc9bc 100644
--- a/pkg/server/hook/auth.go
+++ b/pkg/server/hook/auth.go
@@ -7,8 +7,8 @@ import (
"io"
"sync"
"time"
+
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
@@ -17,6 +17,10 @@ const authTimeout = 5 * time.Second
const saltSize = 16
const hashSize = md5.Size
+type authHook struct {
+ m sync.Map
+}
+
type auth struct {
h *authHook
@@ -42,10 +46,6 @@ var errDupSalt = errors.New("peer repeats salt")
var errAuthFail = errors.New("peer auth fail")
var errTimeout = errors.New("timeout")
-type authHook struct {
- m sync.Map
-}
-
func (a *auth) Init() error {
b := make([]byte, saltSize)
if _, err := rand.Read(b); err != nil {
@@ -149,7 +149,7 @@ func (a *auth) Close() {
a.h.m.Delete(a.salt.self)
}
-func (h *authHook) Open(env env.Env) (interface{}, error) {
+func (h *authHook) New(env env.Env) (interface{}, error) {
a := &auth{
h: h,
secret: env.Value("secret"),
@@ -169,10 +169,6 @@ func (h *authHook) Open(env env.Env) (interface{}, error) {
return a, nil
}
-func newAuthHook(opts.Opts) (hook, error) {
- return &authHook{}, nil
-}
-
func init() {
- register("auth", newAuthHook)
+ register("auth", authHook{})
}
diff --git a/pkg/server/hook/b64.go b/pkg/server/hook/b64.go
index fce42a0..c6637e5 100644
--- a/pkg/server/hook/b64.go
+++ b/pkg/server/hook/b64.go
@@ -2,17 +2,16 @@ package hook
import (
"encoding/base64"
+
"tunnel/pkg/netstring"
- "tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
var b64Enc = base64.RawStdEncoding
-type b64Hook struct{}
+type b64Pipe struct{}
-func (b64Hook) Send(rq, wq queue.Q) error {
+func (b64Pipe) Send(rq, wq queue.Q) error {
e := netstring.NewEncoder(wq.Writer())
for b := range rq {
@@ -22,7 +21,7 @@ func (b64Hook) Send(rq, wq queue.Q) error {
return nil
}
-func (b64Hook) Recv(rq, wq queue.Q) error {
+func (b64Pipe) Recv(rq, wq queue.Q) error {
d := netstring.NewDecoder(rq.Reader())
for {
@@ -39,14 +38,6 @@ func (b64Hook) Recv(rq, wq queue.Q) error {
}
}
-func (h b64Hook) Open(env.Env) (interface{}, error) {
- return h, nil
-}
-
-func newB64Hook(opts.Opts) (hook, error) {
- return b64Hook{}, nil
-}
-
func init() {
- register("b64", newB64Hook)
+ registerPipe("b64", b64Pipe{})
}
diff --git a/pkg/server/hook/b85.go b/pkg/server/hook/b85.go
index bf36b56..d90a1c4 100644
--- a/pkg/server/hook/b85.go
+++ b/pkg/server/hook/b85.go
@@ -3,15 +3,14 @@ package hook
import (
"encoding/ascii85"
"errors"
+
"tunnel/pkg/netstring"
- "tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
-type b85Hook struct{}
+type b85Pipe struct{}
-func (b85Hook) Send(rq, wq queue.Q) error {
+func (b85Pipe) Send(rq, wq queue.Q) error {
e := netstring.NewEncoder(wq.Writer())
for b := range rq {
@@ -23,7 +22,7 @@ func (b85Hook) Send(rq, wq queue.Q) error {
return nil
}
-func (b85Hook) Recv(rq, wq queue.Q) error {
+func (b85Pipe) Recv(rq, wq queue.Q) error {
d := netstring.NewDecoder(rq.Reader())
for {
@@ -44,14 +43,6 @@ func (b85Hook) Recv(rq, wq queue.Q) error {
}
}
-func (h b85Hook) Open(env.Env) (interface{}, error) {
- return h, nil
-}
-
-func newB85Hook(opts.Opts) (hook, error) {
- return b85Hook{}, nil
-}
-
func init() {
- register("b85", newB85Hook)
+ registerPipe("b85", b85Pipe{})
}
diff --git a/pkg/server/hook/dump.go b/pkg/server/hook/dump.go
index 864443b..d871d63 100644
--- a/pkg/server/hook/dump.go
+++ b/pkg/server/hook/dump.go
@@ -7,28 +7,28 @@ import (
"os"
"path"
"time"
+
"tunnel/pkg/config"
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
const dumpDefaultFile = "/tmp/tunnel/dump"
+type dumpHook struct {
+ File string
+ Time bool
+}
+
type dump struct {
f *os.File
h *dumpHook
}
-type dumpHook struct {
- file string
- time bool
-}
-
func (t *dump) write(s string, p []byte) error {
var out bytes.Buffer
- if t.h.time {
+ if t.h.Time {
now := time.Now().Format(config.TimeMsFormat)
fmt.Fprintln(&out, now, s, len(p))
} else {
@@ -69,8 +69,8 @@ func (t *dump) Close() {
}
func (h *dumpHook) where(env env.Env) string {
- if h.file != "" {
- return h.file
+ if h.File != "" {
+ return h.File
}
if v := env.Value("dump.file"); v != "" {
@@ -80,7 +80,7 @@ func (h *dumpHook) where(env env.Env) string {
return dumpDefaultFile
}
-func (h *dumpHook) Open(env env.Env) (interface{}, error) {
+func (h *dumpHook) New(env env.Env) (interface{}, error) {
file := h.where(env)
dir := path.Dir(file)
@@ -102,14 +102,6 @@ func (h *dumpHook) Open(env env.Env) (interface{}, error) {
return t, nil
}
-func newDumpHook(opts opts.Opts) (hook, error) {
- h := &dumpHook{
- file: opts["file"],
- time: opts.Bool("time"),
- }
- return h, nil
-}
-
func init() {
- register("dump", newDumpHook)
+ register("dump", dumpHook{})
}
diff --git a/pkg/server/hook/hex.go b/pkg/server/hook/hex.go
index bc71bf2..362dbd4 100644
--- a/pkg/server/hook/hex.go
+++ b/pkg/server/hook/hex.go
@@ -2,14 +2,13 @@ package hook
import (
"encoding/hex"
- "tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
+
"tunnel/pkg/server/queue"
)
-type hexHook struct{}
+type hexPipe struct{}
-func (hexHook) Send(rq, wq queue.Q) error {
+func (hexPipe) Send(rq, wq queue.Q) error {
enc := hex.NewEncoder(wq.Writer())
for b := range rq {
@@ -19,19 +18,11 @@ func (hexHook) Send(rq, wq queue.Q) error {
return nil
}
-func (hexHook) Recv(rq, wq queue.Q) error {
+func (hexPipe) Recv(rq, wq queue.Q) error {
r := hex.NewDecoder(rq.Reader())
return queue.IoCopy(r, wq.Writer())
}
-func (h hexHook) Open(env.Env) (interface{}, error) {
- return h, nil
-}
-
-func newHexHook(opts.Opts) (hook, error) {
- return hexHook{}, nil
-}
-
func init() {
- register("hex", newHexHook)
+ registerPipe("hex", hexPipe{})
}
diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go
index 69aa237..36b01d4 100644
--- a/pkg/server/hook/hook.go
+++ b/pkg/server/hook/hook.go
@@ -3,16 +3,16 @@ package hook
import (
"fmt"
"log"
+ "reflect"
"sort"
"strings"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
-type hookInitFunc func(opts.Opts) (hook, error)
-
-var hooks = map[string]hookInitFunc{}
+var hooks = map[string]interface{}{}
type Pipe struct {
priv interface{}
@@ -21,13 +21,12 @@ type Pipe struct {
Recv Func
}
-type hook interface {
- Open(env env.Env) (interface{}, error)
+type Hooker interface {
+ New(env env.Env) (interface{}, error)
}
type H interface {
- Open(env env.Env) (*Pipe, error)
- String() string
+ New(env env.Env) (*Pipe, error)
}
type Sender interface {
@@ -44,12 +43,12 @@ func (f Func) Send(rq, wq queue.Q) error {
return f(rq, wq)
}
-func (f Func) Open(env env.Env) (interface{}, error) {
+func (f Func) New(env env.Env) (interface{}, error) {
return f, nil
}
type wrapper struct {
- hook
+ hook Hooker
name string
reverse bool
}
@@ -58,8 +57,8 @@ func (w *wrapper) String() string {
return fmt.Sprintf("hook:%s", w.name)
}
-func (w *wrapper) Open(env env.Env) (*Pipe, error) {
- it, err := w.hook.Open(env)
+func (w *wrapper) New(env env.Env) (*Pipe, error) {
+ it, err := w.hook.New(env)
if err != nil {
return nil, err
}
@@ -87,6 +86,23 @@ func (p *Pipe) Close() {
}
}
+func initHook(i interface{}, opts opts.Opts) (Hooker, error) {
+ if f, ok := i.(Func); ok {
+ return f, nil
+ }
+
+ if p, ok := i.(pipeHolder); ok {
+ return p, nil
+ }
+
+ h := reflect.New(reflect.TypeOf(i)).Interface()
+ if err := opts.Configure(h); err != nil {
+ return nil, err
+ }
+
+ return h.(Hooker), nil
+}
+
func New(desc string) (H, error) {
name, opts := opts.Parse(desc)
reverse := false
@@ -96,9 +112,9 @@ func New(desc string) (H, error) {
reverse = true
}
- if f, ok := hooks[name]; !ok {
+ if i, ok := hooks[name]; !ok {
return nil, fmt.Errorf("unknown hook '%s'", name)
- } else if h, err := f(opts); err != nil {
+ } else if h, err := initHook(i, opts); err != nil {
return nil, fmt.Errorf("%s: %w", name, err)
} else {
w := &wrapper{
@@ -110,18 +126,41 @@ func New(desc string) (H, error) {
}
}
-func register(name string, f hookInitFunc) {
+func register(name string, i interface{}) {
+ switch t := reflect.TypeOf(i); t.Kind() {
+ case reflect.Struct:
+ if _, ok := reflect.New(t).Interface().(Hooker); !ok {
+ log.Panicf("uncompatible hook type '%s'", t.String())
+ }
+ case reflect.Func:
+ if _, ok := i.(Func); !ok {
+ log.Panicf("uncompatible func type '%s'", t.String())
+ }
+ default:
+ log.Panicf("non-struct and non-func type '%s'", t.String())
+ }
+
if _, ok := hooks[name]; ok {
log.Panicf("duplicate hook name '%s'", name)
}
- hooks[name] = f
+ hooks[name] = i
}
-func registerFunc(name string, p Func) {
- register(name, func(opts.Opts) (hook, error) {
- return p, nil
- })
+func registerFunc(name string, f Func) {
+ register(name, f)
+}
+
+type pipeHolder struct {
+ i interface{}
+}
+
+func (p pipeHolder) New(env.Env) (interface{}, error) {
+ return p.i, nil
+}
+
+func registerPipe(name string, i interface{}) {
+ register(name, pipeHolder{i})
}
func GetList() []string {
diff --git a/pkg/server/hook/info-http.go b/pkg/server/hook/info-http.go
index 73480ff..ec56f87 100644
--- a/pkg/server/hook/info-http.go
+++ b/pkg/server/hook/info-http.go
@@ -2,9 +2,9 @@ package hook
import (
"bufio"
+
"tunnel/pkg/http"
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
@@ -42,14 +42,10 @@ func (info *infoHttp) Recv(rq, wq queue.Q) error {
return queue.Copy(rq, wq)
}
-func (infoHttpHook) Open(env env.Env) (interface{}, error) {
+func (infoHttpHook) New(env env.Env) (interface{}, error) {
return &infoHttp{env: env}, nil
}
-func newInfoHttpHook(opts.Opts) (hook, error) {
- return infoHttpHook{}, nil
-}
-
func init() {
- register("info-http", newInfoHttpHook)
+ register("info-http", infoHttpHook{})
}
diff --git a/pkg/server/hook/proxy.go b/pkg/server/hook/proxy.go
index bc6da18..4276d9a 100644
--- a/pkg/server/hook/proxy.go
+++ b/pkg/server/hook/proxy.go
@@ -3,12 +3,11 @@ package hook
import (
"bufio"
"bytes"
- "errors"
"fmt"
"regexp"
+
"tunnel/pkg/http"
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
@@ -16,8 +15,8 @@ var addrPattern = "^([0-9a-zA-Z-.]+|\\[[0-9a-fA-F:]*\\]):[0-9]+$"
var isGoodAddr = regexp.MustCompile(addrPattern).MatchString
type proxyHook struct {
- addr string
- auth string
+ Addr string `opts:"required"`
+ Auth string
}
type proxy struct {
@@ -67,15 +66,15 @@ func (p *proxy) Recv(rq, wq queue.Q) error {
return queue.IoCopy(r, wq.Writer())
}
-func (h *proxyHook) Open(env env.Env) (interface{}, error) {
- addr := env.Expand(h.addr)
+func (h *proxyHook) New(env env.Env) (interface{}, error) {
+ addr := env.Expand(h.Addr)
if !isGoodAddr(addr) {
return nil, fmt.Errorf("invalid addr '%s'", addr)
}
p := &proxy{
addr: addr,
- auth: h.auth,
+ auth: h.Auth,
c: make(chan bool),
env: env,
}
@@ -87,19 +86,6 @@ func (h *proxyHook) Open(env env.Env) (interface{}, error) {
return p, nil
}
-func newProxyHook(opts opts.Opts) (hook, error) {
- h := &proxyHook{
- addr: opts["addr"],
- auth: opts["auth"],
- }
-
- if h.addr == "" {
- return nil, errors.New("expected addr")
- }
-
- return h, nil
-}
-
func init() {
- register("proxy", newProxyHook)
+ register("proxy", proxyHook{})
}
diff --git a/pkg/server/hook/split.go b/pkg/server/hook/split.go
index 6a2c4ca..59c8055 100644
--- a/pkg/server/hook/split.go
+++ b/pkg/server/hook/split.go
@@ -1,27 +1,24 @@
package hook
import (
- "errors"
- "strconv"
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
)
-const splitDefaultSize = 1024
-
-var errBadSize = errors.New("bad size value")
-
type splitHook struct {
- size int
+ Size int `opts:"positive,default:1024"`
+}
+
+type splitPipe struct {
+ h *splitHook
}
-func (h *splitHook) Send(rq, wq queue.Q) error {
+func (p splitPipe) Send(rq, wq queue.Q) error {
for b := range rq {
var upto int
for n := 0; n < len(b); n = upto {
- upto += h.size
+ upto += p.h.Size
if upto > len(b) {
upto = len(b)
@@ -34,24 +31,10 @@ func (h *splitHook) Send(rq, wq queue.Q) error {
return nil
}
-func (h *splitHook) Open(env.Env) (interface{}, error) {
- return h, nil
-}
-
-func newSplitHook(opts opts.Opts) (hook, error) {
- size := splitDefaultSize
-
- if s, ok := opts["size"]; ok {
- var err error
-
- if size, err = strconv.Atoi(s); err != nil || size <= 0 {
- return nil, errBadSize
- }
- }
-
- return &splitHook{size: size}, nil
+func (h *splitHook) New(env.Env) (interface{}, error) {
+ return &splitPipe{h}, nil
}
func init() {
- register("split", newSplitHook)
+ register("split", splitHook{})
}
diff --git a/pkg/server/hook/zip.go b/pkg/server/hook/zip.go
index bde4957..615b50d 100644
--- a/pkg/server/hook/zip.go
+++ b/pkg/server/hook/zip.go
@@ -3,14 +3,13 @@ package hook
import (
"compress/flate"
"io"
- "tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
+
"tunnel/pkg/server/queue"
)
-type zipHook struct{}
+type zipPipe struct{}
-func (zipHook) Send(rq, wq queue.Q) error {
+func (zipPipe) Send(rq, wq queue.Q) error {
w, err := flate.NewWriter(wq.Writer(), flate.BestCompression)
if err != nil {
return err
@@ -28,7 +27,7 @@ func (zipHook) Send(rq, wq queue.Q) error {
return w.Close()
}
-func (zipHook) Recv(rq, wq queue.Q) error {
+func (zipPipe) Recv(rq, wq queue.Q) error {
r := flate.NewReader(rq.Reader())
// FIXME: not received ending due to ultimate conn.Close
@@ -42,14 +41,6 @@ func (zipHook) Recv(rq, wq queue.Q) error {
return r.Close()
}
-func (h zipHook) Open(env.Env) (interface{}, error) {
- return h, nil
-}
-
-func newZipHook(opts.Opts) (hook, error) {
- return zipHook{}, nil
-}
-
func init() {
- register("zip", newZipHook)
+ registerPipe("zip", zipPipe{})
}
diff --git a/pkg/server/opts/opts.go b/pkg/server/opts/opts.go
index 22383d8..c5729a3 100644
--- a/pkg/server/opts/opts.go
+++ b/pkg/server/opts/opts.go
@@ -1,6 +1,19 @@
package opts
-import "strings"
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+)
+
+type fieldSpec struct {
+ name string
+ inline bool
+ required bool
+ positive bool
+ defaultValue string
+}
type Opts map[string]string
@@ -20,7 +33,137 @@ func Parse(s string) (string, Opts) {
return v[0], m
}
-func (m Opts) Bool(key string) bool {
- _, ok := m[key]
+func (opts Opts) Bool(key string) bool {
+ _, ok := opts[key]
return ok
}
+
+func (opts Opts) Configure(v interface{}) error {
+ rv := reflect.ValueOf(v)
+
+ switch {
+ case rv.Kind() != reflect.Ptr:
+ return fmt.Errorf("opts: configure failed: non-pointer %s", rv.Type().String())
+ case rv.IsNil():
+ return fmt.Errorf("opts: configure failed: nil %s", rv.Type().String())
+ case rv.Elem().Kind() != reflect.Struct:
+ return fmt.Errorf("opts: configure failed: non-struct %s", rv.Elem().Type().String())
+ }
+
+ return opts.configure(rv.Elem())
+}
+
+func getFieldSpec(field reflect.StructField) (fs fieldSpec) {
+ fs.name = strings.ToLower(field.Name)
+
+ for _, s := range strings.Split(field.Tag.Get("opts"), ",") {
+ switch {
+ case s == "inline":
+ fs.inline = true
+ case s == "required":
+ fs.required = true
+ case s == "positive":
+ fs.positive = true
+ case strings.HasPrefix(s, "default:"):
+ fs.defaultValue = s[8:]
+ }
+ }
+
+ return
+}
+
+func (opts Opts) setValue(v reflect.Value, fs fieldSpec) error {
+ s, ok := opts[fs.name]
+ kind := v.Kind()
+
+ if kind == reflect.Bool {
+ if s != "" {
+ return fmt.Errorf("%s: boolean option does not need a value", fs.name)
+ }
+
+ if ok {
+ v.SetBool(true)
+ }
+
+ return nil
+ }
+
+ if s == "" {
+ if fs.required {
+ return fmt.Errorf("%s: expect value", fs.name)
+ }
+
+ s = fs.defaultValue
+ }
+
+ if s == "" {
+ return nil
+ }
+
+ switch kind {
+ case reflect.String:
+ v.SetString(s)
+ case reflect.Int:
+ n, err := strconv.Atoi(s)
+ if err != nil {
+ return fmt.Errorf("%s: %w", fs.name, err)
+ }
+ if n <= 0 {
+ return fmt.Errorf("%s: expect positive", fs.name)
+ }
+ v.SetInt(int64(n))
+ default:
+ return fmt.Errorf("%s: unsupported type '%s'", fs.name, v.Type().String())
+ }
+
+ return nil
+}
+
+func (opts Opts) configure(v reflect.Value) error {
+ m := map[string]bool{}
+
+ for s := range opts {
+ m[s] = true
+ }
+
+ var visit func(v reflect.Value) error
+
+ visit = func(v reflect.Value) error {
+ t := v.Type()
+
+ for n := 0; n < v.NumField(); n++ {
+ fs := getFieldSpec(t.Field(n))
+ fv := v.Field(n)
+
+ if fs.inline {
+ if err := visit(fv); err != nil {
+ return err
+ }
+ } else if fv.CanSet() {
+ if err := opts.setValue(v.Field(n), fs); err != nil {
+ return err
+ }
+
+ delete(m, fs.name)
+ }
+ }
+
+ return nil
+ }
+
+ if err := visit(v); err != nil {
+ return err
+ }
+
+ if len(m) > 0 {
+ var ss []string
+
+ for s := range m {
+ ss = append(ss, s)
+ }
+
+ return fmt.Errorf("unknown options: %s", strings.Join(ss, ","))
+ }
+
+ return nil
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index d57d9dd..8c5683a 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -254,10 +254,6 @@ func (c *client) handle() {
r.run(args)
- if r.out.Len() == 0 {
- r.out.Write([]byte("\n"))
- }
-
ew := c.w.Encode(r.out.String())
if ew != nil {
log.Println(c, "handle:", ew)
diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go
index 7ed303d..7c1436e 100644
--- a/pkg/server/socket/defer.go
+++ b/pkg/server/socket/defer.go
@@ -6,28 +6,19 @@ import (
)
type deferSocket struct {
- S
+ dialSocket `opts:"inline"`
}
type deferConn struct {
- sock *deferSocket
+ sock S
wait chan bool
env env.Env
conn Conn
}
-func newDeferSocket(proto, addr string) (S, error) {
- s, err := newDialSocket(proto, addr)
- if err != nil {
- return s, err
- }
-
- return &deferSocket{s}, nil
-}
-
-func (s *deferSocket) Open(env env.Env) (Conn, error) {
+func (s *deferSocket) New(env env.Env) (Conn, error) {
c := &deferConn{
- sock: s,
+ sock: &s.dialSocket,
wait: make(chan bool),
env: env,
}
@@ -54,7 +45,7 @@ func (c *deferConn) Recv(rq queue.Q) error {
return nil
}
- conn, err := c.sock.S.Open(c.env)
+ conn, err := c.sock.New(c.env)
if err != nil {
c.wait <- false
return err
@@ -83,3 +74,7 @@ func (c *deferConn) Close() (err error) {
return
}
+
+func init() {
+ register("defer", deferSocket{})
+}
diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go
index 7623084..b2df3b7 100644
--- a/pkg/server/socket/dial.go
+++ b/pkg/server/socket/dial.go
@@ -4,36 +4,25 @@ import (
"fmt"
"log"
"net"
- "strings"
"time"
+
"tunnel/pkg/server/env"
)
const defaultTimeout = 5 * time.Second
type dialSocket struct {
- proto, addr string
-}
-
-func newDialSocket(proto, addr string) (S, error) {
- return &dialSocket{proto: proto, addr: addr}, nil
+ Proto string `opts:"default:tcp"`
+ Addr string `opts:"required"`
}
func (s *dialSocket) String() string {
- return fmt.Sprintf("%s/%s", s.proto, s.addr)
+ return fmt.Sprintf("%s/%s", s.Proto, s.Addr)
}
-func (s *dialSocket) Open(e env.Env) (Conn, error) {
- addr := e.Expand(s.addr)
-
- switch s.proto {
- case "tcp", "udp":
- if !strings.Contains(addr, ":") {
- addr = "localhost:" + addr
- }
- }
-
- conn, err := net.DialTimeout(s.proto, addr, defaultTimeout)
+func (s *dialSocket) New(e env.Env) (Conn, error) {
+ proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr))
+ conn, err := net.DialTimeout(proto, addr, defaultTimeout)
if err != nil {
return nil, err
}
@@ -49,3 +38,7 @@ func (s *dialSocket) Open(e env.Env) (Conn, error) {
func (s *dialSocket) Close() {
}
+
+func init() {
+ register("dial", dialSocket{})
+}
diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go
index 910e5de..2c2f184 100644
--- a/pkg/server/socket/listen.go
+++ b/pkg/server/socket/listen.go
@@ -5,64 +5,54 @@ import (
"fmt"
"log"
"net"
- "strings"
+
"tunnel/pkg/server/env"
- "tunnel/pkg/server/opts"
)
type listenSocket struct {
- proto, addr string
- listen net.Listener
- redirect bool
- tproxy bool
-}
+ Proto string `opts:"default:tcp"`
+ Addr string `opts:"required"`
-func newListenSocket(proto, addr string, opts opts.Opts) (S, error) {
- redirect := opts.Bool("redirect")
- tproxy := opts.Bool("tproxy")
+ Redirect bool
+ Tproxy bool
- if proto == "tcp" {
- if !strings.Contains(addr, ":") {
- addr = ":" + addr
- }
- }
+ listen net.Listener
+}
- if redirect && proto != "tcp" {
- return nil, errors.New("redirect not supported")
+func (s *listenSocket) Prepare(e env.Env) error {
+ if s.Redirect && s.Proto != "tcp" {
+ return errors.New("redirect not supported")
}
- if tproxy && proto != "tcp" {
- return nil, errors.New("tproxy not supported")
+ if s.Tproxy && s.Proto != "tcp" {
+ return errors.New("tproxy not supported")
}
- if redirect && tproxy {
- return nil, errors.New("redirect and tproxy cannot be used together")
+ if s.Redirect && s.Tproxy {
+ return errors.New("redirect and tproxy cannot be used together")
}
+ proto, addr := parseProtoAddr(s.Proto, s.Addr)
listen, err := net.Listen(proto, addr)
if err != nil {
- return nil, err
+ return err
}
- if tproxy {
+ e.Set("listen", listen.Addr().String())
+
+ if s.Tproxy {
if err := setConnTransparent(listen); err != nil {
listen.Close()
- return nil, err
+ return err
}
}
- s := &listenSocket{
- proto: proto,
- addr: addr,
- listen: listen,
- redirect: redirect,
- tproxy: tproxy,
- }
+ s.listen = listen
- return s, nil
+ return nil
}
-func (s *listenSocket) Open(env env.Env) (Conn, error) {
+func (s *listenSocket) New(env env.Env) (Conn, error) {
var original string
conn, err := s.listen.Accept()
@@ -74,7 +64,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la)
info := fmt.Sprintf("<%s/%s", ra.Network(), ra)
- if s.redirect {
+ if s.Redirect {
if err := getConnOriginalAddr(conn, &original); err != nil {
log.Println("accept", desc, "failed")
conn.Close()
@@ -84,7 +74,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
}
}
- if s.tproxy {
+ if s.Tproxy {
env.Set("original", la.String())
}
@@ -98,9 +88,13 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) {
}
func (s *listenSocket) String() string {
- return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
+ return fmt.Sprintf("%s/%s,listen", s.Proto, s.Addr)
}
func (s *listenSocket) Close() {
s.listen.Close()
}
+
+func init() {
+ register("listen", listenSocket{})
+}
diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go
index a06448a..c442140 100644
--- a/pkg/server/socket/loop.go
+++ b/pkg/server/socket/loop.go
@@ -30,7 +30,7 @@ func (c *loopConn) Close() error {
return nil
}
-func (s *loopSocket) Open(env.Env) (Conn, error) {
+func (s *loopSocket) New(env.Env) (Conn, error) {
return &loopConn{make(chan queue.Q), make(chan error)}, nil
}
@@ -41,6 +41,6 @@ func (s *loopSocket) String() string {
func (s *loopSocket) Close() {
}
-func newLoopSocket() (S, error) {
- return &loopSocket{}, nil
+func init() {
+ register("loop", loopSocket{})
}
diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go
index e4baec2..1be4bba 100644
--- a/pkg/server/socket/proxy.go
+++ b/pkg/server/socket/proxy.go
@@ -5,6 +5,7 @@ import (
"bytes"
"errors"
"fmt"
+
"tunnel/pkg/http"
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
@@ -16,7 +17,7 @@ type status struct {
}
type proxySocket struct {
- proto string
+ Proto string `opts:"default:tcp"`
}
type proxyServer struct {
@@ -28,11 +29,7 @@ type proxyServer struct {
conn Conn
}
-func newProxySocket(proto string) (S, error) {
- return &proxySocket{proto}, nil
-}
-
-func (sock *proxySocket) Open(env env.Env) (Conn, error) {
+func (sock *proxySocket) New(env env.Env) (Conn, error) {
s := &proxyServer{
sock: sock,
auth: env.Value("proxy.auth"),
@@ -78,12 +75,12 @@ func (s *proxyServer) Send(wq queue.Q) error {
}
func (s *proxyServer) initConn(addr string) error {
- dial, err := newDialSocket(s.sock.proto, addr)
- if err != nil {
- return err
+ dial := dialSocket{
+ Proto: s.sock.Proto,
+ Addr: addr,
}
- conn, err := dial.Open(s.env)
+ conn, err := dial.New(s.env)
if err != nil {
dial.Close()
return err
@@ -138,3 +135,7 @@ func (s *proxyServer) Close() (err error) {
return
}
+
+func init() {
+ register("proxy", proxySocket{})
+}
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index 62ce5cf..03b73d9 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -5,8 +5,11 @@ import (
"fmt"
"log"
"net"
+ "reflect"
+ "sort"
"strings"
"sync"
+
"tunnel/pkg/server/env"
"tunnel/pkg/server/opts"
"tunnel/pkg/server/queue"
@@ -21,7 +24,7 @@ type Conn interface {
}
type S interface {
- Open(env env.Env) (Conn, error)
+ New(env env.Env) (Conn, error)
Close()
}
@@ -66,43 +69,62 @@ func (c *conn) Close() error {
return err
}
-func New(desc string) (S, error) {
- base, opts := opts.Parse(desc)
- args := strings.SplitN(base, "/", 2)
-
- var proto string
- var addr string
+func New(desc string, e env.Env) (S, error) {
+ name, opts := opts.Parse(desc)
- if len(args) < 2 {
- addr = args[0]
- } else {
- proto, addr = args[0], args[1]
+ t, ok := sockets[name]
+ if !ok {
+ return nil, fmt.Errorf("%s: unknown type", name)
}
- if proto == "" {
- proto = "tcp"
+ s := reflect.New(t).Interface()
+ if err := opts.Configure(s); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
}
- switch addr {
- case "loop":
- return newLoopSocket()
- case "proxy":
- return newProxySocket(proto)
- case "":
- return nil, fmt.Errorf("bad socket '%s'", desc)
+ if i, ok := s.(interface{ Prepare(env.Env) error }); ok {
+ if err := i.Prepare(e); err != nil {
+ return nil, fmt.Errorf("%s: %w", name, err)
+ }
}
- if proto == "tun" {
- return newTunSocket(addr)
+ return s.(S), nil
+}
+
+func parseProtoAddr(proto, addr string) (string, string) {
+ if proto == "tcp" || proto == "udp" {
+ if strings.HasPrefix(addr, "-:") {
+ addr = "localhost" + addr[1:]
+ }
}
- if opts.Bool("listen") {
- return newListenSocket(proto, addr, opts)
+ return proto, addr
+}
+
+var sockets = map[string]reflect.Type{}
+
+func register(name string, i interface{}) {
+ t := reflect.TypeOf(i)
+ if t.Kind() != reflect.Struct {
+ log.Panicf("non-struct type '%s'", t.String())
+ }
+ if _, ok := reflect.New(t).Interface().(S); !ok {
+ log.Panicf("uncompatible socket type '%s'", t.String())
+ }
+ if _, ok := sockets[name]; ok {
+ log.Panicf("duplicate socket name '%s'", name)
}
+ sockets[name] = t
+}
+
+func GetList() []string {
+ var list []string
- if opts.Bool("defer") {
- return newDeferSocket(proto, addr)
+ for k := range sockets {
+ list = append(list, k)
}
- return newDialSocket(proto, addr)
+ sort.Strings(list)
+
+ return list
}
diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go
index d48c30c..3e673eb 100644
--- a/pkg/server/socket/tun.go
+++ b/pkg/server/socket/tun.go
@@ -8,10 +8,11 @@ import (
"os"
"strings"
"sync"
+ "unsafe"
+
"tunnel/pkg/pack"
"tunnel/pkg/server/env"
"tunnel/pkg/server/queue"
- "unsafe"
)
const maxTunBufSize = 65535
@@ -24,7 +25,7 @@ type ifReq struct {
}
type tunSocket struct {
- name string
+ Name string `opts:"required"`
}
type tunConn struct {
@@ -34,35 +35,30 @@ type tunConn struct {
once sync.Once
}
-func newTunSocket(name string) (S, error) {
- return &tunSocket{name: name}, nil
-
-}
-
func (s *tunSocket) String() string {
- return fmt.Sprintf("tun/%s", s.name)
+ return fmt.Sprintf("tun/%s", s.Name)
}
func (s *tunSocket) Single() {}
-func (s *tunSocket) Open(env.Env) (Conn, error) {
+func (s *tunSocket) New(env.Env) (Conn, error) {
fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
if err != nil {
return nil, err
}
ifr := &ifReq{}
- copy(ifr.name[:], s.name)
+ copy(ifr.name[:], s.Name)
ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil {
unix.Close(fd)
- return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err)
+ return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.Name, err)
}
if err := unix.SetNonblock(fd, true); err != nil {
unix.Close(fd)
- return nil, fmt.Errorf("set nonblock %s: %w", s.name, err)
+ return nil, fmt.Errorf("set nonblock %s: %w", s.Name, err)
}
c := &tunConn{
@@ -124,3 +120,7 @@ func (c *tunConn) Close() error {
return err
}
+
+func init() {
+ register("tun", tunSocket{})
+}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index b6c5ada..8b86ddc 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -11,6 +11,7 @@ import (
"sync"
"sync/atomic"
"time"
+
"tunnel/pkg/config"
"tunnel/pkg/server/env"
"tunnel/pkg/server/hook"
@@ -92,6 +93,8 @@ func (t *tunnel) stopStreams() {
}
func (t *tunnel) Close() {
+ t.env.Detach()
+
t.stopServe()
t.stopStreams()
@@ -139,7 +142,7 @@ func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) {
}
for _, h := range t.hooks {
- p, err := h.Open(env)
+ p, err := h.New(env)
if err != nil {
cleanup()
return nil, fmt.Errorf("%s: %w", h, err)
@@ -159,11 +162,11 @@ func (t *tunnel) serve() {
env.Set("tunnel", t.id)
env.Set("stream", strconv.Itoa(t.nextSid))
- if in, err := t.in.Open(env); err != nil {
+ if in, err := t.in.New(env); err != nil {
if t.alive() {
log.Println(t, err)
}
- } else if out, err := t.out.Open(env); err != nil {
+ } else if out, err := t.out.New(env); err != nil {
log.Println(t, err)
in.Close()
} else if pipes, err := t.openPipes(env); err != nil {
@@ -377,47 +380,41 @@ func parseHooks(args []string) ([]hook.H, error) {
return hooks, nil
}
-func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) {
- var in, out socket.S
- var hooks []hook.H
- var err error
+func (t *tunnel) init(limit int, args []string, env env.Env) (err error) {
+ t.env = env.Fork("tunnel", t.id)
+ defer func() {
+ if err != nil {
+ t.env.Detach()
+ }
+ }()
+
+ closeOnFail := func(s socket.S) {
+ if err != nil {
+ s.Close()
+ }
+ }
n := len(args) - 1
- if in, err = socket.New(args[0]); err != nil {
- return nil, err
+ if t.in, err = socket.New(args[0], t.env); err != nil {
+ return
}
+ defer closeOnFail(t.in)
- if _, ok := in.(socket.Single); ok {
+ if _, ok := t.in.(socket.Single); ok {
limit = 1
}
- if out, err = socket.New(args[n]); err != nil {
- in.Close()
- return nil, err
+ if t.out, err = socket.New(args[n], t.env); err != nil {
+ return
}
+ defer closeOnFail(t.out)
- if hooks, err = parseHooks(args[1:n]); err != nil {
- in.Close()
- out.Close()
- return nil, err
+ if t.hooks, err = parseHooks(args[1:n]); err != nil {
+ return
}
- t := &tunnel{
- args: strings.Join(args, " "),
- quit: make(chan struct{}),
- done: make(chan struct{}),
- hooks: hooks,
- in: in,
- out: out,
- env: env,
- queue: make(chan struct{}, limit),
- streams: make(map[int]*stream),
- }
-
- go t.serve()
-
- return t, nil
+ return
}
func isOkTunnelName(s string) bool {
@@ -471,18 +468,27 @@ func tunnelAdd(r *request) {
r.Fatal("not enough args")
}
- t, err := newTunnel(limit, args, r.c.s.env)
- if err != nil {
- r.Fatal(err)
+ if name == "" {
+ name = r.c.s.tunnels.next()
}
- if name == "" {
- t.id = r.c.s.tunnels.add(t)
- } else {
- t.id = name
- r.c.s.tunnels[t.id] = t
+ t := &tunnel{
+ id: name,
+ args: strings.Join(args, " "),
+ quit: make(chan struct{}),
+ done: make(chan struct{}),
+ queue: make(chan struct{}, limit),
+ streams: make(map[int]*stream),
}
+ if err := t.init(limit, args, r.c.s.env); err != nil {
+ r.Fatal(err)
+ }
+
+ r.c.s.tunnels[t.id] = t
+
+ go t.serve()
+
log.Println(r.c, r, t, "create")
}
@@ -584,6 +590,12 @@ func showHooks(r *request) {
}
}
+func showSockets(r *request) {
+ for _, s := range socket.GetList() {
+ r.Println(s)
+ }
+}
+
func init() {
newCmd(tunnelAdd, "add")
newCmd(tunnelDel, "del")
@@ -591,6 +603,7 @@ func init() {
newCmd(tunnelRename, "rename")
newCmd(showHooks, "hooks")
+ newCmd(showSockets, "sockets")
newCmd(showTunnels, "show")