From dcdb1504a32b709122709a37dc39f4fc869e70d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:58:34 +0800 Subject: [PATCH] feat: add robust transport cloning and improve header handling in reverse proxy --- http2xconnect.go | 26 ++++++++++++++++++++++--- reverseproxy.go | 3 +-- reverseproxy_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 872f5b3..8521672 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -6,8 +6,10 @@ package touka import ( "crypto/tls" + "net" "net/http" "sync" + "time" _ "unsafe" "golang.org/x/net/http2" @@ -34,7 +36,7 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP2(true) @@ -46,7 +48,7 @@ func newHTTP1BridgeTransport() http.RoundTripper { } func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.TLSClientConfig = tlsConfig @@ -60,8 +62,26 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp } func newH2CTransport() http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetUnencryptedHTTP2(true) return transport } + +func cloneDefaultTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok { + return transport.Clone() + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/reverseproxy.go b/reverseproxy.go index afdbd9c..5d9b1ad 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1004,8 +1004,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) - responseHeader.Del("Upgrade") - responseHeader.Del("Connection") + removeHopByHopHeaders(responseHeader) responseHeader.Del("Sec-WebSocket-Accept") c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 85a64d4..9252e4a 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -893,6 +893,46 @@ func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *te } } +func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) { + t.Helper() + + originalDefaultTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("unexpected round trip") + }) + t.Cleanup(func() { + http.DefaultTransport = originalDefaultTransport + }) + + assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) { + t.Helper() + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("%s returned %T, want *http.Transport", name, rt) + } + check(transport) + } + + assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols) + } + }) + assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() { + t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols) + } + if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" { + t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig) + } + }) + assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols) + } + }) +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1509,7 +1549,7 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } defer conn.Close() - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n") if err := brw.Flush(); err != nil { errCh <- fmt.Errorf("upstream flush failed: %w", err) return @@ -1565,6 +1605,9 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if got := resp.Header.Get("Upgrade"); got != "" { t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) } + if got := resp.Header.Get("X-Hop-Token"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got) + } if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) }