feat(reverseproxy): add upstream balancing and failover

This commit is contained in:
wjqserver 2026-04-02 14:40:56 +08:00
parent 59f190ce3a
commit 919236665b
4 changed files with 1394 additions and 116 deletions

View file

@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
```go
type ReverseProxyConfig struct {
Target *url.URL
Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper
FlushInterval time.Duration
@ -78,12 +82,115 @@ type ReverseProxyConfig struct {
### `Target`
必填。表示后端目标地址,至少需要提供 `scheme``host`
`Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme``host`
```go
target, _ := url.Parse("http://backend:9000")
```
### `Targets`
可选。用于配置多个后端目标地址。
- `Target``Targets` 互斥,只能使用其中一种
- `Targets` 的每一项都必须是完整 URL
- 每个 target 仍然可以自带 base path 和 query
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001/base?from=a",
"http://127.0.0.1:9002/base?from=b",
},
}))
```
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
### `LoadBalancing`
用于配置 upstream 选择策略和重试行为。
```go
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
```
当前内置策略:
- `touka.LBRandom()`
- `touka.LBRoundRobin()`
- `touka.LBFirst()`
- `touka.LBLeastConn()`
- `touka.LBIPHash()`
- `touka.LBClientIPHash()`
- `touka.LBURIHash()`
- `touka.LBHeader("X-Upstream", fallback)`
- `touka.LBQuery("tenant", fallback)`
其中:
- `LBFirst()` 适合主备/故障转移顺序
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback则默认回退到 `LBRandom()`
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
Retries: 1,
},
}))
```
重试说明:
- 只对未开始收到上游响应的失败进行重试
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
- `Retries` 表示额外重试次数
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
- `TryInterval` 表示两次重试之间的等待间隔
### `PassiveHealth`
用于配置被动健康检查。它不会后台探测 upstream而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
```go
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
```
- `FailDuration > 0` 时启用被动健康跟踪
- `MaxFails <= 0` 时默认按 `1` 处理
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBFirst(),
},
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
```
### `Transport`
可选。用于自定义底层转发所使用的 `http.RoundTripper`
@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
在请求真正发往后端前,对出站请求做最后修改。
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
常见用途:
- 覆盖 `Host`

View file

@ -23,6 +23,8 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http2"
)
// ForwardedHeadersPolicy controls how forwarding headers are generated.
@ -44,7 +46,11 @@ type BufferPool interface {
// ReverseProxyConfig configures the reverse proxy handler.
type ReverseProxyConfig struct {
Target *url.URL
Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper
FlushInterval time.Duration
@ -61,17 +67,18 @@ type ReverseProxyConfig struct {
}
var (
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
)
type reverseProxyHandler struct {
config ReverseProxyConfig
target *url.URL
receivedBy string
configError error
extendedConnectTransport http.RoundTripper
config ReverseProxyConfig
upstreams []*reverseProxyUpstream
receivedBy string
configError error
roundRobin atomic.Uint64
}
type reverseProxyStatusError struct {
@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
}
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
target := cloneReverseProxyURL(config.Target)
if target != nil {
normalizeReverseProxyTarget(target)
}
proxy := &reverseProxyHandler{
config: config,
target: target,
receivedBy: reverseProxyReceivedBy(config.Via),
}
if config.Transport == nil {
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
if err := validateReverseProxyTarget(target); err != nil {
upstreams, err := buildReverseProxyUpstreams(config)
if err != nil {
proxy.configError = err
} else {
proxy.upstreams = upstreams
}
switch config.ForwardedHeaders {
@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
proxy.configError = err
}
}
if proxy.configError == nil {
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
proxy.configError = err
}
}
return proxy
}
@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
return
}
transport := p.config.Transport
if transport == nil {
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
transport = p.extendedConnectTransport
} else {
transport = http.DefaultTransport
}
}
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
if err != nil {
p.handleError(c, err)
@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
ctx, cancel := p.requestContext(c)
defer cancel()
attempted := make(map[string]struct{}, len(p.upstreams))
attempts := 0
started := time.Now()
var lastErr error
outreq := c.Request.Clone(ctx)
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
}
if outreq.Body != nil {
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header)
}
outreq.Close = false
var connectWriter *io.PipeWriter
defer func() {
if connectWriter != nil {
_ = connectWriter.Close()
}
}()
if outreq.Method == http.MethodConnect {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
defer outreq.Body.Close()
connectWriter = pipeWriter
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, p.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
p.handleError(c, err)
for {
upstream, err := p.selectUpstream(c, attempted)
if err != nil {
if lastErr != nil {
p.handleError(c, lastErr)
return
}
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
return
}
} else {
rewriteReverseProxyURL(outreq, p.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
}
if updatedMaxForwards != "" {
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
}
reqUpType := reverseProxyUpgradeType(outreq.Header)
if reqUpType != "" && !isPrintableASCII(reqUpType) {
p.handleError(c, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
})
attempts++
upstream.inFlight.Add(1)
served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
upstream.inFlight.Add(-1)
if served {
return
}
if attemptErr != nil {
lastErr = attemptErr
}
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
attempted[upstream.key] = struct{}{}
if !p.waitRetryInterval(ctx, started) {
if lastErr != nil {
p.handleError(c, lastErr)
}
return
}
continue
}
if attemptErr != nil {
p.handleError(c, attemptErr)
return
}
if lastErr != nil {
p.handleError(c, lastErr)
return
}
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
return
}
}
removeHopByHopHeaders(outreq.Header)
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
if err != nil {
return false, err, false
}
defer cleanup()
transport := p.transportForUpstream(c.Request, upstream)
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
var (
roundTripMu sync.Mutex
@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
roundTripDone = true
roundTripMu.Unlock()
if err != nil {
p.handleError(c, err)
return
if reverseProxyShouldCountPassiveFailure(outreq, err) {
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
return false, err, true
}
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
res.TransferEncoding = nil
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
handleConnect := p.handleConnectResponse
if reverseProxyIsExtendedConnectRequest(outreq) {
handleConnect = p.handleExtendedConnectResponse
}
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
p.handleError(c, err)
return false, err, false
}
connectWriter = nil
return
return true, nil, false
}
if res.StatusCode == http.StatusSwitchingProtocols {
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
p.handleError(c, err)
return false, err, false
}
return
return true, nil, false
}
removeHopByHopHeaders(res.Header)
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
if reverseProxyShouldPanicOnCopyError(c.Request) {
panic(http.ErrAbortHandler)
}
return
return true, nil, false
}
res.Body.Close()
@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Flush()
}
// Keep the stdlib-compatible fallback here.
// If the backend only exposes additional trailer keys after the body has been
// fully read, the trailer map can grow and those values must be written using
// the TrailerPrefix form instead of the pre-announced bare header keys.
if len(res.Trailer) == announcedTrailers {
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
return
return true, nil, false
}
for key, values := range res.Trailer {
@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Header().Add(prefixedKey, value)
}
}
return true, nil, false
}
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
outreq := c.Request.Clone(ctx)
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
} else if c.Request.GetBody != nil {
body, err := c.Request.GetBody()
if err != nil {
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
}
outreq.Body = body
} else if outreq.Body != nil {
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
}
if outreq.Header == nil {
outreq.Header = make(http.Header)
}
outreq.Close = false
var connectWriter *io.PipeWriter
if outreq.Method == http.MethodConnect {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
connectWriter = pipeWriter
}
cleanup := func() {
if outreq.Body != nil {
_ = outreq.Body.Close()
}
if connectWriter != nil {
_ = connectWriter.Close()
}
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
cleanup()
return nil, nil, nil, err
}
}
} else {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
}
if updatedMaxForwards != "" {
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
}
reqUpType := reverseProxyUpgradeType(outreq.Header)
if reqUpType != "" && !isPrintableASCII(reqUpType) {
cleanup()
return nil, nil, nil, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
}
}
removeHopByHopHeaders(outreq.Header)
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
}
return outreq, connectWriter, cleanup, nil
}
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
if p.config.Transport != nil {
return p.config.Transport
}
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
return upstream.extendedConnectTransport
}
return http.DefaultTransport
}
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
return false
}
lb := p.config.LoadBalancing
if lb.TryDuration > 0 {
return time.Since(started) < lb.TryDuration
}
return attempts <= lb.Retries
}
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
interval := p.config.LoadBalancing.TryInterval
tryDuration := p.config.LoadBalancing.TryDuration
if tryDuration > 0 && interval == 0 {
interval = 250 * time.Millisecond
}
if tryDuration > 0 {
remaining := tryDuration - time.Since(started)
if remaining <= 0 {
return false
}
if interval <= 0 {
return ctx.Err() == nil
}
if interval > remaining {
return false
}
}
if interval <= 0 {
return ctx.Err() == nil
}
timer := time.NewTimer(interval)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error {
return nil
}
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
if config.Target != nil && len(config.Targets) > 0 {
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
}
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
if config.Target != nil {
target := cloneReverseProxyURL(config.Target)
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, err
}
targets = append(targets, target)
}
for i, rawTarget := range config.Targets {
trimmed := strings.TrimSpace(rawTarget)
if trimmed == "" {
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
}
target, err := url.Parse(trimmed)
if err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
targets = append(targets, target)
}
if len(targets) == 0 {
return nil, errReverseProxyNilTarget
}
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
for i, target := range targets {
upstream := &reverseProxyUpstream{
key: fmt.Sprintf("%d:%s", i, target.String()),
target: target,
index: i,
}
if config.Transport == nil {
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
func validateReverseProxyForwardedBy(value string) error {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
return req != nil && req.Context().Value(http.ServerContextKey) != nil
}
func reverseProxyCanRetryRequest(req *http.Request) bool {
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
return false
}
if req.Body == nil || req.ContentLength == 0 {
return true
}
return req.GetBody != nil
}
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
if err == nil || reverseProxyIsBenignTunnelError(err) {
return false
}
if req != nil && req.Context().Err() != nil {
return false
}
return !errors.Is(err, context.Canceled)
}
func reverseProxyMethodIsSafe(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func reverseProxyIsBenignTunnelError(err error) bool {
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
}
@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool {
if err == nil {
return false
}
var streamErr http2.StreamError
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
return true
}
switch err.Error() {
case "body closed by handler", "http2: response body closed", "response body closed":
return true

352
reverseproxy_lb.go Normal file
View file

@ -0,0 +1,352 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"fmt"
"math/rand/v2"
"net/http"
"net/textproto"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
)
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
type ReverseProxyLBPolicy struct {
kind reverseProxyLBPolicyKind
key string
fallback *ReverseProxyLBPolicy
}
type reverseProxyLBPolicyKind uint8
const (
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
reverseProxyLBPolicyRoundRobin
reverseProxyLBPolicyFirst
reverseProxyLBPolicyLeastConn
reverseProxyLBPolicyIPHash
reverseProxyLBPolicyClientIPHash
reverseProxyLBPolicyURIHash
reverseProxyLBPolicyHeader
reverseProxyLBPolicyQuery
)
type reverseProxyUpstream struct {
key string
target *url.URL
index int
extendedConnectTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex
failures []time.Time
}
func LBRandom() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
}
func LBRoundRobin() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
}
func LBFirst() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
}
func LBLeastConn() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
}
func LBIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
}
func LBClientIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
}
func LBURIHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
}
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
switch policy.kind {
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
reverseProxyLBPolicyURIHash:
return nil
case reverseProxyLBPolicyHeader:
if policy.key == "" {
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
}
case reverseProxyLBPolicyQuery:
if policy.key == "" {
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
}
default:
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
}
if policy.fallback != nil {
return validateReverseProxyLBPolicy(*policy.fallback)
}
return nil
}
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
now := time.Now()
policy := p.config.LoadBalancing.Policy
candidates := p.availableUpstreams(now, excluded)
if len(candidates) == 0 && len(excluded) > 0 {
candidates = p.availableUpstreams(now, nil)
}
if len(candidates) == 0 {
return nil, errReverseProxyNoAvailableUpstreams
}
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
}
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
for _, upstream := range p.upstreams {
if _, skip := excluded[upstream.key]; skip {
continue
}
if !upstream.healthy(now, p.config.PassiveHealth) {
continue
}
candidates = append(candidates, upstream)
}
return candidates
}
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
switch policy.kind {
case reverseProxyLBPolicyRoundRobin:
return candidates[p.nextRoundRobinIndex(len(candidates))]
case reverseProxyLBPolicyFirst:
return candidates[0]
case reverseProxyLBPolicyLeastConn:
return p.selectLeastConnUpstream(candidates)
case reverseProxyLBPolicyIPHash:
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
case reverseProxyLBPolicyClientIPHash:
return reverseProxySelectHRW(candidates, c.RequestIP())
case reverseProxyLBPolicyURIHash:
if c.Request == nil || c.Request.URL == nil {
return reverseProxySelectRandom(candidates)
}
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
case reverseProxyLBPolicyHeader:
if c.Request != nil && c.Request.Header != nil {
if values, ok := c.Request.Header[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyQuery:
if c.Request != nil && c.Request.URL != nil {
if values, ok := c.Request.URL.Query()[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyRandom:
fallthrough
default:
return reverseProxySelectRandom(candidates)
}
}
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
if size <= 1 {
return 0
}
return int((p.roundRobin.Add(1) - 1) % uint64(size))
}
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
selected := candidates[0]
lowest := selected.inFlight.Load()
ties := []*reverseProxyUpstream{selected}
for _, upstream := range candidates[1:] {
count := upstream.inFlight.Load()
switch {
case count < lowest:
selected = upstream
lowest = count
ties = []*reverseProxyUpstream{upstream}
case count == lowest:
ties = append(ties, upstream)
}
}
if len(ties) == 1 {
return selected
}
return ties[p.nextRoundRobinIndex(len(ties))]
}
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(candidates) == 1 {
return candidates[0]
}
return candidates[rand.IntN(len(candidates))]
}
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if key == "" {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWScore(key, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWScore(key, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for i := 0; i < len(key); i++ {
h ^= uint64(key[i])
h *= prime64
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
if policy.fallback != nil {
return *policy.fallback
}
return LBRandom()
}
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return true
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
return len(u.failures) < maxFails
}
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
u.failures = append(u.failures, now)
}
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
if len(u.failures) == 0 || window <= 0 {
if window <= 0 {
u.failures = nil
}
return
}
cutoff := now.Add(-window)
keep := 0
for _, failureAt := range u.failures {
if failureAt.Before(cutoff) {
continue
}
u.failures[keep] = failureAt
keep++
}
u.failures = u.failures[:keep]
}
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
if config.FailDuration <= 0 {
return 0
}
if config.MaxFails <= 0 {
return 1
}
return config.MaxFails
}
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
if status <= 0 {
return false
}
for _, unhealthyStatus := range config.UnhealthyStatus {
if status == unhealthyStatus {
return true
}
}
return false
}

View file

@ -13,7 +13,9 @@ import (
"net/http/httptrace"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) {
}
}
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
t.Helper()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Targets: []string{"http://example.net"},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusInternalServerError {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
t.Helper()
type snapshot struct {
Path string
RawQuery string
}
backendOneCh := make(chan snapshot, 1)
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwoCh := make(chan snapshot, 1)
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
Targets: []string{
backendOne.URL + "/one?from=one",
backendTwo.URL + "/two?from=two",
},
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
}))
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
if first.Code != http.StatusOK || first.Body.String() != "one" {
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
}
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
if second.Code != http.StatusOK || second.Body.String() != "two" {
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
}
select {
case got := <-backendOneCh:
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
t.Fatalf("unexpected first upstream request: %#v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first upstream request")
}
select {
case got := <-backendTwoCh:
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
t.Fatalf("unexpected second upstream request: %#v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for second upstream request")
}
}
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBHeader("X-Upstream", LBFirst()),
},
}))
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
}
headers := http.Header{"X-Upstream": {"tenant-a"}}
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
}
if firstSticky.Body.String() != secondSticky.Body.String() {
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
}
}
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBQuery("tenant", LBFirst()),
},
}))
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
}
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
}
if firstSticky.Body.String() != secondSticky.Body.String() {
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
}
}
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBClientIPHash(),
},
}))
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
reqOne.RemoteAddr = "10.0.0.1:1234"
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
rrOne := httptest.NewRecorder()
engine.ServeHTTP(rrOne, reqOne)
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
reqTwo.RemoteAddr = "10.0.0.2:5678"
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
rrTwo := httptest.NewRecorder()
engine.ServeHTTP(rrTwo, reqTwo)
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
}
if rrOne.Body.String() != rrTwo.Body.String() {
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
}
}
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
}
}
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
}))
defer backend.Close()
var attempts atomic.Int64
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
ModifyRequest: func(req *http.Request) {
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rr.Code)
}
if rr.Body.String() != "2" {
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
}
}
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
t.Helper()
backendCalls := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendCalls <- struct{}{}
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
}))
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
if rr.Code != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case <-backendCalls:
t.Fatal("unsafe POST request should not be retried to the next upstream")
default:
}
}
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
t.Helper()
backendOneStarted := make(chan struct{}, 1)
releaseBackendOne := make(chan struct{})
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendOneStarted <- struct{}{}
<-releaseBackendOne
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBLeastConn(),
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
client := proxy.Client()
client.Timeout = 5 * time.Second
firstRespCh := make(chan string, 1)
firstErrCh := make(chan error, 1)
go func() {
resp, err := client.Get(proxy.URL + "/proxy")
if err != nil {
firstErrCh <- err
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
firstErrCh <- err
return
}
firstRespCh <- string(body)
}()
select {
case <-backendOneStarted:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first backend request")
}
secondResp, err := client.Get(proxy.URL + "/proxy")
if err != nil {
close(releaseBackendOne)
t.Fatalf("second request failed: %v", err)
}
secondBody, err := io.ReadAll(secondResp.Body)
_ = secondResp.Body.Close()
if err != nil {
close(releaseBackendOne)
t.Fatalf("read second response: %v", err)
}
if string(secondBody) != "two" {
close(releaseBackendOne)
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
}
close(releaseBackendOne)
select {
case err := <-firstErrCh:
t.Fatalf("first request failed: %v", err)
case body := <-firstRespCh:
if body != "one" {
t.Fatalf("unexpected first response body: %q", body)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first response body")
}
}
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
t.Helper()
primaryCalls := make(chan struct{}, 4)
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
primaryCalls <- struct{}{}
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = io.WriteString(w, "primary down")
}))
defer primary.Close()
secondaryCalls := make(chan struct{}, 4)
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
secondaryCalls <- struct{}{}
_, _ = io.WriteString(w, "secondary up")
}))
defer secondary.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{primary.URL, secondary.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
}
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
}
select {
case <-primaryCalls:
default:
t.Fatal("expected primary to receive the first request")
}
select {
case <-secondaryCalls:
default:
t.Fatal("expected secondary to receive the second request")
}
select {
case <-primaryCalls:
t.Fatal("primary should not receive the second request while unhealthy")
default:
}
}
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
t.Helper()
started := make(chan struct{}, 1)
release := make(chan struct{})
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
started <- struct{}{}
<-release
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backend.URL},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
client := proxy.Client()
respCh := make(chan error, 1)
go func() {
resp, err := client.Do(req)
if resp != nil {
_ = resp.Body.Close()
}
respCh <- err
}()
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend request")
}
cancel()
close(release)
select {
case <-respCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for canceled request to finish")
}
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
}
}
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
t.Helper()
backendCalls := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendCalls <- struct{}{}
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 3,
TryDuration: 100 * time.Millisecond,
TryInterval: 250 * time.Millisecond,
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case <-backendCalls:
t.Fatal("retry budget should expire before the next upstream attempt")
default:
}
}
func TestReverseProxyCustomErrorHandler(t *testing.T) {
t.Helper()
@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
}
}
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 8)
newBackend := func(name string) *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
w.WriteHeader(http.StatusBadRequest)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
line, err := bufio.NewReader(r.Body).ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
return
}
if _, err := io.WriteString(w, name+":"+line); err != nil {
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
return
}
_ = controller.Flush()
}))
server.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
}
server.StartTLS()
return server
}
backendOne := newBackend("one")
defer backendOne.Close()
backendTwo := newBackend("two")
defer backendTwo.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBRoundRobin(),
},
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
doRequest := func(payload string) string {
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
}
if err := pw.Close(); err != nil {
t.Fatalf("close tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
t.Fatalf("read tunneled response body: %v", err)
}
return message
}
if got := doRequest("ping"); got != "one:ping\n" {
t.Fatalf("unexpected first tunneled response: %q", got)
}
if got := doRequest("pong"); got != "two:pong\n" {
t.Fatalf("unexpected second tunneled response: %q", got)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
t.Helper()