feat: add robust transport cloning and improve header handling in reverse proxy

This commit is contained in:
wjqserver 2026-04-02 19:58:34 +08:00
parent 20dc6e4047
commit dcdb1504a3
3 changed files with 68 additions and 6 deletions

View file

@ -6,8 +6,10 @@ package touka
import ( import (
"crypto/tls" "crypto/tls"
"net"
"net/http" "net/http"
"sync" "sync"
"time"
_ "unsafe" _ "unsafe"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -34,7 +36,7 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
func newHTTP2ExtendedConnectTransport() http.RoundTripper { func newHTTP2ExtendedConnectTransport() http.RoundTripper {
enableHTTP2ExtendedConnectProtocol() enableHTTP2ExtendedConnectProtocol()
transport := http.DefaultTransport.(*http.Transport).Clone() transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols) transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP1(true)
transport.Protocols.SetHTTP2(true) transport.Protocols.SetHTTP2(true)
@ -46,7 +48,7 @@ func newHTTP1BridgeTransport() http.RoundTripper {
} }
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
transport := http.DefaultTransport.(*http.Transport).Clone() transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols) transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP1(true)
transport.TLSClientConfig = tlsConfig transport.TLSClientConfig = tlsConfig
@ -60,8 +62,26 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp
} }
func newH2CTransport() http.RoundTripper { func newH2CTransport() http.RoundTripper {
transport := http.DefaultTransport.(*http.Transport).Clone() transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols) transport.Protocols = new(http.Protocols)
transport.Protocols.SetUnencryptedHTTP2(true) transport.Protocols.SetUnencryptedHTTP2(true)
return transport return transport
} }
func cloneDefaultTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
return transport.Clone()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View file

@ -1004,8 +1004,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
responseHeader := c.Writer.Header() responseHeader := c.Writer.Header()
reverseProxyCopyHeader(responseHeader, res.Header) reverseProxyCopyHeader(responseHeader, res.Header)
responseHeader.Del("Upgrade") removeHopByHopHeaders(responseHeader)
responseHeader.Del("Connection")
responseHeader.Del("Sec-WebSocket-Accept") responseHeader.Del("Sec-WebSocket-Accept")
c.Writer.WriteHeader(http.StatusOK) c.Writer.WriteHeader(http.StatusOK)
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {

View file

@ -893,6 +893,46 @@ func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *te
} }
} }
func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) {
t.Helper()
originalDefaultTransport := http.DefaultTransport
http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
return nil, errors.New("unexpected round trip")
})
t.Cleanup(func() {
http.DefaultTransport = originalDefaultTransport
})
assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) {
t.Helper()
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("%s returned %T, want *http.Transport", name, rt)
}
check(transport)
}
assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() {
t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols)
}
})
assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() {
t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols)
}
if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" {
t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig)
}
})
assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) {
if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() {
t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols)
}
})
}
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
t.Helper() t.Helper()
@ -1509,7 +1549,7 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
} }
defer conn.Close() defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("upstream flush failed: %w", err) errCh <- fmt.Errorf("upstream flush failed: %w", err)
return return
@ -1565,6 +1605,9 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
if got := resp.Header.Get("Upgrade"); got != "" { if got := resp.Header.Get("Upgrade"); got != "" {
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
} }
if got := resp.Header.Get("X-Hop-Token"); got != "" {
t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got)
}
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia) t.Fatalf("unexpected Via response header: %#v", gotVia)
} }