diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 5dfcbd1..1dfd760 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -242,11 +242,20 @@ const ( r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "edge-1", })) ``` +如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。 + +- IPv4:`203.0.113.43` +- IPv6 / 带端口:`[2001:db8::17]:443` +- 匿名标识:`_gateway-1` +- 未知:`unknown` + +像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。 + `Via` 不是“留空即禁用”的开关。当前实现中: - 如果 `Via` 非空,则使用该值追加 `Via` @@ -282,11 +291,13 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: +- `CONNECT` 隧道转发(HTTP/1.x) - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 - Trailer 透传 - 1xx 响应透传 +- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理 例如,代理 WebSocket 服务: @@ -341,7 +352,7 @@ func main() { r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "gateway-1", FlushInterval: 100 * time.Millisecond, ModifyRequest: func(req *http.Request) { diff --git a/ecw.go b/ecw.go index 754571f..dedbe27 100644 --- a/ecw.go +++ b/ecw.go @@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := ecw.w.(http.Hijacker) if !ok { - return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") + return nil, nil, http.ErrNotSupported } return hijacker.Hijack() } diff --git a/engine.go b/engine.go index c2eae91..a4350c0 100644 --- a/engine.go +++ b/engine.go @@ -475,21 +475,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { func MethodNotAllowed() HandlerFunc { return func(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) engine := c.engine // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := []string{} - for _, treeIter := range engine.methodTrees { - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - allowedMethods = append(allowedMethods, treeIter.method) - } - } + allowedMethods := engine.allowedMethodsForPath(requestPath) if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) @@ -705,7 +696,7 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) // 查找对应的路由树的根节点 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 @@ -725,7 +716,7 @@ func (engine *Engine) handleRequest(c *Context) { } // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ redirectPath := requestPath @@ -782,6 +773,41 @@ func (engine *Engine) handleRequest(c *Context) { //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } +func routeLookupPath(req *http.Request) string { + if req == nil { + return "" + } + + if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") { + return "/" + req.RequestURI + } + if isGeneralOptionsRequest(req) { + return "" + } + if req.URL == nil { + return "" + } + return req.URL.Path +} + +func isGeneralOptionsRequest(req *http.Request) bool { + return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" +} + +func (engine *Engine) allowedMethodsForPath(requestPath string) []string { + allowedMethods := make([]string, 0, len(engine.methodTrees)) + for _, treeIter := range engine.methodTrees { + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + return allowedMethods +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/respw.go b/respw.go index dd94db3..ef5cc3c 100644 --- a/respw.go +++ b/respw.go @@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // 尝试从底层 ResponseWriter 获取 Hijacker 接口 hj, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("http.Hijacker interface not supported") + return nil, nil, http.ErrNotSupported } // 调用底层的 Hijack 方法 diff --git a/reverseproxy.go b/reverseproxy.go index 1730b1e..977402b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "net/http/httptrace" + "net/http/httputil" "net/netip" "net/textproto" "net/url" @@ -217,6 +218,12 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { default: proxy.config.ForwardedHeaders = ForwardedBoth } + proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) + if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { + if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { + proxy.configError = err + } + } return proxy } @@ -234,11 +241,20 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport = http.DefaultTransport } + updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) + if err != nil { + p.handleError(c, err) + return + } + if handledLocally { + return + } + ctx, cancel := p.requestContext(c) defer cancel() outreq := c.Request.Clone(ctx) - if c.Request.ContentLength == 0 { + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } if outreq.Body != nil { @@ -249,12 +265,35 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { outreq.Header = make(http.Header) } outreq.Close = false - - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" + var connectWriter *io.PipeWriter + defer func() { + if connectWriter != nil { + _ = connectWriter.Close() + } + }() + if outreq.Method == http.MethodConnect { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + defer outreq.Body.Close() + connectWriter = pipeWriter + } + + if outreq.Method == http.MethodConnect { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } + } else { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) reqUpType := reverseProxyUpgradeType(outreq.Header) if reqUpType != "" && !isPrintableASCII(reqUpType) { @@ -318,6 +357,23 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { + removeHopByHopHeaders(res.Header) + res.Header.Del("Content-Length") + res.Header.Del("Transfer-Encoding") + res.ContentLength = -1 + res.TransferEncoding = nil + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return + } + if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + p.handleError(c, err) + } + connectWriter = nil + return + } + if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { @@ -353,6 +409,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) + if reverseProxyShouldPanicOnCopyError(c.Request) { + panic(http.ErrAbortHandler) + } return } res.Body.Close() @@ -378,6 +437,86 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } } +func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { + if c == nil || c.Request == nil { + return "", false, nil + } + + switch c.Request.Method { + case http.MethodOptions, http.MethodTrace: + default: + return "", false, nil + } + + rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) + if rawValue == "" { + return "", false, nil + } + + value, err := strconv.Atoi(rawValue) + if err != nil || value < 0 { + return "", false, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), + } + } + if value == 0 { + switch c.Request.Method { + case http.MethodTrace: + return "", true, p.writeLocalTraceResponse(c) + case http.MethodOptions: + p.writeLocalOptionsResponse(c) + return "", true, nil + } + } + + return strconv.Itoa(value - 1), false, nil +} + +func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { + if c == nil || c.Request == nil { + return nil + } + + traceReq := c.Request.Clone(c.Request.Context()) + traceReq.Body = nil + traceReq.ContentLength = 0 + traceReq.TransferEncoding = nil + traceReq.RequestURI = c.Request.RequestURI + if traceReq.RequestURI == "" && traceReq.URL != nil { + traceReq.RequestURI = traceReq.URL.RequestURI() + } + traceReq.Header = traceReq.Header.Clone() + for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { + traceReq.Header.Del(key) + } + + dump, err := httputil.DumpRequest(traceReq, false) + if err != nil { + return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} + } + + c.Writer.Header().Set("Content-Type", "message/http") + c.Writer.WriteHeader(http.StatusOK) + _, err = c.Writer.Write(dump) + return err +} + +func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { + if c == nil { + return + } + + if c.engine != nil { + if c.Request != nil && c.Request.RequestURI != "*" { + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { + c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + } + } + } + c.Writer.WriteHeader(http.StatusOK) +} + func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { ctx := c.Request.Context() if ctx.Done() != nil { @@ -522,7 +661,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() @@ -561,6 +704,80 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), + } + } + backRead := res.Body + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backRead.Close() + _ = backWrite.Close() + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} + } + + defer clientConn.Close() + defer backRead.Close() + defer backWrite.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backRead.Close() + _ = backWrite.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + if _, err := io.Copy(clientConn, backRead); err != nil { + errc <- err + return + } + if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone + }() + go func() { + if _, err := io.Copy(backWrite, clientConn); err != nil { + errc <- err + return + } + errc <- backWrite.Close() + }() + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -638,6 +855,10 @@ func reverseProxyStatusCode(err error) int { if errors.As(err, &statusErr) && statusErr.status > 0 { return statusErr.status } + var netErr net.Error + if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { + return http.StatusGatewayTimeout + } return http.StatusBadGateway } @@ -651,6 +872,17 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func validateReverseProxyForwardedBy(value string) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil + } + if !isValidForwardedNodeIdentifier(trimmed) { + return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value) + } + return nil +} + func normalizeReverseProxyTarget(target *url.URL) { switch strings.ToLower(target.Scheme) { case "ws": @@ -732,6 +964,83 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { return strings.Join(pairs, ";") } +func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { + return policy == ForwardedBoth || policy == ForwardedRFC7239Only +} + +func isValidForwardedNodeIdentifier(value string) bool { + if value == "" { + return false + } + if strings.HasPrefix(value, "[") { + closing := strings.IndexByte(value, ']') + if closing <= 1 { + return false + } + addr, err := netip.ParseAddr(value[1:closing]) + if err != nil || !addr.Is6() { + return false + } + if closing == len(value)-1 { + return true + } + if value[closing+1] != ':' { + return false + } + return isValidForwardedNodePort(value[closing+2:]) + } + + host, port, hasPort := strings.Cut(value, ":") + if hasPort { + switch { + case host == "unknown", isValidForwardedObfuscatedIdentifier(host): + return isValidForwardedNodePort(port) + default: + addr, err := netip.ParseAddr(host) + return err == nil && addr.Is4() && isValidForwardedNodePort(port) + } + } + + if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) { + return true + } + addr, err := netip.ParseAddr(value) + return err == nil && addr.Is4() +} + +func isValidForwardedNodePort(value string) bool { + if value == "" { + return false + } + if isValidForwardedObfuscatedIdentifier(value) { + return true + } + if len(value) > 5 { + return false + } + port, err := strconv.Atoi(value) + return err == nil && port > 0 && port <= 65535 +} + +func isValidForwardedObfuscatedIdentifier(value string) bool { + if len(value) < 2 || value[0] != '_' { + return false + } + for i := 1; i < len(value); i++ { + b := value[i] + if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') { + continue + } + switch b { + case '.', '_', '-': + continue + default: + return false + } + } + return true +} + func formatForwardedFor(clientIP string) string { addr, err := netip.ParseAddr(clientIP) if err != nil { @@ -817,6 +1126,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) { } } +func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error { + connectTarget, err := reverseProxyConnectTarget(target) + if err != nil { + return &reverseProxyStatusError{status: http.StatusBadRequest, err: err} + } + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = "" + req.URL.RawPath = "" + req.URL.RawQuery = "" + req.URL.Opaque = connectTarget + req.Host = connectTarget + return nil +} + +func reverseProxyConnectTarget(target *url.URL) (string, error) { + if target == nil { + return "", errReverseProxyNilTarget + } + host := target.Hostname() + if host == "" { + return "", errReverseProxyInvalidTarget + } + port := target.Port() + if port == "" { + switch strings.ToLower(target.Scheme) { + case "http": + port = "80" + case "https": + port = "443" + default: + return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme) + } + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum <= 0 || portNum > 65535 { + return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port) + } + return net.JoinHostPort(host, port), nil +} + func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { if base.RawPath == "" && incoming.RawPath == "" { return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" @@ -919,6 +1269,10 @@ func cleanReverseProxyQueryParams(rawQuery string) string { return values.Encode() } +func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { + return req != nil && req.Context().Value(http.ServerContextKey) != nil +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index f82aff9..b7df512 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,6 +2,7 @@ package touka import ( "bufio" + "context" "errors" "fmt" "io" @@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, ForwardedHeaders: ForwardedBoth, - ForwardedBy: "proxy-node", + ForwardedBy: "_proxy-node", Via: "proxy.test", })) @@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { if !strings.Contains(got.Forwarded, "for=198.51.100.10") { t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) } - if !strings.Contains(got.Forwarded, "by=proxy-node") { + if !strings.Contains(got.Forwarded, "by=_proxy-node") { t.Fatalf("forwarded header missing by token: %q", got.Forwarded) } if !strings.Contains(got.Forwarded, "host=client.example") { @@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } +func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) { + t.Helper() + + forwardedCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedCh <- r.Header.Get("Forwarded") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: " _proxy-node ", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + req.RemoteAddr = "198.51.100.10:4567" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case forwarded := <-forwardedCh: + if !strings.Contains(forwarded, "by=_proxy-node") { + t.Fatalf("unexpected Forwarded header: %q", forwarded) + } + if strings.Contains(forwarded, `by=" _proxy-node "`) { + t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Forwarded header") + } +} + func TestReverseProxyDefaultViaFallback(t *testing.T) { t.Helper() @@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, context.DeadlineExceeded + }), + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { t.Helper() @@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { } } +func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.RequestURI = "/trace" + req.Header.Set("Max-Forwards", "0") + req.Header.Set("Authorization", "secret") + req.Header.Set("Cookie", "a=b") + req.Header.Set("Forwarded", "for=192.0.2.1") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got != "message/http" { + t.Fatalf("unexpected content type: %q", got) + } + if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") { + t.Fatalf("trace body missing request line: %q", string(body)) + } + if strings.Contains(string(body), "Authorization:") { + t.Fatalf("trace body leaked authorization header: %q", string(body)) + } + if strings.Contains(string(body), "Cookie:") { + t.Fatalf("trace body leaked cookie header: %q", string(body)) + } + if strings.Contains(string(body), "Forwarded:") { + t.Fatalf("trace body leaked forwarded header: %q", string(body)) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) { + t.Helper() + + maxForwardsCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + maxForwardsCh <- r.Header.Get("Max-Forwards") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.Header.Set("Max-Forwards", "2") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case got := <-maxForwardsCh: + if got != "1" { + t.Fatalf("unexpected Max-Forwards header: %q", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Max-Forwards") + } +} + +func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil) + req.Header.Set("Max-Forwards", "0") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + allow := rr.Header().Get("Allow") + if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) { + t.Fatalf("unexpected Allow header: %q", allow) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { + t.Helper() + + engine := New() + engine.OPTIONS("/", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil) + req.RequestURI = "*" + req.URL.Path = "" + req.URL.RawPath = "" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) + } +} + +func TestReverseProxyConnectTunnel(t *testing.T) { + t.Helper() + + backendAddr := "" + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got, want := r.RequestURI, backendAddr; got != want { + errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, strings.ToUpper(line)) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend write failed: %w", err) + return + } + })) + defer backend.Close() + backendAddr = strings.TrimPrefix(backend.URL, "http://") + + engine := New() + engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n") + if err != nil { + t.Fatalf("write connect request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "200") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if got := respHeader.Get("Content-Length"); got != "" { + t.Fatalf("CONNECT response should not include Content-Length, got %q", got) + } + if got := respHeader.Get("Transfer-Encoding"); got != "" { + t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "PING\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyConnectNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil) + req.URL.Path = "/tunnel" + req.RequestURI = "/tunnel" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")}, + ContentLength: -1, + Request: req, + }, nil + }), + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := proxy.Client().Get(proxy.URL + "/proxy") + if err != nil { + t.Fatalf("perform request: %v", err) + } + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err == nil { + t.Fatal("expected body read to fail after upstream copy error") + } +} + func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { t.Helper() @@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } return u } + +type failingReadCloser struct { + chunks []string + err error +} + +func (r *failingReadCloser) Read(p []byte) (int, error) { + if len(r.chunks) == 0 { + return 0, r.err + } + n := copy(p, r.chunks[0]) + r.chunks = r.chunks[1:] + return n, nil +} + +func (r *failingReadCloser) Close() error { + return nil +}