mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(reverseproxy): add upstream balancing and failover
This commit is contained in:
parent
59f190ce3a
commit
919236665b
4 changed files with 1394 additions and 116 deletions
|
|
@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type ReverseProxyConfig struct {
|
type ReverseProxyConfig struct {
|
||||||
Target *url.URL
|
Target *url.URL
|
||||||
|
Targets []string
|
||||||
|
|
||||||
|
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||||
|
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||||
|
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
FlushInterval time.Duration
|
FlushInterval time.Duration
|
||||||
|
|
@ -78,12 +82,115 @@ type ReverseProxyConfig struct {
|
||||||
|
|
||||||
### `Target`
|
### `Target`
|
||||||
|
|
||||||
必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。
|
与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
target, _ := url.Parse("http://backend:9000")
|
target, _ := url.Parse("http://backend:9000")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `Targets`
|
||||||
|
|
||||||
|
可选。用于配置多个后端目标地址。
|
||||||
|
|
||||||
|
- `Target` 与 `Targets` 互斥,只能使用其中一种
|
||||||
|
- `Targets` 的每一项都必须是完整 URL
|
||||||
|
- 每个 target 仍然可以自带 base path 和 query
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001/base?from=a",
|
||||||
|
"http://127.0.0.1:9002/base?from=b",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
|
||||||
|
|
||||||
|
### `LoadBalancing`
|
||||||
|
|
||||||
|
用于配置 upstream 选择策略和重试行为。
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ReverseProxyLoadBalancingConfig struct {
|
||||||
|
Policy ReverseProxyLBPolicy
|
||||||
|
Retries int
|
||||||
|
TryDuration time.Duration
|
||||||
|
TryInterval time.Duration
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
当前内置策略:
|
||||||
|
|
||||||
|
- `touka.LBRandom()`
|
||||||
|
- `touka.LBRoundRobin()`
|
||||||
|
- `touka.LBFirst()`
|
||||||
|
- `touka.LBLeastConn()`
|
||||||
|
- `touka.LBIPHash()`
|
||||||
|
- `touka.LBClientIPHash()`
|
||||||
|
- `touka.LBURIHash()`
|
||||||
|
- `touka.LBHeader("X-Upstream", fallback)`
|
||||||
|
- `touka.LBQuery("tenant", fallback)`
|
||||||
|
|
||||||
|
其中:
|
||||||
|
|
||||||
|
- `LBFirst()` 适合主备/故障转移顺序
|
||||||
|
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
|
||||||
|
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()`
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001",
|
||||||
|
"http://127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
|
||||||
|
Retries: 1,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
重试说明:
|
||||||
|
|
||||||
|
- 只对未开始收到上游响应的失败进行重试
|
||||||
|
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
|
||||||
|
- `Retries` 表示额外重试次数
|
||||||
|
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
|
||||||
|
- `TryInterval` 表示两次重试之间的等待间隔
|
||||||
|
|
||||||
|
### `PassiveHealth`
|
||||||
|
|
||||||
|
用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ReverseProxyPassiveHealthConfig struct {
|
||||||
|
FailDuration time.Duration
|
||||||
|
MaxFails int
|
||||||
|
UnhealthyStatus []int
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `FailDuration > 0` 时启用被动健康跟踪
|
||||||
|
- `MaxFails <= 0` 时默认按 `1` 处理
|
||||||
|
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001",
|
||||||
|
"http://127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: touka.LBFirst(),
|
||||||
|
},
|
||||||
|
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
### `Transport`
|
### `Transport`
|
||||||
|
|
||||||
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||||
|
|
@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
|
||||||
在请求真正发往后端前,对出站请求做最后修改。
|
在请求真正发往后端前,对出站请求做最后修改。
|
||||||
|
|
||||||
|
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
|
||||||
|
|
||||||
常见用途:
|
常见用途:
|
||||||
|
|
||||||
- 覆盖 `Host`
|
- 覆盖 `Host`
|
||||||
|
|
|
||||||
426
reverseproxy.go
426
reverseproxy.go
|
|
@ -23,6 +23,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ForwardedHeadersPolicy controls how forwarding headers are generated.
|
// ForwardedHeadersPolicy controls how forwarding headers are generated.
|
||||||
|
|
@ -44,7 +46,11 @@ type BufferPool interface {
|
||||||
|
|
||||||
// ReverseProxyConfig configures the reverse proxy handler.
|
// ReverseProxyConfig configures the reverse proxy handler.
|
||||||
type ReverseProxyConfig struct {
|
type ReverseProxyConfig struct {
|
||||||
Target *url.URL
|
Target *url.URL
|
||||||
|
Targets []string
|
||||||
|
|
||||||
|
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||||
|
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||||
|
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
FlushInterval time.Duration
|
FlushInterval time.Duration
|
||||||
|
|
@ -61,17 +67,18 @@ type ReverseProxyConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
||||||
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
||||||
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
||||||
|
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
|
||||||
)
|
)
|
||||||
|
|
||||||
type reverseProxyHandler struct {
|
type reverseProxyHandler struct {
|
||||||
config ReverseProxyConfig
|
config ReverseProxyConfig
|
||||||
target *url.URL
|
upstreams []*reverseProxyUpstream
|
||||||
receivedBy string
|
receivedBy string
|
||||||
configError error
|
configError error
|
||||||
extendedConnectTransport http.RoundTripper
|
roundRobin atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
type reverseProxyStatusError struct {
|
type reverseProxyStatusError struct {
|
||||||
|
|
@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||||
target := cloneReverseProxyURL(config.Target)
|
|
||||||
if target != nil {
|
|
||||||
normalizeReverseProxyTarget(target)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := &reverseProxyHandler{
|
proxy := &reverseProxyHandler{
|
||||||
config: config,
|
config: config,
|
||||||
target: target,
|
|
||||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||||
}
|
}
|
||||||
if config.Transport == nil {
|
|
||||||
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateReverseProxyTarget(target); err != nil {
|
upstreams, err := buildReverseProxyUpstreams(config)
|
||||||
|
if err != nil {
|
||||||
proxy.configError = err
|
proxy.configError = err
|
||||||
|
} else {
|
||||||
|
proxy.upstreams = upstreams
|
||||||
}
|
}
|
||||||
|
|
||||||
switch config.ForwardedHeaders {
|
switch config.ForwardedHeaders {
|
||||||
|
|
@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||||
proxy.configError = err
|
proxy.configError = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if proxy.configError == nil {
|
||||||
|
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
|
||||||
|
proxy.configError = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
|
@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
transport := p.config.Transport
|
|
||||||
if transport == nil {
|
|
||||||
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
|
|
||||||
transport = p.extendedConnectTransport
|
|
||||||
} else {
|
|
||||||
transport = http.DefaultTransport
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.handleError(c, err)
|
p.handleError(c, err)
|
||||||
|
|
@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
|
|
||||||
ctx, cancel := p.requestContext(c)
|
ctx, cancel := p.requestContext(c)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
attempted := make(map[string]struct{}, len(p.upstreams))
|
||||||
|
attempts := 0
|
||||||
|
started := time.Now()
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
outreq := c.Request.Clone(ctx)
|
for {
|
||||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
upstream, err := p.selectUpstream(c, attempted)
|
||||||
outreq.Body = nil
|
if err != nil {
|
||||||
}
|
if lastErr != nil {
|
||||||
if outreq.Body != nil {
|
p.handleError(c, lastErr)
|
||||||
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
|
||||||
defer outreq.Body.Close()
|
|
||||||
}
|
|
||||||
if outreq.Header == nil {
|
|
||||||
outreq.Header = make(http.Header)
|
|
||||||
}
|
|
||||||
outreq.Close = false
|
|
||||||
var connectWriter *io.PipeWriter
|
|
||||||
defer func() {
|
|
||||||
if connectWriter != nil {
|
|
||||||
_ = connectWriter.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if outreq.Method == http.MethodConnect {
|
|
||||||
pipeReader, pipeWriter := io.Pipe()
|
|
||||||
outreq.Body = pipeReader
|
|
||||||
outreq.ContentLength = -1
|
|
||||||
defer outreq.Body.Close()
|
|
||||||
connectWriter = pipeWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
if outreq.Method == http.MethodConnect {
|
|
||||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
|
||||||
rewriteReverseProxyURL(outreq, p.target)
|
|
||||||
if !p.config.PreserveHost {
|
|
||||||
outreq.Host = ""
|
|
||||||
}
|
|
||||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
|
||||||
} else {
|
|
||||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
|
||||||
p.handleError(c, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
rewriteReverseProxyURL(outreq, p.target)
|
|
||||||
if !p.config.PreserveHost {
|
|
||||||
outreq.Host = ""
|
|
||||||
}
|
|
||||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
|
||||||
}
|
|
||||||
if updatedMaxForwards != "" {
|
|
||||||
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
|
||||||
}
|
|
||||||
|
|
||||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
attempts++
|
||||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
upstream.inFlight.Add(1)
|
||||||
p.handleError(c, &reverseProxyStatusError{
|
served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
|
||||||
status: http.StatusBadRequest,
|
upstream.inFlight.Add(-1)
|
||||||
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
|
||||||
})
|
if served {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if attemptErr != nil {
|
||||||
|
lastErr = attemptErr
|
||||||
|
}
|
||||||
|
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
|
||||||
|
attempted[upstream.key] = struct{}{}
|
||||||
|
if !p.waitRetryInterval(ctx, started) {
|
||||||
|
if lastErr != nil {
|
||||||
|
p.handleError(c, lastErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attemptErr != nil {
|
||||||
|
p.handleError(c, attemptErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
p.handleError(c, lastErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
removeHopByHopHeaders(outreq.Header)
|
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
|
||||||
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
|
||||||
outreq.Header.Set("Te", "trailers")
|
if err != nil {
|
||||||
}
|
return false, err, false
|
||||||
if reqUpType != "" {
|
|
||||||
outreq.Header.Set("Connection", "Upgrade")
|
|
||||||
outreq.Header.Set("Upgrade", reqUpType)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.addForwardingHeaders(c.Request, outreq)
|
|
||||||
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
|
||||||
|
|
||||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
|
||||||
outreq.Header.Set("User-Agent", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.config.ModifyRequest != nil {
|
|
||||||
p.config.ModifyRequest(outreq)
|
|
||||||
}
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
transport := p.transportForUpstream(c.Request, upstream)
|
||||||
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
||||||
var (
|
var (
|
||||||
roundTripMu sync.Mutex
|
roundTripMu sync.Mutex
|
||||||
|
|
@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
roundTripDone = true
|
roundTripDone = true
|
||||||
roundTripMu.Unlock()
|
roundTripMu.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.handleError(c, err)
|
if reverseProxyShouldCountPassiveFailure(outreq, err) {
|
||||||
return
|
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||||
|
}
|
||||||
|
return false, err, true
|
||||||
|
}
|
||||||
|
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
|
||||||
|
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||||
}
|
}
|
||||||
|
|
||||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||||
|
|
@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
res.TransferEncoding = nil
|
res.TransferEncoding = nil
|
||||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||||
if !p.modifyResponse(c, res, outreq) {
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
handleConnect := p.handleConnectResponse
|
handleConnect := p.handleConnectResponse
|
||||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||||
handleConnect = p.handleExtendedConnectResponse
|
handleConnect = p.handleExtendedConnectResponse
|
||||||
}
|
}
|
||||||
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||||
p.handleError(c, err)
|
return false, err, false
|
||||||
}
|
}
|
||||||
connectWriter = nil
|
return true, nil, false
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||||
if !p.modifyResponse(c, res, outreq) {
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
|
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
|
||||||
p.handleError(c, err)
|
return false, err, false
|
||||||
}
|
}
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
removeHopByHopHeaders(res.Header)
|
removeHopByHopHeaders(res.Header)
|
||||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||||
|
|
||||||
if !p.modifyResponse(c, res, outreq) {
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||||
|
|
@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
||||||
panic(http.ErrAbortHandler)
|
panic(http.ErrAbortHandler)
|
||||||
}
|
}
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
|
||||||
|
|
@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep the stdlib-compatible fallback here.
|
|
||||||
// If the backend only exposes additional trailer keys after the body has been
|
|
||||||
// fully read, the trailer map can grow and those values must be written using
|
|
||||||
// the TrailerPrefix form instead of the pre-announced bare header keys.
|
|
||||||
if len(res.Trailer) == announcedTrailers {
|
if len(res.Trailer) == announcedTrailers {
|
||||||
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
|
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
|
||||||
return
|
return true, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, values := range res.Trailer {
|
for key, values := range res.Trailer {
|
||||||
|
|
@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
c.Writer.Header().Add(prefixedKey, value)
|
c.Writer.Header().Add(prefixedKey, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||||
|
outreq := c.Request.Clone(ctx)
|
||||||
|
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||||
|
outreq.Body = nil
|
||||||
|
} else if c.Request.GetBody != nil {
|
||||||
|
body, err := c.Request.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
|
||||||
|
}
|
||||||
|
outreq.Body = body
|
||||||
|
} else if outreq.Body != nil {
|
||||||
|
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
||||||
|
}
|
||||||
|
if outreq.Header == nil {
|
||||||
|
outreq.Header = make(http.Header)
|
||||||
|
}
|
||||||
|
outreq.Close = false
|
||||||
|
var connectWriter *io.PipeWriter
|
||||||
|
if outreq.Method == http.MethodConnect {
|
||||||
|
pipeReader, pipeWriter := io.Pipe()
|
||||||
|
outreq.Body = pipeReader
|
||||||
|
outreq.ContentLength = -1
|
||||||
|
connectWriter = pipeWriter
|
||||||
|
}
|
||||||
|
cleanup := func() {
|
||||||
|
if outreq.Body != nil {
|
||||||
|
_ = outreq.Body.Close()
|
||||||
|
}
|
||||||
|
if connectWriter != nil {
|
||||||
|
_ = connectWriter.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if outreq.Method == http.MethodConnect {
|
||||||
|
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||||
|
rewriteReverseProxyURL(outreq, upstream.target)
|
||||||
|
if !p.config.PreserveHost {
|
||||||
|
outreq.Host = ""
|
||||||
|
}
|
||||||
|
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||||
|
} else {
|
||||||
|
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
||||||
|
cleanup()
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rewriteReverseProxyURL(outreq, upstream.target)
|
||||||
|
if !p.config.PreserveHost {
|
||||||
|
outreq.Host = ""
|
||||||
|
}
|
||||||
|
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||||
|
}
|
||||||
|
if updatedMaxForwards != "" {
|
||||||
|
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||||
|
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||||
|
cleanup()
|
||||||
|
return nil, nil, nil, &reverseProxyStatusError{
|
||||||
|
status: http.StatusBadRequest,
|
||||||
|
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removeHopByHopHeaders(outreq.Header)
|
||||||
|
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
||||||
|
outreq.Header.Set("Te", "trailers")
|
||||||
|
}
|
||||||
|
if reqUpType != "" {
|
||||||
|
outreq.Header.Set("Connection", "Upgrade")
|
||||||
|
outreq.Header.Set("Upgrade", reqUpType)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.addForwardingHeaders(c.Request, outreq)
|
||||||
|
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
||||||
|
|
||||||
|
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||||
|
outreq.Header.Set("User-Agent", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.ModifyRequest != nil {
|
||||||
|
p.config.ModifyRequest(outreq)
|
||||||
|
}
|
||||||
|
|
||||||
|
return outreq, connectWriter, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
|
||||||
|
if p.config.Transport != nil {
|
||||||
|
return p.config.Transport
|
||||||
|
}
|
||||||
|
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
|
||||||
|
return upstream.extendedConnectTransport
|
||||||
|
}
|
||||||
|
return http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
|
||||||
|
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lb := p.config.LoadBalancing
|
||||||
|
if lb.TryDuration > 0 {
|
||||||
|
return time.Since(started) < lb.TryDuration
|
||||||
|
}
|
||||||
|
return attempts <= lb.Retries
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
|
||||||
|
interval := p.config.LoadBalancing.TryInterval
|
||||||
|
tryDuration := p.config.LoadBalancing.TryDuration
|
||||||
|
if tryDuration > 0 && interval == 0 {
|
||||||
|
interval = 250 * time.Millisecond
|
||||||
|
}
|
||||||
|
if tryDuration > 0 {
|
||||||
|
remaining := tryDuration - time.Since(started)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
if interval > remaining {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return ctx.Err() == nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(interval)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-timer.C:
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
||||||
|
|
@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
|
||||||
|
if config.Target != nil && len(config.Targets) > 0 {
|
||||||
|
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
|
||||||
|
}
|
||||||
|
|
||||||
|
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
|
||||||
|
if config.Target != nil {
|
||||||
|
target := cloneReverseProxyURL(config.Target)
|
||||||
|
normalizeReverseProxyTarget(target)
|
||||||
|
if err := validateReverseProxyTarget(target); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
targets = append(targets, target)
|
||||||
|
}
|
||||||
|
for i, rawTarget := range config.Targets {
|
||||||
|
trimmed := strings.TrimSpace(rawTarget)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
|
||||||
|
}
|
||||||
|
target, err := url.Parse(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||||
|
}
|
||||||
|
normalizeReverseProxyTarget(target)
|
||||||
|
if err := validateReverseProxyTarget(target); err != nil {
|
||||||
|
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||||
|
}
|
||||||
|
targets = append(targets, target)
|
||||||
|
}
|
||||||
|
if len(targets) == 0 {
|
||||||
|
return nil, errReverseProxyNilTarget
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
|
||||||
|
for i, target := range targets {
|
||||||
|
upstream := &reverseProxyUpstream{
|
||||||
|
key: fmt.Sprintf("%d:%s", i, target.String()),
|
||||||
|
target: target,
|
||||||
|
index: i,
|
||||||
|
}
|
||||||
|
if config.Transport == nil {
|
||||||
|
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||||
|
}
|
||||||
|
upstreams = append(upstreams, upstream)
|
||||||
|
}
|
||||||
|
return upstreams, nil
|
||||||
|
}
|
||||||
|
|
||||||
func validateReverseProxyForwardedBy(value string) error {
|
func validateReverseProxyForwardedBy(value string) error {
|
||||||
trimmed := strings.TrimSpace(value)
|
trimmed := strings.TrimSpace(value)
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
|
|
@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseProxyCanRetryRequest(req *http.Request) bool {
|
||||||
|
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.Body == nil || req.ContentLength == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return req.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
|
||||||
|
if err == nil || reverseProxyIsBenignTunnelError(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req != nil && req.Context().Err() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyMethodIsSafe(method string) bool {
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
|
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
|
||||||
}
|
}
|
||||||
|
|
@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
var streamErr http2.StreamError
|
||||||
|
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
|
||||||
|
return true
|
||||||
|
}
|
||||||
switch err.Error() {
|
switch err.Error() {
|
||||||
case "body closed by handler", "http2: response body closed", "response body closed":
|
case "body closed by handler", "http2: response body closed", "response body closed":
|
||||||
return true
|
return true
|
||||||
|
|
|
||||||
352
reverseproxy_lb.go
Normal file
352
reverseproxy_lb.go
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
|
||||||
|
type ReverseProxyLoadBalancingConfig struct {
|
||||||
|
Policy ReverseProxyLBPolicy
|
||||||
|
Retries int
|
||||||
|
TryDuration time.Duration
|
||||||
|
TryInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
|
||||||
|
type ReverseProxyPassiveHealthConfig struct {
|
||||||
|
FailDuration time.Duration
|
||||||
|
MaxFails int
|
||||||
|
UnhealthyStatus []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
|
||||||
|
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
|
||||||
|
type ReverseProxyLBPolicy struct {
|
||||||
|
kind reverseProxyLBPolicyKind
|
||||||
|
key string
|
||||||
|
fallback *ReverseProxyLBPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
type reverseProxyLBPolicyKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
|
||||||
|
reverseProxyLBPolicyRoundRobin
|
||||||
|
reverseProxyLBPolicyFirst
|
||||||
|
reverseProxyLBPolicyLeastConn
|
||||||
|
reverseProxyLBPolicyIPHash
|
||||||
|
reverseProxyLBPolicyClientIPHash
|
||||||
|
reverseProxyLBPolicyURIHash
|
||||||
|
reverseProxyLBPolicyHeader
|
||||||
|
reverseProxyLBPolicyQuery
|
||||||
|
)
|
||||||
|
|
||||||
|
type reverseProxyUpstream struct {
|
||||||
|
key string
|
||||||
|
target *url.URL
|
||||||
|
index int
|
||||||
|
extendedConnectTransport http.RoundTripper
|
||||||
|
inFlight atomic.Int64
|
||||||
|
|
||||||
|
passiveMu sync.Mutex
|
||||||
|
failures []time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBRandom() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBRoundRobin() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBFirst() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBLeastConn() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBIPHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBClientIPHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBURIHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
|
||||||
|
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||||
|
policy.fallback = &fallback
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
|
||||||
|
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||||
|
policy.fallback = &fallback
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
|
||||||
|
switch policy.kind {
|
||||||
|
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
|
||||||
|
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
|
||||||
|
reverseProxyLBPolicyURIHash:
|
||||||
|
return nil
|
||||||
|
case reverseProxyLBPolicyHeader:
|
||||||
|
if policy.key == "" {
|
||||||
|
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
|
||||||
|
}
|
||||||
|
case reverseProxyLBPolicyQuery:
|
||||||
|
if policy.key == "" {
|
||||||
|
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
|
||||||
|
}
|
||||||
|
if policy.fallback != nil {
|
||||||
|
return validateReverseProxyLBPolicy(*policy.fallback)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
|
||||||
|
now := time.Now()
|
||||||
|
policy := p.config.LoadBalancing.Policy
|
||||||
|
candidates := p.availableUpstreams(now, excluded)
|
||||||
|
if len(candidates) == 0 && len(excluded) > 0 {
|
||||||
|
candidates = p.availableUpstreams(now, nil)
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, errReverseProxyNoAvailableUpstreams
|
||||||
|
}
|
||||||
|
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
|
||||||
|
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
|
||||||
|
for _, upstream := range p.upstreams {
|
||||||
|
if _, skip := excluded[upstream.key]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !upstream.healthy(now, p.config.PassiveHealth) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, upstream)
|
||||||
|
}
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy.kind {
|
||||||
|
case reverseProxyLBPolicyRoundRobin:
|
||||||
|
return candidates[p.nextRoundRobinIndex(len(candidates))]
|
||||||
|
case reverseProxyLBPolicyFirst:
|
||||||
|
return candidates[0]
|
||||||
|
case reverseProxyLBPolicyLeastConn:
|
||||||
|
return p.selectLeastConnUpstream(candidates)
|
||||||
|
case reverseProxyLBPolicyIPHash:
|
||||||
|
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
|
||||||
|
case reverseProxyLBPolicyClientIPHash:
|
||||||
|
return reverseProxySelectHRW(candidates, c.RequestIP())
|
||||||
|
case reverseProxyLBPolicyURIHash:
|
||||||
|
if c.Request == nil || c.Request.URL == nil {
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
|
||||||
|
case reverseProxyLBPolicyHeader:
|
||||||
|
if c.Request != nil && c.Request.Header != nil {
|
||||||
|
if values, ok := c.Request.Header[policy.key]; ok {
|
||||||
|
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||||
|
case reverseProxyLBPolicyQuery:
|
||||||
|
if c.Request != nil && c.Request.URL != nil {
|
||||||
|
if values, ok := c.Request.URL.Query()[policy.key]; ok {
|
||||||
|
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||||
|
case reverseProxyLBPolicyRandom:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
|
||||||
|
if size <= 1 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int((p.roundRobin.Add(1) - 1) % uint64(size))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
selected := candidates[0]
|
||||||
|
lowest := selected.inFlight.Load()
|
||||||
|
ties := []*reverseProxyUpstream{selected}
|
||||||
|
for _, upstream := range candidates[1:] {
|
||||||
|
count := upstream.inFlight.Load()
|
||||||
|
switch {
|
||||||
|
case count < lowest:
|
||||||
|
selected = upstream
|
||||||
|
lowest = count
|
||||||
|
ties = []*reverseProxyUpstream{upstream}
|
||||||
|
case count == lowest:
|
||||||
|
ties = append(ties, upstream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ties) == 1 {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
return ties[p.nextRoundRobinIndex(len(ties))]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(candidates) == 1 {
|
||||||
|
return candidates[0]
|
||||||
|
}
|
||||||
|
return candidates[rand.IntN(len(candidates))]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
selected := candidates[0]
|
||||||
|
bestScore := reverseProxyHRWScore(key, selected.key)
|
||||||
|
for _, upstream := range candidates[1:] {
|
||||||
|
score := reverseProxyHRWScore(key, upstream.key)
|
||||||
|
if score > bestScore {
|
||||||
|
selected = upstream
|
||||||
|
bestScore = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
|
||||||
|
const (
|
||||||
|
offset64 = 14695981039346656037
|
||||||
|
prime64 = 1099511628211
|
||||||
|
)
|
||||||
|
h := uint64(offset64)
|
||||||
|
for i := 0; i < len(key); i++ {
|
||||||
|
h ^= uint64(key[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
h ^= 0xff
|
||||||
|
h *= prime64
|
||||||
|
for i := 0; i < len(upstreamKey); i++ {
|
||||||
|
h ^= uint64(upstreamKey[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
if policy.fallback != nil {
|
||||||
|
return *policy.fallback
|
||||||
|
}
|
||||||
|
return LBRandom()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
|
||||||
|
maxFails := reverseProxyPassiveMaxFails(config)
|
||||||
|
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
u.passiveMu.Lock()
|
||||||
|
defer u.passiveMu.Unlock()
|
||||||
|
u.pruneFailuresLocked(now, config.FailDuration)
|
||||||
|
return len(u.failures) < maxFails
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
|
||||||
|
maxFails := reverseProxyPassiveMaxFails(config)
|
||||||
|
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u.passiveMu.Lock()
|
||||||
|
defer u.passiveMu.Unlock()
|
||||||
|
u.pruneFailuresLocked(now, config.FailDuration)
|
||||||
|
u.failures = append(u.failures, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
|
||||||
|
if len(u.failures) == 0 || window <= 0 {
|
||||||
|
if window <= 0 {
|
||||||
|
u.failures = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := now.Add(-window)
|
||||||
|
keep := 0
|
||||||
|
for _, failureAt := range u.failures {
|
||||||
|
if failureAt.Before(cutoff) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u.failures[keep] = failureAt
|
||||||
|
keep++
|
||||||
|
}
|
||||||
|
u.failures = u.failures[:keep]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
|
||||||
|
if config.FailDuration <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if config.MaxFails <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return config.MaxFails
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
|
||||||
|
if status <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, unhealthyStatus := range config.UnhealthyStatus {
|
||||||
|
if status == unhealthyStatus {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
@ -13,7 +13,9 @@ import (
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Targets: []string{"http://example.net"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
type snapshot struct {
|
||||||
|
Path string
|
||||||
|
RawQuery string
|
||||||
|
}
|
||||||
|
|
||||||
|
backendOneCh := make(chan snapshot, 1)
|
||||||
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||||
|
_, _ = io.WriteString(w, "one")
|
||||||
|
}))
|
||||||
|
defer backendOne.Close()
|
||||||
|
|
||||||
|
backendTwoCh := make(chan snapshot, 1)
|
||||||
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||||
|
_, _ = io.WriteString(w, "two")
|
||||||
|
}))
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
backendOne.URL + "/one?from=one",
|
||||||
|
backendTwo.URL + "/two?from=two",
|
||||||
|
},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
||||||
|
}))
|
||||||
|
|
||||||
|
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
|
||||||
|
if first.Code != http.StatusOK || first.Body.String() != "one" {
|
||||||
|
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||||
|
}
|
||||||
|
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
|
||||||
|
if second.Code != http.StatusOK || second.Body.String() != "two" {
|
||||||
|
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-backendOneCh:
|
||||||
|
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
|
||||||
|
t.Fatalf("unexpected first upstream request: %#v", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for first upstream request")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-backendTwoCh:
|
||||||
|
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
|
||||||
|
t.Fatalf("unexpected second upstream request: %#v", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for second upstream request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "one")
|
||||||
|
}))
|
||||||
|
defer backendOne.Close()
|
||||||
|
|
||||||
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "two")
|
||||||
|
}))
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBHeader("X-Upstream", LBFirst()),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||||
|
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := http.Header{"X-Upstream": {"tenant-a"}}
|
||||||
|
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||||
|
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||||
|
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||||
|
}
|
||||||
|
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||||
|
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "one")
|
||||||
|
}))
|
||||||
|
defer backendOne.Close()
|
||||||
|
|
||||||
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "two")
|
||||||
|
}))
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBQuery("tenant", LBFirst()),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||||
|
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||||
|
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||||
|
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||||
|
}
|
||||||
|
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||||
|
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "one")
|
||||||
|
}))
|
||||||
|
defer backendOne.Close()
|
||||||
|
|
||||||
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "two")
|
||||||
|
}))
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBClientIPHash(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||||
|
reqOne.RemoteAddr = "10.0.0.1:1234"
|
||||||
|
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||||
|
rrOne := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rrOne, reqOne)
|
||||||
|
|
||||||
|
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||||
|
reqTwo.RemoteAddr = "10.0.0.2:5678"
|
||||||
|
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||||
|
rrTwo := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rrTwo, reqTwo)
|
||||||
|
|
||||||
|
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
|
||||||
|
}
|
||||||
|
if rrOne.Body.String() != rrTwo.Body.String() {
|
||||||
|
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBFirst(),
|
||||||
|
Retries: 1,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||||
|
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
var attempts atomic.Int64
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBFirst(),
|
||||||
|
Retries: 1,
|
||||||
|
},
|
||||||
|
ModifyRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
if rr.Body.String() != "2" {
|
||||||
|
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendCalls := make(chan struct{}, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendCalls <- struct{}{}
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBFirst(),
|
||||||
|
Retries: 1,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
|
||||||
|
if rr.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-backendCalls:
|
||||||
|
t.Fatal("unsafe POST request should not be retried to the next upstream")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendOneStarted := make(chan struct{}, 1)
|
||||||
|
releaseBackendOne := make(chan struct{})
|
||||||
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendOneStarted <- struct{}{}
|
||||||
|
<-releaseBackendOne
|
||||||
|
_, _ = io.WriteString(w, "one")
|
||||||
|
}))
|
||||||
|
defer backendOne.Close()
|
||||||
|
|
||||||
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "two")
|
||||||
|
}))
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBLeastConn(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
client := proxy.Client()
|
||||||
|
client.Timeout = 5 * time.Second
|
||||||
|
|
||||||
|
firstRespCh := make(chan string, 1)
|
||||||
|
firstErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
resp, err := client.Get(proxy.URL + "/proxy")
|
||||||
|
if err != nil {
|
||||||
|
firstErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
firstErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
firstRespCh <- string(body)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-backendOneStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for first backend request")
|
||||||
|
}
|
||||||
|
|
||||||
|
secondResp, err := client.Get(proxy.URL + "/proxy")
|
||||||
|
if err != nil {
|
||||||
|
close(releaseBackendOne)
|
||||||
|
t.Fatalf("second request failed: %v", err)
|
||||||
|
}
|
||||||
|
secondBody, err := io.ReadAll(secondResp.Body)
|
||||||
|
_ = secondResp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
close(releaseBackendOne)
|
||||||
|
t.Fatalf("read second response: %v", err)
|
||||||
|
}
|
||||||
|
if string(secondBody) != "two" {
|
||||||
|
close(releaseBackendOne)
|
||||||
|
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
close(releaseBackendOne)
|
||||||
|
select {
|
||||||
|
case err := <-firstErrCh:
|
||||||
|
t.Fatalf("first request failed: %v", err)
|
||||||
|
case body := <-firstRespCh:
|
||||||
|
if body != "one" {
|
||||||
|
t.Fatalf("unexpected first response body: %q", body)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for first response body")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
primaryCalls := make(chan struct{}, 4)
|
||||||
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
primaryCalls <- struct{}{}
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
_, _ = io.WriteString(w, "primary down")
|
||||||
|
}))
|
||||||
|
defer primary.Close()
|
||||||
|
|
||||||
|
secondaryCalls := make(chan struct{}, 4)
|
||||||
|
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
secondaryCalls <- struct{}{}
|
||||||
|
_, _ = io.WriteString(w, "secondary up")
|
||||||
|
}))
|
||||||
|
defer secondary.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{primary.URL, secondary.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBFirst(),
|
||||||
|
},
|
||||||
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
|
||||||
|
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||||
|
}
|
||||||
|
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
|
||||||
|
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-primaryCalls:
|
||||||
|
default:
|
||||||
|
t.Fatal("expected primary to receive the first request")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-secondaryCalls:
|
||||||
|
default:
|
||||||
|
t.Fatal("expected secondary to receive the second request")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-primaryCalls:
|
||||||
|
t.Fatal("primary should not receive the second request while unhealthy")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
started := make(chan struct{}, 1)
|
||||||
|
release := make(chan struct{})
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
started <- struct{}{}
|
||||||
|
<-release
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backend.URL},
|
||||||
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new request: %v", err)
|
||||||
|
}
|
||||||
|
client := proxy.Client()
|
||||||
|
respCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if resp != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
respCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for backend request")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
close(release)
|
||||||
|
select {
|
||||||
|
case <-respCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for canceled request to finish")
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||||
|
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendCalls := make(chan struct{}, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendCalls <- struct{}{}
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBFirst(),
|
||||||
|
Retries: 3,
|
||||||
|
TryDuration: 100 * time.Millisecond,
|
||||||
|
TryInterval: 250 * time.Millisecond,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-backendCalls:
|
||||||
|
t.Fatal("retry budget should expire before the next upstream attempt")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
|
||||||
|
errCh := make(chan error, 8)
|
||||||
|
newBackend := func(name string) *httptest.Server {
|
||||||
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||||
|
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
controller := http.NewResponseController(w)
|
||||||
|
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
|
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = controller.Flush()
|
||||||
|
|
||||||
|
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(w, name+":"+line); err != nil {
|
||||||
|
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = controller.Flush()
|
||||||
|
}))
|
||||||
|
server.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
|
||||||
|
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
|
||||||
|
}
|
||||||
|
server.StartTLS()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
backendOne := newBackend("one")
|
||||||
|
defer backendOne.Close()
|
||||||
|
backendTwo := newBackend("two")
|
||||||
|
defer backendTwo.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: LBRoundRobin(),
|
||||||
|
},
|
||||||
|
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||||
|
Via: "proxy.test",
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewUnstartedServer(engine)
|
||||||
|
proxy.EnableHTTP2 = true
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||||
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||||
|
}
|
||||||
|
proxy.StartTLS()
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||||
|
defer transport.CloseIdleConnections()
|
||||||
|
|
||||||
|
doRequest := func(payload string) string {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new CONNECT request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set(":protocol", "websocket")
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
||||||
|
t.Fatalf("write tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
if err := pw.Close(); err != nil {
|
||||||
|
t.Fatalf("close tunneled request body: %v", err)
|
||||||
|
}
|
||||||
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read tunneled response body: %v", err)
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := doRequest("ping"); got != "one:ping\n" {
|
||||||
|
t.Fatalf("unexpected first tunneled response: %q", got)
|
||||||
|
}
|
||||||
|
if got := doRequest("pong"); got != "two:pong\n" {
|
||||||
|
t.Fatalf("unexpected second tunneled response: %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue