From a9c1662333c2396ca042cd4f160b1966c396f743 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:19:41 +0800 Subject: [PATCH 1/7] fix(reverseproxy): bridge websocket extended connect upstreams --- docs/reverse-proxy.md | 19 +++ http2xconnect.go | 40 +++-- reverseproxy.go | 194 ++++++++++++++++++++++- reverseproxy_lb.go | 3 + reverseproxy_test.go | 351 ++++++++++++++++++++++++++++++++---------- 5 files changed, 508 insertions(+), 99 deletions(-) diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 15ebafd..7d05290 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -68,6 +68,7 @@ type ReverseProxyConfig struct { Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -191,6 +192,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ })) ``` +### `AllowH2CUpstream` + +允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。 + +- 默认关闭 +- 这是一个显式配置项 +- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游 +- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式 + +```go +r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + AllowH2CUpstream: true, +})) +``` + +对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。 + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 diff --git a/http2xconnect.go b/http2xconnect.go index b3b12a0..872f5b3 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -5,12 +5,8 @@ package touka import ( - "context" "crypto/tls" - "net" "net/http" - "net/url" - "strings" "sync" _ "unsafe" @@ -36,18 +32,36 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { return http2.ConfigureServer(srv, nil) } -func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { +func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.Protocols.SetHTTP2(true) + return transport +} - transport := &http2.Transport{} - if target == nil || !strings.EqualFold(target.Scheme, "http") { - return transport +func newHTTP1BridgeTransport() http.RoundTripper { + return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}}) +} + +func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.TLSClientConfig = tlsConfig + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} } - - transport.AllowHTTP = true - transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { - var dialer net.Dialer - return dialer.DialContext(ctx, network, addr) + if len(transport.TLSClientConfig.NextProtos) == 0 { + transport.TLSClientConfig.NextProtos = []string{"http/1.1"} } return transport } + +func newH2CTransport() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 186e163..13ffe89 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -5,7 +5,10 @@ package touka import ( + "bufio" "context" + "crypto/rand" + "encoding/base64" "errors" "fmt" "io" @@ -52,9 +55,10 @@ type ReverseProxyConfig struct { LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -86,6 +90,33 @@ type reverseProxyStatusError struct { err error } +type reverseProxyExtendedConnectBridge struct { + body io.ReadCloser +} + +type reverseProxyH2ReadWriteCloser struct { + io.ReadCloser + ResponseWriter +} + +func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { + n, err := rwc.ResponseWriter.Write(p) + if err != nil { + return 0, err + } + if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + return 0, err + } + return n, nil +} + +func (rwc reverseProxyH2ReadWriteCloser) Close() error { + if rwc.ReadCloser == nil { + return nil + } + return rwc.ReadCloser.Close() +} + func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" @@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte } defer cleanup() - transport := p.transportForUpstream(c.Request, upstream) + transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte upstream.recordFailure(time.Now(), p.config.PassiveHealth) } + if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return true, nil, false + } + if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { + return false, err, false + } + return true, nil, false + } + return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false + } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { removeHopByHopHeaders(res.Header) res.Header.Del("Content-Length") @@ -435,6 +480,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) + bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + if bridged { + outreq = outreq.WithContext(bridgeCtx) + } if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } else if c.Request.GetBody != nil { @@ -451,7 +500,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } outreq.Close = false var connectWriter *io.PipeWriter - if outreq.Method == http.MethodConnect { + if outreq.Method == http.MethodConnect && !bridged { pipeReader, pipeWriter := io.Pipe() outreq.Body = pipeReader outreq.ContentLength = -1 @@ -467,7 +516,13 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } if outreq.Method == http.MethodConnect { - if reverseProxyIsExtendedConnectRequest(outreq) { + if bridged { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else if reverseProxyIsExtendedConnectRequest(outreq) { rewriteReverseProxyURL(outreq, upstream.target) if !p.config.PreserveHost { outreq.Host = "" @@ -526,6 +581,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream * if p.config.Transport != nil { return p.config.Transport } + if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { + if upstream.bridgeTransport != nil { + return upstream.bridgeTransport + } + return http.DefaultTransport + } + if upstream.useH2C && upstream.h2cTransport != nil { + return upstream.h2cTransport + } if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { return upstream.extendedConnectTransport } @@ -915,6 +979,71 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { + if c == nil || c.Request == nil { + res.Body.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned bridged websocket response without writable body"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + responseHeader := c.Writer.Header() + reverseProxyCopyHeader(responseHeader, res.Header) + responseHeader.Del("Upgrade") + responseHeader.Del("Connection") + responseHeader.Del("Sec-WebSocket-Accept") + c.Writer.WriteHeader(http.StatusOK) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backConn.Close() + }() + defer close(backConnClosed) + defer conn.Close() + defer backConn.Close() + + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: conn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { if c == nil || c.Request == nil { res.Body.Close() @@ -1128,13 +1257,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr upstreams := make([]*reverseProxyUpstream, 0, len(targets)) for i, target := range targets { + useH2C := strings.EqualFold(target.Scheme, "h2c") + if useH2C { + target = cloneReverseProxyURL(target) + target.Scheme = "http" + } upstream := &reverseProxyUpstream{ key: fmt.Sprintf("%d:%s", i, target.String()), target: target, index: i, + useH2C: useH2C || config.AllowH2CUpstream, } if config.Transport == nil { - upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport() + upstream.bridgeTransport = newHTTP1BridgeTransport() + if upstream.useH2C { + upstream.h2cTransport = newH2CTransport() + } } upstreams = append(upstreams, upstream) } @@ -1237,6 +1376,47 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { + protocol := reverseProxyExtendedConnectProtocol(req) + if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { + if req == nil { + return context.Background(), false + } + return req.Context(), false + } + + bridge := &reverseProxyExtendedConnectBridge{body: req.Body} + ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge) + req.Header.Del(":protocol") + req.Method = http.MethodGet + req.Body = http.NoBody + req.ContentLength = 0 + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Version", "13") + key, err := reverseProxyGenerateWebSocketKey() + if err == nil { + req.Header.Set("Sec-WebSocket-Key", key) + } + return ctx, true +} + +func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { + if ctx == nil { + return nil + } + bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge) + return bridge +} + +func reverseProxyGenerateWebSocketKey() (string, error) { + key := make([]byte, 16) + if _, err := rand.Read(key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { return reverseProxyExtendedConnectProtocol(req) != "" } diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 9b41af0..d2d45ab 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -57,7 +57,10 @@ type reverseProxyUpstream struct { key string target *url.URL index int + useH2C bool extendedConnectTransport http.RoundTripper + bridgeTransport http.RoundTripper + h2cTransport http.RoundTripper inFlight atomic.Int64 passiveMu sync.Mutex diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b68f74e..b05f426 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -112,7 +112,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { t.Fatalf("unexpected body: %q", string(body)) } if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if got.Path != "/base/api/ping" { t.Fatalf("unexpected upstream path: %q", got.Path) @@ -765,6 +766,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { } } +func TestReverseProxyAllowH2CUpstream(t *testing.T) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen h2c upstream: %v", err) + } + server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Proto", r.Proto) + _, _ = io.WriteString(w, "ok") + })} + server.Protocols = new(http.Protocols) + server.Protocols.SetUnencryptedHTTP2(true) + errCh := make(chan error, 1) + go func() { + errCh <- server.Serve(listener) + }() + defer func() { + _ = server.Close() + <-errCh + }() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://"+listener.Addr().String()), + AllowH2CUpstream: true, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" { + t.Fatalf("expected h2c upstream proto, got %q", got) + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -1363,19 +1401,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if r.ProtoMajor != 2 { - errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + if got := r.Header.Get(":protocol"); got != "" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) w.WriteHeader(http.StatusBadRequest) return } - if got := r.Header.Get(":protocol"); got != "websocket" { - errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- errors.New("missing upstream Sec-WebSocket-Key header") w.WriteHeader(http.StatusBadRequest) return } @@ -1385,36 +1433,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - line, err := bufio.NewReader(r.Body).ReadString('\n') + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "echo:"+line); err != nil { + if _, err := io.WriteString(brw, "echo:"+line); err != nil { errCh <- fmt.Errorf("write tunneled response body failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), Via: "proxy.test", })) @@ -1445,7 +1498,10 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status: %d", resp.StatusCode) } - if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + if got := resp.Header.Get("Upgrade"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) } @@ -1470,6 +1526,116 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 1 { + errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = brw.Flush() + })) + upstream.EnableHTTP2 = true + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + _ = pw.Close() + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { t.Helper() @@ -1477,42 +1643,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { errCh := make(chan error, 8) newBackend := func(name string) *httptest.Server { - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if got := r.Header.Get(":protocol"); got != "websocket" { + if got := r.Header.Get(":protocol"); got != "" { errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) w.WriteHeader(http.StatusBadRequest) return } - - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name) + w.WriteHeader(http.StatusBadRequest) return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() - line, err := bufio.NewReader(r.Body).ReadString('\n') + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name) + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) return } - if _, err := io.WriteString(w, name+":"+line); err != nil { + if _, err := io.WriteString(brw, name+":"+line); err != nil { errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) return } - _ = controller.Flush() + _ = brw.Flush() })) - server.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil { - t.Fatalf("configure %s HTTP/2 server: %v", name, err) - } - server.StartTLS() return server } @@ -1527,8 +1713,7 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { LoadBalancing: ReverseProxyLoadBalancingConfig{ Policy: LBRoundRobin(), }, - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1557,7 +1742,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if _, err := io.WriteString(pw, payload+"\n"); err != nil { t.Fatalf("write tunneled request body: %v", err) @@ -1592,55 +1778,59 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - reader := bufio.NewReader(r.Body) + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + reader := bufio.NewReader(brw) line, err := reader.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "ack:"+line); err != nil { + if _, err := io.WriteString(brw, "ack:"+line); err != nil { errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() if _, err := io.Copy(io.Discard, reader); err != nil { errCh <- fmt.Errorf("wait for request half-close failed: %w", err) return } - if _, err := io.WriteString(w, "after-close\n"); err != nil { + if _, err := io.WriteString(brw, "after-close\n"); err != nil { errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1668,7 +1858,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } reader := bufio.NewReader(resp.Body) @@ -1707,36 +1898,37 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() <-r.Context().Done() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() proxyErrCh := make(chan error, 1) engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { select { case proxyErrCh <- err: @@ -1772,7 +1964,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } writeErrCh := make(chan error, 1) From 50c6a2361405e2ee103a2cc5a301f38c891c2ca6 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:50:27 +0800 Subject: [PATCH 2/7] refactor: simplify reverse proxy bridged connection handling by removing unused bufio --- reverseproxy.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 13ffe89..e674e7e 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -5,7 +5,6 @@ package touka import ( - "bufio" "context" "crypto/rand" "encoding/base64" @@ -1011,7 +1010,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r } conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} - brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) backConnClosed := make(chan struct{}) go func() { @@ -1025,10 +1023,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r defer conn.Close() defer backConn.Close() - if err := brw.Flush(); err != nil { - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) From 7abedc1acea918edb522d00a5e1c3c7e7d45870d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:33:18 +0800 Subject: [PATCH 3/7] enhance: improve reverse proxy error handling and add tests --- reverseproxy.go | 24 ++++---- reverseproxy_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 10 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index e674e7e..2d6dfea 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -101,10 +101,10 @@ type reverseProxyH2ReadWriteCloser struct { func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { - return 0, err + return n, err } if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - return 0, err + return n, err } return n, nil } @@ -479,7 +479,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) - bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) + if err != nil { + return nil, nil, nil, err + } if bridged { outreq = outreq.WithContext(bridgeCtx) } @@ -1370,13 +1373,13 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } -func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { protocol := reverseProxyExtendedConnectProtocol(req) if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { if req == nil { - return context.Background(), false + return context.Background(), false, nil } - return req.Context(), false + return req.Context(), false, nil } bridge := &reverseProxyExtendedConnectBridge{body: req.Body} @@ -1389,10 +1392,11 @@ func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Contex req.Header.Set("Connection", "Upgrade") req.Header.Set("Sec-WebSocket-Version", "13") key, err := reverseProxyGenerateWebSocketKey() - if err == nil { - req.Header.Set("Sec-WebSocket-Key", key) + if err != nil { + return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err) } - return ctx, true + req.Header.Set("Sec-WebSocket-Key", key) + return ctx, true, nil } func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { @@ -1405,7 +1409,7 @@ func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseP func reverseProxyGenerateWebSocketKey() (string, error) { key := make([]byte, 16) - if _, err := rand.Read(key); err != nil { + if _, err := io.ReadFull(rand.Reader, key); err != nil { return "", err } return base64.StdEncoding.EncodeToString(key), nil diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b05f426..8a250b2 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,7 +2,9 @@ package touka import ( "bufio" + "bytes" "context" + crand "crypto/rand" "crypto/tls" "errors" "fmt" @@ -829,6 +831,67 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) { + t.Helper() + + flushErr := errors.New("flush failed") + writer := &flushErrorResponseWriter{flushErr: flushErr} + conn := reverseProxyH2ReadWriteCloser{ + ReadCloser: io.NopCloser(strings.NewReader("")), + ResponseWriter: writer, + } + + n, err := conn.Write([]byte("ping")) + if n != len("ping") { + t.Fatalf("unexpected bytes written: %d", n) + } + if !errors.Is(err, flushErr) { + t.Fatalf("unexpected write error: %v", err) + } + if got := writer.body.String(); got != "ping" { + t.Fatalf("unexpected buffered body: %q", got) + } +} + +func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) { + t.Helper() + + transportCalled := atomic.Bool{} + entropyErr := errors.New("entropy source unavailable") + originalReader := crand.Reader + crand.Reader = errorReader{err: entropyErr} + t.Cleanup(func() { + crand.Reader = originalReader + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + transportCalled.Store(true) + return nil, errors.New("unexpected round trip") + }), + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(reverseProxyStatusCode(err)) + _, _ = io.WriteString(w, err.Error()) + }, + })) + + headers := make(http.Header) + headers.Set(":protocol", "websocket") + rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers) + + if transportCalled.Load() { + t.Fatal("transport should not be called when websocket key generation fails") + } + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) { + t.Fatalf("unexpected error body: %q", body) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -2137,6 +2200,73 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } +type flushErrorResponseWriter struct { + header http.Header + body bytes.Buffer + status int + written bool + flushErr error +} + +func (w *flushErrorResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *flushErrorResponseWriter) WriteHeader(statusCode int) { + if w.written { + return + } + w.status = statusCode + w.written = true +} + +func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(p) +} + +func (w *flushErrorResponseWriter) Flush() {} + +func (w *flushErrorResponseWriter) FlushError() error { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.flushErr +} + +func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *flushErrorResponseWriter) Status() int { + return w.status +} + +func (w *flushErrorResponseWriter) Size() int { + return w.body.Len() +} + +func (w *flushErrorResponseWriter) Written() bool { + return w.written +} + +func (w *flushErrorResponseWriter) IsHijacked() bool { + return false +} + +type errorReader struct { + err error +} + +func (r errorReader) Read([]byte) (int, error) { + return 0, r.err +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) From 20dc6e4047cb42ca8b00206bddb6d762a897fce1 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:44:02 +0800 Subject: [PATCH 4/7] refactor: cache ResponseController in H2ReadWriteCloser for better performance --- reverseproxy.go | 9 +++++---- reverseproxy_test.go | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 2d6dfea..afdbd9c 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -96,20 +96,21 @@ type reverseProxyExtendedConnectBridge struct { type reverseProxyH2ReadWriteCloser struct { io.ReadCloser ResponseWriter + controller *http.ResponseController } -func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { +func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { return n, err } - if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { return n, err } return n, nil } -func (rwc reverseProxyH2ReadWriteCloser) Close() error { +func (rwc *reverseProxyH2ReadWriteCloser) Close() error { if rwc.ReadCloser == nil { return nil } @@ -1012,7 +1013,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } - conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} backConnClosed := make(chan struct{}) go func() { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 8a250b2..85a64d4 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -836,9 +836,10 @@ func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *te flushErr := errors.New("flush failed") writer := &flushErrorResponseWriter{flushErr: flushErr} - conn := reverseProxyH2ReadWriteCloser{ + conn := &reverseProxyH2ReadWriteCloser{ ReadCloser: io.NopCloser(strings.NewReader("")), ResponseWriter: writer, + controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)), } n, err := conn.Write([]byte("ping")) From dcdb1504a32b709122709a37dc39f4fc869e70d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:58:34 +0800 Subject: [PATCH 5/7] feat: add robust transport cloning and improve header handling in reverse proxy --- http2xconnect.go | 26 ++++++++++++++++++++++--- reverseproxy.go | 3 +-- reverseproxy_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 872f5b3..8521672 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -6,8 +6,10 @@ package touka import ( "crypto/tls" + "net" "net/http" "sync" + "time" _ "unsafe" "golang.org/x/net/http2" @@ -34,7 +36,7 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP2(true) @@ -46,7 +48,7 @@ func newHTTP1BridgeTransport() http.RoundTripper { } func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.TLSClientConfig = tlsConfig @@ -60,8 +62,26 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp } func newH2CTransport() http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetUnencryptedHTTP2(true) return transport } + +func cloneDefaultTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok { + return transport.Clone() + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/reverseproxy.go b/reverseproxy.go index afdbd9c..5d9b1ad 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1004,8 +1004,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) - responseHeader.Del("Upgrade") - responseHeader.Del("Connection") + removeHopByHopHeaders(responseHeader) responseHeader.Del("Sec-WebSocket-Accept") c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 85a64d4..9252e4a 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -893,6 +893,46 @@ func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *te } } +func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) { + t.Helper() + + originalDefaultTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("unexpected round trip") + }) + t.Cleanup(func() { + http.DefaultTransport = originalDefaultTransport + }) + + assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) { + t.Helper() + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("%s returned %T, want *http.Transport", name, rt) + } + check(transport) + } + + assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols) + } + }) + assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() { + t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols) + } + if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" { + t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig) + } + }) + assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols) + } + }) +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1509,7 +1549,7 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } defer conn.Close() - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n") if err := brw.Flush(); err != nil { errCh <- fmt.Errorf("upstream flush failed: %w", err) return @@ -1565,6 +1605,9 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if got := resp.Header.Get("Upgrade"); got != "" { t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) } + if got := resp.Header.Get("X-Hop-Token"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got) + } if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) } From d53693952a8cccd8e6579c7a92f7a63b5a6bb5d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:13:50 +0800 Subject: [PATCH 6/7] refactor: improve TLS config handling and add bridge connection tests --- http2xconnect.go | 5 +- reverseproxy.go | 9 ++- reverseproxy_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 8521672..c691a77 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -51,9 +51,10 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) - transport.TLSClientConfig = tlsConfig - if transport.TLSClientConfig == nil { + if tlsConfig == nil { transport.TLSClientConfig = &tls.Config{} + } else { + transport.TLSClientConfig = tlsConfig.Clone() } if len(transport.TLSClientConfig.NextProtos) == 0 { transport.TLSClientConfig.NextProtos = []string{"http/1.1"} diff --git a/reverseproxy.go b/reverseproxy.go index 5d9b1ad..148a9b4 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1024,7 +1024,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r }() defer close(backConnClosed) defer conn.Close() - defer backConn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} @@ -1374,11 +1373,11 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { } func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { + if req == nil { + return context.Background(), false, nil + } protocol := reverseProxyExtendedConnectProtocol(req) - if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { - if req == nil { - return context.Background(), false, nil - } + if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { return req.Context(), false, nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 9252e4a..bf7b0bb 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -933,6 +933,29 @@ func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing }) } +func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) { + t.Helper() + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type: %T", rt) + } + if transport.TLSClientConfig == nil { + t.Fatal("expected TLS client config") + } + if transport.TLSClientConfig == tlsConfig { + t.Fatal("expected bridge transport to clone TLS config") + } + if len(tlsConfig.NextProtos) != 0 { + t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos) + } + if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" { + t.Fatalf("unexpected transport NextProtos: %#v", got) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1633,6 +1656,98 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + closeCalls := atomic.Int32{} + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodGet { + return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) + } + backend := &countingReadWriteCloser{ + readData: []byte("echo:ping\n"), + closeCalls: &closeCalls, + closeWriteErr: http.ErrNotSupported, + } + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: http.Header{ + "Connection": []string{"Upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-WebSocket-Accept": []string{"ignored"}, + }, + Body: backend, + Request: req, + }, nil + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer clientTransport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := clientTransport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + _ = resp.Body.Close() + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + _ = resp.Body.Close() + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + _ = resp.Body.Close() + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + _ = resp.Body.Close() + t.Fatalf("close tunneled request body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("close response body: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if closeCalls.Load() > 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := closeCalls.Load(); got != 1 { + t.Fatalf("expected backend connection to close exactly once, got %d", got) + } +} + func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { t.Helper() @@ -2311,6 +2426,37 @@ func (r errorReader) Read([]byte) (int, error) { return 0, r.err } +type countingReadWriteCloser struct { + readData []byte + writeBuf bytes.Buffer + closeCalls *atomic.Int32 + closeWriteErr error +} + +func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 { + return 0, io.EOF + } + n := copy(p, r.readData) + r.readData = r.readData[n:] + return n, nil +} + +func (r *countingReadWriteCloser) Write(p []byte) (int, error) { + return r.writeBuf.Write(p) +} + +func (r *countingReadWriteCloser) Close() error { + if r.closeCalls != nil { + r.closeCalls.Add(1) + } + return nil +} + +func (r *countingReadWriteCloser) CloseWrite() error { + return r.closeWriteErr +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) From 1a6325d461d6c2594b29c8ba61c1af99a0dd6454 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:29:15 +0800 Subject: [PATCH 7/7] feat: improve reverse proxy tunnel management with sync.Once and better error handling --- reverseproxy.go | 53 ++++++++++++++++++++------------------------ reverseproxy_test.go | 38 +++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 148a9b4..1b89b2a 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -518,24 +518,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } } - if outreq.Method == http.MethodConnect { - if bridged { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else if reverseProxyIsExtendedConnectRequest(outreq) { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else { - if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { - cleanup() - return nil, nil, nil, err - } + if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err } } else { rewriteReverseProxyURL(outreq, upstream.target) @@ -1014,26 +1000,35 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} - backConnClosed := make(chan struct{}) + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = conn.Close() + _ = backConn.Close() + }) + } go func() { - select { - case <-req.Context().Done(): - case <-backConnClosed: - } - backConn.Close() + <-req.Context().Done() + closeTunnel() }() - defer close(backConnClosed) - defer conn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) go copyer.copyFromBackend(errc) - firstErr := <-errc - if firstErr == nil { - firstErr = <-errc + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc + if reverseProxyIsBenignTunnelError(err) { + continue + } + if firstErr == nil { + firstErr = err + closeTunnel() + } } + closeTunnel() if reverseProxyIsBenignTunnelError(firstErr) { return nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index bf7b0bb..9cbc734 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -17,6 +17,7 @@ import ( "net/url" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1662,14 +1663,24 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { enableHTTP2ExtendedConnectProtocol() closeCalls := atomic.Int32{} + backendReadDone := make(chan struct{}, 1) transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Method != http.MethodGet { return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) } - backend := &countingReadWriteCloser{ - readData: []byte("echo:ping\n"), + var respondOnce sync.Once + var backend *countingReadWriteCloser + backend = &countingReadWriteCloser{ + readDataCh: make(chan []byte, 1), closeCalls: &closeCalls, - closeWriteErr: http.ErrNotSupported, + closeWriteErr: nil, + afterWrite: func() { + respondOnce.Do(func() { + backendReadDone <- struct{}{} + backend.readDataCh <- []byte("echo:ping\n") + close(backend.readDataCh) + }) + }, } return &http.Response{ StatusCode: http.StatusSwitchingProtocols, @@ -1719,6 +1730,12 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { _ = resp.Body.Close() t.Fatalf("write tunneled request body: %v", err) } + select { + case <-backendReadDone: + case <-time.After(2 * time.Second): + _ = resp.Body.Close() + t.Fatal("backend did not receive tunneled request body") + } message, err := bufio.NewReader(resp.Body).ReadString('\n') if err != nil { _ = resp.Body.Close() @@ -2428,12 +2445,21 @@ func (r errorReader) Read([]byte) (int, error) { type countingReadWriteCloser struct { readData []byte + readDataCh chan []byte writeBuf bytes.Buffer closeCalls *atomic.Int32 closeWriteErr error + afterWrite func() } func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 && r.readDataCh != nil { + data, ok := <-r.readDataCh + if !ok { + return 0, io.EOF + } + r.readData = data + } if len(r.readData) == 0 { return 0, io.EOF } @@ -2443,7 +2469,11 @@ func (r *countingReadWriteCloser) Read(p []byte) (int, error) { } func (r *countingReadWriteCloser) Write(p []byte) (int, error) { - return r.writeBuf.Write(p) + n, err := r.writeBuf.Write(p) + if err == nil && r.afterWrite != nil { + r.afterWrite() + } + return n, err } func (r *countingReadWriteCloser) Close() error {