summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/client/client.go22
-rw-r--r--pkg/config/config.go2
-rw-r--r--pkg/server/env.go8
-rw-r--r--pkg/server/server.go17
-rw-r--r--pkg/server/socket/socket.go26
-rw-r--r--pkg/server/socket/tun.go133
-rw-r--r--pkg/server/tunnel.go33
7 files changed, 198 insertions, 43 deletions
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 66aa745..eee397f 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -14,15 +14,23 @@ var errClosed = errors.New("server closed connection")
type Client struct {
conn net.Conn
+ r *netstring.Decoder
+ w *netstring.Encoder
}
func New(path string) (*Client, error) {
- conn, err := net.Dial("unixpacket", path)
+ conn, err := net.Dial("unix", path)
if err != nil {
return nil, err
}
- return &Client{conn: conn}, nil
+ c := &Client{
+ conn: conn,
+ r: netstring.NewDecoder(conn),
+ w: netstring.NewEncoder(conn),
+ }
+
+ return c, nil
}
func (c *Client) Send(args []string) (string, error) {
@@ -40,23 +48,21 @@ func (c *Client) Send(args []string) (string, error) {
enc.Encode(s)
}
- buf := make([]byte, config.BufSize)
-
- _, ew := c.conn.Write(out.Bytes())
+ ew := c.w.Encode(out.String())
if ew != nil {
return "", ew
}
- nr, er := c.conn.Read(buf)
+ resp, er := c.r.Decode()
if er != nil {
- if er == io.EOF {
+ if errors.Is(er, io.EOF) {
return "", errClosed
}
return "", er
}
- return string(buf[:nr]), nil
+ return resp, nil
}
func (c *Client) Close() {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5184b7e..0e58553 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -8,8 +8,6 @@ import (
const TimeFormat = "2006-01-02/15:04:05"
-const BufSize = 64 * 1024
-
const IoTimeout = 5 * time.Second
func GetSystemSocketPath() string {
diff --git a/pkg/server/env.go b/pkg/server/env.go
index a3c9f49..ea93a0d 100644
--- a/pkg/server/env.go
+++ b/pkg/server/env.go
@@ -1,5 +1,9 @@
package server
+import (
+ "strings"
+)
+
func varGet(r *request) {
r.expect(1)
@@ -11,9 +15,9 @@ func varGet(r *request) {
}
func varSet(r *request) {
- r.expect(2)
+ value := strings.Join(r.args[1:], " ")
- if err := r.c.s.env.Set(r.args[0], r.args[1]); err != nil {
+ if err := r.c.s.env.Set(r.args[0], value); err != nil {
r.Fatal(err)
}
}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 6f3c312..e794b56 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -10,7 +10,6 @@ import (
"strings"
"sync"
"time"
- "tunnel/pkg/config"
"tunnel/pkg/netstring"
"tunnel/pkg/server/env"
)
@@ -37,6 +36,8 @@ type client struct {
s *Server
conn net.Conn
+ r *netstring.Decoder
+ w *netstring.Encoder
nextRid int
}
@@ -133,7 +134,7 @@ func (s *Server) isDone() bool {
}
func New(path string) (*Server, error) {
- listen, err := net.Listen("unixpacket", path)
+ listen, err := net.Listen("unix", path)
if err != nil {
return nil, err
}
@@ -194,6 +195,8 @@ func (s *Server) newClient(conn net.Conn) *client {
c := &client{
s: s,
conn: conn,
+ r: netstring.NewDecoder(conn),
+ w: netstring.NewEncoder(conn),
id: s.nextCid,
}
@@ -234,18 +237,16 @@ func (c *client) newRequest() *request {
func (c *client) handle() {
defer c.close()
- buf := make([]byte, config.BufSize)
-
for {
- nr, er := c.conn.Read(buf)
+ req, er := c.r.Decode()
if er != nil {
- if er != io.EOF {
+ if !errors.Is(er, io.EOF) {
log.Println(c, "handle:", er)
}
break
}
- args, err := c.decode(buf[:nr])
+ args, err := c.decode([]byte(req))
if err != nil {
log.Println(c, "decode:", err)
break
@@ -259,7 +260,7 @@ func (c *client) handle() {
r.out.Write([]byte("\n"))
}
- _, ew := c.conn.Write(r.out.Bytes())
+ ew := c.w.Encode(r.out.String())
if ew != nil {
log.Println(c, "handle:", ew)
break
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index a945ce0..7b5ea73 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -12,7 +12,7 @@ import (
"tunnel/pkg/server/queue"
)
-var errAlreadyClosed = errors.New("already closed")
+var ErrAlreadyClosed = errors.New("already closed")
type exported struct {
info string
@@ -50,26 +50,12 @@ func newConn(cn net.Conn) Channel {
return c
}
-func (c *conn) final(f func() error, err error) error {
- if e := f(); e != nil {
- if e == errAlreadyClosed {
- return nil
- } else {
- return e
- }
- }
-
- return err
-}
-
func (c *conn) Send(wq queue.Q) error {
- err := queue.IoCopy(c, wq.Writer())
- return c.final(c.Close, err)
+ return queue.IoCopy(c, wq.Writer())
}
func (c *conn) Recv(rq queue.Q) error {
- err := queue.IoCopy(rq.Reader(), c)
- return c.final(c.Close, err)
+ return queue.IoCopy(rq.Reader(), c)
}
func (c *conn) String() string {
@@ -78,7 +64,7 @@ func (c *conn) String() string {
}
func (c *conn) Close() error {
- err := errAlreadyClosed
+ err := ErrAlreadyClosed
c.once.Do(func() {
log.Println("close", c)
@@ -113,6 +99,10 @@ func New(desc string, env env.Env) (S, error) {
return nil, fmt.Errorf("bad socket '%s'", desc)
}
+ if proto == "tun" {
+ return newTunSocket(addr)
+ }
+
if _, ok := opts["listen"]; ok {
return newListenSocket(proto, addr)
}
diff --git a/pkg/server/socket/tun.go b/pkg/server/socket/tun.go
new file mode 100644
index 0000000..78bdfd4
--- /dev/null
+++ b/pkg/server/socket/tun.go
@@ -0,0 +1,133 @@
+package socket
+
+import (
+ "errors"
+ "fmt"
+ "golang.org/x/sys/unix"
+ "log"
+ "os"
+ "strings"
+ "sync"
+ "tunnel/pkg/pack"
+ "tunnel/pkg/server/env"
+ "tunnel/pkg/server/queue"
+ "unsafe"
+)
+
+const maxTunBufSize = 65535
+
+var errPartialWrite = errors.New("partial write")
+
+type ifReq struct {
+ name [unix.IFNAMSIZ]uint8
+ flags uint16
+}
+
+type tunSocket struct {
+ name string
+}
+
+type tunChannel struct {
+ name string
+ s *tunSocket
+ fp *os.File
+ once sync.Once
+}
+
+func ioctl(fd int, req uintptr, ptr unsafe.Pointer) error {
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(ptr))
+ if errno != 0 {
+ return errno
+ }
+
+ return nil
+}
+
+func newTunSocket(name string) (S, error) {
+ return &tunSocket{name: name}, nil
+
+}
+
+func (s *tunSocket) String() string {
+ return fmt.Sprintf("tun/%s", s.name)
+}
+
+func (s *tunSocket) Open(env.Env) (Channel, error) {
+ fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ ifr := &ifReq{}
+ copy(ifr.name[:], s.name)
+ ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
+
+ if err := ioctl(fd, unix.TUNSETIFF, unsafe.Pointer(ifr)); err != nil {
+ unix.Close(fd)
+ return nil, fmt.Errorf("ioctl TUNSETIFF %s: %w", s.name, err)
+ }
+
+ if err := unix.SetNonblock(fd, true); err != nil {
+ unix.Close(fd)
+ return nil, fmt.Errorf("set nonblock %s: %w", s.name, err)
+ }
+
+ c := &tunChannel{
+ name: strings.Trim(string(ifr.name[:]), "\x00"),
+ fp: os.NewFile(uintptr(fd), "tun"),
+ }
+
+ return c, nil
+}
+
+func (s *tunSocket) Close() {
+}
+
+func (c *tunChannel) Send(wq queue.Q) error {
+ buf := make([]byte, maxTunBufSize)
+ enc := pack.NewEncoder(wq.Writer())
+
+ for {
+ n, err := c.fp.Read(buf)
+ if err != nil {
+ return err
+ }
+
+ enc.Lps(buf[0:n])
+ }
+}
+
+func (c *tunChannel) Recv(rq queue.Q) error {
+ dec := pack.NewDecoder(rq.Reader())
+
+ for {
+ b, err := dec.Lps()
+ if err != nil {
+ return err
+ }
+
+ n, err := c.fp.Write(b)
+ if err != nil {
+ return err
+ }
+
+ if n != len(b) {
+ return errPartialWrite
+ }
+ }
+}
+
+func (c *tunChannel) String() string {
+ return "tun/" + c.name
+}
+
+func (c *tunChannel) Close() error {
+ err := ErrAlreadyClosed
+
+ c.once.Do(func() {
+ log.Println("close", c)
+ err = c.fp.Close()
+ })
+
+ return err
+}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
index 79734b9..4543714 100644
--- a/pkg/server/tunnel.go
+++ b/pkg/server/tunnel.go
@@ -146,10 +146,10 @@ func (t *tunnel) serve() {
}
} else if t.alive() {
log.Println(t, err)
- t.sleep(5 * time.Second)
}
if !ok {
+ t.sleep(5 * time.Second)
t.release()
}
}
@@ -234,7 +234,21 @@ func (s *stream) channel(c socket.Channel, m *metric, rq, wq queue.Q) {
watch := func(q queue.Q, f func(q queue.Q) error) {
defer s.wg.Done()
- if err := f(q); err != nil && !errors.Is(err, io.EOF) {
+ err := f(q)
+
+ if errors.Is(err, io.EOF) {
+ err = nil
+ }
+
+ if e := c.Close(); e != nil {
+ if e == socket.ErrAlreadyClosed {
+ err = nil
+ } else {
+ err = e
+ }
+ }
+
+ if err != nil {
log.Println(s.t, s, err)
}
}
@@ -377,7 +391,7 @@ func isOkTunnelName(s string) bool {
func tunnelAdd(r *request) {
args := r.args
name := ""
- limit := maxQueueLimit
+ limit := 1
for len(args) > 1 {
if args[0] == "name" {
@@ -394,8 +408,17 @@ func tunnelAdd(r *request) {
continue
}
- if args[0] == "mono" {
- limit = 1
+ if args[0] == "limit" {
+ if n, _ := strconv.Atoi(args[1]); n > 0 && n < maxQueueLimit {
+ limit = n
+ } else {
+ r.Fatal("bad limit")
+ }
+ args = args[2:]
+ }
+
+ if args[0] == "unlim" {
+ limit = maxQueueLimit
args = args[1:]
continue
}