refactor: improve TLS config handling and add bridge connection tests

This commit is contained in:
wjqserver 2026-04-02 22:13:50 +08:00
parent dcdb1504a3
commit d53693952a
3 changed files with 153 additions and 7 deletions

View file

@ -51,9 +51,10 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp
transport := cloneDefaultTransport() transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols) transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP1(true)
transport.TLSClientConfig = tlsConfig if tlsConfig == nil {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{} transport.TLSClientConfig = &tls.Config{}
} else {
transport.TLSClientConfig = tlsConfig.Clone()
} }
if len(transport.TLSClientConfig.NextProtos) == 0 { if len(transport.TLSClientConfig.NextProtos) == 0 {
transport.TLSClientConfig.NextProtos = []string{"http/1.1"} transport.TLSClientConfig.NextProtos = []string{"http/1.1"}

View file

@ -1024,7 +1024,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
}() }()
defer close(backConnClosed) defer close(backConnClosed)
defer conn.Close() defer conn.Close()
defer backConn.Close()
errc := make(chan error, 2) errc := make(chan error, 2)
copyer := switchProtocolCopier{user: conn, backend: backConn} copyer := switchProtocolCopier{user: conn, backend: backConn}
@ -1374,11 +1373,11 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
} }
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
protocol := reverseProxyExtendedConnectProtocol(req)
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
if req == nil { if req == nil {
return context.Background(), false, nil return context.Background(), false, nil
} }
protocol := reverseProxyExtendedConnectProtocol(req)
if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
return req.Context(), false, nil return req.Context(), false, nil
} }

View file

@ -933,6 +933,29 @@ func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing
}) })
} }
func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) {
t.Helper()
tlsConfig := &tls.Config{InsecureSkipVerify: true}
rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig)
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("unexpected transport type: %T", rt)
}
if transport.TLSClientConfig == nil {
t.Fatal("expected TLS client config")
}
if transport.TLSClientConfig == tlsConfig {
t.Fatal("expected bridge transport to clone TLS config")
}
if len(tlsConfig.NextProtos) != 0 {
t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos)
}
if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" {
t.Fatalf("unexpected transport NextProtos: %#v", got)
}
}
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
t.Helper() t.Helper()
@ -1633,6 +1656,98 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
} }
} }
func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
closeCalls := atomic.Int32{}
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.Method != http.MethodGet {
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
}
backend := &countingReadWriteCloser{
readData: []byte("echo:ping\n"),
closeCalls: &closeCalls,
closeWriteErr: http.ErrNotSupported,
}
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: http.Header{
"Connection": []string{"Upgrade"},
"Upgrade": []string{"websocket"},
"Sec-WebSocket-Accept": []string{"ignored"},
},
Body: backend,
Request: req,
}, nil
})
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Transport: transport,
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer clientTransport.CloseIdleConnections()
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := clientTransport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if _, err := io.WriteString(pw, "ping\n"); err != nil {
_ = resp.Body.Close()
t.Fatalf("write tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
_ = resp.Body.Close()
t.Fatalf("read tunneled response body: %v", err)
}
if message != "echo:ping\n" {
_ = resp.Body.Close()
t.Fatalf("unexpected tunneled response body: %q", message)
}
if err := pw.Close(); err != nil {
_ = resp.Body.Close()
t.Fatalf("close tunneled request body: %v", err)
}
if err := resp.Body.Close(); err != nil {
t.Fatalf("close response body: %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if closeCalls.Load() > 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
if got := closeCalls.Load(); got != 1 {
t.Fatalf("expected backend connection to close exactly once, got %d", got)
}
}
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
t.Helper() t.Helper()
@ -2311,6 +2426,37 @@ func (r errorReader) Read([]byte) (int, error) {
return 0, r.err return 0, r.err
} }
type countingReadWriteCloser struct {
readData []byte
writeBuf bytes.Buffer
closeCalls *atomic.Int32
closeWriteErr error
}
func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
if len(r.readData) == 0 {
return 0, io.EOF
}
n := copy(p, r.readData)
r.readData = r.readData[n:]
return n, nil
}
func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
return r.writeBuf.Write(p)
}
func (r *countingReadWriteCloser) Close() error {
if r.closeCalls != nil {
r.closeCalls.Add(1)
}
return nil
}
func (r *countingReadWriteCloser) CloseWrite() error {
return r.closeWriteErr
}
func mustParseURL(t *testing.T, raw string) *url.URL { func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper() t.Helper()
u, err := url.Parse(raw) u, err := url.Parse(raw)