From 1a6325d461d6c2594b29c8ba61c1af99a0dd6454 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:29:15 +0800 Subject: [PATCH] feat: improve reverse proxy tunnel management with sync.Once and better error handling --- reverseproxy.go | 53 ++++++++++++++++++++------------------------ reverseproxy_test.go | 38 +++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 148a9b4..1b89b2a 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -518,24 +518,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } } - if outreq.Method == http.MethodConnect { - if bridged { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else if reverseProxyIsExtendedConnectRequest(outreq) { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else { - if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { - cleanup() - return nil, nil, nil, err - } + if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err } } else { rewriteReverseProxyURL(outreq, upstream.target) @@ -1014,26 +1000,35 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} - backConnClosed := make(chan struct{}) + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = conn.Close() + _ = backConn.Close() + }) + } go func() { - select { - case <-req.Context().Done(): - case <-backConnClosed: - } - backConn.Close() + <-req.Context().Done() + closeTunnel() }() - defer close(backConnClosed) - defer conn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) go copyer.copyFromBackend(errc) - firstErr := <-errc - if firstErr == nil { - firstErr = <-errc + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc + if reverseProxyIsBenignTunnelError(err) { + continue + } + if firstErr == nil { + firstErr = err + closeTunnel() + } } + closeTunnel() if reverseProxyIsBenignTunnelError(firstErr) { return nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index bf7b0bb..9cbc734 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -17,6 +17,7 @@ import ( "net/url" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1662,14 +1663,24 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { enableHTTP2ExtendedConnectProtocol() closeCalls := atomic.Int32{} + backendReadDone := make(chan struct{}, 1) transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Method != http.MethodGet { return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) } - backend := &countingReadWriteCloser{ - readData: []byte("echo:ping\n"), + var respondOnce sync.Once + var backend *countingReadWriteCloser + backend = &countingReadWriteCloser{ + readDataCh: make(chan []byte, 1), closeCalls: &closeCalls, - closeWriteErr: http.ErrNotSupported, + closeWriteErr: nil, + afterWrite: func() { + respondOnce.Do(func() { + backendReadDone <- struct{}{} + backend.readDataCh <- []byte("echo:ping\n") + close(backend.readDataCh) + }) + }, } return &http.Response{ StatusCode: http.StatusSwitchingProtocols, @@ -1719,6 +1730,12 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { _ = resp.Body.Close() t.Fatalf("write tunneled request body: %v", err) } + select { + case <-backendReadDone: + case <-time.After(2 * time.Second): + _ = resp.Body.Close() + t.Fatal("backend did not receive tunneled request body") + } message, err := bufio.NewReader(resp.Body).ReadString('\n') if err != nil { _ = resp.Body.Close() @@ -2428,12 +2445,21 @@ func (r errorReader) Read([]byte) (int, error) { type countingReadWriteCloser struct { readData []byte + readDataCh chan []byte writeBuf bytes.Buffer closeCalls *atomic.Int32 closeWriteErr error + afterWrite func() } func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 && r.readDataCh != nil { + data, ok := <-r.readDataCh + if !ok { + return 0, io.EOF + } + r.readData = data + } if len(r.readData) == 0 { return 0, io.EOF } @@ -2443,7 +2469,11 @@ func (r *countingReadWriteCloser) Read(p []byte) (int, error) { } func (r *countingReadWriteCloser) Write(p []byte) (int, error) { - return r.writeBuf.Write(p) + n, err := r.writeBuf.Write(p) + if err == nil && r.afterWrite != nil { + r.afterWrite() + } + return n, err } func (r *countingReadWriteCloser) Close() error {