summaryrefslogtreecommitdiff
path: root/pkg/server
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/server')
-rw-r--r--pkg/server/automap.go12
-rw-r--r--pkg/server/server.go14
-rw-r--r--pkg/server/socket/socket.go106
-rw-r--r--pkg/server/stream.go193
-rw-r--r--pkg/server/tunnel.go294
5 files changed, 384 insertions, 235 deletions
diff --git a/pkg/server/automap.go b/pkg/server/automap.go
new file mode 100644
index 0000000..15cafe4
--- /dev/null
+++ b/pkg/server/automap.go
@@ -0,0 +1,12 @@
+package server
+
+type automap map[int]interface{}
+
+func (m automap) add(v interface{}) int {
+ for k := 0;; k++ {
+ if _, ok := m[k]; !ok {
+ m[k] = v
+ return k
+ }
+ }
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 4f012d0..ce910f3 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -24,24 +24,24 @@ type Server struct {
once sync.Once
done chan struct{}
- streams streams
+ tunnels automap
env env
- nextCid uint64
+ nextCid int
}
type client struct {
- id uint64
+ id int
s *Server
conn net.Conn
- nextRid uint64
+ nextRid int
}
type request struct {
- id uint64
+ id int
c *client
@@ -135,7 +135,7 @@ func New() (*Server, error) {
listen: listen,
since: time.Now(),
done: make(chan struct{}),
- streams: make(streams),
+ tunnels: make(automap),
}
return s, nil
@@ -307,7 +307,7 @@ func (r *request) run(query string) {
r.parse(query)
- log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
+ log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " "))
r.c.s.mu.Lock()
defer r.c.s.mu.Unlock()
diff --git a/pkg/server/socket/socket.go b/pkg/server/socket/socket.go
index f097a80..cad1ad3 100644
--- a/pkg/server/socket/socket.go
+++ b/pkg/server/socket/socket.go
@@ -21,6 +21,7 @@ type S interface {
}
type listenSocket struct {
+ proto, addr string
listen net.Listener
}
@@ -31,11 +32,10 @@ type dialSocket struct {
type connChannel struct {
conn net.Conn
once sync.Once
- cancel chan struct{}
}
func newConnChannel(conn net.Conn) Channel {
- return &connChannel{conn: conn, cancel: make(chan struct{})}
+ return &connChannel{conn: conn}
}
func (cc *connChannel) Send(wq queue.Q) (err error) {
@@ -60,31 +60,23 @@ func (cc *connChannel) Recv(rq queue.Q) (err error) {
}
func (cc *connChannel) String() string {
- addr := cc.conn.RemoteAddr()
- return fmt.Sprintf("%s/%s", addr.Network(), addr.String())
-}
-
-func (cc *connChannel) isCanceled() bool {
- select {
- case <- cc.cancel:
- return true
- default:
- return false
- }
+ local, remote := cc.conn.LocalAddr(), cc.conn.RemoteAddr()
+ return fmt.Sprintf("%s/%s->%s", local.Network(), local, remote)
}
func (cc *connChannel) shutdown(err *error) {
- select {
- case <- cc.cancel:
+ miss := true
+
+ cc.once.Do(func () {
+ miss = false
+ log.Println("close", cc)
+ if e := cc.conn.Close(); e != nil && *err != nil {
+ *err = e
+ }
+ })
+
+ if miss {
*err = nil
- default:
- cc.once.Do(func () {
- close(cc.cancel)
- log.Println("close", cc)
- if e := cc.conn.Close(); e != nil && *err != nil {
- *err = e
- }
- })
}
}
@@ -94,8 +86,10 @@ func (cc *connChannel) Close() {
}
func newListenSocket(proto, addr string) (S, error) {
- if !strings.Contains(addr, ":") {
- addr = ":" + addr
+ if proto == "tcp" {
+ if !strings.Contains(addr, ":") {
+ addr = ":" + addr
+ }
}
listen, err := net.Listen(proto, addr)
@@ -103,7 +97,13 @@ func newListenSocket(proto, addr string) (S, error) {
return nil, err
}
- return &listenSocket{listen: listen}, nil
+ s := &listenSocket{
+ proto: proto,
+ addr: addr,
+ listen: listen,
+ }
+
+ return s, nil
}
func (s *listenSocket) Open() (Channel, error) {
@@ -114,14 +114,29 @@ func (s *listenSocket) Open() (Channel, error) {
return newConnChannel(conn), nil
}
+func (s *listenSocket) String() string {
+ return fmt.Sprintf("%s/%s,listen", s.proto, s.addr)
+}
+
func (s *listenSocket) Close() {
s.listen.Close()
}
func newDialSocket(proto, addr string) (S, error) {
+ switch proto {
+ case "tcp", "udp":
+ if !strings.Contains(addr, ":") {
+ addr = "localhost:" + addr
+ }
+ }
+
return &dialSocket{proto: proto, addr: addr}, nil
}
+func (s *dialSocket) String() string {
+ return fmt.Sprintf("%s/%s", s.proto, s.addr)
+}
+
func (s *dialSocket) Open() (Channel, error) {
conn, err := net.Dial(s.proto, s.addr)
if err != nil {
@@ -133,19 +148,40 @@ func (s *dialSocket) Open() (Channel, error) {
func (s *dialSocket) Close() {
}
-func New(desc string) (S, error) {
- args := strings.Split(desc, "/")
+func New(name string) (S, error) {
+ vv := strings.Split(name, ",")
+ args := strings.Split(vv[0], "/")
+ opts := map[string]string{}
- if len(args) != 2 {
- return nil, fmt.Errorf("bad socket '%s'", desc)
+ for _, v := range vv[1:] {
+ ss := strings.SplitN(v, "=", 2)
+ if len(ss) < 2 {
+ opts[ss[0]] = ""
+ } else {
+ opts[ss[0]] = ss[1]
+ }
}
- proto, addr := args[0], args[1]
+ var proto string
+ var addr string
- switch proto {
- case "tcp-listen": return newListenSocket("tcp", addr)
- case "tcp": return newDialSocket("tcp", addr)
+ if len(args) < 2 {
+ addr = args[0]
+ } else {
+ proto, addr = args[0], args[1]
+ }
+
+ if proto == "" {
+ proto = "tcp"
+ }
+
+ if addr == "" {
+ return nil, fmt.Errorf("bad socket '%s'", name)
+ }
+
+ if _, ok := opts["listen"]; ok {
+ return newListenSocket(proto, addr)
}
- return nil, fmt.Errorf("bad socket '%s': unknown type", desc)
+ return newDialSocket(proto, addr)
}
diff --git a/pkg/server/stream.go b/pkg/server/stream.go
deleted file mode 100644
index 7c9cc82..0000000
--- a/pkg/server/stream.go
+++ /dev/null
@@ -1,193 +0,0 @@
-package server
-
-import (
- "tunnel/pkg/server/module"
- "tunnel/pkg/server/queue"
- "tunnel/pkg/server/socket"
- "strings"
- "sort"
- "fmt"
- "log"
-)
-
-type stream struct {
- id string
- args string
-
- in, out socket.S
- m []module.M
-}
-
-type streams map[string]*stream
-
-func (s *stream) String() string {
- return fmt.Sprintf("stream(%s)", s.id)
-}
-
-func (s *stream) Close() {
- s.in.Close()
- s.out.Close()
-}
-
-func (s *stream) run() {
- for {
- if in, err := s.in.Open(); err != nil {
- log.Println(s, err)
- } else {
- log.Printf("%s accept %s", s, in)
- go s.run2(in)
- }
- }
-}
-
-func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
- watch := func (q queue.Q, f func (q queue.Q) error) {
- if err := f(q); err != nil {
- log.Println(s, err)
- }
- }
-
- go func () {
- watch(wq, c.Send)
- close(wq)
- }()
-
- go watch(rq, c.Recv)
-}
-
-func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
- go func () {
- if err := f(rq, wq); err != nil {
- log.Println(s, err)
- }
-
- close(wq)
- }()
-}
-
-func (s *stream) run2(in socket.Channel) {
- out, err := s.out.Open()
- if err != nil {
- log.Println(s, err)
- in.Close()
- return
- }
-
- rq, wq := queue.New(), queue.New()
-
- s.watchChannel(rq, wq, in)
-
- for _, m := range s.m {
- send, recv := m.Open()
- if send != nil {
- q := queue.New()
- s.watchPipe(wq, q, send)
- wq = q
- }
- if recv != nil {
- q := queue.New()
- s.watchPipe(q, rq, recv)
- rq = q
- }
- }
-
- s.watchChannel(wq, rq, out)
-}
-
-func newStream(id string, args []string) (*stream, error) {
- var in, out socket.S
- var err error
-
- n := len(args) - 1
-
- if in, err = socket.New(args[0]); err != nil {
- return nil, err
- }
-
- if out, err = socket.New(args[n]); err != nil {
- in.Close()
- return nil, err
- }
-
- s := &stream{
- id: id,
- args: strings.Join(args, " "),
- in: in,
- out: out,
- }
-
- reverse := false
-
- for _, arg := range args[1:n] {
- var m module.M
-
- if arg == "-" {
- reverse = true
- continue
- }
-
- if arg == "+" {
- reverse = false
- continue
- }
-
- if m, err = module.New(arg); err != nil {
- s.Close()
- return nil, err
- }
-
- if reverse {
- m = module.Reverse(m)
- reverse = false
- }
-
- s.m = append(s.m, m)
- }
-
- if reverse {
- s.Close()
- return nil, fmt.Errorf("bad '-' usage")
- }
-
- go s.run()
-
- return s, nil
-}
-
-func streamAdd(r *request) {
- if r.argc < 3 {
- r.Fatal("not enough args")
- }
-
- id := r.args[0]
- if _, ok := r.c.s.streams[id]; ok {
- r.Fatal("duplicate id")
- }
-
- s, err := newStream(id, r.args[1:])
- if err != nil {
- r.Fatal(err)
- }
-
- r.c.s.streams[id] = s
-}
-
-func streamShow(r *request) {
- var keys []string
-
- for k := range r.c.s.streams {
- keys = append(keys, k)
- }
-
- sort.Strings(keys)
-
- for _, k := range keys {
- s := r.c.s.streams[k]
- r.Println(s.id, s.args)
- }
-}
-
-func init() {
- newCmd(streamAdd, "add")
- newCmd(streamShow, "show")
-}
diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go
new file mode 100644
index 0000000..b1c3fe8
--- /dev/null
+++ b/pkg/server/tunnel.go
@@ -0,0 +1,294 @@
+package server
+
+import (
+ "tunnel/pkg/server/module"
+ "tunnel/pkg/server/queue"
+ "tunnel/pkg/server/socket"
+ "tunnel/pkg/config"
+ "sync/atomic"
+ "strings"
+ "time"
+ "sort"
+ "sync"
+ "fmt"
+ "log"
+)
+
+type stream struct {
+ id int
+ n int32
+ t *tunnel
+ since time.Time
+ in, out socket.Channel
+}
+
+type tunnel struct {
+ id int
+ args string
+
+ streams map[int]*stream
+
+ mu sync.Mutex
+
+ nextSid int
+
+ in, out socket.S
+ m []module.M
+}
+
+func (s *stream) String() string {
+ return fmt.Sprintf("stream(%d)", s.id)
+}
+
+func (t *tunnel) String() string {
+ return fmt.Sprintf("tunnel(%d)", t.id)
+}
+
+func (t *tunnel) Close() {
+ t.in.Close()
+ t.out.Close()
+}
+
+func (t *tunnel) run() {
+ for {
+ if in, err := t.in.Open(); err != nil {
+ log.Println(t, err)
+ time.Sleep(5 * time.Second)
+ } else {
+ log.Println(t, "open", in)
+ go t.run2(in)
+ }
+ }
+}
+
+func (t *tunnel) newStream(in, out socket.Channel) *stream {
+ s := &stream{
+ t: t,
+ in: in,
+ out: out,
+ id: t.nextSid,
+ since: time.Now(),
+ }
+
+ t.mu.Lock()
+ t.nextSid++
+ t.streams[s.id] = s
+ t.mu.Unlock()
+
+ return s
+}
+
+func (s *stream) ref() {
+ atomic.AddInt32(&s.n, 1)
+}
+
+func (s *stream) unref() {
+ if atomic.AddInt32(&s.n, -1) == 0 {
+ log.Println(s.t, s, "close")
+
+ s.t.mu.Lock()
+ delete(s.t.streams, s.id)
+ s.t.mu.Unlock()
+ }
+}
+
+func (s *stream) watchChannel(rq, wq queue.Q, c socket.Channel) {
+ watch := func (q queue.Q, f func (q queue.Q) error) {
+ s.ref()
+ defer s.unref()
+
+ if err := f(q); err != nil {
+ log.Println(s.t, s, err)
+ }
+ }
+
+ go func () {
+ watch(wq, c.Send)
+ close(wq)
+ }()
+
+ go watch(rq, c.Recv)
+}
+
+func (s *stream) watchPipe(rq, wq queue.Q, f func (rq, wq queue.Q) error) {
+ go func () {
+ s.ref()
+ defer s.unref()
+
+ if err := f(rq, wq); err != nil {
+ log.Println(s.t, s, err)
+ }
+
+ close(wq)
+ }()
+}
+
+func (t *tunnel) run2(in socket.Channel) {
+ out, err := t.out.Open()
+ if err != nil {
+ log.Println(t, err)
+ in.Close()
+ return
+ }
+
+ log.Println(t, "open", out)
+
+ s := t.newStream(in, out)
+
+ s.ref()
+ defer s.unref()
+
+ rq, wq := queue.New(), queue.New()
+
+ s.watchChannel(rq, wq, in)
+
+ for _, m := range t.m {
+ send, recv := m.Open()
+ if send != nil {
+ q := queue.New()
+ s.watchPipe(wq, q, send)
+ wq = q
+ }
+ if recv != nil {
+ q := queue.New()
+ s.watchPipe(q, rq, recv)
+ rq = q
+ }
+ }
+
+ s.watchChannel(wq, rq, out)
+
+ log.Println(t, s, "create", in, out)
+}
+
+func newTunnel(args []string) (*tunnel, error) {
+ var in, out socket.S
+ var err error
+
+ n := len(args) - 1
+
+ if in, err = socket.New(args[0]); err != nil {
+ return nil, err
+ }
+
+ if out, err = socket.New(args[n]); err != nil {
+ in.Close()
+ return nil, err
+ }
+
+ t := &tunnel{
+ args: strings.Join(args, " "),
+ in: in,
+ out: out,
+ streams: make(map[int]*stream),
+ }
+
+ reverse := false
+
+ for _, arg := range args[1:n] {
+ var m module.M
+
+ if arg == "-" {
+ reverse = true
+ continue
+ }
+
+ if arg == "+" {
+ reverse = false
+ continue
+ }
+
+ if m, err = module.New(arg); err != nil {
+ t.Close()
+ return nil, err
+ }
+
+ if reverse {
+ m = module.Reverse(m)
+ reverse = false
+ }
+
+ t.m = append(t.m, m)
+ }
+
+ if reverse {
+ t.Close()
+ return nil, fmt.Errorf("bad '-' usage")
+ }
+
+ go t.run()
+
+ return t, nil
+}
+
+func tunnelAdd(r *request) {
+ if r.argc < 2 {
+ r.Fatal("not enough args")
+ }
+
+ t, err := newTunnel(r.args)
+ if err != nil {
+ r.Fatal(err)
+ }
+
+ log.Println(r.c, r, t, "create")
+
+ t.id = r.c.s.tunnels.add(t)
+}
+
+func foreachTunnel(m automap, f func (t *tunnel)) {
+ var keys []int
+
+ for k := range m {
+ keys = append(keys, k)
+ }
+
+ sort.Ints(keys)
+
+ for _, k := range keys {
+ f(m[k].(*tunnel))
+ }
+}
+
+func foreachStream(m map[int]*stream, f func (s *stream)) {
+ var keys []int
+
+ for k := range m {
+ keys = append(keys, k)
+ }
+
+ sort.Ints(keys)
+
+ for _, k := range keys {
+ f(m[k])
+ }
+}
+
+func tunnelShow(r *request) {
+ foreachTunnel(r.c.s.tunnels, func (t *tunnel) {
+ r.Println(t.id, t.args)
+ })
+}
+
+func streamShow(r *request) {
+ foreachTunnel(r.c.s.tunnels, func (t *tunnel) {
+ r.Println(t.id, t.args)
+
+ t.mu.Lock()
+ if len(t.streams) == 0 {
+ r.Println("\t", "nothing")
+ } else {
+ foreachStream(t.streams, func (s *stream) {
+ tm := s.since.Format(config.TimeFormat)
+ r.Println("\t", s.id, tm, s.in, s.out)
+ })
+ }
+ t.mu.Unlock()
+ })
+}
+
+func init() {
+ newCmd(tunnelAdd, "add")
+ newCmd(tunnelShow, "show")
+ newCmd(streamShow, "stream show")
+}