From ed44c592d314c3212222748eb674211916c47120 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 03:18:49 +0800 Subject: [PATCH 1/4] fix(reverseproxy): align forwarding and tunnel semantics --- docs/reverse-proxy.md | 15 +- ecw.go | 2 +- engine.go | 52 +++-- respw.go | 2 +- reverseproxy.go | 368 +++++++++++++++++++++++++++++++++- reverseproxy_test.go | 451 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 864 insertions(+), 26 deletions(-) diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 5dfcbd1..1dfd760 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -242,11 +242,20 @@ const ( r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "edge-1", })) ``` +如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。 + +- IPv4:`203.0.113.43` +- IPv6 / 带端口:`[2001:db8::17]:443` +- 匿名标识:`_gateway-1` +- 未知:`unknown` + +像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。 + `Via` 不是“留空即禁用”的开关。当前实现中: - 如果 `Via` 非空,则使用该值追加 `Via` @@ -282,11 +291,13 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: +- `CONNECT` 隧道转发(HTTP/1.x) - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 - Trailer 透传 - 1xx 响应透传 +- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理 例如,代理 WebSocket 服务: @@ -341,7 +352,7 @@ func main() { r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "gateway-1", FlushInterval: 100 * time.Millisecond, ModifyRequest: func(req *http.Request) { diff --git a/ecw.go b/ecw.go index 754571f..dedbe27 100644 --- a/ecw.go +++ b/ecw.go @@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := ecw.w.(http.Hijacker) if !ok { - return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") + return nil, nil, http.ErrNotSupported } return hijacker.Hijack() } diff --git a/engine.go b/engine.go index c2eae91..a4350c0 100644 --- a/engine.go +++ b/engine.go @@ -475,21 +475,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { func MethodNotAllowed() HandlerFunc { return func(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) engine := c.engine // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := []string{} - for _, treeIter := range engine.methodTrees { - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - allowedMethods = append(allowedMethods, treeIter.method) - } - } + allowedMethods := engine.allowedMethodsForPath(requestPath) if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) @@ -705,7 +696,7 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) // 查找对应的路由树的根节点 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 @@ -725,7 +716,7 @@ func (engine *Engine) handleRequest(c *Context) { } // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ redirectPath := requestPath @@ -782,6 +773,41 @@ func (engine *Engine) handleRequest(c *Context) { //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } +func routeLookupPath(req *http.Request) string { + if req == nil { + return "" + } + + if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") { + return "/" + req.RequestURI + } + if isGeneralOptionsRequest(req) { + return "" + } + if req.URL == nil { + return "" + } + return req.URL.Path +} + +func isGeneralOptionsRequest(req *http.Request) bool { + return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" +} + +func (engine *Engine) allowedMethodsForPath(requestPath string) []string { + allowedMethods := make([]string, 0, len(engine.methodTrees)) + for _, treeIter := range engine.methodTrees { + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + return allowedMethods +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/respw.go b/respw.go index dd94db3..ef5cc3c 100644 --- a/respw.go +++ b/respw.go @@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // 尝试从底层 ResponseWriter 获取 Hijacker 接口 hj, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("http.Hijacker interface not supported") + return nil, nil, http.ErrNotSupported } // 调用底层的 Hijack 方法 diff --git a/reverseproxy.go b/reverseproxy.go index 1730b1e..977402b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "net/http/httptrace" + "net/http/httputil" "net/netip" "net/textproto" "net/url" @@ -217,6 +218,12 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { default: proxy.config.ForwardedHeaders = ForwardedBoth } + proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) + if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { + if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { + proxy.configError = err + } + } return proxy } @@ -234,11 +241,20 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport = http.DefaultTransport } + updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) + if err != nil { + p.handleError(c, err) + return + } + if handledLocally { + return + } + ctx, cancel := p.requestContext(c) defer cancel() outreq := c.Request.Clone(ctx) - if c.Request.ContentLength == 0 { + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } if outreq.Body != nil { @@ -249,12 +265,35 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { outreq.Header = make(http.Header) } outreq.Close = false - - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" + var connectWriter *io.PipeWriter + defer func() { + if connectWriter != nil { + _ = connectWriter.Close() + } + }() + if outreq.Method == http.MethodConnect { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + defer outreq.Body.Close() + connectWriter = pipeWriter + } + + if outreq.Method == http.MethodConnect { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } + } else { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) reqUpType := reverseProxyUpgradeType(outreq.Header) if reqUpType != "" && !isPrintableASCII(reqUpType) { @@ -318,6 +357,23 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { + removeHopByHopHeaders(res.Header) + res.Header.Del("Content-Length") + res.Header.Del("Transfer-Encoding") + res.ContentLength = -1 + res.TransferEncoding = nil + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return + } + if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + p.handleError(c, err) + } + connectWriter = nil + return + } + if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { @@ -353,6 +409,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) + if reverseProxyShouldPanicOnCopyError(c.Request) { + panic(http.ErrAbortHandler) + } return } res.Body.Close() @@ -378,6 +437,86 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } } +func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { + if c == nil || c.Request == nil { + return "", false, nil + } + + switch c.Request.Method { + case http.MethodOptions, http.MethodTrace: + default: + return "", false, nil + } + + rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) + if rawValue == "" { + return "", false, nil + } + + value, err := strconv.Atoi(rawValue) + if err != nil || value < 0 { + return "", false, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), + } + } + if value == 0 { + switch c.Request.Method { + case http.MethodTrace: + return "", true, p.writeLocalTraceResponse(c) + case http.MethodOptions: + p.writeLocalOptionsResponse(c) + return "", true, nil + } + } + + return strconv.Itoa(value - 1), false, nil +} + +func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { + if c == nil || c.Request == nil { + return nil + } + + traceReq := c.Request.Clone(c.Request.Context()) + traceReq.Body = nil + traceReq.ContentLength = 0 + traceReq.TransferEncoding = nil + traceReq.RequestURI = c.Request.RequestURI + if traceReq.RequestURI == "" && traceReq.URL != nil { + traceReq.RequestURI = traceReq.URL.RequestURI() + } + traceReq.Header = traceReq.Header.Clone() + for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { + traceReq.Header.Del(key) + } + + dump, err := httputil.DumpRequest(traceReq, false) + if err != nil { + return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} + } + + c.Writer.Header().Set("Content-Type", "message/http") + c.Writer.WriteHeader(http.StatusOK) + _, err = c.Writer.Write(dump) + return err +} + +func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { + if c == nil { + return + } + + if c.engine != nil { + if c.Request != nil && c.Request.RequestURI != "*" { + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { + c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + } + } + } + c.Writer.WriteHeader(http.StatusOK) +} + func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { ctx := c.Request.Context() if ctx.Done() != nil { @@ -522,7 +661,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() @@ -561,6 +704,80 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), + } + } + backRead := res.Body + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backRead.Close() + _ = backWrite.Close() + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} + } + + defer clientConn.Close() + defer backRead.Close() + defer backWrite.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backRead.Close() + _ = backWrite.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + if _, err := io.Copy(clientConn, backRead); err != nil { + errc <- err + return + } + if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone + }() + go func() { + if _, err := io.Copy(backWrite, clientConn); err != nil { + errc <- err + return + } + errc <- backWrite.Close() + }() + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -638,6 +855,10 @@ func reverseProxyStatusCode(err error) int { if errors.As(err, &statusErr) && statusErr.status > 0 { return statusErr.status } + var netErr net.Error + if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { + return http.StatusGatewayTimeout + } return http.StatusBadGateway } @@ -651,6 +872,17 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func validateReverseProxyForwardedBy(value string) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil + } + if !isValidForwardedNodeIdentifier(trimmed) { + return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value) + } + return nil +} + func normalizeReverseProxyTarget(target *url.URL) { switch strings.ToLower(target.Scheme) { case "ws": @@ -732,6 +964,83 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { return strings.Join(pairs, ";") } +func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { + return policy == ForwardedBoth || policy == ForwardedRFC7239Only +} + +func isValidForwardedNodeIdentifier(value string) bool { + if value == "" { + return false + } + if strings.HasPrefix(value, "[") { + closing := strings.IndexByte(value, ']') + if closing <= 1 { + return false + } + addr, err := netip.ParseAddr(value[1:closing]) + if err != nil || !addr.Is6() { + return false + } + if closing == len(value)-1 { + return true + } + if value[closing+1] != ':' { + return false + } + return isValidForwardedNodePort(value[closing+2:]) + } + + host, port, hasPort := strings.Cut(value, ":") + if hasPort { + switch { + case host == "unknown", isValidForwardedObfuscatedIdentifier(host): + return isValidForwardedNodePort(port) + default: + addr, err := netip.ParseAddr(host) + return err == nil && addr.Is4() && isValidForwardedNodePort(port) + } + } + + if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) { + return true + } + addr, err := netip.ParseAddr(value) + return err == nil && addr.Is4() +} + +func isValidForwardedNodePort(value string) bool { + if value == "" { + return false + } + if isValidForwardedObfuscatedIdentifier(value) { + return true + } + if len(value) > 5 { + return false + } + port, err := strconv.Atoi(value) + return err == nil && port > 0 && port <= 65535 +} + +func isValidForwardedObfuscatedIdentifier(value string) bool { + if len(value) < 2 || value[0] != '_' { + return false + } + for i := 1; i < len(value); i++ { + b := value[i] + if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') { + continue + } + switch b { + case '.', '_', '-': + continue + default: + return false + } + } + return true +} + func formatForwardedFor(clientIP string) string { addr, err := netip.ParseAddr(clientIP) if err != nil { @@ -817,6 +1126,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) { } } +func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error { + connectTarget, err := reverseProxyConnectTarget(target) + if err != nil { + return &reverseProxyStatusError{status: http.StatusBadRequest, err: err} + } + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = "" + req.URL.RawPath = "" + req.URL.RawQuery = "" + req.URL.Opaque = connectTarget + req.Host = connectTarget + return nil +} + +func reverseProxyConnectTarget(target *url.URL) (string, error) { + if target == nil { + return "", errReverseProxyNilTarget + } + host := target.Hostname() + if host == "" { + return "", errReverseProxyInvalidTarget + } + port := target.Port() + if port == "" { + switch strings.ToLower(target.Scheme) { + case "http": + port = "80" + case "https": + port = "443" + default: + return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme) + } + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum <= 0 || portNum > 65535 { + return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port) + } + return net.JoinHostPort(host, port), nil +} + func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { if base.RawPath == "" && incoming.RawPath == "" { return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" @@ -919,6 +1269,10 @@ func cleanReverseProxyQueryParams(rawQuery string) string { return values.Encode() } +func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { + return req != nil && req.Context().Value(http.ServerContextKey) != nil +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index f82aff9..b7df512 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,6 +2,7 @@ package touka import ( "bufio" + "context" "errors" "fmt" "io" @@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, ForwardedHeaders: ForwardedBoth, - ForwardedBy: "proxy-node", + ForwardedBy: "_proxy-node", Via: "proxy.test", })) @@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { if !strings.Contains(got.Forwarded, "for=198.51.100.10") { t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) } - if !strings.Contains(got.Forwarded, "by=proxy-node") { + if !strings.Contains(got.Forwarded, "by=_proxy-node") { t.Fatalf("forwarded header missing by token: %q", got.Forwarded) } if !strings.Contains(got.Forwarded, "host=client.example") { @@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } +func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) { + t.Helper() + + forwardedCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedCh <- r.Header.Get("Forwarded") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: " _proxy-node ", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + req.RemoteAddr = "198.51.100.10:4567" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case forwarded := <-forwardedCh: + if !strings.Contains(forwarded, "by=_proxy-node") { + t.Fatalf("unexpected Forwarded header: %q", forwarded) + } + if strings.Contains(forwarded, `by=" _proxy-node "`) { + t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Forwarded header") + } +} + func TestReverseProxyDefaultViaFallback(t *testing.T) { t.Helper() @@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, context.DeadlineExceeded + }), + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { t.Helper() @@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { } } +func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.RequestURI = "/trace" + req.Header.Set("Max-Forwards", "0") + req.Header.Set("Authorization", "secret") + req.Header.Set("Cookie", "a=b") + req.Header.Set("Forwarded", "for=192.0.2.1") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got != "message/http" { + t.Fatalf("unexpected content type: %q", got) + } + if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") { + t.Fatalf("trace body missing request line: %q", string(body)) + } + if strings.Contains(string(body), "Authorization:") { + t.Fatalf("trace body leaked authorization header: %q", string(body)) + } + if strings.Contains(string(body), "Cookie:") { + t.Fatalf("trace body leaked cookie header: %q", string(body)) + } + if strings.Contains(string(body), "Forwarded:") { + t.Fatalf("trace body leaked forwarded header: %q", string(body)) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) { + t.Helper() + + maxForwardsCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + maxForwardsCh <- r.Header.Get("Max-Forwards") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.Header.Set("Max-Forwards", "2") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case got := <-maxForwardsCh: + if got != "1" { + t.Fatalf("unexpected Max-Forwards header: %q", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Max-Forwards") + } +} + +func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil) + req.Header.Set("Max-Forwards", "0") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + allow := rr.Header().Get("Allow") + if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) { + t.Fatalf("unexpected Allow header: %q", allow) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { + t.Helper() + + engine := New() + engine.OPTIONS("/", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil) + req.RequestURI = "*" + req.URL.Path = "" + req.URL.RawPath = "" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) + } +} + +func TestReverseProxyConnectTunnel(t *testing.T) { + t.Helper() + + backendAddr := "" + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got, want := r.RequestURI, backendAddr; got != want { + errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, strings.ToUpper(line)) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend write failed: %w", err) + return + } + })) + defer backend.Close() + backendAddr = strings.TrimPrefix(backend.URL, "http://") + + engine := New() + engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n") + if err != nil { + t.Fatalf("write connect request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "200") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if got := respHeader.Get("Content-Length"); got != "" { + t.Fatalf("CONNECT response should not include Content-Length, got %q", got) + } + if got := respHeader.Get("Transfer-Encoding"); got != "" { + t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "PING\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyConnectNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil) + req.URL.Path = "/tunnel" + req.RequestURI = "/tunnel" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")}, + ContentLength: -1, + Request: req, + }, nil + }), + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := proxy.Client().Get(proxy.URL + "/proxy") + if err != nil { + t.Fatalf("perform request: %v", err) + } + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err == nil { + t.Fatal("expected body read to fail after upstream copy error") + } +} + func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { t.Helper() @@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } return u } + +type failingReadCloser struct { + chunks []string + err error +} + +func (r *failingReadCloser) Read(p []byte) (int, error) { + if len(r.chunks) == 0 { + return 0, r.err + } + n := copy(p, r.chunks[0]) + r.chunks = r.chunks[1:] + return n, nil +} + +func (r *failingReadCloser) Close() error { + return nil +} From 2165cc4114e9c33a7b8176df17b0ffe370abe822 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 03:53:17 +0800 Subject: [PATCH 2/4] feat(http2): support OPTIONS * and extended CONNECT --- docs/reverse-proxy.md | 1 + docs/routing.md | 2 + engine.go | 25 +++++++++ go.mod | 3 +- go.sum | 2 + http2xconnect.go | 53 ++++++++++++++++++ reverseproxy.go | 119 ++++++++++++++++++++++++++++++++++++---- reverseproxy_test.go | 123 +++++++++++++++++++++++++++++++++++++++++- 8 files changed, 316 insertions(+), 12 deletions(-) create mode 100644 http2xconnect.go diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 1dfd760..959d866 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: - `CONNECT` 隧道转发(HTTP/1.x) +- HTTP/2 extended `CONNECT` - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 diff --git a/docs/routing.md b/docs/routing.md index e90308e..223081a 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -22,6 +22,8 @@ r.ANY("/any", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) ``` +服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。 + ## 路径参数 (Named Parameters) 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 diff --git a/engine.go b/engine.go index a4350c0..b7cf330 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ package touka import ( "context" "errors" + "io" "reflect" "runtime" "strings" @@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { if engine.serverProtocols != nil { srv.Protocols = engine.serverProtocols + if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { + if err := configureHTTP2ExtendedConnectServer(srv); err != nil { + panic(err) + } + } } } @@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleRequest 负责根据请求查找路由并执行相应的处理函数链 // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { + if isGeneralOptionsRequest(c.Request) { + engine.handleGeneralOptions(c) + return + } + httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string { return allowedMethods } +func (engine *Engine) handleGeneralOptions(c *Context) { + if c == nil || c.Request == nil { + return + } + + c.Writer.Header().Set("Content-Length", "0") + if c.Request.ContentLength != 0 { + mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10) + _, _ = io.Copy(io.Discard, mb) + } + c.Writer.WriteHeader(http.StatusOK) + c.Abort() +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/go.mod b/go.mod index 42f4be4..bd0c046 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,10 @@ require ( github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 + golang.org/x/net v0.52.0 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.52.0 // indirect + golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index b49879b..6a8d0c6 100644 --- a/go.sum +++ b/go.sum @@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= diff --git a/http2xconnect.go b/http2xconnect.go new file mode 100644 index 0000000..b3b12a0 --- /dev/null +++ b/http2xconnect.go @@ -0,0 +1,53 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "strings" + "sync" + _ "unsafe" + + "golang.org/x/net/http2" +) + +var enableHTTP2ExtendedConnectOnce sync.Once + +//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol +var xnetDisableHTTP2ExtendedConnectProtocol bool + +func enableHTTP2ExtendedConnectProtocol() { + enableHTTP2ExtendedConnectOnce.Do(func() { + xnetDisableHTTP2ExtendedConnectProtocol = false + }) +} + +func configureHTTP2ExtendedConnectServer(srv *http.Server) error { + if srv == nil { + return nil + } + enableHTTP2ExtendedConnectProtocol() + return http2.ConfigureServer(srv, nil) +} + +func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { + enableHTTP2ExtendedConnectProtocol() + + transport := &http2.Transport{} + if target == nil || !strings.EqualFold(target.Scheme, "http") { + return transport + } + + transport.AllowHTTP = true + transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + } + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 977402b..e01f4d0 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -67,10 +67,11 @@ var ( ) type reverseProxyHandler struct { - config ReverseProxyConfig - target *url.URL - receivedBy string - configError error + config ReverseProxyConfig + target *url.URL + receivedBy string + configError error + extendedConnectTransport http.RoundTripper } type reverseProxyStatusError struct { @@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { target: target, receivedBy: reverseProxyReceivedBy(config.Via), } + if config.Transport == nil { + proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + } if err := validateReverseProxyTarget(target); err != nil { proxy.configError = err @@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport := p.config.Transport if transport == nil { - transport = http.DefaultTransport + if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil { + transport = p.extendedConnectTransport + } else { + transport = http.DefaultTransport + } } updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) @@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } if outreq.Method == http.MethodConnect { - if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { - p.handleError(c, err) - return + if reverseProxyIsExtendedConnectRequest(outreq) { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } } } else { rewriteReverseProxyURL(outreq, p.target) @@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { if !p.modifyResponse(c, res, outreq) { return } - if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + handleConnect := p.handleConnectResponse + if reverseProxyIsExtendedConnectRequest(outreq) { + handleConnect = p.handleExtendedConnectResponse + } + if err := handleConnect(c, outreq, res, connectWriter); err != nil { p.handleError(c, err) } connectWriter = nil @@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if c == nil || c.Request == nil { + res.Body.Close() + if backWrite != nil { + _ = backWrite.Close() + } + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} + } + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + c.Writer.WriteHeader(res.StatusCode) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(backWrite, c.Request.Body) + closeErr := backWrite.Close() + if err != nil && !reverseProxyIsBenignTunnelError(err) { + errc <- err + return + } + errc <- closeErr + }() + go func() { + copyErr := p.copyResponse(c.Writer, res.Body, -1) + closeErr := res.Body.Close() + if copyErr != nil { + errc <- copyErr + return + } + errc <- closeErr + }() + + firstErr := <-errc + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + secondErr := <-errc + + for _, err := range []error{firstErr, secondErr} { + if reverseProxyIsBenignTunnelError(err) { + continue + } + return err + } + return nil +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { + return reverseProxyExtendedConnectProtocol(req) != "" +} + +func reverseProxyExtendedConnectProtocol(req *http.Request) string { + if req == nil || req.Method != http.MethodConnect || req.Header == nil { + return "" + } + return textproto.TrimString(req.Header.Get(":protocol")) +} + func isValidForwardedNodeIdentifier(value string) bool { if value == "" { return false @@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { return req != nil && req.Context().Value(http.ServerContextKey) != nil } +func reverseProxyIsBenignTunnelError(err error) bool { + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b7df512..345dd97 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -3,6 +3,7 @@ package touka import ( "bufio" "context" + "crypto/tls" "errors" "fmt" "io" @@ -15,6 +16,8 @@ import ( "strings" "testing" "time" + + "golang.org/x/net/http2" ) func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { @@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { } } -func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { +func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) { t.Helper() engine := New() @@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { rr := httptest.NewRecorder() engine.ServeHTTP(rr, req) - if rr.Code != http.StatusNotFound { + if rr.Code != http.StatusOK { t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) } + if got := rr.Header().Get("Content-Length"); got != "0" { + t.Fatalf("unexpected Content-Length header: %q", got) + } } func TestReverseProxyConnectTunnel(t *testing.T) { @@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.ProtoMajor != 2 { + errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get(":protocol"); got != "websocket" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.URL.Path; got != "/ws" { + errCh <- fmt.Errorf("unexpected upstream path: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + line, err := bufio.NewReader(r.Body).ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper() From 59f190ce3a6097e659fcb46ecc630a65a8e8eebb Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 04:09:43 +0800 Subject: [PATCH 3/4] fix(http2): preserve extended CONNECT tunnel shutdown semantics --- reverseproxy.go | 52 ++++++++--- reverseproxy_test.go | 217 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+), 11 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index e01f4d0..bb1784b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -829,6 +829,19 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + }) + } + go func() { + <-req.Context().Done() + closeTunnel() + }() + errc := make(chan error, 2) go func() { _, err := io.Copy(backWrite, c.Request.Body) @@ -849,19 +862,24 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt errc <- closeErr }() - firstErr := <-errc - _ = c.Request.Body.Close() - _ = backWrite.Close() - _ = res.Body.Close() - secondErr := <-errc - - for _, err := range []error{firstErr, secondErr} { + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc if reverseProxyIsBenignTunnelError(err) { continue } - return err + if firstErr == nil { + firstErr = err + closeTunnel() + } } - return nil + closeTunnel() + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + + return firstErr + } func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { @@ -902,7 +920,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { @@ -1371,7 +1389,19 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { } func reverseProxyIsBenignTunnelError(err error) bool { - return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) +} + +func reverseProxyIsClosedBodyError(err error) bool { + if err == nil { + return false + } + switch err.Error() { + case "body closed by handler", "http2: response body closed", "response body closed": + return true + default: + return false + } } func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 345dd97..e56aa5e 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -967,6 +967,223 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + reader := bufio.NewReader(r.Body) + line, err := reader.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "ack:"+line); err != nil { + errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) + return + } + _ = controller.Flush() + + if _, err := io.Copy(io.Discard, reader); err != nil { + errCh <- fmt.Errorf("wait for request half-close failed: %w", err) + return + } + if _, err := io.WriteString(w, "after-close\n"); err != nil { + errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + reader := bufio.NewReader(resp.Body) + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read immediate tunneled response: %v", err) + } + if message != "ack:ping\n" { + t.Fatalf("unexpected immediate tunneled response: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + message, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("read post-close tunneled response: %v", err) + } + if message != "after-close\n" { + t.Fatalf("unexpected post-close tunneled response: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + <-r.Context().Done() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + proxyErrCh := make(chan error, 1) + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + select { + case proxyErrCh <- err: + default: + } + }, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + writeErrCh := make(chan error, 1) + go func() { + _, err := io.WriteString(pw, strings.Repeat("x", 1<<20)) + writeErrCh <- err + }() + time.Sleep(50 * time.Millisecond) + + cancel() + _ = pw.CloseWithError(context.Canceled) + select { + case <-writeErrCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for request body writer to unblock") + } + + select { + case err := <-proxyErrCh: + t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper() From 919236665bfa59a55a3c46e3bfec9a4114edf28f Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:40:56 +0800 Subject: [PATCH 4/4] feat(reverseproxy): add upstream balancing and failover --- docs/reverse-proxy.md | 113 +++++++- reverseproxy.go | 426 +++++++++++++++++++++-------- reverseproxy_lb.go | 352 ++++++++++++++++++++++++ reverseproxy_test.go | 619 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1394 insertions(+), 116 deletions(-) create mode 100644 reverseproxy_lb.go diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 959d866..15ebafd 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ ```go type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL + Targets []string + + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration @@ -78,12 +82,115 @@ type ReverseProxyConfig struct { ### `Target` -必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 +与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。 ```go target, _ := url.Parse("http://backend:9000") ``` +### `Targets` + +可选。用于配置多个后端目标地址。 + +- `Target` 与 `Targets` 互斥,只能使用其中一种 +- `Targets` 的每一项都必须是完整 URL +- 每个 target 仍然可以自带 base path 和 query + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001/base?from=a", + "http://127.0.0.1:9002/base?from=b", + }, +})) +``` + +这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。 + +### `LoadBalancing` + +用于配置 upstream 选择策略和重试行为。 + +```go +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} +``` + +当前内置策略: + +- `touka.LBRandom()` +- `touka.LBRoundRobin()` +- `touka.LBFirst()` +- `touka.LBLeastConn()` +- `touka.LBIPHash()` +- `touka.LBClientIPHash()` +- `touka.LBURIHash()` +- `touka.LBHeader("X-Upstream", fallback)` +- `touka.LBQuery("tenant", fallback)` + +其中: + +- `LBFirst()` 适合主备/故障转移顺序 +- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback +- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()` + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBHeader("X-Upstream", touka.LBFirst()), + Retries: 1, + }, +})) +``` + +重试说明: + +- 只对未开始收到上游响应的失败进行重试 +- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试 +- `Retries` 表示额外重试次数 +- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机 +- `TryInterval` 表示两次重试之间的等待间隔 + +### `PassiveHealth` + +用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。 + +```go +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} +``` + +- `FailDuration > 0` 时启用被动健康跟踪 +- `MaxFails <= 0` 时默认按 `1` 处理 +- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBFirst(), + }, + PassiveHealth: touka.ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, +})) +``` + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 @@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ 在请求真正发往后端前,对出站请求做最后修改。 +如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。 + 常见用途: - 覆盖 `Host` diff --git a/reverseproxy.go b/reverseproxy.go index bb1784b..186e163 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -23,6 +23,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/net/http2" ) // ForwardedHeadersPolicy controls how forwarding headers are generated. @@ -44,7 +46,11 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL + Targets []string + + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration @@ -61,17 +67,18 @@ type ReverseProxyConfig struct { } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") - errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) type reverseProxyHandler struct { - config ReverseProxyConfig - target *url.URL - receivedBy string - configError error - extendedConnectTransport http.RoundTripper + config ReverseProxyConfig + upstreams []*reverseProxyUpstream + receivedBy string + configError error + roundRobin atomic.Uint64 } type reverseProxyStatusError struct { @@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc { } func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { - target := cloneReverseProxyURL(config.Target) - if target != nil { - normalizeReverseProxyTarget(target) - } - proxy := &reverseProxyHandler{ config: config, - target: target, receivedBy: reverseProxyReceivedBy(config.Via), } - if config.Transport == nil { - proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) - } - if err := validateReverseProxyTarget(target); err != nil { + upstreams, err := buildReverseProxyUpstreams(config) + if err != nil { proxy.configError = err + } else { + proxy.upstreams = upstreams } switch config.ForwardedHeaders { @@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { proxy.configError = err } } + if proxy.configError == nil { + if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil { + proxy.configError = err + } + } return proxy } @@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } - transport := p.config.Transport - if transport == nil { - if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil { - transport = p.extendedConnectTransport - } else { - transport = http.DefaultTransport - } - } - updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) if err != nil { p.handleError(c, err) @@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { ctx, cancel := p.requestContext(c) defer cancel() + attempted := make(map[string]struct{}, len(p.upstreams)) + attempts := 0 + started := time.Now() + var lastErr error - outreq := c.Request.Clone(ctx) - if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { - outreq.Body = nil - } - if outreq.Body != nil { - outreq.Body = &noopCloseReader{readCloser: outreq.Body} - defer outreq.Body.Close() - } - if outreq.Header == nil { - outreq.Header = make(http.Header) - } - outreq.Close = false - var connectWriter *io.PipeWriter - defer func() { - if connectWriter != nil { - _ = connectWriter.Close() - } - }() - if outreq.Method == http.MethodConnect { - pipeReader, pipeWriter := io.Pipe() - outreq.Body = pipeReader - outreq.ContentLength = -1 - defer outreq.Body.Close() - connectWriter = pipeWriter - } - - if outreq.Method == http.MethodConnect { - if reverseProxyIsExtendedConnectRequest(outreq) { - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else { - if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { - p.handleError(c, err) + for { + upstream, err := p.selectUpstream(c, attempted) + if err != nil { + if lastErr != nil { + p.handleError(c, lastErr) return } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err}) + return } - } else { - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } - if updatedMaxForwards != "" { - outreq.Header.Set("Max-Forwards", updatedMaxForwards) - } - reqUpType := reverseProxyUpgradeType(outreq.Header) - if reqUpType != "" && !isPrintableASCII(reqUpType) { - p.handleError(c, &reverseProxyStatusError{ - status: http.StatusBadRequest, - err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), - }) + attempts++ + upstream.inFlight.Add(1) + served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards) + upstream.inFlight.Add(-1) + + if served { + return + } + if attemptErr != nil { + lastErr = attemptErr + } + if retriable && p.shouldRetryAttempt(c.Request, attempts, started) { + attempted[upstream.key] = struct{}{} + if !p.waitRetryInterval(ctx, started) { + if lastErr != nil { + p.handleError(c, lastErr) + } + return + } + continue + } + if attemptErr != nil { + p.handleError(c, attemptErr) + return + } + if lastErr != nil { + p.handleError(c, lastErr) + return + } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams}) return } +} - removeHopByHopHeaders(outreq.Header) - if headerValuesContainToken(c.Request.Header["Te"], "trailers") { - outreq.Header.Set("Te", "trailers") - } - if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) - } - - p.addForwardingHeaders(c.Request, outreq) - appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) - - if _, ok := outreq.Header["User-Agent"]; !ok { - outreq.Header.Set("User-Agent", "") - } - - if p.config.ModifyRequest != nil { - p.config.ModifyRequest(outreq) +func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) { + outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards) + if err != nil { + return false, err, false } + defer cleanup() + transport := p.transportForUpstream(c.Request, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { roundTripDone = true roundTripMu.Unlock() if err != nil { - p.handleError(c, err) - return + if reverseProxyShouldCountPassiveFailure(outreq, err) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) + } + return false, err, true + } + if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) } if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { @@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { res.TransferEncoding = nil appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } handleConnect := p.handleConnectResponse if reverseProxyIsExtendedConnectRequest(outreq) { handleConnect = p.handleExtendedConnectResponse } if err := handleConnect(c, outreq, res, connectWriter); err != nil { - p.handleError(c, err) + return false, err, false } - connectWriter = nil - return + return true, nil, false } if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } if err := p.handleUpgradeResponse(c, outreq, res); err != nil { - p.handleError(c, err) + return false, err, false } - return + return true, nil, false } removeHopByHopHeaders(res.Header) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } reverseProxyCopyHeader(c.Writer.Header(), res.Header) @@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { if reverseProxyShouldPanicOnCopyError(c.Request) { panic(http.ErrAbortHandler) } - return + return true, nil, false } res.Body.Close() @@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Flush() } - // Keep the stdlib-compatible fallback here. - // If the backend only exposes additional trailer keys after the body has been - // fully read, the trailer map can grow and those values must be written using - // the TrailerPrefix form instead of the pre-announced bare header keys. if len(res.Trailer) == announcedTrailers { reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) - return + return true, nil, false } for key, values := range res.Trailer { @@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Header().Add(prefixedKey, value) } } + return true, nil, false +} + +func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { + outreq := c.Request.Clone(ctx) + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { + outreq.Body = nil + } else if c.Request.GetBody != nil { + body, err := c.Request.GetBody() + if err != nil { + return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err) + } + outreq.Body = body + } else if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false + var connectWriter *io.PipeWriter + if outreq.Method == http.MethodConnect { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + connectWriter = pipeWriter + } + cleanup := func() { + if outreq.Body != nil { + _ = outreq.Body.Close() + } + if connectWriter != nil { + _ = connectWriter.Close() + } + } + + if outreq.Method == http.MethodConnect { + if reverseProxyIsExtendedConnectRequest(outreq) { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err + } + } + } else { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) + } + + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + cleanup() + return nil, nil, nil, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + } + } + + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) + } + + return outreq, connectWriter, cleanup, nil +} + +func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper { + if p.config.Transport != nil { + return p.config.Transport + } + if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { + return upstream.extendedConnectTransport + } + return http.DefaultTransport +} + +func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool { + if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) { + return false + } + lb := p.config.LoadBalancing + if lb.TryDuration > 0 { + return time.Since(started) < lb.TryDuration + } + return attempts <= lb.Retries +} + +func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool { + interval := p.config.LoadBalancing.TryInterval + tryDuration := p.config.LoadBalancing.TryDuration + if tryDuration > 0 && interval == 0 { + interval = 250 * time.Millisecond + } + if tryDuration > 0 { + remaining := tryDuration - time.Since(started) + if remaining <= 0 { + return false + } + if interval <= 0 { + return ctx.Err() == nil + } + if interval > remaining { + return false + } + } + if interval <= 0 { + return ctx.Err() == nil + } + timer := time.NewTimer(interval) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } } func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { @@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) { + if config.Target != nil && len(config.Targets) > 0 { + return nil, errors.New("reverse proxy Target and Targets cannot be used together") + } + + targets := make([]*url.URL, 0, max(1, len(config.Targets))) + if config.Target != nil { + target := cloneReverseProxyURL(config.Target) + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, err + } + targets = append(targets, target) + } + for i, rawTarget := range config.Targets { + trimmed := strings.TrimSpace(rawTarget) + if trimmed == "" { + return nil, fmt.Errorf("reverse proxy target at index %d is empty", i) + } + target, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + targets = append(targets, target) + } + if len(targets) == 0 { + return nil, errReverseProxyNilTarget + } + + upstreams := make([]*reverseProxyUpstream, 0, len(targets)) + for i, target := range targets { + upstream := &reverseProxyUpstream{ + key: fmt.Sprintf("%d:%s", i, target.String()), + target: target, + index: i, + } + if config.Transport == nil { + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + func validateReverseProxyForwardedBy(value string) error { trimmed := strings.TrimSpace(value) if trimmed == "" { @@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { return req != nil && req.Context().Value(http.ServerContextKey) != nil } +func reverseProxyCanRetryRequest(req *http.Request) bool { + if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) { + return false + } + if req.Body == nil || req.ContentLength == 0 { + return true + } + return req.GetBody != nil +} + +func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool { + if err == nil || reverseProxyIsBenignTunnelError(err) { + return false + } + if req != nil && req.Context().Err() != nil { + return false + } + return !errors.Is(err, context.Canceled) +} + +func reverseProxyMethodIsSafe(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + return true + default: + return false + } +} + func reverseProxyIsBenignTunnelError(err error) bool { return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) } @@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool { if err == nil { return false } + var streamErr http2.StreamError + if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel { + return true + } switch err.Error() { case "body closed by handler", "http2: response body closed", "response body closed": return true diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go new file mode 100644 index 0000000..9b41af0 --- /dev/null +++ b/reverseproxy_lb.go @@ -0,0 +1,352 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "fmt" + "math/rand/v2" + "net/http" + "net/textproto" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ReverseProxyLoadBalancingConfig configures upstream selection and retries. +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} + +// ReverseProxyPassiveHealthConfig configures inline passive health tracking. +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} + +// ReverseProxyLBPolicy selects an upstream from the configured target pool. +// Use the helper constructors such as LBRandom or LBHeader to build a policy. +type ReverseProxyLBPolicy struct { + kind reverseProxyLBPolicyKind + key string + fallback *ReverseProxyLBPolicy +} + +type reverseProxyLBPolicyKind uint8 + +const ( + reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota + reverseProxyLBPolicyRoundRobin + reverseProxyLBPolicyFirst + reverseProxyLBPolicyLeastConn + reverseProxyLBPolicyIPHash + reverseProxyLBPolicyClientIPHash + reverseProxyLBPolicyURIHash + reverseProxyLBPolicyHeader + reverseProxyLBPolicyQuery +) + +type reverseProxyUpstream struct { + key string + target *url.URL + index int + extendedConnectTransport http.RoundTripper + inFlight atomic.Int64 + + passiveMu sync.Mutex + failures []time.Time +} + +func LBRandom() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom} +} + +func LBRoundRobin() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin} +} + +func LBFirst() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst} +} + +func LBLeastConn() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn} +} + +func LBIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash} +} + +func LBClientIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash} +} + +func LBURIHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash} +} + +func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error { + switch policy.kind { + case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst, + reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash, + reverseProxyLBPolicyURIHash: + return nil + case reverseProxyLBPolicyHeader: + if policy.key == "" { + return fmt.Errorf("reverse proxy header load-balancing policy requires a header field") + } + case reverseProxyLBPolicyQuery: + if policy.key == "" { + return fmt.Errorf("reverse proxy query load-balancing policy requires a query key") + } + default: + return fmt.Errorf("reverse proxy load-balancing policy is invalid") + } + if policy.fallback != nil { + return validateReverseProxyLBPolicy(*policy.fallback) + } + return nil +} + +func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) { + now := time.Now() + policy := p.config.LoadBalancing.Policy + candidates := p.availableUpstreams(now, excluded) + if len(candidates) == 0 && len(excluded) > 0 { + candidates = p.availableUpstreams(now, nil) + } + if len(candidates) == 0 { + return nil, errReverseProxyNoAvailableUpstreams + } + return p.selectUpstreamWithPolicy(c, candidates, policy), nil +} + +func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream { + candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams)) + for _, upstream := range p.upstreams { + if _, skip := excluded[upstream.key]; skip { + continue + } + if !upstream.healthy(now, p.config.PassiveHealth) { + continue + } + candidates = append(candidates, upstream) + } + return candidates +} + +func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + + switch policy.kind { + case reverseProxyLBPolicyRoundRobin: + return candidates[p.nextRoundRobinIndex(len(candidates))] + case reverseProxyLBPolicyFirst: + return candidates[0] + case reverseProxyLBPolicyLeastConn: + return p.selectLeastConnUpstream(candidates) + case reverseProxyLBPolicyIPHash: + return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr)) + case reverseProxyLBPolicyClientIPHash: + return reverseProxySelectHRW(candidates, c.RequestIP()) + case reverseProxyLBPolicyURIHash: + if c.Request == nil || c.Request.URL == nil { + return reverseProxySelectRandom(candidates) + } + return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI()) + case reverseProxyLBPolicyHeader: + if c.Request != nil && c.Request.Header != nil { + if values, ok := c.Request.Header[policy.key]; ok { + return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyQuery: + if c.Request != nil && c.Request.URL != nil { + if values, ok := c.Request.URL.Query()[policy.key]; ok { + return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyRandom: + fallthrough + default: + return reverseProxySelectRandom(candidates) + } +} + +func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int { + if size <= 1 { + return 0 + } + return int((p.roundRobin.Add(1) - 1) % uint64(size)) +} + +func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + selected := candidates[0] + lowest := selected.inFlight.Load() + ties := []*reverseProxyUpstream{selected} + for _, upstream := range candidates[1:] { + count := upstream.inFlight.Load() + switch { + case count < lowest: + selected = upstream + lowest = count + ties = []*reverseProxyUpstream{upstream} + case count == lowest: + ties = append(ties, upstream) + } + } + if len(ties) == 1 { + return selected + } + return ties[p.nextRoundRobinIndex(len(ties))] +} + +func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(candidates) == 1 { + return candidates[0] + } + return candidates[rand.IntN(len(candidates))] +} + +func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if key == "" { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWScore(key, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWScore(key, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + +func reverseProxyHRWScore(key, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for i := 0; i < len(key); i++ { + h ^= uint64(key[i]) + h *= prime64 + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + +func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { + if policy.fallback != nil { + return *policy.fallback + } + return LBRandom() +} + +func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return true + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + return len(u.failures) < maxFails +} + +func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + u.failures = append(u.failures, now) +} + +func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) { + if len(u.failures) == 0 || window <= 0 { + if window <= 0 { + u.failures = nil + } + return + } + cutoff := now.Add(-window) + keep := 0 + for _, failureAt := range u.failures { + if failureAt.Before(cutoff) { + continue + } + u.failures[keep] = failureAt + keep++ + } + u.failures = u.failures[:keep] +} + +func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int { + if config.FailDuration <= 0 { + return 0 + } + if config.MaxFails <= 0 { + return 1 + } + return config.MaxFails +} + +func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool { + if status <= 0 { + return false + } + for _, unhealthyStatus := range config.UnhealthyStatus { + if status == unhealthyStatus { + return true + } + } + return false +} diff --git a/reverseproxy_test.go b/reverseproxy_test.go index e56aa5e..b68f74e 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -13,7 +13,9 @@ import ( "net/http/httptrace" "net/textproto" "net/url" + "strconv" "strings" + "sync/atomic" "testing" "time" @@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) { } } +func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Targets: []string{"http://example.net"}, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) { + t.Helper() + + type snapshot struct { + Path string + RawQuery string + } + + backendOneCh := make(chan snapshot, 1) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwoCh := make(chan snapshot, 1) + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ + Targets: []string{ + backendOne.URL + "/one?from=one", + backendTwo.URL + "/two?from=two", + }, + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, + })) + + first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil) + if first.Code != http.StatusOK || first.Body.String() != "one" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "two" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case got := <-backendOneCh: + if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" { + t.Fatalf("unexpected first upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first upstream request") + } + + select { + case got := <-backendTwoCh: + if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" { + t.Fatalf("unexpected second upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second upstream request") + } +} + +func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBHeader("X-Upstream", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + headers := http.Header{"X-Upstream": {"tenant-a"}} + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBQuery("tenant", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"}) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBClientIPHash(), + }, + })) + + reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqOne.RemoteAddr = "10.0.0.1:1234" + reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrOne := httptest.NewRecorder() + engine.ServeHTTP(rrOne, reqOne) + + reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqTwo.RemoteAddr = "10.0.0.2:5678" + reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrTwo := httptest.NewRecorder() + engine.ServeHTTP(rrTwo, reqTwo) + + if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK { + t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code) + } + if rrOne.Body.String() != rrTwo.Body.String() { + t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String()) + } +} + +func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, r.Header.Get("X-Attempt")) + })) + defer backend.Close() + + var attempts atomic.Int64 + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10)) + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + if rr.Body.String() != "2" { + t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String()) + } +} + +func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("unsafe POST request should not be retried to the next upstream") + default: + } +} + +func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) { + t.Helper() + + backendOneStarted := make(chan struct{}, 1) + releaseBackendOne := make(chan struct{}) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneStarted <- struct{}{} + <-releaseBackendOne + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBLeastConn(), + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + client := proxy.Client() + client.Timeout = 5 * time.Second + + firstRespCh := make(chan string, 1) + firstErrCh := make(chan error, 1) + go func() { + resp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + firstErrCh <- err + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + firstErrCh <- err + return + } + firstRespCh <- string(body) + }() + + select { + case <-backendOneStarted: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first backend request") + } + + secondResp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + close(releaseBackendOne) + t.Fatalf("second request failed: %v", err) + } + secondBody, err := io.ReadAll(secondResp.Body) + _ = secondResp.Body.Close() + if err != nil { + close(releaseBackendOne) + t.Fatalf("read second response: %v", err) + } + if string(secondBody) != "two" { + close(releaseBackendOne) + t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody)) + } + + close(releaseBackendOne) + select { + case err := <-firstErrCh: + t.Fatalf("first request failed: %v", err) + case body := <-firstRespCh: + if body != "one" { + t.Fatalf("unexpected first response body: %q", body) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first response body") + } +} + +func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) { + t.Helper() + + primaryCalls := make(chan struct{}, 4) + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryCalls <- struct{}{} + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "primary down") + })) + defer primary.Close() + + secondaryCalls := make(chan struct{}, 4) + secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secondaryCalls <- struct{}{} + _, _ = io.WriteString(w, "secondary up") + })) + defer secondary.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{primary.URL, secondary.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + }, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, + })) + + first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "secondary up" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case <-primaryCalls: + default: + t.Fatal("expected primary to receive the first request") + } + select { + case <-secondaryCalls: + default: + t.Fatal("expected secondary to receive the second request") + } + select { + case <-primaryCalls: + t.Fatal("primary should not receive the second request while unhealthy") + default: + } +} + +func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) { + t.Helper() + + started := make(chan struct{}, 1) + release := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-release + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backend.URL}, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + client := proxy.Client() + respCh := make(chan error, 1) + go func() { + resp, err := client.Do(req) + if resp != nil { + _ = resp.Body.Close() + } + respCh <- err + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend request") + } + cancel() + close(release) + select { + case <-respCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for canceled request to finish") + } + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 3, + TryDuration: 100 * time.Millisecond, + TryInterval: 250 * time.Millisecond, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("retry budget should expire before the next upstream attempt") + default: + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 8) + newBackend := func(name string) *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got := r.Header.Get(":protocol"); got != "websocket" { + errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) + w.WriteHeader(http.StatusBadRequest) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + line, err := bufio.NewReader(r.Body).ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) + return + } + if _, err := io.WriteString(w, name+":"+line); err != nil { + errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) + return + } + _ = controller.Flush() + })) + server.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil { + t.Fatalf("configure %s HTTP/2 server: %v", name, err) + } + server.StartTLS() + return server + } + + backendOne := newBackend("one") + defer backendOne.Close() + backendTwo := newBackend("two") + defer backendTwo.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBRoundRobin(), + }, + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + doRequest := func(payload string) string { + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, payload+"\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + return message + } + + if got := doRequest("ping"); got != "one:ping\n" { + t.Fatalf("unexpected first tunneled response: %q", got) + } + if got := doRequest("pong"); got != "two:pong\n" { + t.Fatalf("unexpected second tunneled response: %q", got) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { t.Helper()