mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #78 from infinite-iroha/break/v1-enhance-reverse-proxy
fix(reverseproxy): bridge websocket extended connect upstreams
This commit is contained in:
commit
863f984990
5 changed files with 881 additions and 109 deletions
|
|
@ -68,6 +68,7 @@ type ReverseProxyConfig struct {
|
|||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
BufferPool BufferPool
|
||||
AllowH2CUpstream bool
|
||||
|
||||
ModifyRequest func(*http.Request)
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
|
@ -191,6 +192,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
}))
|
||||
```
|
||||
|
||||
### `AllowH2CUpstream`
|
||||
|
||||
允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。
|
||||
|
||||
- 默认关闭
|
||||
- 这是一个显式配置项
|
||||
- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
|
||||
- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式
|
||||
|
||||
```go
|
||||
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
AllowH2CUpstream: true,
|
||||
}))
|
||||
```
|
||||
|
||||
对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
|
||||
|
||||
### `Transport`
|
||||
|
||||
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||
|
|
|
|||
|
|
@ -5,13 +5,11 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
|
@ -36,18 +34,55 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
|||
return http2.ConfigureServer(srv, nil)
|
||||
}
|
||||
|
||||
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
|
||||
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
transport := &http2.Transport{}
|
||||
if target == nil || !strings.EqualFold(target.Scheme, "http") {
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetHTTP1(true)
|
||||
transport.Protocols.SetHTTP2(true)
|
||||
return transport
|
||||
}
|
||||
}
|
||||
|
||||
transport.AllowHTTP = true
|
||||
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
func newHTTP1BridgeTransport() http.RoundTripper {
|
||||
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
|
||||
}
|
||||
|
||||
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetHTTP1(true)
|
||||
if tlsConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
} else {
|
||||
transport.TLSClientConfig = tlsConfig.Clone()
|
||||
}
|
||||
if len(transport.TLSClientConfig.NextProtos) == 0 {
|
||||
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func newH2CTransport() http.RoundTripper {
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetUnencryptedHTTP2(true)
|
||||
return transport
|
||||
}
|
||||
|
||||
func cloneDefaultTransport() *http.Transport {
|
||||
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
return transport.Clone()
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
196
reverseproxy.go
196
reverseproxy.go
|
|
@ -6,6 +6,8 @@ package touka
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -55,6 +57,7 @@ type ReverseProxyConfig struct {
|
|||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
BufferPool BufferPool
|
||||
AllowH2CUpstream bool
|
||||
|
||||
ModifyRequest func(*http.Request)
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
|
@ -86,6 +89,34 @@ type reverseProxyStatusError struct {
|
|||
err error
|
||||
}
|
||||
|
||||
type reverseProxyExtendedConnectBridge struct {
|
||||
body io.ReadCloser
|
||||
}
|
||||
|
||||
type reverseProxyH2ReadWriteCloser struct {
|
||||
io.ReadCloser
|
||||
ResponseWriter
|
||||
controller *http.ResponseController
|
||||
}
|
||||
|
||||
func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
|
||||
n, err := rwc.ResponseWriter.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (rwc *reverseProxyH2ReadWriteCloser) Close() error {
|
||||
if rwc.ReadCloser == nil {
|
||||
return nil
|
||||
}
|
||||
return rwc.ReadCloser.Close()
|
||||
}
|
||||
|
||||
func (e *reverseProxyStatusError) Error() string {
|
||||
if e == nil || e.err == nil {
|
||||
return ""
|
||||
|
|
@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
}
|
||||
defer cleanup()
|
||||
|
||||
transport := p.transportForUpstream(c.Request, upstream)
|
||||
transport := p.transportForUpstream(outreq, upstream)
|
||||
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
||||
var (
|
||||
roundTripMu sync.Mutex
|
||||
|
|
@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
|
||||
if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil {
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return true, nil, false
|
||||
}
|
||||
if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil {
|
||||
return false, err, false
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||
removeHopByHopHeaders(res.Header)
|
||||
res.Header.Del("Content-Length")
|
||||
|
|
@ -435,6 +480,13 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
|||
|
||||
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||
outreq := c.Request.Clone(ctx)
|
||||
bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if bridged {
|
||||
outreq = outreq.WithContext(bridgeCtx)
|
||||
}
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
} else if c.Request.GetBody != nil {
|
||||
|
|
@ -451,7 +503,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
|||
}
|
||||
outreq.Close = false
|
||||
var connectWriter *io.PipeWriter
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if outreq.Method == http.MethodConnect && !bridged {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
|
|
@ -466,19 +518,11 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
|||
}
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
||||
cleanup()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
|
|
@ -526,6 +570,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *
|
|||
if p.config.Transport != nil {
|
||||
return p.config.Transport
|
||||
}
|
||||
if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil {
|
||||
if upstream.bridgeTransport != nil {
|
||||
return upstream.bridgeTransport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
if upstream.useH2C && upstream.h2cTransport != nil {
|
||||
return upstream.h2cTransport
|
||||
}
|
||||
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
|
||||
return upstream.extendedConnectTransport
|
||||
}
|
||||
|
|
@ -915,6 +968,73 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")}
|
||||
}
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("backend returned bridged websocket response without writable body"),
|
||||
}
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
responseHeader := c.Writer.Header()
|
||||
reverseProxyCopyHeader(responseHeader, res.Header)
|
||||
removeHopByHopHeaders(responseHeader)
|
||||
responseHeader.Del("Sec-WebSocket-Accept")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller}
|
||||
|
||||
var closeOnce sync.Once
|
||||
closeTunnel := func() {
|
||||
closeOnce.Do(func() {
|
||||
_ = conn.Close()
|
||||
_ = backConn.Close()
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
<-req.Context().Done()
|
||||
closeTunnel()
|
||||
}()
|
||||
|
||||
errc := make(chan error, 2)
|
||||
copyer := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go copyer.copyToBackend(errc)
|
||||
go copyer.copyFromBackend(errc)
|
||||
|
||||
var firstErr error
|
||||
for i := 0; i < 2; i++ {
|
||||
err := <-errc
|
||||
if reverseProxyIsBenignTunnelError(err) {
|
||||
continue
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
closeTunnel()
|
||||
}
|
||||
}
|
||||
closeTunnel()
|
||||
if reverseProxyIsBenignTunnelError(firstErr) {
|
||||
return nil
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
|
|
@ -1128,13 +1248,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr
|
|||
|
||||
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
|
||||
for i, target := range targets {
|
||||
useH2C := strings.EqualFold(target.Scheme, "h2c")
|
||||
if useH2C {
|
||||
target = cloneReverseProxyURL(target)
|
||||
target.Scheme = "http"
|
||||
}
|
||||
upstream := &reverseProxyUpstream{
|
||||
key: fmt.Sprintf("%d:%s", i, target.String()),
|
||||
target: target,
|
||||
index: i,
|
||||
useH2C: useH2C || config.AllowH2CUpstream,
|
||||
}
|
||||
if config.Transport == nil {
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport()
|
||||
upstream.bridgeTransport = newHTTP1BridgeTransport()
|
||||
if upstream.useH2C {
|
||||
upstream.h2cTransport = newH2CTransport()
|
||||
}
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
|
|
@ -1237,6 +1367,48 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
|||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
|
||||
if req == nil {
|
||||
return context.Background(), false, nil
|
||||
}
|
||||
protocol := reverseProxyExtendedConnectProtocol(req)
|
||||
if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
||||
return req.Context(), false, nil
|
||||
}
|
||||
|
||||
bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
|
||||
ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge)
|
||||
req.Header.Del(":protocol")
|
||||
req.Method = http.MethodGet
|
||||
req.Body = http.NoBody
|
||||
req.ContentLength = 0
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
key, err := reverseProxyGenerateWebSocketKey()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err)
|
||||
}
|
||||
req.Header.Set("Sec-WebSocket-Key", key)
|
||||
return ctx, true, nil
|
||||
}
|
||||
|
||||
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge)
|
||||
return bridge
|
||||
}
|
||||
|
||||
func reverseProxyGenerateWebSocketKey() (string, error) {
|
||||
key := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,7 +57,10 @@ type reverseProxyUpstream struct {
|
|||
key string
|
||||
target *url.URL
|
||||
index int
|
||||
useH2C bool
|
||||
extendedConnectTransport http.RoundTripper
|
||||
bridgeTransport http.RoundTripper
|
||||
h2cTransport http.RoundTripper
|
||||
inFlight atomic.Int64
|
||||
|
||||
passiveMu sync.Mutex
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ package touka
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -15,6 +17,7 @@ import (
|
|||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -112,7 +115,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
t.Fatalf("unexpected body: %q", string(body))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if got.Path != "/base/api/ping" {
|
||||
t.Fatalf("unexpected upstream path: %q", got.Path)
|
||||
|
|
@ -765,6 +769,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAllowH2CUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen h2c upstream: %v", err)
|
||||
}
|
||||
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Proto", r.Proto)
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
})}
|
||||
server.Protocols = new(http.Protocols)
|
||||
server.Protocols.SetUnencryptedHTTP2(true)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Serve(listener)
|
||||
}()
|
||||
defer func() {
|
||||
_ = server.Close()
|
||||
<-errCh
|
||||
}()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://"+listener.Addr().String()),
|
||||
AllowH2CUpstream: true,
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" {
|
||||
t.Fatalf("expected h2c upstream proto, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -791,6 +832,131 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
flushErr := errors.New("flush failed")
|
||||
writer := &flushErrorResponseWriter{flushErr: flushErr}
|
||||
conn := &reverseProxyH2ReadWriteCloser{
|
||||
ReadCloser: io.NopCloser(strings.NewReader("")),
|
||||
ResponseWriter: writer,
|
||||
controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)),
|
||||
}
|
||||
|
||||
n, err := conn.Write([]byte("ping"))
|
||||
if n != len("ping") {
|
||||
t.Fatalf("unexpected bytes written: %d", n)
|
||||
}
|
||||
if !errors.Is(err, flushErr) {
|
||||
t.Fatalf("unexpected write error: %v", err)
|
||||
}
|
||||
if got := writer.body.String(); got != "ping" {
|
||||
t.Fatalf("unexpected buffered body: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
transportCalled := atomic.Bool{}
|
||||
entropyErr := errors.New("entropy source unavailable")
|
||||
originalReader := crand.Reader
|
||||
crand.Reader = errorReader{err: entropyErr}
|
||||
t.Cleanup(func() {
|
||||
crand.Reader = originalReader
|
||||
})
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
transportCalled.Store(true)
|
||||
return nil, errors.New("unexpected round trip")
|
||||
}),
|
||||
ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
|
||||
w.WriteHeader(reverseProxyStatusCode(err))
|
||||
_, _ = io.WriteString(w, err.Error())
|
||||
},
|
||||
}))
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set(":protocol", "websocket")
|
||||
rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers)
|
||||
|
||||
if transportCalled.Load() {
|
||||
t.Fatal("transport should not be called when websocket key generation fails")
|
||||
}
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) {
|
||||
t.Fatalf("unexpected error body: %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalDefaultTransport := http.DefaultTransport
|
||||
http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("unexpected round trip")
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
http.DefaultTransport = originalDefaultTransport
|
||||
})
|
||||
|
||||
assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) {
|
||||
t.Helper()
|
||||
transport, ok := rt.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("%s returned %T, want *http.Transport", name, rt)
|
||||
}
|
||||
check(transport)
|
||||
}
|
||||
|
||||
assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) {
|
||||
if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() {
|
||||
t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols)
|
||||
}
|
||||
})
|
||||
assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) {
|
||||
if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() {
|
||||
t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols)
|
||||
}
|
||||
if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" {
|
||||
t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig)
|
||||
}
|
||||
})
|
||||
assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) {
|
||||
if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() {
|
||||
t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig)
|
||||
transport, ok := rt.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected transport type: %T", rt)
|
||||
}
|
||||
if transport.TLSClientConfig == nil {
|
||||
t.Fatal("expected TLS client config")
|
||||
}
|
||||
if transport.TLSClientConfig == tlsConfig {
|
||||
t.Fatal("expected bridge transport to clone TLS config")
|
||||
}
|
||||
if len(tlsConfig.NextProtos) != 0 {
|
||||
t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos)
|
||||
}
|
||||
if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" {
|
||||
t.Fatalf("unexpected transport NextProtos: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -1363,19 +1529,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.ProtoMajor != 2 {
|
||||
errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto)
|
||||
if got := r.Header.Get(":protocol"); got != "" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
||||
errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
||||
errCh <- errors.New("missing upstream Sec-WebSocket-Key header")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
|
@ -1385,36 +1561,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "echo:"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
|
|
@ -1445,7 +1626,13 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" {
|
||||
if got := resp.Header.Get("Upgrade"); got != "" {
|
||||
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
|
||||
}
|
||||
if got := resp.Header.Get("X-Hop-Token"); got != "" {
|
||||
t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got)
|
||||
}
|
||||
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
|
||||
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||
}
|
||||
|
||||
|
|
@ -1470,6 +1657,224 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
closeCalls := atomic.Int32{}
|
||||
backendReadDone := make(chan struct{}, 1)
|
||||
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Method != http.MethodGet {
|
||||
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
|
||||
}
|
||||
var respondOnce sync.Once
|
||||
var backend *countingReadWriteCloser
|
||||
backend = &countingReadWriteCloser{
|
||||
readDataCh: make(chan []byte, 1),
|
||||
closeCalls: &closeCalls,
|
||||
closeWriteErr: nil,
|
||||
afterWrite: func() {
|
||||
respondOnce.Do(func() {
|
||||
backendReadDone <- struct{}{}
|
||||
backend.readDataCh <- []byte("echo:ping\n")
|
||||
close(backend.readDataCh)
|
||||
})
|
||||
},
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: http.Header{
|
||||
"Connection": []string{"Upgrade"},
|
||||
"Upgrade": []string{"websocket"},
|
||||
"Sec-WebSocket-Accept": []string{"ignored"},
|
||||
},
|
||||
Body: backend,
|
||||
Request: req,
|
||||
}, nil
|
||||
})
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Transport: transport,
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer clientTransport.CloseIdleConnections()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := clientTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-backendReadDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
_ = resp.Body.Close()
|
||||
t.Fatal("backend did not receive tunneled request body")
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
_ = resp.Body.Close()
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
if message != "echo:ping\n" {
|
||||
_ = resp.Body.Close()
|
||||
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||
}
|
||||
if err := pw.Close(); err != nil {
|
||||
_ = resp.Body.Close()
|
||||
t.Fatalf("close tunneled request body: %v", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatalf("close response body: %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if closeCalls.Load() > 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if got := closeCalls.Load(); got != 1 {
|
||||
t.Fatalf("expected backend connection to close exactly once, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ProtoMajor != 1 {
|
||||
errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
if message != "echo:ping\n" {
|
||||
t.Fatalf("unexpected tunneled response body: %q", message)
|
||||
}
|
||||
_ = pw.Close()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -1477,42 +1882,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
|
||||
errCh := make(chan error, 8)
|
||||
newBackend := func(name string) *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
if got := r.Header.Get(":protocol"); got != "" {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
|
||||
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
||||
errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name)
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, name+":"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, name+":"+line); err != nil {
|
||||
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
server.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
|
||||
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
|
||||
}
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
|
|
@ -1527,7 +1952,6 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBRoundRobin(),
|
||||
},
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
|
|
@ -1557,7 +1981,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
|
|
@ -1592,54 +2017,58 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
reader := bufio.NewReader(r.Body)
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(brw)
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "ack:"+line); err != nil {
|
||||
if _, err := io.WriteString(brw, "ack:"+line); err != nil {
|
||||
errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
|
||||
if _, err := io.Copy(io.Discard, reader); err != nil {
|
||||
errCh <- fmt.Errorf("wait for request half-close failed: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, "after-close\n"); err != nil {
|
||||
if _, err := io.WriteString(brw, "after-close\n"); err != nil {
|
||||
errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
_ = brw.Flush()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
|
|
@ -1668,7 +2097,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
|
@ -1707,35 +2137,36 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
|
|||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("enable full duplex failed: %w", err)
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("upstream response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
_ = brw.Flush()
|
||||
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
upstream.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil {
|
||||
t.Fatalf("configure upstream HTTP/2 server: %v", err)
|
||||
}
|
||||
upstream.StartTLS()
|
||||
defer upstream.Close()
|
||||
|
||||
proxyErrCh := make(chan error, 1)
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, upstream.URL),
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
select {
|
||||
|
|
@ -1772,7 +2203,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
|
|||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
writeErrCh := make(chan error, 1)
|
||||
|
|
@ -1944,6 +2376,117 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error)
|
|||
return fn(req)
|
||||
}
|
||||
|
||||
type flushErrorResponseWriter struct {
|
||||
header http.Header
|
||||
body bytes.Buffer
|
||||
status int
|
||||
written bool
|
||||
flushErr error
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.written {
|
||||
return
|
||||
}
|
||||
w.status = statusCode
|
||||
w.written = true
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Write(p []byte) (int, error) {
|
||||
if !w.written {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.body.Write(p)
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Flush() {}
|
||||
|
||||
func (w *flushErrorResponseWriter) FlushError() error {
|
||||
if !w.written {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.flushErr
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Status() int {
|
||||
return w.status
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Size() int {
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) Written() bool {
|
||||
return w.written
|
||||
}
|
||||
|
||||
func (w *flushErrorResponseWriter) IsHijacked() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type errorReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r errorReader) Read([]byte) (int, error) {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
type countingReadWriteCloser struct {
|
||||
readData []byte
|
||||
readDataCh chan []byte
|
||||
writeBuf bytes.Buffer
|
||||
closeCalls *atomic.Int32
|
||||
closeWriteErr error
|
||||
afterWrite func()
|
||||
}
|
||||
|
||||
func (r *countingReadWriteCloser) Read(p []byte) (int, error) {
|
||||
if len(r.readData) == 0 && r.readDataCh != nil {
|
||||
data, ok := <-r.readDataCh
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
r.readData = data
|
||||
}
|
||||
if len(r.readData) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.readData)
|
||||
r.readData = r.readData[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *countingReadWriteCloser) Write(p []byte) (int, error) {
|
||||
n, err := r.writeBuf.Write(p)
|
||||
if err == nil && r.afterWrite != nil {
|
||||
r.afterWrite()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *countingReadWriteCloser) Close() error {
|
||||
if r.closeCalls != nil {
|
||||
r.closeCalls.Add(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *countingReadWriteCloser) CloseWrite() error {
|
||||
return r.closeWriteErr
|
||||
}
|
||||
|
||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||
t.Helper()
|
||||
u, err := url.Parse(raw)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue