package server import ( "tunnel/pkg/config" "tunnel/pkg/netstring" "strings" "errors" "bytes" "sync" "time" "fmt" "log" "net" "io" ) type Server struct { listen net.Listener since time.Time wg sync.WaitGroup mu sync.Mutex once sync.Once done chan struct{} tunnels automap env env nextCid int } type client struct { id int s *Server conn net.Conn nextRid int } type request struct { id int c *client cmd *cmd argc int args []string out *bytes.Buffer } type requestError string var errNotImplemented = errors.New("not implemented") func (c *client) String() string { return fmt.Sprintf("client(%d)", c.id) } func (r *request) String() string { return fmt.Sprintf("request(%d)", r.id) } func (r *request) Print(v ...interface{}) { fmt.Fprint(r.out, v...) } func (r *request) Printf(format string, v ...interface{}) { fmt.Fprintf(r.out, format, v...) } func (r *request) Println(v ...interface{}) { fmt.Fprintln(r.out, v...) } func (r *request) Fatal(v ...interface{}) { panic(requestError(fmt.Sprint(v...))) } func (r *request) Fatalf(format string, v ...interface{}) { panic(requestError(fmt.Sprintf(format, v...))) } func (r *request) expect(c ...int) { desc := func (n int) string { var sep string if n == 1 { sep = " is " } else { sep = " are " } return fmt.Sprint(n, sep, "expected") } check := func (cond bool, args ...interface{}) { if cond { r.Fatal(args...) } } switch len(c) { case 0: check(r.argc > 0, "args are not expected") case 1: check(r.argc < c[0], "not enough args: ", desc(c[0])) check(r.argc > c[0], "too many args: ", desc(c[0])) case 2: check(r.argc < c[0], "not enough args: at least ", desc(c[0])) check(r.argc > c[1], "too many args: no more than ", desc(c[1])) } } func (s *Server) isDone() bool { select { case <- s.done: return true default: return false } } func New() (*Server, error) { listen, err := net.Listen(config.SockType, config.SockPath) if err != nil { return nil, err } s := &Server{ listen: listen, since: time.Now(), done: make(chan struct{}), tunnels: make(automap), } return s, nil } func (s *Server) Run() { for { conn, err := s.listen.Accept() if err != nil { if s.isDone() { break } log.Print(err) continue } c := s.newClient(conn) log.Println(c, "accept") s.wg.Add(1) go c.handle() } s.wg.Wait() } func (s *Server) Stop() { s.once.Do(func () { close(s.done) s.listen.Close() }) } func (s *Server) newClient(conn net.Conn) *client { c := &client{ s: s, conn: conn, id: s.nextCid, } s.nextCid++ return c } func (c *client) newRequest() *request { r := &request{ c: c, id: c.nextRid, out: new(bytes.Buffer), } c.nextRid++ return r } func (c *client) handle() { defer c.close() buf := make([]byte, config.BufSize) for { nr, er := c.conn.Read(buf) if er != nil { if er != io.EOF { log.Println(c, "handle:", er) } break } query := string(buf[:nr]) r := c.newRequest() r.run(query) if r.out.Len() == 0 { r.out.Write([]byte("\n")) } _, ew := c.conn.Write(r.out.Bytes()) if ew != nil { log.Println(c, "handle:", ew) break } } } func (r *request) decode(query string) []string { dec := netstring.NewDecoder(bytes.NewReader([]byte(query))) var t []string for { if s, err := dec.Decode(); err == nil { t = append(t, s) } else { if err == io.EOF { break } r.Fatal("failed to parse request") } } return t } func (r *request) eval(args []string) []string { repl := func (v string) string { if v, ok := r.c.s.env.get(v[1:]); ok { return v } r.Fatal("unbound variable ", v) return v } eval := func (s string) string { var t string for ;; s = t { t = varTokenRe.ReplaceAllStringFunc(s, repl) if s == t { return s } } } for n, s := range args { if strings.HasPrefix(s, "^") { args[n] = s[1:] } else { args[n] = eval(s) } } return args } func (r *request) parse(query string) { args := r.eval(r.decode(query)) c, args := getCmd(args) if c == nil { r.Fatal("command not found") } r.args = args r.argc = len(args) r.cmd = c } func (r *request) run(query string) { defer func () { switch err := recover().(type) { case requestError: r.Print(err) default: panic(err) case nil: } }() r.parse(query) log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " ")) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() r.cmd.f(r) } func (c *client) close() { log.Println(c, "close") c.conn.Close() c.s.wg.Done() }