package server

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"tunnel/pkg/netstring"
	"tunnel/pkg/server/env"
)

// Server accepts connections on a Unix socket and runs the commands that
// clients send over them. Create one with New, run it with Serve and shut
// it down with Stop.
//
// Wire protocol (see client.reader and client.decode): every request is a
// single netstring whose payload is itself a sequence of netstrings, one per
// argument; every response is a single netstring carrying the command's
// output.
type Server struct {
	listen  net.Listener
	since   time.Time // when New created the server
	version string    // version string passed to New

	wg   sync.WaitGroup // one count per live client reader goroutine
	mu   sync.Mutex     // serializes command execution (held by request.run)
	once sync.Once      // makes Stop idempotent
	done chan struct{}  // closed by Stop

	tunnels automap
	env     env.Env // evaluates command arguments before they run

	nextCid int // id for the next accepted client; only touched by Serve's loop
}

// client is one accepted connection. Requests are decoded from conn with r
// and responses are encoded to conn with w.
type client struct {
	id      int
	s       *Server
	conn    net.Conn
	r       *netstring.Decoder
	w       *netstring.Encoder
	nextRid int // id for the next request on this connection
}

// request is a single command being executed. Everything written through its
// Write/Print* methods is buffered in out and becomes the command's output.
type request struct {
	id   int
	c    *client
	fail bool         // set when the command aborted via Fatal/Fatalf
	out  bytes.Buffer // accumulated command output
}

// requestError is the panic value raised by Fatal and Fatalf. request.run
// recovers it, appends it to the output and marks the request as failed.
type requestError string

// errNotImplemented is not referenced in this file.
// NOTE(review): presumably returned by command handlers elsewhere in the package — confirm.
var errNotImplemented = errors.New("not implemented")

// String identifies the client in log messages.
func (c *client) String() string {
	return fmt.Sprintf("client:%d", c.id)
}

// String identifies the request in log messages.
func (r *request) String() string {
	return fmt.Sprintf("request:%d", r.id)
}

// Write appends p to the request's output, so a request can be used as an
// io.Writer. It never fails.
func (r *request) Write(p []byte) (int, error) {
	return r.out.Write(p)
}

// Print formats v as fmt.Print does and appends the result to the output.
func (r *request) Print(v ...interface{}) {
	fmt.Fprint(&r.out, v...)
}

// Printf formats v as fmt.Printf does and appends the result to the output.
func (r *request) Printf(format string, v ...interface{}) {
	fmt.Fprintf(&r.out, format, v...)
}

// Println formats v as fmt.Println does and appends the result to the output.
func (r *request) Println(v ...interface{}) {
	fmt.Fprintln(&r.out, v...)
}

// Fatal aborts the running command. The message, formatted as fmt.Sprint
// does, is appended to the output and the request is marked as failed (see
// request.run). Fatal does not return.
func (r *request) Fatal(v ...interface{}) {
	panic(requestError(fmt.Sprint(v...)))
}

// Fatalf is like Fatal but formats its message as fmt.Sprintf does.
func (r *request) Fatalf(format string, v ...interface{}) {
	panic(requestError(fmt.Sprintf(format, v...)))
}

// isDone reports whether Stop has been called.
func (s *Server) isDone() bool {
	select {
	case <-s.done:
		return true
	default:
		return false
	}
}

// New listens on the Unix socket at path and returns a Server ready for
// Serve. version is stored on the server; the creation time is recorded.
func New(path string, version string) (*Server, error) {
	listen, err := net.Listen("unix", path)
	if err != nil {
		return nil, err
	}
	log.Println("listen at", path)
	s := &Server{
		env:     env.New(),
		listen:  listen,
		since:   time.Now(),
		done:    make(chan struct{}),
		tunnels: make(automap),
		version: version,
	}
	return s, nil
}

// Env returns the environment used to evaluate command arguments.
func (s *Server) Env() env.Env {
	return s.env
}

// Socket returns the address the server is listening on.
func (s *Server) Socket() string {
	return s.listen.Addr().String()
}

// Serve accepts connections until Stop is called and serves each client in
// its own goroutine. Once the listener has been closed, Serve blocks until
// every client goroutine has exited. Stop does not close client connections,
// so Serve keeps waiting while clients stay connected.
func (s *Server) Serve() {
	// Errors such as EMFILE ("too many open files") can persist, so retrying
	// Accept immediately would spin the CPU and flood the log. Back off
	// exponentially between consecutive failures instead.
	const (
		minDelay = 5 * time.Millisecond
		maxDelay = time.Second
	)
	var delay time.Duration
	for {
		conn, err := s.listen.Accept()
		if err != nil {
			if s.isDone() {
				break
			}
			log.Print(err)
			if delay == 0 {
				delay = minDelay
			} else {
				delay *= 2
			}
			if delay > maxDelay {
				delay = maxDelay
			}
			// Wake up early if Stop is called while waiting. The next Accept
			// then fails on the closed listener and isDone ends the loop.
			select {
			case <-s.done:
			case <-time.After(delay):
			}
			continue
		}
		delay = 0
		c := s.newClient(conn)
		log.Println(c, "accept")
		s.wg.Add(1)
		go c.reader()
	}
	s.wg.Wait()
}

// Stop makes Serve stop accepting connections by closing the listener.
// Connections that are already open are left alone. Stop may be called more
// than once.
func (s *Server) Stop() {
	s.once.Do(func() {
		close(s.done)
		s.listen.Close()
	})
}

// newClient wraps conn in a client with the next client id. nextCid is not
// locked, so this must only be called from Serve's accept loop.
func (s *Server) newClient(conn net.Conn) *client {
	c := &client{
		s:    s,
		conn: conn,
		r:    netstring.NewDecoder(conn),
		w:    netstring.NewEncoder(conn),
		id:   s.nextCid,
	}
	s.nextCid++
	return c
}

// Command runs args in-process, as if a client had sent them, using a
// connection-less client. On success the command's output is discarded and
// nil is returned. If the command failed, the error message is its output,
// or "failed" when it produced none.
func (s *Server) Command(args []string) error {
	r := &request{c: &client{s: s}}
	r.run(args)
	if r.fail {
		if r.out.Len() == 0 {
			return errors.New("failed")
		}
		return errors.New(r.out.String())
	}
	return nil
}

// newRequest creates the next request for c and assigns it the next
// per-connection request id.
func (c *client) newRequest() *request {
	r := &request{
		c:  c,
		id: c.nextRid,
	}
	c.nextRid++
	return r
}

// reader serves c's requests one at a time. Each netstring received is split
// into arguments and executed, and the output is sent back as one netstring.
// The failure flag is not sent, only the output. reader returns, closing the
// connection, on EOF (not logged) or on any read, decode or write error
// (logged).
func (c *client) reader() {
	defer c.close()
	for {
		req, er := c.r.Decode()
		if er != nil {
			if !errors.Is(er, io.EOF) {
				log.Println(c, "reader:", er)
			}
			break
		}
		args, err := c.decode([]byte(req))
		if err != nil {
			log.Println(c, "decode:", err)
			break
		}
		r := c.newRequest()
		r.run(args)
		ew := c.w.Encode(r.out.String())
		if ew != nil {
			log.Println(c, "reader:", ew)
			break
		}
	}
}

// decode splits one request payload, itself a concatenation of netstrings,
// into its arguments. Decoding stops cleanly at io.EOF; any other error
// aborts it.
func (c *client) decode(b []byte) ([]string, error) {
	dec := netstring.NewDecoder(bytes.NewReader(b))
	var t []string
	for {
		if s, err := dec.Decode(); err == nil {
			t = append(t, s)
		} else {
			if !errors.Is(err, io.EOF) {
				return nil, err
			}
			break
		}
	}
	return t, nil
}

// run evaluates args through the server environment and executes the command
// while holding the server mutex, so commands never run concurrently.
//
// A requestError panic (from Fatal/Fatalf) is recovered: its message is
// appended to the output and r.fail is set. Any other panic is re-raised.
func (r *request) run(args []string) {
	// Deferred before the lock, so it runs after the deferred Unlock and the
	// mutex is already released while the panic is handled.
	defer func() {
		switch err := recover().(type) {
		case requestError:
			r.Print(err)
			r.fail = true
		default:
			panic(err)
		case nil:
		}
	}()
	log.Println(r.c, r, ">", strings.Join(args, " "))
	r.c.s.mu.Lock()
	defer r.c.s.mu.Unlock()
	runCmd(r, r.c.s.env.EvalStrings(args))
}

// close closes the connection and releases the Serve wait-group slot taken
// when the client was accepted.
func (c *client) close() {
	log.Println(c, "close")
	c.conn.Close()
	c.s.wg.Done()
}