From d53693952a8cccd8e6579c7a92f7a63b5a6bb5d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:13:50 +0800 Subject: [PATCH] refactor: improve TLS config handling and add bridge connection tests --- http2xconnect.go | 5 +- reverseproxy.go | 9 ++- reverseproxy_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 8521672..c691a77 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -51,9 +51,10 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) - transport.TLSClientConfig = tlsConfig - if transport.TLSClientConfig == nil { + if tlsConfig == nil { transport.TLSClientConfig = &tls.Config{} + } else { + transport.TLSClientConfig = tlsConfig.Clone() } if len(transport.TLSClientConfig.NextProtos) == 0 { transport.TLSClientConfig.NextProtos = []string{"http/1.1"} diff --git a/reverseproxy.go b/reverseproxy.go index 5d9b1ad..148a9b4 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1024,7 +1024,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r }() defer close(backConnClosed) defer conn.Close() - defer backConn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} @@ -1374,11 +1373,11 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { } func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { + if req == nil { + return context.Background(), false, nil + } protocol := reverseProxyExtendedConnectProtocol(req) - if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { - if req == nil { - return context.Background(), false, nil - } + if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { return req.Context(), false, nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 9252e4a..bf7b0bb 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -933,6 +933,29 @@ func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing }) } +func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) { + t.Helper() + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type: %T", rt) + } + if transport.TLSClientConfig == nil { + t.Fatal("expected TLS client config") + } + if transport.TLSClientConfig == tlsConfig { + t.Fatal("expected bridge transport to clone TLS config") + } + if len(tlsConfig.NextProtos) != 0 { + t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos) + } + if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" { + t.Fatalf("unexpected transport NextProtos: %#v", got) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1633,6 +1656,98 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + closeCalls := atomic.Int32{} + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodGet { + return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) + } + backend := &countingReadWriteCloser{ + readData: []byte("echo:ping\n"), + closeCalls: &closeCalls, + closeWriteErr: http.ErrNotSupported, + } + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: http.Header{ + "Connection": []string{"Upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-WebSocket-Accept": []string{"ignored"}, + }, + Body: backend, + Request: req, + }, nil + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer clientTransport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := clientTransport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + _ = resp.Body.Close() + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + _ = resp.Body.Close() + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + _ = resp.Body.Close() + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + _ = resp.Body.Close() + t.Fatalf("close tunneled request body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("close response body: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if closeCalls.Load() > 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := closeCalls.Load(); got != 1 { + t.Fatalf("expected backend connection to close exactly once, got %d", got) + } +} + func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { t.Helper() @@ -2311,6 +2426,37 @@ func (r errorReader) Read([]byte) (int, error) { return 0, r.err } +type countingReadWriteCloser struct { + readData []byte + writeBuf bytes.Buffer + closeCalls *atomic.Int32 + closeWriteErr error +} + +func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 { + return 0, io.EOF + } + n := copy(p, r.readData) + r.readData = r.readData[n:] + return n, nil +} + +func (r *countingReadWriteCloser) Write(p []byte) (int, error) { + return r.writeBuf.Write(p) +} + +func (r *countingReadWriteCloser) Close() error { + if r.closeCalls != nil { + r.closeCalls.Add(1) + } + return nil +} + +func (r *countingReadWriteCloser) CloseWrite() error { + return r.closeWriteErr +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw)