From 2f01cc1db51368f36ba6ae664a3b0db9234f683f Mon Sep 17 00:00:00 2001 From: Mikhail Osipov Date: Fri, 27 Aug 2021 21:38:16 +0300 Subject: add origin host and port to env --- pkg/server/hook/proxy.go | 2 +- pkg/server/socket/listen.go | 36 ++++++++++++++++++++++++------------ pkg/server/socket/sys.go | 10 ++++------ 3 files changed, 29 insertions(+), 19 deletions(-) (limited to 'pkg/server') diff --git a/pkg/server/hook/proxy.go b/pkg/server/hook/proxy.go index bba17e3..0a79056 100644 --- a/pkg/server/hook/proxy.go +++ b/pkg/server/hook/proxy.go @@ -53,7 +53,7 @@ func (p *proxy) Recv(rq, wq queue.Q) error { resp, err := http.ParseResponse(r) if err == nil && resp.Code != http.OK { - err = fmt.Errorf("connect failed: %d %s", resp.Code, resp.Desc) + err = fmt.Errorf("connect %s failed: %d %s", p.addr, resp.Code, resp.Desc) } if err != nil { diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index 9d19677..6848620 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -52,36 +52,48 @@ func (s *listenSocket) Prepare(e env.Env) error { return nil } +func setOriginHostPort(env env.Env, origin, host, port string) { + env.Set("origin", origin) + env.Set("origin.host", host) + env.Set("origin.port", port) +} + func (s *listenSocket) New(env env.Env) (Conn, error) { - var original string + var origin string conn, err := s.listen.Accept() if err != nil { return nil, err } - la, ra := conn.LocalAddr(), conn.RemoteAddr() - desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la) + ra := conn.RemoteAddr() + desc := fmt.Sprintf("%s/%s->%s", s.Proto, ra, s.Addr) info := fmt.Sprintf("<%s/%s", ra.Network(), ra) if s.Redirect { - if err := getConnOriginalAddr(conn, &original); err != nil { - log.Println("accept", desc, "failed") - conn.Close() - return nil, err - } else { - env.Set("original", original) + host, port, err := getConnOriginalHostPort(conn) + if err != nil { + defer conn.Close() + return nil, fmt.Errorf("accept %s failed: %s", desc, err) } + origin = net.JoinHostPort(host, port) + setOriginHostPort(env, origin, host, port) } if s.Tproxy { - env.Set("original", la.String()) + origin = conn.LocalAddr().String() + host, port, err := net.SplitHostPort(origin) + if err != nil { + defer conn.Close() + return nil, fmt.Errorf("accept %s failed: %s", desc, err) + } + setOriginHostPort(env, origin, host, port) } - if original == "" { + if origin == "" { log.Println("accept", desc) } else { - log.Println("accept", desc, "original", original) + log.Println("accept", desc, "origin", origin) } return newConn(conn, desc, info), nil diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go index f90d2da..b2d836e 100644 --- a/pkg/server/socket/sys.go +++ b/pkg/server/socket/sys.go @@ -106,7 +106,7 @@ func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { } } -func getConnOriginalAddr(conn net.Conn, addr *string) error { +func getConnOriginalHostPort(conn net.Conn) (string, string, error) { var sa unix.RawSockaddrAny f := func(fd int) error { @@ -114,7 +114,7 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { } if err := withConnControl(conn, f); err != nil { - return fmt.Errorf("get-original-addr: %w", err) + return "", "", fmt.Errorf("get-original-addr: %w", err) } var host net.IP @@ -128,12 +128,10 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa)) host, port = sin.Addr[:], sin.Port default: - return errors.New("get-original-addr: unknown address family") + return "", "", fmt.Errorf("get-original-addr: unknown address family %d", sa.Addr.Family) } - *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port)))) - - return nil + return host.String(), strconv.Itoa(int(be16toh(port))), nil } func setSocketTransparent(fd int) error { -- cgit v1.2.3-70-g09d2