feat: add robust transport cloning and improve header handling in reverse proxy

This commit is contained in:
wjqserver 2026-04-02 19:58:34 +08:00
parent 20dc6e4047
commit dcdb1504a3
3 changed files with 68 additions and 6 deletions

View file

@ -893,6 +893,46 @@ func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *te
}
}
func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) {
t.Helper()
originalDefaultTransport := http.DefaultTransport
http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
return nil, errors.New("unexpected round trip")
})
t.Cleanup(func() {
http.DefaultTransport = originalDefaultTransport
})
assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) {
t.Helper()
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("%s returned %T, want *http.Transport", name, rt)
}
check(transport)
}
assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() {
t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols)
}
})
assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() {
t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols)
}
if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" {
t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig)
}
})
assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() {
t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols)
}
})
}
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
t.Helper()
@ -1509,7 +1549,7 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("upstream flush failed: %w", err)
return
@ -1565,6 +1605,9 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
if got := resp.Header.Get("Upgrade"); got != "" {
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
}
if got := resp.Header.Get("X-Hop-Token"); got != "" {
t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia)
}