mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(reverseproxy): add upstream balancing and failover
This commit is contained in:
parent
59f190ce3a
commit
919236665b
4 changed files with 1394 additions and 116 deletions
|
|
@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
|
||||
```go
|
||||
type ReverseProxyConfig struct {
|
||||
Target *url.URL
|
||||
Target *url.URL
|
||||
Targets []string
|
||||
|
||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
|
|
@ -78,12 +82,115 @@ type ReverseProxyConfig struct {
|
|||
|
||||
### `Target`
|
||||
|
||||
必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
|
||||
```go
|
||||
target, _ := url.Parse("http://backend:9000")
|
||||
```
|
||||
|
||||
### `Targets`
|
||||
|
||||
可选。用于配置多个后端目标地址。
|
||||
|
||||
- `Target` 与 `Targets` 互斥,只能使用其中一种
|
||||
- `Targets` 的每一项都必须是完整 URL
|
||||
- 每个 target 仍然可以自带 base path 和 query
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001/base?from=a",
|
||||
"http://127.0.0.1:9002/base?from=b",
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
|
||||
|
||||
### `LoadBalancing`
|
||||
|
||||
用于配置 upstream 选择策略和重试行为。
|
||||
|
||||
```go
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
```
|
||||
|
||||
当前内置策略:
|
||||
|
||||
- `touka.LBRandom()`
|
||||
- `touka.LBRoundRobin()`
|
||||
- `touka.LBFirst()`
|
||||
- `touka.LBLeastConn()`
|
||||
- `touka.LBIPHash()`
|
||||
- `touka.LBClientIPHash()`
|
||||
- `touka.LBURIHash()`
|
||||
- `touka.LBHeader("X-Upstream", fallback)`
|
||||
- `touka.LBQuery("tenant", fallback)`
|
||||
|
||||
其中:
|
||||
|
||||
- `LBFirst()` 适合主备/故障转移顺序
|
||||
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
|
||||
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()`
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
重试说明:
|
||||
|
||||
- 只对未开始收到上游响应的失败进行重试
|
||||
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
|
||||
- `Retries` 表示额外重试次数
|
||||
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
|
||||
- `TryInterval` 表示两次重试之间的等待间隔
|
||||
|
||||
### `PassiveHealth`
|
||||
|
||||
用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
|
||||
|
||||
```go
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
```
|
||||
|
||||
- `FailDuration > 0` 时启用被动健康跟踪
|
||||
- `MaxFails <= 0` 时默认按 `1` 处理
|
||||
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBFirst(),
|
||||
},
|
||||
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
### `Transport`
|
||||
|
||||
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||
|
|
@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
|
||||
在请求真正发往后端前,对出站请求做最后修改。
|
||||
|
||||
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
|
||||
|
||||
常见用途:
|
||||
|
||||
- 覆盖 `Host`
|
||||
|
|
|
|||
426
reverseproxy.go
426
reverseproxy.go
|
|
@ -23,6 +23,8 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// ForwardedHeadersPolicy controls how forwarding headers are generated.
|
||||
|
|
@ -44,7 +46,11 @@ type BufferPool interface {
|
|||
|
||||
// ReverseProxyConfig configures the reverse proxy handler.
|
||||
type ReverseProxyConfig struct {
|
||||
Target *url.URL
|
||||
Target *url.URL
|
||||
Targets []string
|
||||
|
||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
|
|
@ -61,17 +67,18 @@ type ReverseProxyConfig struct {
|
|||
}
|
||||
|
||||
var (
|
||||
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
||||
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
||||
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
||||
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
||||
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
||||
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
||||
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
|
||||
)
|
||||
|
||||
type reverseProxyHandler struct {
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
receivedBy string
|
||||
configError error
|
||||
extendedConnectTransport http.RoundTripper
|
||||
config ReverseProxyConfig
|
||||
upstreams []*reverseProxyUpstream
|
||||
receivedBy string
|
||||
configError error
|
||||
roundRobin atomic.Uint64
|
||||
}
|
||||
|
||||
type reverseProxyStatusError struct {
|
||||
|
|
@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
|
|||
}
|
||||
|
||||
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||
target := cloneReverseProxyURL(config.Target)
|
||||
if target != nil {
|
||||
normalizeReverseProxyTarget(target)
|
||||
}
|
||||
|
||||
proxy := &reverseProxyHandler{
|
||||
config: config,
|
||||
target: target,
|
||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||
}
|
||||
if config.Transport == nil {
|
||||
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
}
|
||||
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
upstreams, err := buildReverseProxyUpstreams(config)
|
||||
if err != nil {
|
||||
proxy.configError = err
|
||||
} else {
|
||||
proxy.upstreams = upstreams
|
||||
}
|
||||
|
||||
switch config.ForwardedHeaders {
|
||||
|
|
@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
|||
proxy.configError = err
|
||||
}
|
||||
}
|
||||
if proxy.configError == nil {
|
||||
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
|
||||
proxy.configError = err
|
||||
}
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
|
@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
return
|
||||
}
|
||||
|
||||
transport := p.config.Transport
|
||||
if transport == nil {
|
||||
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
|
||||
transport = p.extendedConnectTransport
|
||||
} else {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||
if err != nil {
|
||||
p.handleError(c, err)
|
||||
|
|
@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
|
||||
ctx, cancel := p.requestContext(c)
|
||||
defer cancel()
|
||||
attempted := make(map[string]struct{}, len(p.upstreams))
|
||||
attempts := 0
|
||||
started := time.Now()
|
||||
var lastErr error
|
||||
|
||||
outreq := c.Request.Clone(ctx)
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header)
|
||||
}
|
||||
outreq.Close = false
|
||||
var connectWriter *io.PipeWriter
|
||||
defer func() {
|
||||
if connectWriter != nil {
|
||||
_ = connectWriter.Close()
|
||||
}
|
||||
}()
|
||||
if outreq.Method == http.MethodConnect {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
defer outreq.Body.Close()
|
||||
connectWriter = pipeWriter
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||
p.handleError(c, err)
|
||||
for {
|
||||
upstream, err := p.selectUpstream(c, attempted)
|
||||
if err != nil {
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
return
|
||||
}
|
||||
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
}
|
||||
if updatedMaxForwards != "" {
|
||||
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||
}
|
||||
|
||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||
p.handleError(c, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
||||
})
|
||||
attempts++
|
||||
upstream.inFlight.Add(1)
|
||||
served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
|
||||
upstream.inFlight.Add(-1)
|
||||
|
||||
if served {
|
||||
return
|
||||
}
|
||||
if attemptErr != nil {
|
||||
lastErr = attemptErr
|
||||
}
|
||||
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
|
||||
attempted[upstream.key] = struct{}{}
|
||||
if !p.waitRetryInterval(ctx, started) {
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attemptErr != nil {
|
||||
p.handleError(c, attemptErr)
|
||||
return
|
||||
}
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
return
|
||||
}
|
||||
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
p.addForwardingHeaders(c.Request, outreq)
|
||||
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
if p.config.ModifyRequest != nil {
|
||||
p.config.ModifyRequest(outreq)
|
||||
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
|
||||
outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
|
||||
if err != nil {
|
||||
return false, err, false
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
transport := p.transportForUpstream(c.Request, upstream)
|
||||
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
||||
var (
|
||||
roundTripMu sync.Mutex
|
||||
|
|
@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
roundTripDone = true
|
||||
roundTripMu.Unlock()
|
||||
if err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
if reverseProxyShouldCountPassiveFailure(outreq, err) {
|
||||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
return false, err, true
|
||||
}
|
||||
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
|
||||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||
|
|
@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
res.TransferEncoding = nil
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
handleConnect := p.handleConnectResponse
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
handleConnect = p.handleExtendedConnectResponse
|
||||
}
|
||||
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||
p.handleError(c, err)
|
||||
return false, err, false
|
||||
}
|
||||
connectWriter = nil
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
|
||||
p.handleError(c, err)
|
||||
return false, err, false
|
||||
}
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(res.Header)
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||
|
|
@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
|
|
@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Keep the stdlib-compatible fallback here.
|
||||
// If the backend only exposes additional trailer keys after the body has been
|
||||
// fully read, the trailer map can grow and those values must be written using
|
||||
// the TrailerPrefix form instead of the pre-announced bare header keys.
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
for key, values := range res.Trailer {
|
||||
|
|
@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
c.Writer.Header().Add(prefixedKey, value)
|
||||
}
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||
outreq := c.Request.Clone(ctx)
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
} else if c.Request.GetBody != nil {
|
||||
body, err := c.Request.GetBody()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
|
||||
}
|
||||
outreq.Body = body
|
||||
} else if outreq.Body != nil {
|
||||
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header)
|
||||
}
|
||||
outreq.Close = false
|
||||
var connectWriter *io.PipeWriter
|
||||
if outreq.Method == http.MethodConnect {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
connectWriter = pipeWriter
|
||||
}
|
||||
cleanup := func() {
|
||||
if outreq.Body != nil {
|
||||
_ = outreq.Body.Close()
|
||||
}
|
||||
if connectWriter != nil {
|
||||
_ = connectWriter.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
||||
cleanup()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
}
|
||||
if updatedMaxForwards != "" {
|
||||
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||
}
|
||||
|
||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||
cleanup()
|
||||
return nil, nil, nil, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
||||
}
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
p.addForwardingHeaders(c.Request, outreq)
|
||||
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
if p.config.ModifyRequest != nil {
|
||||
p.config.ModifyRequest(outreq)
|
||||
}
|
||||
|
||||
return outreq, connectWriter, cleanup, nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
|
||||
if p.config.Transport != nil {
|
||||
return p.config.Transport
|
||||
}
|
||||
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
|
||||
return upstream.extendedConnectTransport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
|
||||
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
|
||||
return false
|
||||
}
|
||||
lb := p.config.LoadBalancing
|
||||
if lb.TryDuration > 0 {
|
||||
return time.Since(started) < lb.TryDuration
|
||||
}
|
||||
return attempts <= lb.Retries
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
|
||||
interval := p.config.LoadBalancing.TryInterval
|
||||
tryDuration := p.config.LoadBalancing.TryDuration
|
||||
if tryDuration > 0 && interval == 0 {
|
||||
interval = 250 * time.Millisecond
|
||||
}
|
||||
if tryDuration > 0 {
|
||||
remaining := tryDuration - time.Since(started)
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if interval <= 0 {
|
||||
return ctx.Err() == nil
|
||||
}
|
||||
if interval > remaining {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if interval <= 0 {
|
||||
return ctx.Err() == nil
|
||||
}
|
||||
timer := time.NewTimer(interval)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
||||
|
|
@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
|
||||
if config.Target != nil && len(config.Targets) > 0 {
|
||||
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
|
||||
}
|
||||
|
||||
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
|
||||
if config.Target != nil {
|
||||
target := cloneReverseProxyURL(config.Target)
|
||||
normalizeReverseProxyTarget(target)
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
for i, rawTarget := range config.Targets {
|
||||
trimmed := strings.TrimSpace(rawTarget)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
|
||||
}
|
||||
target, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||
}
|
||||
normalizeReverseProxyTarget(target)
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
if len(targets) == 0 {
|
||||
return nil, errReverseProxyNilTarget
|
||||
}
|
||||
|
||||
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
|
||||
for i, target := range targets {
|
||||
upstream := &reverseProxyUpstream{
|
||||
key: fmt.Sprintf("%d:%s", i, target.String()),
|
||||
target: target,
|
||||
index: i,
|
||||
}
|
||||
if config.Transport == nil {
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
func validateReverseProxyForwardedBy(value string) error {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
|
|
@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
|||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||
}
|
||||
|
||||
func reverseProxyCanRetryRequest(req *http.Request) bool {
|
||||
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
|
||||
return false
|
||||
}
|
||||
if req.Body == nil || req.ContentLength == 0 {
|
||||
return true
|
||||
}
|
||||
return req.GetBody != nil
|
||||
}
|
||||
|
||||
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
|
||||
if err == nil || reverseProxyIsBenignTunnelError(err) {
|
||||
return false
|
||||
}
|
||||
if req != nil && req.Context().Err() != nil {
|
||||
return false
|
||||
}
|
||||
return !errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func reverseProxyMethodIsSafe(method string) bool {
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
|
||||
}
|
||||
|
|
@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool {
|
|||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var streamErr http2.StreamError
|
||||
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
|
||||
return true
|
||||
}
|
||||
switch err.Error() {
|
||||
case "body closed by handler", "http2: response body closed", "response body closed":
|
||||
return true
|
||||
|
|
|
|||
352
reverseproxy_lb.go
Normal file
352
reverseproxy_lb.go
Normal file
|
|
@ -0,0 +1,352 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
|
||||
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
|
||||
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
|
||||
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
|
||||
type ReverseProxyLBPolicy struct {
|
||||
kind reverseProxyLBPolicyKind
|
||||
key string
|
||||
fallback *ReverseProxyLBPolicy
|
||||
}
|
||||
|
||||
type reverseProxyLBPolicyKind uint8
|
||||
|
||||
const (
|
||||
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
|
||||
reverseProxyLBPolicyRoundRobin
|
||||
reverseProxyLBPolicyFirst
|
||||
reverseProxyLBPolicyLeastConn
|
||||
reverseProxyLBPolicyIPHash
|
||||
reverseProxyLBPolicyClientIPHash
|
||||
reverseProxyLBPolicyURIHash
|
||||
reverseProxyLBPolicyHeader
|
||||
reverseProxyLBPolicyQuery
|
||||
)
|
||||
|
||||
type reverseProxyUpstream struct {
|
||||
key string
|
||||
target *url.URL
|
||||
index int
|
||||
extendedConnectTransport http.RoundTripper
|
||||
inFlight atomic.Int64
|
||||
|
||||
passiveMu sync.Mutex
|
||||
failures []time.Time
|
||||
}
|
||||
|
||||
func LBRandom() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
|
||||
}
|
||||
|
||||
func LBRoundRobin() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
|
||||
}
|
||||
|
||||
func LBFirst() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
|
||||
}
|
||||
|
||||
func LBLeastConn() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
|
||||
}
|
||||
|
||||
func LBIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
|
||||
}
|
||||
|
||||
func LBClientIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
|
||||
}
|
||||
|
||||
func LBURIHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
|
||||
}
|
||||
|
||||
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
|
||||
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
|
||||
reverseProxyLBPolicyURIHash:
|
||||
return nil
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
|
||||
}
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
|
||||
}
|
||||
if policy.fallback != nil {
|
||||
return validateReverseProxyLBPolicy(*policy.fallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
|
||||
now := time.Now()
|
||||
policy := p.config.LoadBalancing.Policy
|
||||
candidates := p.availableUpstreams(now, excluded)
|
||||
if len(candidates) == 0 && len(excluded) > 0 {
|
||||
candidates = p.availableUpstreams(now, nil)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, errReverseProxyNoAvailableUpstreams
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
|
||||
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
|
||||
for _, upstream := range p.upstreams {
|
||||
if _, skip := excluded[upstream.key]; skip {
|
||||
continue
|
||||
}
|
||||
if !upstream.healthy(now, p.config.PassiveHealth) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, upstream)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRoundRobin:
|
||||
return candidates[p.nextRoundRobinIndex(len(candidates))]
|
||||
case reverseProxyLBPolicyFirst:
|
||||
return candidates[0]
|
||||
case reverseProxyLBPolicyLeastConn:
|
||||
return p.selectLeastConnUpstream(candidates)
|
||||
case reverseProxyLBPolicyIPHash:
|
||||
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
|
||||
case reverseProxyLBPolicyClientIPHash:
|
||||
return reverseProxySelectHRW(candidates, c.RequestIP())
|
||||
case reverseProxyLBPolicyURIHash:
|
||||
if c.Request == nil || c.Request.URL == nil {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if c.Request != nil && c.Request.Header != nil {
|
||||
if values, ok := c.Request.Header[policy.key]; ok {
|
||||
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
if values, ok := c.Request.URL.Query()[policy.key]; ok {
|
||||
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyRandom:
|
||||
fallthrough
|
||||
default:
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int((p.roundRobin.Add(1) - 1) % uint64(size))
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
selected := candidates[0]
|
||||
lowest := selected.inFlight.Load()
|
||||
ties := []*reverseProxyUpstream{selected}
|
||||
for _, upstream := range candidates[1:] {
|
||||
count := upstream.inFlight.Load()
|
||||
switch {
|
||||
case count < lowest:
|
||||
selected = upstream
|
||||
lowest = count
|
||||
ties = []*reverseProxyUpstream{upstream}
|
||||
case count == lowest:
|
||||
ties = append(ties, upstream)
|
||||
}
|
||||
}
|
||||
if len(ties) == 1 {
|
||||
return selected
|
||||
}
|
||||
return ties[p.nextRoundRobinIndex(len(ties))]
|
||||
}
|
||||
|
||||
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(candidates) == 1 {
|
||||
return candidates[0]
|
||||
}
|
||||
return candidates[rand.IntN(len(candidates))]
|
||||
}
|
||||
|
||||
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if key == "" {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
selected := candidates[0]
|
||||
bestScore := reverseProxyHRWScore(key, selected.key)
|
||||
for _, upstream := range candidates[1:] {
|
||||
score := reverseProxyHRWScore(key, upstream.key)
|
||||
if score > bestScore {
|
||||
selected = upstream
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
|
||||
const (
|
||||
offset64 = 14695981039346656037
|
||||
prime64 = 1099511628211
|
||||
)
|
||||
h := uint64(offset64)
|
||||
for i := 0; i < len(key); i++ {
|
||||
h ^= uint64(key[i])
|
||||
h *= prime64
|
||||
}
|
||||
h ^= 0xff
|
||||
h *= prime64
|
||||
for i := 0; i < len(upstreamKey); i++ {
|
||||
h ^= uint64(upstreamKey[i])
|
||||
h *= prime64
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
if policy.fallback != nil {
|
||||
return *policy.fallback
|
||||
}
|
||||
return LBRandom()
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
return len(u.failures) < maxFails
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
u.failures = append(u.failures, now)
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
|
||||
if len(u.failures) == 0 || window <= 0 {
|
||||
if window <= 0 {
|
||||
u.failures = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
cutoff := now.Add(-window)
|
||||
keep := 0
|
||||
for _, failureAt := range u.failures {
|
||||
if failureAt.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
u.failures[keep] = failureAt
|
||||
keep++
|
||||
}
|
||||
u.failures = u.failures[:keep]
|
||||
}
|
||||
|
||||
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
|
||||
if config.FailDuration <= 0 {
|
||||
return 0
|
||||
}
|
||||
if config.MaxFails <= 0 {
|
||||
return 1
|
||||
}
|
||||
return config.MaxFails
|
||||
}
|
||||
|
||||
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
|
||||
if status <= 0 {
|
||||
return false
|
||||
}
|
||||
for _, unhealthyStatus := range config.UnhealthyStatus {
|
||||
if status == unhealthyStatus {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -13,7 +13,9 @@ import (
|
|||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Targets: []string{"http://example.net"},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
type snapshot struct {
|
||||
Path string
|
||||
RawQuery string
|
||||
}
|
||||
|
||||
backendOneCh := make(chan snapshot, 1)
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwoCh := make(chan snapshot, 1)
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
backendOne.URL + "/one?from=one",
|
||||
backendTwo.URL + "/two?from=two",
|
||||
},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
||||
}))
|
||||
|
||||
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
|
||||
if first.Code != http.StatusOK || first.Body.String() != "one" {
|
||||
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||
}
|
||||
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
|
||||
if second.Code != http.StatusOK || second.Body.String() != "two" {
|
||||
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-backendOneCh:
|
||||
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
|
||||
t.Fatalf("unexpected first upstream request: %#v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first upstream request")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-backendTwoCh:
|
||||
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
|
||||
t.Fatalf("unexpected second upstream request: %#v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for second upstream request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBHeader("X-Upstream", LBFirst()),
|
||||
},
|
||||
}))
|
||||
|
||||
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||
}
|
||||
|
||||
headers := http.Header{"X-Upstream": {"tenant-a"}}
|
||||
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||
}
|
||||
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBQuery("tenant", LBFirst()),
|
||||
},
|
||||
}))
|
||||
|
||||
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||
}
|
||||
|
||||
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||
}
|
||||
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBClientIPHash(),
|
||||
},
|
||||
}))
|
||||
|
||||
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||
reqOne.RemoteAddr = "10.0.0.1:1234"
|
||||
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||
rrOne := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rrOne, reqOne)
|
||||
|
||||
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||
reqTwo.RemoteAddr = "10.0.0.2:5678"
|
||||
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||
rrTwo := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rrTwo, reqTwo)
|
||||
|
||||
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
|
||||
}
|
||||
if rrOne.Body.String() != rrTwo.Body.String() {
|
||||
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
var attempts atomic.Int64
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
ModifyRequest: func(req *http.Request) {
|
||||
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
if rr.Body.String() != "2" {
|
||||
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendCalls := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-backendCalls:
|
||||
t.Fatal("unsafe POST request should not be retried to the next upstream")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOneStarted := make(chan struct{}, 1)
|
||||
releaseBackendOne := make(chan struct{})
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendOneStarted <- struct{}{}
|
||||
<-releaseBackendOne
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBLeastConn(),
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
client := proxy.Client()
|
||||
client.Timeout = 5 * time.Second
|
||||
|
||||
firstRespCh := make(chan string, 1)
|
||||
firstErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := client.Get(proxy.URL + "/proxy")
|
||||
if err != nil {
|
||||
firstErrCh <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
firstErrCh <- err
|
||||
return
|
||||
}
|
||||
firstRespCh <- string(body)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-backendOneStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first backend request")
|
||||
}
|
||||
|
||||
secondResp, err := client.Get(proxy.URL + "/proxy")
|
||||
if err != nil {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("second request failed: %v", err)
|
||||
}
|
||||
secondBody, err := io.ReadAll(secondResp.Body)
|
||||
_ = secondResp.Body.Close()
|
||||
if err != nil {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("read second response: %v", err)
|
||||
}
|
||||
if string(secondBody) != "two" {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
|
||||
}
|
||||
|
||||
close(releaseBackendOne)
|
||||
select {
|
||||
case err := <-firstErrCh:
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
case body := <-firstRespCh:
|
||||
if body != "one" {
|
||||
t.Fatalf("unexpected first response body: %q", body)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first response body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
primaryCalls := make(chan struct{}, 4)
|
||||
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
primaryCalls <- struct{}{}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = io.WriteString(w, "primary down")
|
||||
}))
|
||||
defer primary.Close()
|
||||
|
||||
secondaryCalls := make(chan struct{}, 4)
|
||||
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
secondaryCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "secondary up")
|
||||
}))
|
||||
defer secondary.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{primary.URL, secondary.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
},
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||
},
|
||||
}))
|
||||
|
||||
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
|
||||
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||
}
|
||||
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
|
||||
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-primaryCalls:
|
||||
default:
|
||||
t.Fatal("expected primary to receive the first request")
|
||||
}
|
||||
select {
|
||||
case <-secondaryCalls:
|
||||
default:
|
||||
t.Fatal("expected secondary to receive the second request")
|
||||
}
|
||||
select {
|
||||
case <-primaryCalls:
|
||||
t.Fatal("primary should not receive the second request while unhealthy")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
started := make(chan struct{}, 1)
|
||||
release := make(chan struct{})
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
started <- struct{}{}
|
||||
<-release
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backend.URL},
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
client := proxy.Client()
|
||||
respCh := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := client.Do(req)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
respCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend request")
|
||||
}
|
||||
cancel()
|
||||
close(release)
|
||||
select {
|
||||
case <-respCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for canceled request to finish")
|
||||
}
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendCalls := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 3,
|
||||
TryDuration: 100 * time.Millisecond,
|
||||
TryInterval: 250 * time.Millisecond,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-backendCalls:
|
||||
t.Fatal("retry budget should expire before the next upstream attempt")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 8)
|
||||
newBackend := func(name string) *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, name+":"+line); err != nil {
|
||||
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
}))
|
||||
server.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
|
||||
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
|
||||
}
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
backendOne := newBackend("one")
|
||||
defer backendOne.Close()
|
||||
backendTwo := newBackend("two")
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBRoundRobin(),
|
||||
},
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
doRequest := func(payload string) string {
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
if err := pw.Close(); err != nil {
|
||||
t.Fatalf("close tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
if got := doRequest("ping"); got != "one:ping\n" {
|
||||
t.Fatalf("unexpected first tunneled response: %q", got)
|
||||
}
|
||||
if got := doRequest("pong"); got != "two:pong\n" {
|
||||
t.Fatalf("unexpected second tunneled response: %q", got)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue