diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2020-12-16 15:27:48 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2020-12-16 15:27:48 +0300 |
| commit | 6fed9dd0dd62718f78eca11e30a71c2712636fbd (patch) | |
| tree | 8d1f90b96efbe8ea8aea350c283325adc216ef9d /pkg | |
| parent | 050ea053dd549f0dd01beddfcd74989858391fd7 (diff) | |
hook and socket args check fix, tests
Diffstat (limited to 'pkg')
28 files changed, 785 insertions, 365 deletions
diff --git a/pkg/server/automap.go b/pkg/server/automap.go index 67ae5b0..f38f758 100644 --- a/pkg/server/automap.go +++ b/pkg/server/automap.go @@ -10,16 +10,21 @@ type automap map[string]interface{} var errExists = errors.New("already exists") var errNotFound = errors.New("no such entry") -func (m automap) add(v interface{}) string { +func (m automap) next() string { for n := 1; ; n++ { k := fmt.Sprintf("%d", n) if _, ok := m[k]; !ok { - m[k] = v return k } } } +func (m automap) add(v interface{}) string { + k := m.next() + m[k] = v + return k +} + func (m automap) rename(old string, new string) (interface{}, error) { if _, ok := m[old]; !ok { return nil, errNotFound diff --git a/pkg/server/env/env.go b/pkg/server/env/env.go index 2c97669..7aa93d2 100644 --- a/pkg/server/env/env.go +++ b/pkg/server/env/env.go @@ -4,12 +4,23 @@ import ( "errors" "regexp" "sort" + "strings" "sync" ) type env struct { + // vars m map[string]string + + // forks + c map[string]*env + + // key in parent env's map + k string + sync.Mutex + + // parent env p *env } @@ -36,23 +47,84 @@ func (e Env) init() { } } -func (e Env) Fork() Env { +func (e Env) Fork(path ...string) Env { t := New() t.p = e.env + + if len(path) > 0 { + k := strings.Join(path, ".") + + e.Lock() + if e.c == nil { + e.c = make(map[string]*env) + } else if _, ok := e.c[k]; ok { + panic("env fork already exists: " + k) + } + e.c[k] = t.env + e.Unlock() + + t.k = k + } + return t } -func (e *env) Find(key string) (string, bool) { +func (e Env) Detach() { + if e.p != nil { + e.p.Lock() + delete(e.p.c, e.k) + e.p.Unlock() + } +} + +func (e *env) find(key string) (string, bool, func() (string, bool)) { e.Lock() defer e.Unlock() v, ok := e.m[key] if ok { - return v, ok + return v, ok, nil + } + + if e.c != nil { + s := strings.Split(key, ".") + + for k := ""; len(s) > 1; s = s[1:] { + if k == "" { + k = s[0] + } else { + k += "." + s[0] + } + + if c, ok := e.c[k]; ok { + return "", false, func() (string, bool) { + return c.Find(strings.Join(s[1:], ".")) + } + } + } } if e.p != nil { - return e.p.Find(key) + return "", false, func() (string, bool) { + return e.p.Find(key) + } + } + + return "", false, nil +} + +func (e *env) Find(key string) (string, bool) { + if key == "" { + return "", false + } + + v, ok, f := e.find(key) + if ok { + return v, ok + } + + if f != nil { + return f() } return "", false diff --git a/pkg/server/hook/aes.go b/pkg/server/hook/aes.go index e437651..dc48f49 100644 --- a/pkg/server/hook/aes.go +++ b/pkg/server/hook/aes.go @@ -6,8 +6,8 @@ import ( "crypto/md5" "crypto/rand" "io" + "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) @@ -64,24 +64,16 @@ func (a *aesPipe) Recv(rq, wq queue.Q) error { return queue.IoCopy(reader, wq.Writer()) } -func newAes(env env.Env) *aesPipe { +func (aesHook) New(env env.Env) (interface{}, error) { s := env.Value("secret") h := md5.Sum([]byte(s)) a := &aesPipe{key: make([]byte, 16)} copy(a.key, h[:]) - return a -} - -func (aesHook) Open(env env.Env) (interface{}, error) { - return newAes(env), nil -} - -func newAesHook(opts.Opts) (hook, error) { - return aesHook{}, nil + return a, nil } func init() { - register("aes", newAesHook) + register("aes", aesHook{}) } diff --git a/pkg/server/hook/auth.go b/pkg/server/hook/auth.go index 5910b56..dbfc9bc 100644 --- a/pkg/server/hook/auth.go +++ b/pkg/server/hook/auth.go @@ -7,8 +7,8 @@ import ( "io" "sync" "time" + "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) @@ -17,6 +17,10 @@ const authTimeout = 5 * time.Second const saltSize = 16 const hashSize = md5.Size +type authHook struct { + m sync.Map +} + type auth struct { h *authHook @@ -42,10 +46,6 @@ var errDupSalt = errors.New("peer repeats salt") var errAuthFail = errors.New("peer auth fail") var errTimeout = errors.New("timeout") -type authHook struct { - m sync.Map -} - func (a *auth) Init() error { b := make([]byte, saltSize) if _, err := rand.Read(b); err != nil { @@ -149,7 +149,7 @@ func (a *auth) Close() { a.h.m.Delete(a.salt.self) } -func (h *authHook) Open(env env.Env) (interface{}, error) { +func (h *authHook) New(env env.Env) (interface{}, error) { a := &auth{ h: h, secret: env.Value("secret"), @@ -169,10 +169,6 @@ func (h *authHook) Open(env env.Env) (interface{}, error) { return a, nil } -func newAuthHook(opts.Opts) (hook, error) { - return &authHook{}, nil -} - func init() { - register("auth", newAuthHook) + register("auth", authHook{}) } diff --git a/pkg/server/hook/b64.go b/pkg/server/hook/b64.go index fce42a0..c6637e5 100644 --- a/pkg/server/hook/b64.go +++ b/pkg/server/hook/b64.go @@ -2,17 +2,16 @@ package hook import ( "encoding/base64" + "tunnel/pkg/netstring" - "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) var b64Enc = base64.RawStdEncoding -type b64Hook struct{} +type b64Pipe struct{} -func (b64Hook) Send(rq, wq queue.Q) error { +func (b64Pipe) Send(rq, wq queue.Q) error { e := netstring.NewEncoder(wq.Writer()) for b := range rq { @@ -22,7 +21,7 @@ func (b64Hook) Send(rq, wq queue.Q) error { return nil } -func (b64Hook) Recv(rq, wq queue.Q) error { +func (b64Pipe) Recv(rq, wq queue.Q) error { d := netstring.NewDecoder(rq.Reader()) for { @@ -39,14 +38,6 @@ func (b64Hook) Recv(rq, wq queue.Q) error { } } -func (h b64Hook) Open(env.Env) (interface{}, error) { - return h, nil -} - -func newB64Hook(opts.Opts) (hook, error) { - return b64Hook{}, nil -} - func init() { - register("b64", newB64Hook) + registerPipe("b64", b64Pipe{}) } diff --git a/pkg/server/hook/b85.go b/pkg/server/hook/b85.go index bf36b56..d90a1c4 100644 --- a/pkg/server/hook/b85.go +++ b/pkg/server/hook/b85.go @@ -3,15 +3,14 @@ package hook import ( "encoding/ascii85" "errors" + "tunnel/pkg/netstring" - "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) -type b85Hook struct{} +type b85Pipe struct{} -func (b85Hook) Send(rq, wq queue.Q) error { +func (b85Pipe) Send(rq, wq queue.Q) error { e := netstring.NewEncoder(wq.Writer()) for b := range rq { @@ -23,7 +22,7 @@ func (b85Hook) Send(rq, wq queue.Q) error { return nil } -func (b85Hook) Recv(rq, wq queue.Q) error { +func (b85Pipe) Recv(rq, wq queue.Q) error { d := netstring.NewDecoder(rq.Reader()) for { @@ -44,14 +43,6 @@ func (b85Hook) Recv(rq, wq queue.Q) error { } } -func (h b85Hook) Open(env.Env) (interface{}, error) { - return h, nil -} - -func newB85Hook(opts.Opts) (hook, error) { - return b85Hook{}, nil -} - func init() { - register("b85", newB85Hook) + registerPipe("b85", b85Pipe{}) } diff --git a/pkg/server/hook/dump.go b/pkg/server/hook/dump.go index 864443b..d871d63 100644 --- a/pkg/server/hook/dump.go +++ b/pkg/server/hook/dump.go @@ -7,28 +7,28 @@ import ( "os" "path" "time" + "tunnel/pkg/config" "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) const dumpDefaultFile = "/tmp/tunnel/dump" +type dumpHook struct { + File string + Time bool +} + type dump struct { f *os.File h *dumpHook } -type dumpHook struct { - file string - time bool -} - func (t *dump) write(s string, p []byte) error { var out bytes.Buffer - if t.h.time { + if t.h.Time { now := time.Now().Format(config.TimeMsFormat) fmt.Fprintln(&out, now, s, len(p)) } else { @@ -69,8 +69,8 @@ func (t *dump) Close() { } func (h *dumpHook) where(env env.Env) string { - if h.file != "" { - return h.file + if h.File != "" { + return h.File } if v := env.Value("dump.file"); v != "" { @@ -80,7 +80,7 @@ func (h *dumpHook) where(env env.Env) string { return dumpDefaultFile } -func (h *dumpHook) Open(env env.Env) (interface{}, error) { +func (h *dumpHook) New(env env.Env) (interface{}, error) { file := h.where(env) dir := path.Dir(file) @@ -102,14 +102,6 @@ func (h *dumpHook) Open(env env.Env) (interface{}, error) { return t, nil } -func newDumpHook(opts opts.Opts) (hook, error) { - h := &dumpHook{ - file: opts["file"], - time: opts.Bool("time"), - } - return h, nil -} - func init() { - register("dump", newDumpHook) + register("dump", dumpHook{}) } diff --git a/pkg/server/hook/hex.go b/pkg/server/hook/hex.go index bc71bf2..362dbd4 100644 --- a/pkg/server/hook/hex.go +++ b/pkg/server/hook/hex.go @@ -2,14 +2,13 @@ package hook import ( "encoding/hex" - "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" + "tunnel/pkg/server/queue" ) -type hexHook struct{} +type hexPipe struct{} -func (hexHook) Send(rq, wq queue.Q) error { +func (hexPipe) Send(rq, wq queue.Q) error { enc := hex.NewEncoder(wq.Writer()) for b := range rq { @@ -19,19 +18,11 @@ func (hexHook) Send(rq, wq queue.Q) error { return nil } -func (hexHook) Recv(rq, wq queue.Q) error { +func (hexPipe) Recv(rq, wq queue.Q) error { r := hex.NewDecoder(rq.Reader()) return queue.IoCopy(r, wq.Writer()) } -func (h hexHook) Open(env.Env) (interface{}, error) { - return h, nil -} - -func newHexHook(opts.Opts) (hook, error) { - return hexHook{}, nil -} - func init() { - register("hex", newHexHook) + registerPipe("hex", hexPipe{}) } diff --git a/pkg/server/hook/hook.go b/pkg/server/hook/hook.go index 69aa237..36b01d4 100644 --- a/pkg/server/hook/hook.go +++ b/pkg/server/hook/hook.go @@ -3,16 +3,16 @@ package hook import ( "fmt" "log" + "reflect" "sort" "strings" + "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) -type hookInitFunc func(opts.Opts) (hook, error) - -var hooks = map[string]hookInitFunc{} +var hooks = map[string]interface{}{} type Pipe struct { priv interface{} @@ -21,13 +21,12 @@ type Pipe struct { Recv Func } -type hook interface { - Open(env env.Env) (interface{}, error) +type Hooker interface { + New(env env.Env) (interface{}, error) } type H interface { - Open(env env.Env) (*Pipe, error) - String() string + New(env env.Env) (*Pipe, error) } type Sender interface { @@ -44,12 +43,12 @@ func (f Func) Send(rq, wq queue.Q) error { return f(rq, wq) } -func (f Func) Open(env env.Env) (interface{}, error) { +func (f Func) New(env env.Env) (interface{}, error) { return f, nil } type wrapper struct { - hook + hook Hooker name string reverse bool } @@ -58,8 +57,8 @@ func (w *wrapper) String() string { return fmt.Sprintf("hook:%s", w.name) } -func (w *wrapper) Open(env env.Env) (*Pipe, error) { - it, err := w.hook.Open(env) +func (w *wrapper) New(env env.Env) (*Pipe, error) { + it, err := w.hook.New(env) if err != nil { return nil, err } @@ -87,6 +86,23 @@ func (p *Pipe) Close() { } } +func initHook(i interface{}, opts opts.Opts) (Hooker, error) { + if f, ok := i.(Func); ok { + return f, nil + } + + if p, ok := i.(pipeHolder); ok { + return p, nil + } + + h := reflect.New(reflect.TypeOf(i)).Interface() + if err := opts.Configure(h); err != nil { + return nil, err + } + + return h.(Hooker), nil +} + func New(desc string) (H, error) { name, opts := opts.Parse(desc) reverse := false @@ -96,9 +112,9 @@ func New(desc string) (H, error) { reverse = true } - if f, ok := hooks[name]; !ok { + if i, ok := hooks[name]; !ok { return nil, fmt.Errorf("unknown hook '%s'", name) - } else if h, err := f(opts); err != nil { + } else if h, err := initHook(i, opts); err != nil { return nil, fmt.Errorf("%s: %w", name, err) } else { w := &wrapper{ @@ -110,18 +126,41 @@ func New(desc string) (H, error) { } } -func register(name string, f hookInitFunc) { +func register(name string, i interface{}) { + switch t := reflect.TypeOf(i); t.Kind() { + case reflect.Struct: + if _, ok := reflect.New(t).Interface().(Hooker); !ok { + log.Panicf("uncompatible hook type '%s'", t.String()) + } + case reflect.Func: + if _, ok := i.(Func); !ok { + log.Panicf("uncompatible func type '%s'", t.String()) + } + default: + log.Panicf("non-struct and non-func type '%s'", t.String()) + } + if _, ok := hooks[name]; ok { log.Panicf("duplicate hook name '%s'", name) } - hooks[name] = f + hooks[name] = i } -func registerFunc(name string, p Func) { - register(name, func(opts.Opts) (hook, error) { - return p, nil - }) +func registerFunc(name string, f Func) { + register(name, f) +} + +type pipeHolder struct { + i interface{} +} + +func (p pipeHolder) New(env.Env) (interface{}, error) { + return p.i, nil +} + +func registerPipe(name string, i interface{}) { + register(name, pipeHolder{i}) } func GetList() []string { diff --git a/pkg/server/hook/info-http.go b/pkg/server/hook/info-http.go index 73480ff..ec56f87 100644 --- a/pkg/server/hook/info-http.go +++ b/pkg/server/hook/info-http.go @@ -2,9 +2,9 @@ package hook import ( "bufio" + "tunnel/pkg/http" "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) @@ -42,14 +42,10 @@ func (info *infoHttp) Recv(rq, wq queue.Q) error { return queue.Copy(rq, wq) } -func (infoHttpHook) Open(env env.Env) (interface{}, error) { +func (infoHttpHook) New(env env.Env) (interface{}, error) { return &infoHttp{env: env}, nil } -func newInfoHttpHook(opts.Opts) (hook, error) { - return infoHttpHook{}, nil -} - func init() { - register("info-http", newInfoHttpHook) + register("info-http", infoHttpHook{}) } diff --git a/pkg/server/hook/proxy.go b/pkg/server/hook/proxy.go index bc6da18..4276d9a 100644 --- a/pkg/server/hook/proxy.go +++ b/pkg/server/hook/proxy.go @@ -3,12 +3,11 @@ package hook import ( "bufio" "bytes" - "errors" "fmt" "regexp" + "tunnel/pkg/http" "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) @@ -16,8 +15,8 @@ var addrPattern = "^([0-9a-zA-Z-.]+|\\[[0-9a-fA-F:]*\\]):[0-9]+$" var isGoodAddr = regexp.MustCompile(addrPattern).MatchString type proxyHook struct { - addr string - auth string + Addr string `opts:"required"` + Auth string } type proxy struct { @@ -67,15 +66,15 @@ func (p *proxy) Recv(rq, wq queue.Q) error { return queue.IoCopy(r, wq.Writer()) } -func (h *proxyHook) Open(env env.Env) (interface{}, error) { - addr := env.Expand(h.addr) +func (h *proxyHook) New(env env.Env) (interface{}, error) { + addr := env.Expand(h.Addr) if !isGoodAddr(addr) { return nil, fmt.Errorf("invalid addr '%s'", addr) } p := &proxy{ addr: addr, - auth: h.auth, + auth: h.Auth, c: make(chan bool), env: env, } @@ -87,19 +86,6 @@ func (h *proxyHook) Open(env env.Env) (interface{}, error) { return p, nil } -func newProxyHook(opts opts.Opts) (hook, error) { - h := &proxyHook{ - addr: opts["addr"], - auth: opts["auth"], - } - - if h.addr == "" { - return nil, errors.New("expected addr") - } - - return h, nil -} - func init() { - register("proxy", newProxyHook) + register("proxy", proxyHook{}) } diff --git a/pkg/server/hook/split.go b/pkg/server/hook/split.go index 6a2c4ca..59c8055 100644 --- a/pkg/server/hook/split.go +++ b/pkg/server/hook/split.go @@ -1,27 +1,24 @@ package hook import ( - "errors" - "strconv" "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" ) -const splitDefaultSize = 1024 - -var errBadSize = errors.New("bad size value") - type splitHook struct { - size int + Size int `opts:"positive,default:1024"` +} + +type splitPipe struct { + h *splitHook } -func (h *splitHook) Send(rq, wq queue.Q) error { +func (p splitPipe) Send(rq, wq queue.Q) error { for b := range rq { var upto int for n := 0; n < len(b); n = upto { - upto += h.size + upto += p.h.Size if upto > len(b) { upto = len(b) @@ -34,24 +31,10 @@ func (h *splitHook) Send(rq, wq queue.Q) error { return nil } -func (h *splitHook) Open(env.Env) (interface{}, error) { - return h, nil -} - -func newSplitHook(opts opts.Opts) (hook, error) { - size := splitDefaultSize - - if s, ok := opts["size"]; ok { - var err error - - if size, err = strconv.Atoi(s); err != nil || size <= 0 { - return nil, errBadSize - } - } - - return &splitHook{size: size}, nil +func (h *splitHook) New(env.Env) (interface{}, error) { + return &splitPipe{h}, nil } func init() { - register("split", newSplitHook) + register("split", splitHook{}) } diff --git a/pkg/server/hook/zip.go b/pkg/server/hook/zip.go index bde4957..615b50d 100644 --- a/pkg/server/hook/zip.go +++ b/pkg/server/hook/zip.go @@ -3,14 +3,13 @@ package hook import ( "compress/flate" "io" - "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" + "tunnel/pkg/server/queue" ) -type zipHook struct{} +type zipPipe struct{} -func (zipHook) Send(rq, wq queue.Q) error { +func (zipPipe) Send(rq, wq queue.Q) error { w, err := flate.NewWriter(wq.Writer(), flate.BestCompression) if err != nil { return err @@ -28,7 +27,7 @@ func (zipHook) Send(rq, wq queue.Q) error { return w.Close() } -func (zipHook) Recv(rq, wq queue.Q) error { +func (zipPipe) Recv(rq, wq queue.Q) error { r := flate.NewReader(rq.Reader()) // FIXME: not received ending due to ultimate conn.Close @@ -42,14 +41,6 @@ func (zipHook) Recv(rq, wq queue.Q) error { return r.Close() } -func (h zipHook) Open(env.Env) (interface{}, error) { - return h, nil -} - -func newZipHook(opts.Opts) (hook, error) { - return zipHook{}, nil -} - func init() { - register("zip", newZipHook) + registerPipe("zip", zipPipe{}) } diff --git a/pkg/server/opts/opts.go b/pkg/server/opts/opts.go index 22383d8..c5729a3 100644 --- a/pkg/server/opts/opts.go +++ b/pkg/server/opts/opts.go @@ -1,6 +1,19 @@ package opts -import "strings" +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type fieldSpec struct { + name string + inline bool + required bool + positive bool + defaultValue string +} type Opts map[string]string @@ -20,7 +33,137 @@ func Parse(s string) (string, Opts) { return v[0], m } -func (m Opts) Bool(key string) bool { - _, ok := m[key] +func (opts Opts) Bool(key string) bool { + _, ok := opts[key] return ok } + +func (opts Opts) Configure(v interface{}) error { + rv := reflect.ValueOf(v) + + switch { + case rv.Kind() != reflect.Ptr: + return fmt.Errorf("opts: configure failed: non-pointer %s", rv.Type().String()) + case rv.IsNil(): + return fmt.Errorf("opts: configure failed: nil %s", rv.Type().String()) + case rv.Elem().Kind() != reflect.Struct: + return fmt.Errorf("opts: configure failed: non-struct %s", rv.Elem().Type().String()) + } + + return opts.configure(rv.Elem()) +} + +func getFieldSpec(field reflect.StructField) (fs fieldSpec) { + fs.name = strings.ToLower(field.Name) + + for _, s := range strings.Split(field.Tag.Get("opts"), ",") { + switch { + case s == "inline": + fs.inline = true + case s == "required": + fs.required = true + case s == "positive": + fs.positive = true + case strings.HasPrefix(s, "default:"): + fs.defaultValue = s[8:] + } + } + + return +} + +func (opts Opts) setValue(v reflect.Value, fs fieldSpec) error { + s, ok := opts[fs.name] + kind := v.Kind() + + if kind == reflect.Bool { + if s != "" { + return fmt.Errorf("%s: boolean option does not need a value", fs.name) + } + + if ok { + v.SetBool(true) + } + + return nil + } + + if s == "" { + if fs.required { + return fmt.Errorf("%s: expect value", fs.name) + } + + s = fs.defaultValue + } + + if s == "" { + return nil + } + + switch kind { + case reflect.String: + v.SetString(s) + case reflect.Int: + n, err := strconv.Atoi(s) + if err != nil { + return fmt.Errorf("%s: %w", fs.name, err) + } + if n <= 0 { + return fmt.Errorf("%s: expect positive", fs.name) + } + v.SetInt(int64(n)) + default: + return fmt.Errorf("%s: unsupported type '%s'", fs.name, v.Type().String()) + } + + return nil +} + +func (opts Opts) configure(v reflect.Value) error { + m := map[string]bool{} + + for s := range opts { + m[s] = true + } + + var visit func(v reflect.Value) error + + visit = func(v reflect.Value) error { + t := v.Type() + + for n := 0; n < v.NumField(); n++ { + fs := getFieldSpec(t.Field(n)) + fv := v.Field(n) + + if fs.inline { + if err := visit(fv); err != nil { + return err + } + } else if fv.CanSet() { + if err := opts.setValue(v.Field(n), fs); err != nil { + return err + } + + delete(m, fs.name) + } + } + + return nil + } + + if err := visit(v); err != nil { + return err + } + + if len(m) > 0 { + var ss []string + + for s := range m { + ss = append(ss, s) + } + + return fmt.Errorf("unknown options: %s", strings.Join(ss, ",")) + } + + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index d57d9dd..8c5683a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -254,10 +254,6 @@ func (c *client) handle() { r.run(args) - if r.out.Len() == 0 { - r.out.Write([]byte("\n")) - } - ew := c.w.Encode(r.out.String()) if ew != nil { log.Println(c, "handle:", ew) diff --git a/pkg/server/socket/defer.go b/pkg/server/socket/defer.go index 7ed303d..7c1436e 100644 --- a/pkg/server/socket/defer.go +++ b/pkg/server/socket/defer.go @@ -6,28 +6,19 @@ import ( ) type deferSocket struct { - S + dialSocket `opts:"inline"` } type deferConn struct { - sock *deferSocket + sock S wait chan bool env env.Env conn Conn } -func newDeferSocket(proto, addr string) (S, error) { - s, err := newDialSocket(proto, addr) - if err != nil { - return s, err - } - - return &deferSocket{s}, nil -} - -func (s *deferSocket) Open(env env.Env) (Conn, error) { +func (s *deferSocket) New(env env.Env) (Conn, error) { c := &deferConn{ - sock: s, + sock: &s.dialSocket, wait: make(chan bool), env: env, } @@ -54,7 +45,7 @@ func (c *deferConn) Recv(rq queue.Q) error { return nil } - conn, err := c.sock.S.Open(c.env) + conn, err := c.sock.New(c.env) if err != nil { c.wait <- false return err @@ -83,3 +74,7 @@ func (c *deferConn) Close() (err error) { return } + +func init() { + register("defer", deferSocket{}) +} diff --git a/pkg/server/socket/dial.go b/pkg/server/socket/dial.go index 7623084..b2df3b7 100644 --- a/pkg/server/socket/dial.go +++ b/pkg/server/socket/dial.go @@ -4,36 +4,25 @@ import ( "fmt" "log" "net" - "strings" "time" + "tunnel/pkg/server/env" ) const defaultTimeout = 5 * time.Second type dialSocket struct { - proto, addr string -} - -func newDialSocket(proto, addr string) (S, error) { - return &dialSocket{proto: proto, addr: addr}, nil + Proto string `opts:"default:tcp"` + Addr string `opts:"required"` } func (s *dialSocket) String() string { - return fmt.Sprintf("%s/%s", s.proto, s.addr) + return fmt.Sprintf("%s/%s", s.Proto, s.Addr) } -func (s *dialSocket) Open(e env.Env) (Conn, error) { - addr := e.Expand(s.addr) - - switch s.proto { - case "tcp", "udp": - if !strings.Contains(addr, ":") { - addr = "localhost:" + addr - } - } - - conn, err := net.DialTimeout(s.proto, addr, defaultTimeout) +func (s *dialSocket) New(e env.Env) (Conn, error) { + proto, addr := parseProtoAddr(s.Proto, e.Expand(s.Addr)) + conn, err := net.DialTimeout(proto, addr, defaultTimeout) if err != nil { return nil, err } @@ -49,3 +38,7 @@ func (s *dialSocket) Open(e env.Env) (Conn, error) { func (s *dialSocket) Close() { } + +func init() { + register("dial", dialSocket{}) +} diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index 910e5de..2c2f184 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -5,64 +5,54 @@ import ( "fmt" "log" "net" - "strings" + "tunnel/pkg/server/env" - "tunnel/pkg/server/opts" ) type listenSocket struct { - proto, addr string - listen net.Listener - redirect bool - tproxy bool -} + Proto string `opts:"default:tcp"` + Addr string `opts:"required"` -func newListenSocket(proto, addr string, opts opts.Opts) (S, error) { - redirect := opts.Bool("redirect") - tproxy := opts.Bool("tproxy") + Redirect bool + Tproxy bool - if proto == "tcp" { - if !strings.Contains(addr, ":") { - addr = ":" + addr - } - } + listen net.Listener +} - if redirect && proto != "tcp" { - return nil, errors.New("redirect not supported") +func (s *listenSocket) Prepare(e env.Env) error { + if s.Redirect && s.Proto != "tcp" { + return errors.New("redirect not supported") } - if tproxy && proto != "tcp" { - return nil, errors.New("tproxy not supported") + if s.Tproxy && s.Proto != "tcp" { + return errors.New("tproxy not supported") } - if redirect && tproxy { - return nil, errors.New("redirect and tproxy cannot be used together") + if s.Redirect && s.Tproxy { + return errors.New("redirect and tproxy cannot be used together") } + proto, addr := parseProtoAddr(s.Proto, s.Addr) listen, err := net.Listen(proto, addr) if err != nil { - return nil, err + return err } - if tproxy { + e.Set("listen", listen.Addr().String()) + + if s.Tproxy { if err := setConnTransparent(listen); err != nil { listen.Close() - return nil, err + return err } } - s := &listenSocket{ - proto: proto, - addr: addr, - listen: listen, - redirect: redirect, - tproxy: tproxy, - } + s.listen = listen - return s, nil + return nil } -func (s *listenSocket) Open(env env.Env) (Conn, error) { +func (s *listenSocket) New(env env.Env) (Conn, error) { var original string conn, err := s.listen.Accept() @@ -74,7 +64,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la) info := fmt.Sprintf("<%s/%s", ra.Network(), ra) - if s.redirect { + if s.Redirect { if err := getConnOriginalAddr(conn, &original); err != nil { log.Println("accept", desc, "failed") conn.Close() @@ -84,7 +74,7 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } } - if s.tproxy { + if s.Tproxy { env.Set("original", la.String()) } @@ -98,9 +88,13 @@ func (s *listenSocket) Open(env env.Env) (Conn, error) { } func (s *listenSocket) String() string { - return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) + return fmt.Sprintf("%s/%s,listen", s.Proto, s.Addr) } func (s *listenSocket) Close() { s.listen.Close() } + +func init() { + register("listen", listenSocket{}) +} diff --git a/pkg/server/socket/loop.go b/pkg/server/socket/loop.go index a06448a..c442140 100644 --- a/pkg/server/socket/loop.go +++ b/pkg/server/socket/loop.go @@ -30,7 +30,7 @@ func (c *loopConn) Close() error { return nil } -func (s *loopSocket) Open(env.Env) (Conn, error) { +func (s *loopSocket) New(env.Env) (Conn, error) { return &loopConn{make(chan queue.Q), make(chan error)}, nil } @@ -41,6 +41,6 @@ func (s *loopSocket) String() string { func (s *loopSocket) Close() { } -func newLoopSocket() (S, error) { - return &loopSocket{}, nil +func init() { + register("loop", loopSocket{}) } diff --git a/pkg/server/socket/proxy.go b/pkg/server/socket/proxy.go index e4baec2..1be4bba 100644 --- a/pkg/server/socket/proxy.go +++ b/pkg/server/socket/proxy.go @@ -5,6 +5,7 @@ import ( "bytes" "errors" "fmt" + "tunnel/pkg/http" "tunnel/pkg/server/env" "tunnel/pkg/server/queue" @@ -16,7 +17,7 @@ type status struct { } type proxySocket struct { - proto string + Proto string `opts:"default:tcp"` } type proxyServer struct { @@ -28,11 +29,7 @@ type proxyServer struct { conn Conn } -func newProxySocket(proto string) (S, error) { - return &proxySocket{proto}, nil -} - -func (sock *proxySocket) Open(env env.Env) (Conn, error) { +func (sock *proxySocket) New(env env.Env) (Conn, error) { s := &proxyServer{ sock: sock, auth: env.Value("proxy.auth"), @@ -78,12 +75,12 @@ func (s *proxyServer) Send(wq queue.Q) error { } func (s *proxyServer) initConn(addr string) error { - dial, err := newDialSocket(s.sock.proto, addr) - if err != nil { - return err + dial := dialSocket{ + Proto: s.sock.Proto, + Addr: addr, } - conn, err := dial.Open(s.env) + conn, err := dial.New(s.env) if err != nil { dial.Close() return err @@ -138,3 +135,7 @@ func (s *proxyServer) Close() (err error) { return } + +func init() { + register("proxy", proxySocket{}) +} diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index 62ce5cf..03b73d9 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -5,8 +5,11 @@ import ( "fmt" "log" "net" + "reflect" + "sort" "strings" "sync" + "tunnel/pkg/server/env" "tunnel/pkg/server/opts" "tunnel/pkg/server/queue" @@ -21,7 +24,7 @@ type Conn interface { } type S interface { - Open(env env.Env) (Conn, error) + New(env env.Env) (Conn, error) Close() } @@ -66,43 +69,62 @@ func (c *conn) Close() error { return err } -func New(desc string) (S, error) { - base, opts := opts.Parse(desc) - args := strings.SplitN(base, "/", 2) - - var proto string - var addr string +func New(desc string, e env.Env) (S, error) { + name, opts := opts.Parse(desc) - if len(args) < 2 { - addr = args[0] - } else { - proto, addr = args[0], args[1] + t, ok := sockets[name] + if !ok { + return nil, fmt.Errorf("%s: unknown type", name) } - if proto == "" { - proto = "tcp" + s := reflect.New(t).Interface() + if err := opts.Configure(s); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) } - switch addr { - case "loop": - return newLoopSocket() - case "proxy": - return newProxySocket(proto) - case "": - return nil, fmt.Errorf("bad socket '%s'", desc) + if i, ok := s.(interface{ Prepare(env.Env) error }); ok { + if err := i.Prepare(e); err != nil { + return nil, fmt.Errorf("%s: %w", name, err) + } } - if proto == "tun" { - return newTunSocket(addr) + return s.(S), nil +} + +func parseProtoAddr(proto, addr string) (string, string) { + if proto == "tcp" || proto == "udp" { + if strings.HasPrefix(addr, "-:") { + addr = "localhost" + addr[1:] + } } - if opts.Bool("listen") { - return newListenSocket(proto, addr, opts) + return proto, addr +} + +var sockets = map[string]reflect.Type{} + +func register(name string, i interface{}) { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Struct { + log.Panicf("non-struct type '%s'", t.String()) + } + if _, ok := reflect.New(t).Interface().(S); !ok { + log.Panicf("uncompatible socket type '%s'", t.String()) + } + if _, ok := sockets[name]; ok { + log.Panicf("duplicate socket name '%s'", name) } + sockets[name] = t +} + +func GetList() []string { + var list []string - if opts.Bool("defer") { - return newDeferSocket(proto, addr) + for k := range sockets { + list = append(list, k) } - return newDialSocket(proto, addr) + sort.Strings(list) + + return list } diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go index d48c30c..3e673eb 100644 --- a/pkg/server/socket/tun.go +++ b/pkg/server/socket/tun.go @@ -8,10 +8,11 @@ import ( "os" "strings" "sync" + "unsafe" + "tunnel/pkg/pack" "tunnel/pkg/server/env" "tunnel/pkg/server/queue" - "unsafe" ) const maxTunBufSize = 65535 @@ -24,7 +25,7 @@ type ifReq struct { } type tunSocket struct { - name string + Name string `opts:"required"` } type tunConn struct { @@ -34,35 +35,30 @@ type tunConn struct { once sync.Once } -func newTunSocket(name string) (S, error) { - return &tunSocket{name: name}, nil - -} - func (s *tunSocket) String() string { - return fmt.Sprintf("tun/%s", s.name) + return fmt.Sprintf("tun/%s", s.Name) } func (s *tunSocket) Single() {} -func (s *tunSocket) Open(env.Env) (Conn, error) { +func (s *tunSocket) New(env.Env) (Conn, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) if err != nil { return nil, err } ifr := &ifReq{} - copy(ifr.name[:], s.name) + copy(ifr.name[:], s.Name) ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil { unix.Close(fd) - return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err) + return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.Name, err) } if err := unix.SetNonblock(fd, true); err != nil { unix.Close(fd) - return nil, fmt.Errorf("set nonblock %s: %w", s.name, err) + return nil, fmt.Errorf("set nonblock %s: %w", s.Name, err) } c := &tunConn{ @@ -124,3 +120,7 @@ func (c *tunConn) Close() error { return err } + +func init() { + register("tun", tunSocket{}) +} diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index b6c5ada..8b86ddc 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -11,6 +11,7 @@ import ( "sync" "sync/atomic" "time" + "tunnel/pkg/config" "tunnel/pkg/server/env" "tunnel/pkg/server/hook" @@ -92,6 +93,8 @@ func (t *tunnel) stopStreams() { } func (t *tunnel) Close() { + t.env.Detach() + t.stopServe() t.stopStreams() @@ -139,7 +142,7 @@ func (t *tunnel) openPipes(env env.Env) ([]*hook.Pipe, error) { } for _, h := range t.hooks { - p, err := h.Open(env) + p, err := h.New(env) if err != nil { cleanup() return nil, fmt.Errorf("%s: %w", h, err) @@ -159,11 +162,11 @@ func (t *tunnel) serve() { env.Set("tunnel", t.id) env.Set("stream", strconv.Itoa(t.nextSid)) - if in, err := t.in.Open(env); err != nil { + if in, err := t.in.New(env); err != nil { if t.alive() { log.Println(t, err) } - } else if out, err := t.out.Open(env); err != nil { + } else if out, err := t.out.New(env); err != nil { log.Println(t, err) in.Close() } else if pipes, err := t.openPipes(env); err != nil { @@ -377,47 +380,41 @@ func parseHooks(args []string) ([]hook.H, error) { return hooks, nil } -func newTunnel(limit int, args []string, env env.Env) (*tunnel, error) { - var in, out socket.S - var hooks []hook.H - var err error +func (t *tunnel) init(limit int, args []string, env env.Env) (err error) { + t.env = env.Fork("tunnel", t.id) + defer func() { + if err != nil { + t.env.Detach() + } + }() + + closeOnFail := func(s socket.S) { + if err != nil { + s.Close() + } + } n := len(args) - 1 - if in, err = socket.New(args[0]); err != nil { - return nil, err + if t.in, err = socket.New(args[0], t.env); err != nil { + return } + defer closeOnFail(t.in) - if _, ok := in.(socket.Single); ok { + if _, ok := t.in.(socket.Single); ok { limit = 1 } - if out, err = socket.New(args[n]); err != nil { - in.Close() - return nil, err + if t.out, err = socket.New(args[n], t.env); err != nil { + return } + defer closeOnFail(t.out) - if hooks, err = parseHooks(args[1:n]); err != nil { - in.Close() - out.Close() - return nil, err + if t.hooks, err = parseHooks(args[1:n]); err != nil { + return } - t := &tunnel{ - args: strings.Join(args, " "), - quit: make(chan struct{}), - done: make(chan struct{}), - hooks: hooks, - in: in, - out: out, - env: env, - queue: make(chan struct{}, limit), - streams: make(map[int]*stream), - } - - go t.serve() - - return t, nil + return } func isOkTunnelName(s string) bool { @@ -471,18 +468,27 @@ func tunnelAdd(r *request) { r.Fatal("not enough args") } - t, err := newTunnel(limit, args, r.c.s.env) - if err != nil { - r.Fatal(err) + if name == "" { + name = r.c.s.tunnels.next() } - if name == "" { - t.id = r.c.s.tunnels.add(t) - } else { - t.id = name - r.c.s.tunnels[t.id] = t + t := &tunnel{ + id: name, + args: strings.Join(args, " "), + quit: make(chan struct{}), + done: make(chan struct{}), + queue: make(chan struct{}, limit), + streams: make(map[int]*stream), } + if err := t.init(limit, args, r.c.s.env); err != nil { + r.Fatal(err) + } + + r.c.s.tunnels[t.id] = t + + go t.serve() + log.Println(r.c, r, t, "create") } @@ -584,6 +590,12 @@ func showHooks(r *request) { } } +func showSockets(r *request) { + for _, s := range socket.GetList() { + r.Println(s) + } +} + func init() { newCmd(tunnelAdd, "add") newCmd(tunnelDel, "del") @@ -591,6 +603,7 @@ func init() { newCmd(tunnelRename, "rename") newCmd(showHooks, "hooks") + newCmd(showSockets, "sockets") newCmd(showTunnels, "show") diff --git a/pkg/test/auth_test.go b/pkg/test/auth_test.go new file mode 100644 index 0000000..cf5f28b --- /dev/null +++ b/pkg/test/auth_test.go @@ -0,0 +1,36 @@ +package test + +import ( + "testing" +) + +func TestAuthHook(t *testing.T) { + const msg = "Hello, World!" + + c, s := newClientServer(t) + defer closeClientServer(c, s) + + listen := xListen(t, "tcp", "127.0.0.1:0") + defer listen.Close() + + xaddr := c.AddListenTunnel("X", "-aes -auth dial,addr=%s", listen.Addr()) + taddr := c.AddListenTunnel("T", "auth aes dial,addr=%s", xaddr) + + c.Exec("set tunnel.X.secret secret") + c.Exec("set tunnel.T.secret secret") + + out := xDial(t, "tcp", taddr) + defer out.Close() + + in := xAccept(t, listen) + defer in.Close() + + xWrite(t, out, msg) + + buf := make([]byte, len(msg)) + xReadFull(t, in, buf) + + if r := string(buf); r != msg { + t.Fatalf("wrong reply: send '%s', recv '%s'", msg, r) + } +} diff --git a/pkg/test/env_test.go b/pkg/test/env_test.go index b03a5c5..5d6f03e 100644 --- a/pkg/test/env_test.go +++ b/pkg/test/env_test.go @@ -7,19 +7,10 @@ import ( func TestEnv(t *testing.T) { const msg = "Hello, World!" - c, s, err := newClientServer() - if err != nil { - t.Fatal(err) - } - - defer s.Stop() - defer c.Close() - - r, err := c.Send([]string{"echo", msg}) - if err != nil { - t.Fatal(err) - } + c, s := newClientServer(t) + defer closeClientServer(c, s) + r := c.Send("echo %s", msg) if r != msg { t.Errorf("wrong reply: send '%s', recv '%s'", msg, r) } diff --git a/pkg/test/hook_test.go b/pkg/test/hook_test.go new file mode 100644 index 0000000..06204bb --- /dev/null +++ b/pkg/test/hook_test.go @@ -0,0 +1,59 @@ +package test + +import ( + "testing" + + "encoding/hex" + "strings" +) + +func TestUpperHook(t *testing.T) { + const msg = "Hello, World!" + + c, s := newClientServer(t) + defer closeClientServer(c, s) + + tunnel := "add name %s listen,addr=127.0.0.1:0 upper loop" + c.Exec(tunnel, t.Name()) + + addr := c.Send("get tunnel.%s.listen", t.Name()) + + conn := xDial(t, "tcp", addr) + defer conn.Close() + + xWrite(t, conn, msg) + + buf := make([]byte, len(msg)) + xReadFull(t, conn, buf) + + if r := string(buf); r != strings.ToUpper(msg) { + t.Fatalf("wrong reply: send '%s', recv '%s'", msg, r) + } +} + +func TestHexHook(t *testing.T) { + const msg = "Hello, World!" + + c, s := newClientServer(t) + defer closeClientServer(c, s) + + listen := xListen(t, "tcp", "127.0.0.1:0") + defer listen.Close() + + addr := c.AddListenTunnel(t.Name(), "hex dial,addr=%s", listen.Addr()) + + out := xDial(t, "tcp", addr) + defer out.Close() + + in := xAccept(t, listen) + defer in.Close() + + xWrite(t, out, msg) + + buf := make([]byte, 2*len(msg)) + xReadFull(t, in, buf) + + if r := string(buf); r != hex.EncodeToString([]byte(msg)) { + t.Fatalf("wrong reply: send '%s', recv '%s'", msg, r) + } +} diff --git a/pkg/test/proxy_test.go b/pkg/test/proxy_test.go new file mode 100644 index 0000000..b2fb097 --- /dev/null +++ b/pkg/test/proxy_test.go @@ -0,0 +1,36 @@ +package test + +import ( + "testing" +) + +func TestProxyHook(t *testing.T) { + const msg = "Hello, World!" + + c, s := newClientServer(t) + defer closeClientServer(c, s) + + listen := xListen(t, "tcp", "127.0.0.1:0") + defer listen.Close() + + saddr := c.AddListenTunnel("S", "proxy") + caddr := c.AddListenTunnel("C", "proxy,addr=%s dial,addr=%s", listen.Addr(), saddr) + + c.Exec("set tunnel.S.proxy.auth user:password") + c.Exec("set tunnel.C.proxy.auth user:password") + + out := xDial(t, "tcp", caddr) + defer out.Close() + + in := xAccept(t, listen) + defer in.Close() + + xWrite(t, out, msg) + + buf := make([]byte, len(msg)) + xReadFull(t, in, buf) + + if r := string(buf); r != msg { + t.Fatalf("wrong reply: send '%s', recv '%s'", msg, r) + } +} diff --git a/pkg/test/test.go b/pkg/test/test.go index 6a2c776..89722fd 100644 --- a/pkg/test/test.go +++ b/pkg/test/test.go @@ -1,25 +1,44 @@ package test import ( + "fmt" + "io" + "net" "os" "path/filepath" - "strconv" + "strings" + "testing" + "time" "tunnel/pkg/client" "tunnel/pkg/server" ) -func getSocketPath() string { - s := "tunnel.test." + strconv.Itoa(os.Getpid()) +type env struct { + *testing.T +} + +func getSocketPath(id string) string { + s := fmt.Sprintf("tunnel.%d.test.%s", os.Getpid(), id) return filepath.Join(os.TempDir(), s) } -func newClientServer() (*client.Client, *server.Server, error) { - socket := getSocketPath() +type Client struct { + *client.Client + + t *testing.T +} + +type Server struct { + *server.Server +} + +func newClientServer(t *testing.T) (*Client, *Server) { + socket := getSocketPath(t.Name()) s, err := server.New(socket) if err != nil { - return nil, nil, err + t.Fatal(err) } go s.Serve() @@ -27,8 +46,105 @@ func newClientServer() (*client.Client, *server.Server, error) { c, err := client.New(socket) if err != nil { s.Stop() - return nil, nil, err + t.Fatal(err) + } + + return &Client{c, t}, &Server{s} +} + +func closeClientServer(c *Client, s *Server) { + c.Close() + s.Stop() +} + +func (c *Client) Send(format string, args ...interface{}) string { + s := fmt.Sprintf(format, args...) + t := strings.Split(s, " ") + + r, err := c.Client.Send(t) + if err != nil { + c.t.Fatal(err) } - return c, s, nil + return r +} + +func (c *Client) Exec(format string, args ...interface{}) { + s := c.Send(format, args...) + if s != "" { + c.t.Fatal(s) + } +} + +func (c *Client) AddListenTunnel(name string, format string, args ...interface{}) string { + t := append([]interface{}{name}, args...) + c.Exec("add name %s listen,addr=127.0.0.1:0 "+format, t...) + return c.Send("get tunnel.%s.listen", name) +} + +func xListen(t *testing.T, network, address string) net.Listener { + listen, err := net.Listen(network, address) + if err != nil { + t.Fatal(err) + } + return listen +} + +func xDial(t *testing.T, network, address string) net.Conn { + d := net.Dialer{Timeout: 100 * time.Millisecond} + conn, err := d.Dial(network, address) + if err != nil { + t.Fatal(err) + } + return conn +} + +func xAccept(t *testing.T, listen net.Listener) net.Conn { + var conn net.Conn + + c := make(chan error, 1) + go func() { + var err error + conn, err = listen.Accept() + c <- err + }() + + timer := time.NewTimer(100 * time.Millisecond) + select { + case err := <-c: + if err != nil { + t.Fatal(err) + } + case <-timer.C: + t.Fatal("accept timeout") + } + + return conn +} + +func xWrite(t *testing.T, conn net.Conn, i interface{}) { + var buf []byte + + switch v := i.(type) { + case string: + buf = []byte(v) + case []byte: + buf = v + default: + t.Fatalf("unexpected type %T", i) + } + + conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err := conn.Write(buf); err != nil { + t.Fatal("write to conn:", err) + } + conn.SetDeadline(time.Time{}) +} + +func xReadFull(t *testing.T, conn net.Conn, buf []byte) { + conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatal("read from conn:", err) + } + conn.SetDeadline(time.Time{}) } |
