package server import ( "bytes" "errors" "fmt" "io" "log" "net" "strings" "sync" "time" "tunnel/pkg/netstring" "tunnel/pkg/server/env" ) type Server struct { listen net.Listener since time.Time wg sync.WaitGroup mu sync.Mutex once sync.Once done chan struct{} tunnels automap env env.Env nextCid int } type client struct { id int s *Server conn net.Conn r *netstring.Decoder w *netstring.Encoder nextRid int } type request struct { id int c *client cmd *cmd argc int args []string fail bool out bytes.Buffer } type requestError string var errNotImplemented = errors.New("not implemented") func (c *client) String() string { return fmt.Sprintf("client:%d", c.id) } func (r *request) String() string { return fmt.Sprintf("request:%d", r.id) } func (r *request) Write(p []byte) (int, error) { return r.out.Write(p) } func (r *request) Print(v ...interface{}) { fmt.Fprint(&r.out, v...) } func (r *request) Printf(format string, v ...interface{}) { fmt.Fprintf(&r.out, format, v...) } func (r *request) Println(v ...interface{}) { fmt.Fprintln(&r.out, v...) } func (r *request) Fatal(v ...interface{}) { panic(requestError(fmt.Sprint(v...))) } func (r *request) Fatalf(format string, v ...interface{}) { panic(requestError(fmt.Sprintf(format, v...))) } func (r *request) expect(c ...int) { desc := func(n int) string { var sep string if n == 1 { sep = " is " } else { sep = " are " } return fmt.Sprint(n, sep, "expected") } check := func(cond bool, args ...interface{}) { if cond { r.Fatal(args...) } } switch len(c) { case 0: check(r.argc > 0, "args are not expected") case 1: check(r.argc < c[0], "not enough args: ", desc(c[0])) check(r.argc > c[0], "too many args: ", desc(c[0])) case 2: check(r.argc < c[0], "not enough args: at least ", desc(c[0])) check(r.argc > c[1], "too many args: no more than ", desc(c[1])) } } func (s *Server) isDone() bool { select { case <-s.done: return true default: return false } } func New(path string) (*Server, error) { listen, err := net.Listen("unix", path) if err != nil { return nil, err } log.Println("listen at", path) s := &Server{ env: env.New(), listen: listen, since: time.Now(), done: make(chan struct{}), tunnels: make(automap), } return s, nil } func (s *Server) Env() env.Env { return s.env } func (s *Server) Socket() string { return s.listen.Addr().String() } func (s *Server) Run() { for { conn, err := s.listen.Accept() if err != nil { if s.isDone() { break } log.Print(err) continue } c := s.newClient(conn) log.Println(c, "accept") s.wg.Add(1) go c.handle() } s.wg.Wait() } func (s *Server) Stop() { s.once.Do(func() { close(s.done) s.listen.Close() }) } func (s *Server) newClient(conn net.Conn) *client { c := &client{ s: s, conn: conn, r: netstring.NewDecoder(conn), w: netstring.NewEncoder(conn), id: s.nextCid, } s.nextCid++ return c } func (s *Server) Command(query string) error { args := strings.Split(query, " ") r := &request{c: &client{s: s}} r.run(args) if r.fail { if r.out.Len() == 0 { return errors.New("failed") } return errors.New(r.out.String()) } return nil } func (c *client) newRequest() *request { r := &request{ c: c, id: c.nextRid, } c.nextRid++ return r } func (c *client) handle() { defer c.close() for { req, er := c.r.Decode() if er != nil { if !errors.Is(er, io.EOF) { log.Println(c, "handle:", er) } break } args, err := c.decode([]byte(req)) if err != nil { log.Println(c, "decode:", err) break } r := c.newRequest() r.run(args) if r.out.Len() == 0 { r.out.Write([]byte("\n")) } ew := c.w.Encode(r.out.String()) if ew != nil { log.Println(c, "handle:", ew) break } } } func (c *client) decode(b []byte) ([]string, error) { dec := netstring.NewDecoder(bytes.NewReader(b)) var t []string for { if s, err := dec.Decode(); err == nil { t = append(t, s) } else { if !errors.Is(err, io.EOF) { return nil, err } break } } return t, nil } func (r *request) eval(args []string) []string { var out []string for _, s := range args { var t string if strings.HasPrefix(s, ":") { t = s[1:] } else { t = r.c.s.env.Eval(s) } out = append(out, t) } return out } func (r *request) parse(args []string) { if c, args := getCmd(r.eval(args)); c == nil { r.Fatal("command not found") } else { r.args = args r.argc = len(args) r.cmd = c } } func (r *request) run(args []string) { defer func() { switch err := recover().(type) { case requestError: r.Print(err) r.fail = true default: panic(err) case nil: } }() log.Println(r.c, r, ">", strings.Join(args, " ")) r.parse(args) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() r.cmd.f(r) } func (c *client) close() { log.Println(c, "close") c.conn.Close() c.s.wg.Done() }