diff options
Diffstat (limited to 'pkg/server/server.go')
| -rw-r--r-- | pkg/server/server.go | 70 |
1 files changed, 50 insertions, 20 deletions
diff --git a/pkg/server/server.go b/pkg/server/server.go index a9a50a4..badf5c0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -51,7 +51,9 @@ type request struct { argc int args []string - out *bytes.Buffer + fail bool + + out bytes.Buffer } type requestError string @@ -67,15 +69,15 @@ func (r *request) String() string { } func (r *request) Print(v ...interface{}) { - fmt.Fprint(r.out, v...) + fmt.Fprint(&r.out, v...) } func (r *request) Printf(format string, v ...interface{}) { - fmt.Fprintf(r.out, format, v...) + fmt.Fprintf(&r.out, format, v...) } func (r *request) Println(v ...interface{}) { - fmt.Fprintln(r.out, v...) + fmt.Fprintln(&r.out, v...) } func (r *request) Fatal(v ...interface{}) { @@ -145,6 +147,14 @@ func New(path string) (*Server, error) { return s, nil } +func (s *Server) Env() env.Env { + return s.env +} + +func (s *Server) Socket() string { + return s.listen.Addr().String() +} + func (s *Server) Run() { for { conn, err := s.listen.Accept() @@ -188,11 +198,28 @@ func (s *Server) newClient(conn net.Conn) *client { return c } +func (s *Server) Command(query string) error { + args := strings.Split(query, " ") + + r := &request{c: &client{s: s}} + + r.run(args) + + if r.fail { + if r.out.Len() == 0 { + return errors.New("failed") + } + + return errors.New(r.out.String()) + } + + return nil +} + func (c *client) newRequest() *request { r := &request{ - c: c, - id: c.nextRid, - out: new(bytes.Buffer), + c: c, + id: c.nextRid, } c.nextRid++ @@ -214,11 +241,15 @@ func (c *client) handle() { break } - query := string(buf[:nr]) + args, err := c.decode(buf[:nr]) + if err != nil { + log.Println(c, "decode:", err) + break + } r := c.newRequest() - r.run(query) + r.run(args) if r.out.Len() == 0 { r.out.Write([]byte("\n")) @@ -232,8 +263,8 @@ func (c *client) handle() { } } -func (r *request) decode(query string) []string { - dec := netstring.NewDecoder(bytes.NewReader([]byte(query))) +func (c *client) decode(b []byte) ([]string, error) { + dec := netstring.NewDecoder(bytes.NewReader(b)) var t []string for { @@ -241,14 +272,14 @@ func (r *request) decode(query string) []string { t = append(t, s) } else { if !errors.Is(err, io.EOF) { - r.Fatal("request parse failed") + return nil, err } break } } - return t + return t, nil } func (r *request) eval(args []string) []string { @@ -263,10 +294,8 @@ func (r *request) eval(args []string) []string { return args } -func (r *request) parse(query string) { - args := r.eval(r.decode(query)) - - c, args := getCmd(args) +func (r *request) parse(args []string) { + c, args := getCmd(r.eval(args)) if c == nil { r.Fatal("command not found") } @@ -276,20 +305,21 @@ func (r *request) parse(query string) { r.cmd = c } -func (r *request) run(query string) { +func (r *request) run(args []string) { defer func() { switch err := recover().(type) { case requestError: r.Print(err) + r.fail = true default: panic(err) case nil: } }() - r.parse(query) + log.Println(r.c, r, ">", strings.Join(args, " ")) - log.Println(r.c, r, ">", r.cmd.name, strings.Join(r.args, " ")) + r.parse(args) r.c.s.mu.Lock() defer r.c.s.mu.Unlock() |
