mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(reverseproxy): bridge websocket extended connect upstreams
This commit is contained in:
parent
919236665b
commit
a9c1662333
5 changed files with 508 additions and 99 deletions
194
reverseproxy.go
194
reverseproxy.go
|
|
@ -5,7 +5,10 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -52,9 +55,10 @@ type ReverseProxyConfig struct {
|
|||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
BufferPool BufferPool
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
BufferPool BufferPool
|
||||
AllowH2CUpstream bool
|
||||
|
||||
ModifyRequest func(*http.Request)
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
|
@ -86,6 +90,33 @@ type reverseProxyStatusError struct {
|
|||
err error
|
||||
}
|
||||
|
||||
type reverseProxyExtendedConnectBridge struct {
|
||||
body io.ReadCloser
|
||||
}
|
||||
|
||||
type reverseProxyH2ReadWriteCloser struct {
|
||||
io.ReadCloser
|
||||
ResponseWriter
|
||||
}
|
||||
|
||||
func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
|
||||
n, err := rwc.ResponseWriter.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (rwc reverseProxyH2ReadWriteCloser) Close() error {
|
||||
if rwc.ReadCloser == nil {
|
||||
return nil
|
||||
}
|
||||
return rwc.ReadCloser.Close()
|
||||
}
|
||||
|
||||
func (e *reverseProxyStatusError) Error() string {
|
||||
if e == nil || e.err == nil {
|
||||
return ""
|
||||
|
|
@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
}
|
||||
defer cleanup()
|
||||
|
||||
transport := p.transportForUpstream(c.Request, upstream)
|
||||
transport := p.transportForUpstream(outreq, upstream)
|
||||
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
||||
var (
|
||||
roundTripMu sync.Mutex
|
||||
|
|
@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
|
||||
if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil {
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return true, nil, false
|
||||
}
|
||||
if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil {
|
||||
return false, err, false
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||
removeHopByHopHeaders(res.Header)
|
||||
res.Header.Del("Content-Length")
|
||||
|
|
@ -435,6 +480,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
|
||||
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||
outreq := c.Request.Clone(ctx)
|
||||
bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq)
|
||||
if bridged {
|
||||
outreq = outreq.WithContext(bridgeCtx)
|
||||
}
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
} else if c.Request.GetBody != nil {
|
||||
|
|
@ -451,7 +500,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
|||
}
|
||||
outreq.Close = false
|
||||
var connectWriter *io.PipeWriter
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if outreq.Method == http.MethodConnect && !bridged {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
|
|
@ -467,7 +516,13 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
|||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
if bridged {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
|
|
@ -526,6 +581,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *
|
|||
if p.config.Transport != nil {
|
||||
return p.config.Transport
|
||||
}
|
||||
if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil {
|
||||
if upstream.bridgeTransport != nil {
|
||||
return upstream.bridgeTransport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
if upstream.useH2C && upstream.h2cTransport != nil {
|
||||
return upstream.h2cTransport
|
||||
}
|
||||
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
|
||||
return upstream.extendedConnectTransport
|
||||
}
|
||||
|
|
@ -915,6 +979,71 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")}
|
||||
}
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("backend returned bridged websocket response without writable body"),
|
||||
}
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
responseHeader := c.Writer.Header()
|
||||
reverseProxyCopyHeader(responseHeader, res.Header)
|
||||
responseHeader.Del("Upgrade")
|
||||
responseHeader.Del("Connection")
|
||||
responseHeader.Del("Sec-WebSocket-Accept")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer}
|
||||
brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
|
||||
|
||||
backConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnClosed:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnClosed)
|
||||
defer conn.Close()
|
||||
defer backConn.Close()
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go copyer.copyToBackend(errc)
|
||||
go copyer.copyFromBackend(errc)
|
||||
|
||||
firstErr := <-errc
|
||||
if firstErr == nil {
|
||||
firstErr = <-errc
|
||||
}
|
||||
if reverseProxyIsBenignTunnelError(firstErr) {
|
||||
return nil
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
|
|
@ -1128,13 +1257,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr
|
|||
|
||||
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
|
||||
for i, target := range targets {
|
||||
useH2C := strings.EqualFold(target.Scheme, "h2c")
|
||||
if useH2C {
|
||||
target = cloneReverseProxyURL(target)
|
||||
target.Scheme = "http"
|
||||
}
|
||||
upstream := &reverseProxyUpstream{
|
||||
key: fmt.Sprintf("%d:%s", i, target.String()),
|
||||
target: target,
|
||||
index: i,
|
||||
useH2C: useH2C || config.AllowH2CUpstream,
|
||||
}
|
||||
if config.Transport == nil {
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport()
|
||||
upstream.bridgeTransport = newHTTP1BridgeTransport()
|
||||
if upstream.useH2C {
|
||||
upstream.h2cTransport = newH2CTransport()
|
||||
}
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
|
|
@ -1237,6 +1376,47 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
|||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) {
|
||||
protocol := reverseProxyExtendedConnectProtocol(req)
|
||||
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
||||
if req == nil {
|
||||
return context.Background(), false
|
||||
}
|
||||
return req.Context(), false
|
||||
}
|
||||
|
||||
bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
|
||||
ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge)
|
||||
req.Header.Del(":protocol")
|
||||
req.Method = http.MethodGet
|
||||
req.Body = http.NoBody
|
||||
req.ContentLength = 0
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
key, err := reverseProxyGenerateWebSocketKey()
|
||||
if err == nil {
|
||||
req.Header.Set("Sec-WebSocket-Key", key)
|
||||
}
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge)
|
||||
return bridge
|
||||
}
|
||||
|
||||
func reverseProxyGenerateWebSocketKey() (string, error) {
|
||||
key := make([]byte, 16)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue