diff options
Diffstat (limited to 'pkg/server')
| -rw-r--r-- | pkg/server/automap.go | 12 | ||||
| -rw-r--r-- | pkg/server/server.go | 14 | ||||
| -rw-r--r-- | pkg/server/socket/socket.go | 106 | ||||
| -rw-r--r-- | pkg/server/stream.go | 193 | ||||
| -rw-r--r-- | pkg/server/tunnel.go | 294 |
5 files changed, 384 insertions, 235 deletions
diff --git a/pkg/server/automap.go b/pkg/server/automap.go new file mode 100644 index 0000000..15cafe4 --- /dev/null +++ b/pkg/server/automap.go @@ -0,0 +1,12 @@ +package server + +type automap map[int]interface{} + +func (m automap) add(v interface{}) int { + for k := 0;; k++ { + if _, ok := m[k]; !ok { + m[k] = v + return k + } + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 4f012d0..ce910f3 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -24,24 +24,24 @@ type Server struct { once sync.Once done chan struct{} - streams streams + tunnels automap env env - nextCid uint64 + nextCid int } type client struct { - id uint64 + id int s *Server conn net.Conn - nextRid uint64 + nextRid int } type request struct { - id uint64 + id int c *client @@ -135,7 +135,7 @@ func New() (*Server, error) { listen: listen, since: time.Now(), done: make(chan struct{}), - streams: make(streams), + tunnels: make(automap), } return s, nil @@ -307,7 +307,7 @@ func (r *request) run(query string) { r.parse(query) - log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " ")) + log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " ")) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go index f097a80..cad1ad3 100644 --- a/pkg/server/socket/socket.go +++ b/pkg/server/socket/socket.go @@ -21,6 +21,7 @@ type S interface { } type listenSocket struct { + proto, addr string listen net.Listener } @@ -31,11 +32,10 @@ type dialSocket struct { type connChannel struct { conn net.Conn once sync.Once - cancel chan struct{} } func newConnChannel(conn net.Conn) Channel { - return &connChannel{conn: conn, cancel: make(chan struct{})} + return &connChannel{conn: conn} } func (cc *connChannel) Send(wq queue.Q) (err error) { @@ -60,31 +60,23 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) { } func (cc *connChannel) String() string { - addr := cc.conn.RemoteAddr() - return fmt.Sprintf("%s/%s", addr.Network(), addr.String()) -} - -func (cc *connChannel) isCanceled() bool { - select { - case <- cc.cancel: - return true - default: - return false - } + local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr() + return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote) } func (cc *connChannel) shutdown(err *error) { - select { - case <- cc.cancel: + miss := true + + cc.once.Do(func () { + miss = false + log.Println("close", cc) + if e := cc.conn.Close(); e != nil && *err != nil { + *err = e + } + }) + + if miss { *err = nil - default: - cc.once.Do(func () { - close(cc.cancel) - log.Println("close", cc) - if e := cc.conn.Close(); e != nil && *err != nil { - *err = e - } - }) } } @@ -94,8 +86,10 @@ func (cc *connChannel) Close() { } func newListenSocket(proto, addr string) (S, error) { - if !strings.Contains(addr, ":") { - addr = ":" + addr + if proto == "tcp" { + if !strings.Contains(addr, ":") { + addr = ":" + addr + } } listen, err := net.Listen(proto, addr) @@ -103,7 +97,13 @@ func newListenSocket(proto, addr string) (S, error) { return nil, err } - return &listenSocket{listen: listen}, nil + s := &listenSocket{ + proto: proto, + addr: addr, + listen: listen, + } + + return s, nil } func (s *listenSocket) Open() (Channel, error) { @@ -114,14 +114,29 @@ func (s *listenSocket) Open() (Channel, error) { return newConnChannel(conn), nil } +func (s *listenSocket) String() string { + return fmt.Sprintf("%s/%s,listen", s.proto, s.addr) +} + func (s *listenSocket) Close() { s.listen.Close() } func newDialSocket(proto, addr string) (S, error) { + switch proto { + case "tcp", "udp": + if !strings.Contains(addr, ":") { + addr = "localhost:" + addr + } + } + return &dialSocket{proto: proto, addr: addr}, nil } +func (s *dialSocket) String() string { + return fmt.Sprintf("%s/%s", s.proto, s.addr) +} + func (s *dialSocket) Open() (Channel, error) { conn, err := net.Dial(s.proto, s.addr) if err != nil { @@ -133,19 +148,40 @@ func (s *dialSocket) Open() (Channel, error) { func (s *dialSocket) Close() { } -func New(desc string) (S, error) { - args := strings.Split(desc, "/") +func New(name string) (S, error) { + vv := strings.Split(name, ",") + args := strings.Split(vv[0], "/") + opts := map[string]string{} - if len(args) != 2 { - return nil, fmt.Errorf("bad socket '%s'", desc) + for _, v := range vv[1:] { + ss := strings.SplitN(v, "=", 2) + if len(ss) < 2 { + opts[ss[0]] = "" + } else { + opts[ss[0]] = ss[1] + } } - proto, addr := args[0], args[1] + var proto string + var addr string - switch proto { - case "tcp-listen": return newListenSocket("tcp", addr) - case "tcp": return newDialSocket("tcp", addr) + if len(args) < 2 { + addr = args[0] + } else { + proto, addr = args[0], args[1] + } + + if proto == "" { + proto = "tcp" + } + + if addr == "" { + return nil, fmt.Errorf("bad socket '%s'", name) + } + + if _, ok := opts["listen"]; ok { + return newListenSocket(proto, addr) } - return nil, fmt.Errorf("bad socket '%s': unknown type", desc) + return newDialSocket(proto, addr) } diff --git a/pkg/server/stream.go b/pkg/server/stream.go deleted file mode 100644 index 7c9cc82..0000000 --- a/pkg/server/stream.go +++ /dev/null @@ -1,193 +0,0 @@ -package server - -import ( - "tunnel/pkg/server/module" - "tunnel/pkg/server/queue" - "tunnel/pkg/server/socket" - "strings" - "sort" - "fmt" - "log" -) - -type stream struct { - id string - args string - - in, out socket.S - m []module.M -} - -type streams map[string]*stream - -func (s *stream) String() string { - return fmt.Sprintf("stream(%s)", s.id) -} - -func (s *stream) Close() { - s.in.Close() - s.out.Close() -} - -func (s *stream) run() { - for { - if in, err := s.in.Open(); err != nil { - log.Println(s, err) - } else { - log.Printf("%s accept %s", s, in) - go s.run2(in) - } - } -} - -func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { - watch := func (q queue.Q, f func (q queue.Q) error) { - if err := f(q); err != nil { - log.Println(s, err) - } - } - - go func () { - watch(wq, c.Send) - close(wq) - }() - - go watch(rq, c.Recv) -} - -func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { - go func () { - if err := f(rq, wq); err != nil { - log.Println(s, err) - } - - close(wq) - }() -} - -func (s *stream) run2(in socket.Channel) { - out, err := s.out.Open() - if err != nil { - log.Println(s, err) - in.Close() - return - } - - rq, wq := queue.New(), queue.New() - - s.watchChannel(rq, wq, in) - - for _, m := range s.m { - send, recv := m.Open() - if send != nil { - q := queue.New() - s.watchPipe(wq, q, send) - wq = q - } - if recv != nil { - q := queue.New() - s.watchPipe(q, rq, recv) - rq = q - } - } - - s.watchChannel(wq, rq, out) -} - -func newStream(id string, args []string) (*stream, error) { - var in, out socket.S - var err error - - n := len(args) - 1 - - if in, err = socket.New(args[0]); err != nil { - return nil, err - } - - if out, err = socket.New(args[n]); err != nil { - in.Close() - return nil, err - } - - s := &stream{ - id: id, - args: strings.Join(args, " "), - in: in, - out: out, - } - - reverse := false - - for _, arg := range args[1:n] { - var m module.M - - if arg == "-" { - reverse = true - continue - } - - if arg == "+" { - reverse = false - continue - } - - if m, err = module.New(arg); err != nil { - s.Close() - return nil, err - } - - if reverse { - m = module.Reverse(m) - reverse = false - } - - s.m = append(s.m, m) - } - - if reverse { - s.Close() - return nil, fmt.Errorf("bad '-' usage") - } - - go s.run() - - return s, nil -} - -func streamAdd(r *request) { - if r.argc < 3 { - r.Fatal("not enough args") - } - - id := r.args[0] - if _, ok := r.c.s.streams[id]; ok { - r.Fatal("duplicate id") - } - - s, err := newStream(id, r.args[1:]) - if err != nil { - r.Fatal(err) - } - - r.c.s.streams[id] = s -} - -func streamShow(r *request) { - var keys []string - - for k := range r.c.s.streams { - keys = append(keys, k) - } - - sort.Strings(keys) - - for _, k := range keys { - s := r.c.s.streams[k] - r.Println(s.id, s.args) - } -} - -func init() { - newCmd(streamAdd, "add") - newCmd(streamShow, "show") -} diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go new file mode 100644 index 0000000..b1c3fe8 --- /dev/null +++ b/pkg/server/tunnel.go @@ -0,0 +1,294 @@ +package server + +import ( + "tunnel/pkg/server/module" + "tunnel/pkg/server/queue" + "tunnel/pkg/server/socket" + "tunnel/pkg/config" + "sync/atomic" + "strings" + "time" + "sort" + "sync" + "fmt" + "log" +) + +type stream struct { + id int + n int32 + t *tunnel + since time.Time + in, out socket.Channel +} + +type tunnel struct { + id int + args string + + streams map[int]*stream + + mu sync.Mutex + + nextSid int + + in, out socket.S + m []module.M +} + +func (s *stream) String() string { + return fmt.Sprintf("stream(%d)", s.id) +} + +func (t *tunnel) String() string { + return fmt.Sprintf("tunnel(%d)", t.id) +} + +func (t *tunnel) Close() { + t.in.Close() + t.out.Close() +} + +func (t *tunnel) run() { + for { + if in, err := t.in.Open(); err != nil { + log.Println(t, err) + time.Sleep(5 * time.Second) + } else { + log.Println(t, "open", in) + go t.run2(in) + } + } +} + +func (t *tunnel) newStream(in, out socket.Channel) *stream { + s := &stream{ + t: t, + in: in, + out: out, + id: t.nextSid, + since: time.Now(), + } + + t.mu.Lock() + t.nextSid++ + t.streams[s.id] = s + t.mu.Unlock() + + return s +} + +func (s *stream) ref() { + atomic.AddInt32(&s.n, 1) +} + +func (s *stream) unref() { + if atomic.AddInt32(&s.n, -1) == 0 { + log.Println(s.t, s, "close") + + s.t.mu.Lock() + delete(s.t.streams, s.id) + s.t.mu.Unlock() + } +} + +func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) { + watch := func (q queue.Q, f func (q queue.Q) error) { + s.ref() + defer s.unref() + + if err := f(q); err != nil { + log.Println(s.t, s, err) + } + } + + go func () { + watch(wq, c.Send) + close(wq) + }() + + go watch(rq, c.Recv) +} + +func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) { + go func () { + s.ref() + defer s.unref() + + if err := f(rq, wq); err != nil { + log.Println(s.t, s, err) + } + + close(wq) + }() +} + +func (t *tunnel) run2(in socket.Channel) { + out, err := t.out.Open() + if err != nil { + log.Println(t, err) + in.Close() + return + } + + log.Println(t, "open", out) + + s := t.newStream(in, out) + + s.ref() + defer s.unref() + + rq, wq := queue.New(), queue.New() + + s.watchChannel(rq, wq, in) + + for _, m := range t.m { + send, recv := m.Open() + if send != nil { + q := queue.New() + s.watchPipe(wq, q, send) + wq = q + } + if recv != nil { + q := queue.New() + s.watchPipe(q, rq, recv) + rq = q + } + } + + s.watchChannel(wq, rq, out) + + log.Println(t, s, "create", in, out) +} + +func newTunnel(args []string) (*tunnel, error) { + var in, out socket.S + var err error + + n := len(args) - 1 + + if in, err = socket.New(args[0]); err != nil { + return nil, err + } + + if out, err = socket.New(args[n]); err != nil { + in.Close() + return nil, err + } + + t := &tunnel{ + args: strings.Join(args, " "), + in: in, + out: out, + streams: make(map[int]*stream), + } + + reverse := false + + for _, arg := range args[1:n] { + var m module.M + + if arg == "-" { + reverse = true + continue + } + + if arg == "+" { + reverse = false + continue + } + + if m, err = module.New(arg); err != nil { + t.Close() + return nil, err + } + + if reverse { + m = module.Reverse(m) + reverse = false + } + + t.m = append(t.m, m) + } + + if reverse { + t.Close() + return nil, fmt.Errorf("bad '-' usage") + } + + go t.run() + + return t, nil +} + +func tunnelAdd(r *request) { + if r.argc < 2 { + r.Fatal("not enough args") + } + + t, err := newTunnel(r.args) + if err != nil { + r.Fatal(err) + } + + log.Println(r.c, r, t, "create") + + t.id = r.c.s.tunnels.add(t) +} + +func foreachTunnel(m automap, f func (t *tunnel)) { + var keys []int + + for k := range m { + keys = append(keys, k) + } + + sort.Ints(keys) + + for _, k := range keys { + f(m[k].(*tunnel)) + } +} + +func foreachStream(m map[int]*stream, f func (s *stream)) { + var keys []int + + for k := range m { + keys = append(keys, k) + } + + sort.Ints(keys) + + for _, k := range keys { + f(m[k]) + } +} + +func tunnelShow(r *request) { + foreachTunnel(r.c.s.tunnels, func (t *tunnel) { + r.Println(t.id, t.args) + }) +} + +func streamShow(r *request) { + foreachTunnel(r.c.s.tunnels, func (t *tunnel) { + r.Println(t.id, t.args) + + t.mu.Lock() + if len(t.streams) == 0 { + r.Println("\t", "nothing") + } else { + foreachStream(t.streams, func (s *stream) { + tm := s.since.Format(config.TimeFormat) + r.Println("\t", s.id, tm, s.in, s.out) + }) + } + t.mu.Unlock() + }) +} + +func init() { + newCmd(tunnelAdd, "add") + newCmd(tunnelShow, "show") + newCmd(streamShow, "stream show") +} |
