diff options
| author | Mikhail Osipov <mike.osipov@gmail.com> | 2021-08-27 21:38:16 +0300 |
|---|---|---|
| committer | Mikhail Osipov <mike.osipov@gmail.com> | 2021-08-27 21:38:16 +0300 |
| commit | 2f01cc1db51368f36ba6ae664a3b0db9234f683f (patch) | |
| tree | fb4a869419816fe4fd83440611a9bc7b6a8140df /pkg | |
| parent | 3015840d9835717762de52e091adb58f1e2e3b63 (diff) | |
add origin host and port to env
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/server/hook/proxy.go | 2 | ||||
| -rw-r--r-- | pkg/server/socket/listen.go | 36 | ||||
| -rw-r--r-- | pkg/server/socket/sys.go | 10 | ||||
| -rw-r--r-- | pkg/test/test.go | 2 |
4 files changed, 30 insertions, 20 deletions
diff --git a/pkg/server/hook/proxy.go b/pkg/server/hook/proxy.go index bba17e3..0a79056 100644 --- a/pkg/server/hook/proxy.go +++ b/pkg/server/hook/proxy.go @@ -53,7 +53,7 @@ func (p *proxy) Recv(rq, wq queue.Q) error { resp, err := http.ParseResponse(r) if err == nil && resp.Code != http.OK { - err = fmt.Errorf("connect failed: %d %s", resp.Code, resp.Desc) + err = fmt.Errorf("connect %s failed: %d %s", p.addr, resp.Code, resp.Desc) } if err != nil { diff --git a/pkg/server/socket/listen.go b/pkg/server/socket/listen.go index 9d19677..6848620 100644 --- a/pkg/server/socket/listen.go +++ b/pkg/server/socket/listen.go @@ -52,36 +52,48 @@ func (s *listenSocket) Prepare(e env.Env) error { return nil } +func setOriginHostPort(env env.Env, origin, host, port string) { + env.Set("origin", origin) + env.Set("origin.host", host) + env.Set("origin.port", port) +} + func (s *listenSocket) New(env env.Env) (Conn, error) { - var original string + var origin string conn, err := s.listen.Accept() if err != nil { return nil, err } - la, ra := conn.LocalAddr(), conn.RemoteAddr() - desc := fmt.Sprintf("%s/%s->%s", la.Network(), ra, la) + ra := conn.RemoteAddr() + desc := fmt.Sprintf("%s/%s->%s", s.Proto, ra, s.Addr) info := fmt.Sprintf("<%s/%s", ra.Network(), ra) if s.Redirect { - if err := getConnOriginalAddr(conn, &original); err != nil { - log.Println("accept", desc, "failed") - conn.Close() - return nil, err - } else { - env.Set("original", original) + host, port, err := getConnOriginalHostPort(conn) + if err != nil { + defer conn.Close() + return nil, fmt.Errorf("accept %s failed: %s", desc, err) } + origin = net.JoinHostPort(host, port) + setOriginHostPort(env, origin, host, port) } if s.Tproxy { - env.Set("original", la.String()) + origin = conn.LocalAddr().String() + host, port, err := net.SplitHostPort(origin) + if err != nil { + defer conn.Close() + return nil, fmt.Errorf("accept %s failed: %s", desc, err) + } + setOriginHostPort(env, origin, host, port) } - if original == "" { + if origin == "" { log.Println("accept", desc) } else { - log.Println("accept", desc, "original", original) + log.Println("accept", desc, "origin", origin) } return newConn(conn, desc, info), nil diff --git a/pkg/server/socket/sys.go b/pkg/server/socket/sys.go index f90d2da..b2d836e 100644 --- a/pkg/server/socket/sys.go +++ b/pkg/server/socket/sys.go @@ -106,7 +106,7 @@ func getSocketOriginalDst(fd int, sa *unix.RawSockaddrAny) error { } } -func getConnOriginalAddr(conn net.Conn, addr *string) error { +func getConnOriginalHostPort(conn net.Conn) (string, string, error) { var sa unix.RawSockaddrAny f := func(fd int) error { @@ -114,7 +114,7 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { } if err := withConnControl(conn, f); err != nil { - return fmt.Errorf("get-original-addr: %w", err) + return "", "", fmt.Errorf("get-original-addr: %w", err) } var host net.IP @@ -128,12 +128,10 @@ func getConnOriginalAddr(conn net.Conn, addr *string) error { sin := (*unix.RawSockaddrInet6)(unsafe.Pointer(&sa)) host, port = sin.Addr[:], sin.Port default: - return errors.New("get-original-addr: unknown address family") + return "", "", fmt.Errorf("get-original-addr: unknown address family %d", sa.Addr.Family) } - *addr = net.JoinHostPort(host.String(), strconv.Itoa(int(be16toh(port)))) - - return nil + return host.String(), strconv.Itoa(int(be16toh(port))), nil } func setSocketTransparent(fd int) error { diff --git a/pkg/test/test.go b/pkg/test/test.go index e2ecfcb..1237fe7 100644 --- a/pkg/test/test.go +++ b/pkg/test/test.go @@ -54,7 +54,7 @@ type Server struct { func (e *env) newInstance() *Client { socket := getSocketPath(e.Name()) - s, err := server.New(socket) + s, err := server.New(socket, "test") if err != nil { e.Fatal(err) } |
