mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat: improve reverse proxy tunnel management with sync.Once and better error handling
This commit is contained in:
parent
d53693952a
commit
1a6325d461
2 changed files with 58 additions and 33 deletions
|
|
@ -518,25 +518,11 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if outreq.Method == http.MethodConnect {
|
if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) {
|
||||||
if bridged {
|
|
||||||
rewriteReverseProxyURL(outreq, upstream.target)
|
|
||||||
if !p.config.PreserveHost {
|
|
||||||
outreq.Host = ""
|
|
||||||
}
|
|
||||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
|
||||||
} else if reverseProxyIsExtendedConnectRequest(outreq) {
|
|
||||||
rewriteReverseProxyURL(outreq, upstream.target)
|
|
||||||
if !p.config.PreserveHost {
|
|
||||||
outreq.Host = ""
|
|
||||||
}
|
|
||||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
|
||||||
} else {
|
|
||||||
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
||||||
cleanup()
|
cleanup()
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
rewriteReverseProxyURL(outreq, upstream.target)
|
rewriteReverseProxyURL(outreq, upstream.target)
|
||||||
if !p.config.PreserveHost {
|
if !p.config.PreserveHost {
|
||||||
|
|
@ -1014,26 +1000,35 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
|
||||||
|
|
||||||
conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller}
|
conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller}
|
||||||
|
|
||||||
backConnClosed := make(chan struct{})
|
var closeOnce sync.Once
|
||||||
go func() {
|
closeTunnel := func() {
|
||||||
select {
|
closeOnce.Do(func() {
|
||||||
case <-req.Context().Done():
|
_ = conn.Close()
|
||||||
case <-backConnClosed:
|
_ = backConn.Close()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
backConn.Close()
|
go func() {
|
||||||
|
<-req.Context().Done()
|
||||||
|
closeTunnel()
|
||||||
}()
|
}()
|
||||||
defer close(backConnClosed)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
errc := make(chan error, 2)
|
errc := make(chan error, 2)
|
||||||
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
go copyer.copyToBackend(errc)
|
go copyer.copyToBackend(errc)
|
||||||
go copyer.copyFromBackend(errc)
|
go copyer.copyFromBackend(errc)
|
||||||
|
|
||||||
firstErr := <-errc
|
var firstErr error
|
||||||
if firstErr == nil {
|
for i := 0; i < 2; i++ {
|
||||||
firstErr = <-errc
|
err := <-errc
|
||||||
|
if reverseProxyIsBenignTunnelError(err) {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
closeTunnel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closeTunnel()
|
||||||
if reverseProxyIsBenignTunnelError(firstErr) {
|
if reverseProxyIsBenignTunnelError(firstErr) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -1662,14 +1663,24 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
|
||||||
enableHTTP2ExtendedConnectProtocol()
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
closeCalls := atomic.Int32{}
|
closeCalls := atomic.Int32{}
|
||||||
|
backendReadDone := make(chan struct{}, 1)
|
||||||
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
if req.Method != http.MethodGet {
|
if req.Method != http.MethodGet {
|
||||||
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
|
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
|
||||||
}
|
}
|
||||||
backend := &countingReadWriteCloser{
|
var respondOnce sync.Once
|
||||||
readData: []byte("echo:ping\n"),
|
var backend *countingReadWriteCloser
|
||||||
|
backend = &countingReadWriteCloser{
|
||||||
|
readDataCh: make(chan []byte, 1),
|
||||||
closeCalls: &closeCalls,
|
closeCalls: &closeCalls,
|
||||||
closeWriteErr: http.ErrNotSupported,
|
closeWriteErr: nil,
|
||||||
|
afterWrite: func() {
|
||||||
|
respondOnce.Do(func() {
|
||||||
|
backendReadDone <- struct{}{}
|
||||||
|
backend.readDataCh <- []byte("echo:ping\n")
|
||||||
|
close(backend.readDataCh)
|
||||||
|
})
|
||||||
|
},
|
||||||
}
|
}
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
|
@ -1719,6 +1730,12 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
t.Fatalf("write tunneled request body: %v", err)
|
t.Fatalf("write tunneled request body: %v", err)
|
||||||
}
|
}
|
||||||
|
select {
|
||||||
|
case <-backendReadDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
t.Fatal("backend did not receive tunneled request body")
|
||||||
|
}
|
||||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
@ -2428,12 +2445,21 @@ func (r errorReader) Read([]byte) (int, error) {
|
||||||
|
|
||||||
type countingReadWriteCloser struct {
|
type countingReadWriteCloser struct {
|
||||||
readData []byte
|
readData []byte
|
||||||
|
readDataCh chan []byte
|
||||||
writeBuf bytes.Buffer
|
writeBuf bytes.Buffer
|
||||||
closeCalls *atomic.Int32
|
closeCalls *atomic.Int32
|
||||||
closeWriteErr error
|
closeWriteErr error
|
||||||
|
afterWrite func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
|
func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
|
||||||
|
if len(r.readData) == 0 && r.readDataCh != nil {
|
||||||
|
data, ok := <-r.readDataCh
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
r.readData = data
|
||||||
|
}
|
||||||
if len(r.readData) == 0 {
|
if len(r.readData) == 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
@ -2443,7 +2469,11 @@ func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
|
func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
|
||||||
return r.writeBuf.Write(p)
|
n, err := r.writeBuf.Write(p)
|
||||||
|
if err == nil && r.afterWrite != nil {
|
||||||
|
r.afterWrite()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *countingReadWriteCloser) Close() error {
|
func (r *countingReadWriteCloser) Close() error {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue