feat: improve reverse proxy tunnel management with sync.Once and better error handling

This commit is contained in:
wjqserver 2026-04-03 00:29:15 +08:00
parent d53693952a
commit 1a6325d461
2 changed files with 58 additions and 33 deletions

View file

@ -518,24 +518,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
} }
} }
if outreq.Method == http.MethodConnect { if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) {
if bridged { if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
rewriteReverseProxyURL(outreq, upstream.target) cleanup()
if !p.config.PreserveHost { return nil, nil, nil, err
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
cleanup()
return nil, nil, nil, err
}
} }
} else { } else {
rewriteReverseProxyURL(outreq, upstream.target) rewriteReverseProxyURL(outreq, upstream.target)
@ -1014,26 +1000,35 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller}
backConnClosed := make(chan struct{}) var closeOnce sync.Once
closeTunnel := func() {
closeOnce.Do(func() {
_ = conn.Close()
_ = backConn.Close()
})
}
go func() { go func() {
select { <-req.Context().Done()
case <-req.Context().Done(): closeTunnel()
case <-backConnClosed:
}
backConn.Close()
}() }()
defer close(backConnClosed)
defer conn.Close()
errc := make(chan error, 2) errc := make(chan error, 2)
copyer := switchProtocolCopier{user: conn, backend: backConn} copyer := switchProtocolCopier{user: conn, backend: backConn}
go copyer.copyToBackend(errc) go copyer.copyToBackend(errc)
go copyer.copyFromBackend(errc) go copyer.copyFromBackend(errc)
firstErr := <-errc var firstErr error
if firstErr == nil { for i := 0; i < 2; i++ {
firstErr = <-errc err := <-errc
if reverseProxyIsBenignTunnelError(err) {
continue
}
if firstErr == nil {
firstErr = err
closeTunnel()
}
} }
closeTunnel()
if reverseProxyIsBenignTunnelError(firstErr) { if reverseProxyIsBenignTunnelError(firstErr) {
return nil return nil
} }

View file

@ -17,6 +17,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -1662,14 +1663,24 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
enableHTTP2ExtendedConnectProtocol() enableHTTP2ExtendedConnectProtocol()
closeCalls := atomic.Int32{} closeCalls := atomic.Int32{}
backendReadDone := make(chan struct{}, 1)
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
} }
backend := &countingReadWriteCloser{ var respondOnce sync.Once
readData: []byte("echo:ping\n"), var backend *countingReadWriteCloser
backend = &countingReadWriteCloser{
readDataCh: make(chan []byte, 1),
closeCalls: &closeCalls, closeCalls: &closeCalls,
closeWriteErr: http.ErrNotSupported, closeWriteErr: nil,
afterWrite: func() {
respondOnce.Do(func() {
backendReadDone <- struct{}{}
backend.readDataCh <- []byte("echo:ping\n")
close(backend.readDataCh)
})
},
} }
return &http.Response{ return &http.Response{
StatusCode: http.StatusSwitchingProtocols, StatusCode: http.StatusSwitchingProtocols,
@ -1719,6 +1730,12 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
_ = resp.Body.Close() _ = resp.Body.Close()
t.Fatalf("write tunneled request body: %v", err) t.Fatalf("write tunneled request body: %v", err)
} }
select {
case <-backendReadDone:
case <-time.After(2 * time.Second):
_ = resp.Body.Close()
t.Fatal("backend did not receive tunneled request body")
}
message, err := bufio.NewReader(resp.Body).ReadString('\n') message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil { if err != nil {
_ = resp.Body.Close() _ = resp.Body.Close()
@ -2428,12 +2445,21 @@ func (r errorReader) Read([]byte) (int, error) {
type countingReadWriteCloser struct { type countingReadWriteCloser struct {
readData []byte readData []byte
readDataCh chan []byte
writeBuf bytes.Buffer writeBuf bytes.Buffer
closeCalls *atomic.Int32 closeCalls *atomic.Int32
closeWriteErr error closeWriteErr error
afterWrite func()
} }
func (r *countingReadWriteCloser) Read(p []byte) (int, error) { func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
if len(r.readData) == 0 && r.readDataCh != nil {
data, ok := <-r.readDataCh
if !ok {
return 0, io.EOF
}
r.readData = data
}
if len(r.readData) == 0 { if len(r.readData) == 0 {
return 0, io.EOF return 0, io.EOF
} }
@ -2443,7 +2469,11 @@ func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
} }
func (r *countingReadWriteCloser) Write(p []byte) (int, error) { func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
return r.writeBuf.Write(p) n, err := r.writeBuf.Write(p)
if err == nil && r.afterWrite != nil {
r.afterWrite()
}
return n, err
} }
func (r *countingReadWriteCloser) Close() error { func (r *countingReadWriteCloser) Close() error {