diff options
Diffstat (limited to 'pkg/server')
| -rw-r--r-- | pkg/server/env.go | 8 | ||||
| -rw-r--r-- | pkg/server/server.go | 17 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 26 | ||||
| -rw-r--r-- | pkg/server/socket/tun.go | 133 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 33 |
5 files changed, 184 insertions, 33 deletions
diff --git a/pkg/server/env.go b/pkg/server/env.go index a3c9f49..ea93a0d 100644 --- a/pkg/server/env.go +++ b/pkg/server/env.go @@ -1,5 +1,9 @@ package server +import ( + "strings" +) + func varGet(r *request) { r.expect(1) @@ -11,9 +15,9 @@ func varGet(r *request) { } func varSet(r *request) { - r.expect(2) + value := strings.Join(r.args[1:], " ") - if err := r.c.s.env.Set(r.args[0], r.args[1]); err != nil { + if err := r.c.s.env.Set(r.args[0], value); err != nil { r.Fatal(err) } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 6f3c312..e794b56 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -10,7 +10,6 @@ import ( "strings" "sync" "time" - "tunnel/pkg/config" "tunnel/pkg/netstring" "tunnel/pkg/server/env" ) @@ -37,6 +36,8 @@ type client struct { s *Server conn net.Conn + r *netstring.Decoder + w *netstring.Encoder nextRid int } @@ -133,7 +134,7 @@ func (s *Server) isDone() bool { } func New(path string) (*Server, error) { - listen, err := net.Listen("unixpacket", path) + listen, err := net.Listen("unix", path) if err != nil { return nil, err } @@ -194,6 +195,8 @@ func (s *Server) newClient(conn net.Conn) *client { c := &client{ s: s, conn: conn, + r: netstring.NewDecoder(conn), + w: netstring.NewEncoder(conn), id: s.nextCid, } @@ -234,18 +237,16 @@ func (c *client) newRequest() *request { func (c *client) handle() { defer c.close() - buf := make([]byte, config.BufSize) - for { - nr, er := c.conn.Read(buf) + req, er := c.r.Decode() if er != nil { - if er != io.EOF { + if !errors.Is(er, io.EOF) { log.Println(c, "handle:", er) } break } - args, err := c.decode(buf[:nr]) + args, err := c.decode([]byte(req)) if err != nil { log.Println(c, "decode:", err) break @@ -259,7 +260,7 @@ func (c *client) handle() { r.out.Write([]byte("\n")) } - _, ew := c.conn.Write(r.out.Bytes()) + ew := c.w.Encode(r.out.String()) if ew != nil { log.Println(c, "handle:", ew) break diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index a945ce0..7b5ea73 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -12,7 +12,7 @@ import ( "tunnel/pkg/server/queue" ) -var errAlreadyClosed = errors.New("already closed") +var ErrAlreadyClosed = errors.New("already closed") type exported struct { info string @@ -50,26 +50,12 @@ func newConn(cn net.Conn) Channel { return c } -func (c *conn) final(f func() error, err error) error { - if e := f(); e != nil { - if e == errAlreadyClosed { - return nil - } else { - return e - } - } - - return err -} - func (c *conn) Send(wq queue.Q) error { - err := queue.IoCopy(c, wq.Writer()) - return c.final(c.Close, err) + return queue.IoCopy(c, wq.Writer()) } func (c *conn) Recv(rq queue.Q) error { - err := queue.IoCopy(rq.Reader(), c) - return c.final(c.Close, err) + return queue.IoCopy(rq.Reader(), c) } func (c *conn) String() string { @@ -78,7 +64,7 @@ func (c *conn) String() string { } func (c *conn) Close() error { - err := errAlreadyClosed + err := ErrAlreadyClosed c.once.Do(func() { log.Println("close", c) @@ -113,6 +99,10 @@ func New(desc string, env env.Env) (S, error) { return nil, fmt.Errorf("bad socket '%s'", desc) } + if proto == "tun" { + return newTunSocket(addr) + } + if _, ok := opts["listen"]; ok { return newListenSocket(proto, addr) } diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go new file mode 100644 index 0000000..78bdfd4 --- /dev/null +++ b/pkg/server/socket/tun.go @@ -0,0 +1,133 @@ +package socket + +import ( + "errors" + "fmt" + "golang.org/x/sys/unix" + "log" + "os" + "strings" + "sync" + "tunnel/pkg/pack" + "tunnel/pkg/server/env" + "tunnel/pkg/server/queue" + "unsafe" +) + +const maxTunBufSize = 65535 + +var errPartialWrite = errors.New("partial write") + +type ifReq struct { + name [unix.IFNAMSIZ]uint8 + flags uint16 +} + +type tunSocket struct { + name string +} + +type tunChannel struct { + name string + s *tunSocket + fp *os.File + once sync.Once +} + +func ioctl(fd int, req uintptr, ptr unsafe.Pointer) error { + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(ptr)) + if errno != 0 { + return errno + } + + return nil +} + +func newTunSocket(name string) (S, error) { + return &tunSocket{name: name}, nil + +} + +func (s *tunSocket) String() string { + return fmt.Sprintf("tun/%s", s.name) +} + +func (s *tunSocket) Open(env.Env) (Channel, error) { + fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) + if err != nil { + return nil, err + } + + ifr := &ifReq{} + copy(ifr.name[:], s.name) + ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI + + if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil { + unix.Close(fd) + return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err) + } + + if err := unix.SetNonblock(fd, true); err != nil { + unix.Close(fd) + return nil, fmt.Errorf("set nonblock %s: %w", s.name, err) + } + + c := &tunChannel{ + name: strings.Trim(string(ifr.name[:]), "\x00"), + fp: os.NewFile(uintptr(fd), "tun"), + } + + return c, nil +} + +func (s *tunSocket) Close() { +} + +func (c *tunChannel) Send(wq queue.Q) error { + buf := make([]byte, maxTunBufSize) + enc := pack.NewEncoder(wq.Writer()) + + for { + n, err := c.fp.Read(buf) + if err != nil { + return err + } + + enc.Lps(buf[0:n]) + } +} + +func (c *tunChannel) Recv(rq queue.Q) error { + dec := pack.NewDecoder(rq.Reader()) + + for { + b, err := dec.Lps() + if err != nil { + return err + } + + n, err := c.fp.Write(b) + if err != nil { + return err + } + + if n != len(b) { + return errPartialWrite + } + } +} + +func (c *tunChannel) String() string { + return "tun/" + c.name +} + +func (c *tunChannel) Close() error { + err := ErrAlreadyClosed + + c.once.Do(func() { + log.Println("close", c) + err = c.fp.Close() + }) + + return err +} diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 79734b9..4543714 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -146,10 +146,10 @@ func (t *tunnel) serve() { } } else if t.alive() { log.Println(t, err) - t.sleep(5 * time.Second) } if !ok { + t.sleep(5 * time.Second) t.release() } } @@ -234,7 +234,21 @@ func (s *stream) channel(c socket.Channel, m *metric, rq, wq queue.Q) { watch := func(q queue.Q, f func(q queue.Q) error) { defer s.wg.Done() - if err := f(q); err != nil && !errors.Is(err, io.EOF) { + err := f(q) + + if errors.Is(err, io.EOF) { + err = nil + } + + if e := c.Close(); e != nil { + if e == socket.ErrAlreadyClosed { + err = nil + } else { + err = e + } + } + + if err != nil { log.Println(s.t, s, err) } } @@ -377,7 +391,7 @@ func isOkTunnelName(s string) bool { func tunnelAdd(r *request) { args := r.args name := "" - limit := maxQueueLimit + limit := 1 for len(args) > 1 { if args[0] == "name" { @@ -394,8 +408,17 @@ func tunnelAdd(r *request) { continue } - if args[0] == "mono" { - limit = 1 + if args[0] == "limit" { + if n, _ := strconv.Atoi(args[1]); n > 0 && n < maxQueueLimit { + limit = n + } else { + r.Fatal("bad limit") + } + args = args[2:] + } + + if args[0] == "unlim" { + limit = maxQueueLimit args = args[1:] continue } |
