package server

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"tunnel/pkg/config"
	"tunnel/pkg/netstring"
)

// Backoff bounds used by Run when Accept fails while the server is still
// running. The values mirror the ones net/http's Server.Serve uses.
const (
	minAcceptDelay = 5 * time.Millisecond
	maxAcceptDelay = 1 * time.Second
)

// Server listens on config.SockPath and executes the commands sent by the
// clients that connect to it.
type Server struct {
	listen  net.Listener
	since   time.Time      // set by New to the creation time
	wg      sync.WaitGroup // one count per running client handler
	once    sync.Once      // makes Stop idempotent
	done    chan struct{}  // closed by Stop
	env     env
	nextCid uint64 // id handed to the next accepted client
}

// client is one accepted connection. Its requests are served one after
// another by a single goroutine running handle.
type client struct {
	id      uint64
	s       *Server
	conn    net.Conn
	nextRid uint64 // id handed to the next request on this connection
}

// request is a single command invocation received from a client.
type request struct {
	id     uint64
	c      *client
	cmd    *cmd          // resolved by parse
	argc   int           // len(args)
	args   []string      // command arguments after %variable substitution
	failed bool          // set by run when the command aborted with a panic
	out    *bytes.Buffer // reply written back to the client by handle
}

func (c *client) String() string {
	return fmt.Sprintf("client(%d)", c.id)
}

func (r *request) String() string {
	return fmt.Sprintf("request(%d)", r.id)
}

// Print appends its operands, formatted as by fmt.Print, to the reply.
func (r *request) Print(v ...interface{}) {
	fmt.Fprint(r.out, v...)
}

// Printf appends a fmt.Printf-formatted message to the reply.
func (r *request) Printf(format string, v ...interface{}) {
	fmt.Fprintf(r.out, format, v...)
}

// Println appends its operands, formatted as by fmt.Println, to the reply.
func (r *request) Println(v ...interface{}) {
	fmt.Fprintln(r.out, v...)
}

// Fatal aborts the request by panicking with the fmt.Sprint-formatted
// message; run recovers the panic and appends the message to the reply.
func (r *request) Fatal(v ...interface{}) {
	panic(fmt.Sprint(v...))
}

// Fatalf is like Fatal but formats its message as fmt.Sprintf does.
func (r *request) Fatalf(format string, v ...interface{}) {
	panic(fmt.Sprintf(format, v...))
}

// Fatalln is like Fatal but formats its message as fmt.Sprintln does.
func (r *request) Fatalln(v ...interface{}) {
	panic(fmt.Sprintln(v...))
}

// expect aborts the request (via Fatal) when the argument count is out of
// range:
//
//	expect()       no arguments are allowed
//	expect(n)      exactly n arguments are required
//	expect(lo, hi) between lo and hi arguments, inclusive
//
// With more than two values nothing is checked.
func (r *request) expect(c ...int) {
	desc := func(n int) string {
		var sep string
		if n == 1 {
			sep = " is "
		} else {
			sep = " are "
		}
		return fmt.Sprint(n, sep, "expected")
	}
	check := func(cond bool, args ...interface{}) {
		if cond {
			r.Fatal(args...)
		}
	}
	switch len(c) {
	case 0:
		check(r.argc > 0, "args are not expected")
	case 1:
		check(r.argc < c[0], "not enough args: ", desc(c[0]))
		check(r.argc > c[0], "too many args: ", desc(c[0]))
	case 2:
		check(r.argc < c[0], "not enough args: at least ", desc(c[0]))
		check(r.argc > c[1], "too many args: no more than ", desc(c[1]))
	}
}

// isDone reports whether Stop has been called.
func (s *Server) isDone() bool {
	select {
	case <-s.done:
		return true
	default:
		return false
	}
}

// New creates a Server listening on config.SockPath using the network type
// config.SockType. Call Run to serve it and Stop to shut it down.
func New() (*Server, error) {
	listen, err := net.Listen(config.SockType, config.SockPath)
	if err != nil {
		return nil, err
	}
	s := &Server{
		listen: listen,
		since:  time.Now(),
		done:   make(chan struct{}),
	}
	return s, nil
}

// Run accepts connections until Stop is called, serving each client in its
// own goroutine, and then waits for every client handler to finish.
//
// NOTE: Stop does not close connections that were already accepted, so Run
// returns only once each connected client has disconnected.
func (s *Server) Run() {
	var delay time.Duration // current Accept backoff; reset after a success
	for {
		conn, err := s.listen.Accept()
		if err != nil {
			if s.isDone() {
				break
			}
			log.Print(err)
			// Back off so a persistent failure (e.g. running out of file
			// descriptors) does not turn this loop into a busy spin. Stop
			// interrupts the wait; the next Accept then fails on the closed
			// listener and the loop exits.
			delay = nextAcceptDelay(delay)
			select {
			case <-time.After(delay):
			case <-s.done:
			}
			continue
		}
		delay = 0
		c := s.newClient(conn)
		log.Println(c, "accept")
		s.wg.Add(1)
		go c.handle()
	}
	s.wg.Wait()
}

// nextAcceptDelay returns the backoff to wait after another Accept failure:
// minAcceptDelay first, then doubling prev, capped at maxAcceptDelay.
func nextAcceptDelay(prev time.Duration) time.Duration {
	if prev == 0 {
		return minAcceptDelay
	}
	if next := 2 * prev; next < maxAcceptDelay {
		return next
	}
	return maxAcceptDelay
}

// Stop makes Run stop accepting connections. Only the first call has an
// effect, so it is safe to call more than once.
func (s *Server) Stop() {
	s.once.Do(func() {
		// Close done before the listener so that Run, woken by the failing
		// Accept, already observes isDone() == true.
		close(s.done)
		s.listen.Close()
	})
}

// newClient wraps conn and assigns it the next client id. In this file it
// is only called from Run's accept loop, which is why nextCid is unguarded.
func (s *Server) newClient(conn net.Conn) *client {
	c := &client{
		s:    s,
		conn: conn,
		id:   s.nextCid,
	}
	s.nextCid++
	return c
}

// newRequest creates an empty request carrying the client's next request id.
func (c *client) newRequest() *request {
	r := &request{
		c:   c,
		id:  c.nextRid,
		out: new(bytes.Buffer),
	}
	c.nextRid++
	return r
}

// handle serves c until reading or writing fails or the peer closes the
// connection. The data returned by each Read is treated as one complete
// request.
//
// NOTE(review): that framing assumes each request arrives in a single read
// of at most config.BufSize bytes (e.g. small writes, or a message-oriented
// config.SockType); a larger or split request would be parsed in pieces.
func (c *client) handle() {
	defer c.close()
	buf := make([]byte, config.BufSize)
	for {
		nr, er := c.conn.Read(buf)
		// Per the io.Reader contract, bytes returned together with an error
		// must be processed before the error is considered.
		if nr > 0 {
			if ew := c.respond(buf[:nr]); ew != nil {
				log.Println(c, "handle:", ew)
				break
			}
		}
		if er != nil {
			if er != io.EOF {
				log.Println(c, "handle:", er)
			}
			break
		}
	}
}

// respond runs the request encoded in data and writes its reply to the
// connection. An empty reply is sent as a single newline so the client
// always receives something.
func (c *client) respond(data []byte) error {
	r := c.newRequest()
	r.run(string(data))
	if r.out.Len() == 0 {
		r.out.WriteByte('\n')
	}
	_, err := c.conn.Write(r.out.Bytes())
	return err
}

// decode splits query into its netstring-encoded tokens, aborting the
// request if the encoding is malformed.
func (r *request) decode(query string) []string {
	dec := netstring.NewDecoder(bytes.NewReader([]byte(query)))
	var t []string
	for {
		s, err := dec.Decode()
		if err == io.EOF {
			break
		}
		if err != nil {
			r.Fatal("failed to parse request")
		}
		t = append(t, s)
	}
	return t
}

// parse decodes query, looks the command up with getCmd and substitutes
// variables: every argument of the form "%name" is replaced by the value
// bound to name in the server environment. An unknown command or an unbound
// variable aborts the request.
func (r *request) parse(query string) {
	c, args := getCmd(r.decode(query))
	if c == nil {
		r.Fatal("command not found")
	}
	for n, s := range args {
		if !strings.HasPrefix(s, "%") {
			continue
		}
		v, ok := r.c.s.env.get(s[1:])
		if !ok {
			r.Fatal("unbound variable ", s)
		}
		args[n] = v
	}
	r.args = args
	r.argc = len(args)
	r.cmd = c
}

// run parses and executes query, leaving the reply in r.out. Any panic
// raised while doing so — normally from Fatal or expect, but runtime panics
// as well — is recovered: the request is marked failed and the panic value
// is appended to the reply.
func (r *request) run(query string) {
	defer func() {
		if e := recover(); e != nil {
			r.failed = true
			r.Print(e)
		}
	}()
	r.parse(query)
	log.Printf("%s %s run [%s] '%s'", r.c, r, r.cmd.name, strings.Join(r.args, " "))
	r.cmd.f(r)
}

// close closes the connection and releases the WaitGroup count that Run
// added before starting handle.
func (c *client) close() {
	log.Println(c, "close")
	c.conn.Close()
	c.s.wg.Done()
}