mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #77 from infinite-iroha/break/v1-enhance-reverse-proxy
feat(reverseproxy): add upstream balancing and protocol improvements
This commit is contained in:
commit
0d7721a24c
11 changed files with 2759 additions and 92 deletions
|
|
@ -60,6 +60,10 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
```go
|
||||
type ReverseProxyConfig struct {
|
||||
Target *url.URL
|
||||
Targets []string
|
||||
|
||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
|
|
@ -78,12 +82,115 @@ type ReverseProxyConfig struct {
|
|||
|
||||
### `Target`
|
||||
|
||||
必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
|
||||
```go
|
||||
target, _ := url.Parse("http://backend:9000")
|
||||
```
|
||||
|
||||
### `Targets`
|
||||
|
||||
可选。用于配置多个后端目标地址。
|
||||
|
||||
- `Target` 与 `Targets` 互斥,只能使用其中一种
|
||||
- `Targets` 的每一项都必须是完整 URL
|
||||
- 每个 target 仍然可以自带 base path 和 query
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001/base?from=a",
|
||||
"http://127.0.0.1:9002/base?from=b",
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
|
||||
|
||||
### `LoadBalancing`
|
||||
|
||||
用于配置 upstream 选择策略和重试行为。
|
||||
|
||||
```go
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
```
|
||||
|
||||
当前内置策略:
|
||||
|
||||
- `touka.LBRandom()`
|
||||
- `touka.LBRoundRobin()`
|
||||
- `touka.LBFirst()`
|
||||
- `touka.LBLeastConn()`
|
||||
- `touka.LBIPHash()`
|
||||
- `touka.LBClientIPHash()`
|
||||
- `touka.LBURIHash()`
|
||||
- `touka.LBHeader("X-Upstream", fallback)`
|
||||
- `touka.LBQuery("tenant", fallback)`
|
||||
|
||||
其中:
|
||||
|
||||
- `LBFirst()` 适合主备/故障转移顺序
|
||||
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
|
||||
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()`
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
重试说明:
|
||||
|
||||
- 只对未开始收到上游响应的失败进行重试
|
||||
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
|
||||
- `Retries` 表示额外重试次数
|
||||
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
|
||||
- `TryInterval` 表示两次重试之间的等待间隔
|
||||
|
||||
### `PassiveHealth`
|
||||
|
||||
用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
|
||||
|
||||
```go
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
```
|
||||
|
||||
- `FailDuration > 0` 时启用被动健康跟踪
|
||||
- `MaxFails <= 0` 时默认按 `1` 处理
|
||||
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBFirst(),
|
||||
},
|
||||
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
### `Transport`
|
||||
|
||||
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||
|
|
@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
|
||||
在请求真正发往后端前,对出站请求做最后修改。
|
||||
|
||||
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
|
||||
|
||||
常见用途:
|
||||
|
||||
- 覆盖 `Host`
|
||||
|
|
@ -242,11 +351,20 @@ const (
|
|||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
ForwardedHeaders: touka.ForwardedBoth,
|
||||
ForwardedBy: "gateway-1",
|
||||
ForwardedBy: "_gateway-1",
|
||||
Via: "edge-1",
|
||||
}))
|
||||
```
|
||||
|
||||
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
|
||||
|
||||
- IPv4:`203.0.113.43`
|
||||
- IPv6 / 带端口:`[2001:db8::17]:443`
|
||||
- 匿名标识:`_gateway-1`
|
||||
- 未知:`unknown`
|
||||
|
||||
像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
|
||||
|
||||
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||
|
||||
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||
|
|
@ -282,11 +400,14 @@ Touka 会尽量遵循代理链语义:
|
|||
|
||||
Touka 的反向代理实现支持以下能力:
|
||||
|
||||
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||
- HTTP/2 extended `CONNECT`
|
||||
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||
- WebSocket 等 101 Switching Protocols 场景
|
||||
- SSE(Server-Sent Events)立即刷新
|
||||
- Trailer 透传
|
||||
- 1xx 响应透传
|
||||
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
|
||||
|
||||
例如,代理 WebSocket 服务:
|
||||
|
||||
|
|
@ -341,7 +462,7 @@ func main() {
|
|||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
ForwardedHeaders: touka.ForwardedBoth,
|
||||
ForwardedBy: "gateway-1",
|
||||
ForwardedBy: "_gateway-1",
|
||||
Via: "gateway-1",
|
||||
FlushInterval: 100 * time.Millisecond,
|
||||
ModifyRequest: func(req *http.Request) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ r.ANY("/any", handle)
|
|||
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
||||
```
|
||||
|
||||
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
|
||||
|
||||
## 路径参数 (Named Parameters)
|
||||
|
||||
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
||||
|
|
|
|||
2
ecw.go
2
ecw.go
|
|
@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
|
|||
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := ecw.w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
|
|
|||
77
engine.go
77
engine.go
|
|
@ -7,6 +7,7 @@ package touka
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
|
@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
|||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||
if engine.serverProtocols != nil {
|
||||
srv.Protocols = engine.serverProtocols
|
||||
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
|
||||
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -475,21 +481,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
|||
func MethodNotAllowed() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := c.Request.URL.Path
|
||||
requestPath := routeLookupPath(c.Request)
|
||||
engine := c.engine
|
||||
// 是否是OPTIONS方式
|
||||
if httpMethod == http.MethodOptions {
|
||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||
allowedMethods := []string{}
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
allowedMethods = append(allowedMethods, treeIter.method)
|
||||
}
|
||||
}
|
||||
allowedMethods := engine.allowedMethodsForPath(requestPath)
|
||||
if len(allowedMethods) > 0 {
|
||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
||||
|
|
@ -704,8 +701,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||
// 这是路由查找和执行的核心逻辑
|
||||
func (engine *Engine) handleRequest(c *Context) {
|
||||
if isGeneralOptionsRequest(c.Request) {
|
||||
engine.handleGeneralOptions(c)
|
||||
return
|
||||
}
|
||||
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := c.Request.URL.Path
|
||||
requestPath := routeLookupPath(c.Request)
|
||||
|
||||
// 查找对应的路由树的根节点
|
||||
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
||||
|
|
@ -725,7 +727,7 @@ func (engine *Engine) handleRequest(c *Context) {
|
|||
}
|
||||
|
||||
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
||||
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
||||
if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
|
||||
if value.tsr && engine.RedirectTrailingSlash {
|
||||
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
||||
redirectPath := requestPath
|
||||
|
|
@ -782,6 +784,55 @@ func (engine *Engine) handleRequest(c *Context) {
|
|||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||
}
|
||||
|
||||
func routeLookupPath(req *http.Request) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
|
||||
return "/" + req.RequestURI
|
||||
}
|
||||
if isGeneralOptionsRequest(req) {
|
||||
return ""
|
||||
}
|
||||
if req.URL == nil {
|
||||
return ""
|
||||
}
|
||||
return req.URL.Path
|
||||
}
|
||||
|
||||
func isGeneralOptionsRequest(req *http.Request) bool {
|
||||
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
||||
}
|
||||
|
||||
func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
|
||||
allowedMethods := make([]string, 0, len(engine.methodTrees))
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
allowedMethods = append(allowedMethods, treeIter.method)
|
||||
}
|
||||
}
|
||||
return allowedMethods
|
||||
}
|
||||
|
||||
func (engine *Engine) handleGeneralOptions(c *Context) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Length", "0")
|
||||
if c.Request.ContentLength != 0 {
|
||||
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
|
||||
_, _ = io.Copy(io.Discard, mb)
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||
func (engine *Engine) Context() context.Context {
|
||||
|
|
|
|||
3
go.mod
3
go.mod
|
|
@ -8,9 +8,10 @@ require (
|
|||
github.com/WJQSERVER/wanf v0.0.8
|
||||
github.com/fenthope/reco v0.0.5
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
||||
golang.org/x/net v0.52.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
|||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
|
|
|
|||
53
http2xconnect.go
Normal file
53
http2xconnect.go
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var enableHTTP2ExtendedConnectOnce sync.Once
|
||||
|
||||
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
|
||||
var xnetDisableHTTP2ExtendedConnectProtocol bool
|
||||
|
||||
func enableHTTP2ExtendedConnectProtocol() {
|
||||
enableHTTP2ExtendedConnectOnce.Do(func() {
|
||||
xnetDisableHTTP2ExtendedConnectProtocol = false
|
||||
})
|
||||
}
|
||||
|
||||
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
||||
if srv == nil {
|
||||
return nil
|
||||
}
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
return http2.ConfigureServer(srv, nil)
|
||||
}
|
||||
|
||||
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
transport := &http2.Transport{}
|
||||
if target == nil || !strings.EqualFold(target.Scheme, "http") {
|
||||
return transport
|
||||
}
|
||||
|
||||
transport.AllowHTTP = true
|
||||
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return transport
|
||||
}
|
||||
2
respw.go
2
respw.go
|
|
@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
||||
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("http.Hijacker interface not supported")
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// 调用底层的 Hijack 方法
|
||||
|
|
|
|||
813
reverseproxy.go
813
reverseproxy.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
|
|
@ -22,6 +23,8 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// ForwardedHeadersPolicy controls how forwarding headers are generated.
|
||||
|
|
@ -44,6 +47,10 @@ type BufferPool interface {
|
|||
// ReverseProxyConfig configures the reverse proxy handler.
|
||||
type ReverseProxyConfig struct {
|
||||
Target *url.URL
|
||||
Targets []string
|
||||
|
||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
|
|
@ -63,13 +70,15 @@ var (
|
|||
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
||||
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
||||
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
||||
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
|
||||
)
|
||||
|
||||
type reverseProxyHandler struct {
|
||||
config ReverseProxyConfig
|
||||
target *url.URL
|
||||
upstreams []*reverseProxyUpstream
|
||||
receivedBy string
|
||||
configError error
|
||||
roundRobin atomic.Uint64
|
||||
}
|
||||
|
||||
type reverseProxyStatusError struct {
|
||||
|
|
@ -197,19 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
|
|||
}
|
||||
|
||||
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||
target := cloneReverseProxyURL(config.Target)
|
||||
if target != nil {
|
||||
normalizeReverseProxyTarget(target)
|
||||
}
|
||||
|
||||
proxy := &reverseProxyHandler{
|
||||
config: config,
|
||||
target: target,
|
||||
receivedBy: reverseProxyReceivedBy(config.Via),
|
||||
}
|
||||
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
upstreams, err := buildReverseProxyUpstreams(config)
|
||||
if err != nil {
|
||||
proxy.configError = err
|
||||
} else {
|
||||
proxy.upstreams = upstreams
|
||||
}
|
||||
|
||||
switch config.ForwardedHeaders {
|
||||
|
|
@ -217,6 +223,17 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
|||
default:
|
||||
proxy.config.ForwardedHeaders = ForwardedBoth
|
||||
}
|
||||
proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy)
|
||||
if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) {
|
||||
if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil {
|
||||
proxy.configError = err
|
||||
}
|
||||
}
|
||||
if proxy.configError == nil {
|
||||
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
|
||||
proxy.configError = err
|
||||
}
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
|
@ -229,62 +246,75 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
return
|
||||
}
|
||||
|
||||
transport := p.config.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||
if err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
}
|
||||
if handledLocally {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := p.requestContext(c)
|
||||
defer cancel()
|
||||
attempted := make(map[string]struct{}, len(p.upstreams))
|
||||
attempts := 0
|
||||
started := time.Now()
|
||||
var lastErr error
|
||||
|
||||
outreq := c.Request.Clone(ctx)
|
||||
if c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
for {
|
||||
upstream, err := p.selectUpstream(c, attempted)
|
||||
if err != nil {
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
return
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header)
|
||||
}
|
||||
outreq.Close = false
|
||||
|
||||
rewriteReverseProxyURL(outreq, p.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
|
||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||
p.handleError(c, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
||||
})
|
||||
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
|
||||
return
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
attempts++
|
||||
upstream.inFlight.Add(1)
|
||||
served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
|
||||
upstream.inFlight.Add(-1)
|
||||
|
||||
p.addForwardingHeaders(c.Request, outreq)
|
||||
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
if served {
|
||||
return
|
||||
}
|
||||
|
||||
if p.config.ModifyRequest != nil {
|
||||
p.config.ModifyRequest(outreq)
|
||||
if attemptErr != nil {
|
||||
lastErr = attemptErr
|
||||
}
|
||||
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
|
||||
attempted[upstream.key] = struct{}{}
|
||||
if !p.waitRetryInterval(ctx, started) {
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attemptErr != nil {
|
||||
p.handleError(c, attemptErr)
|
||||
return
|
||||
}
|
||||
if lastErr != nil {
|
||||
p.handleError(c, lastErr)
|
||||
return
|
||||
}
|
||||
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
|
||||
outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
|
||||
if err != nil {
|
||||
return false, err, false
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
transport := p.transportForUpstream(c.Request, upstream)
|
||||
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
|
||||
var (
|
||||
roundTripMu sync.Mutex
|
||||
|
|
@ -314,26 +344,51 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
roundTripDone = true
|
||||
roundTripMu.Unlock()
|
||||
if err != nil {
|
||||
p.handleError(c, err)
|
||||
return
|
||||
if reverseProxyShouldCountPassiveFailure(outreq, err) {
|
||||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
return false, err, true
|
||||
}
|
||||
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
|
||||
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||
removeHopByHopHeaders(res.Header)
|
||||
res.Header.Del("Content-Length")
|
||||
res.Header.Del("Transfer-Encoding")
|
||||
res.ContentLength = -1
|
||||
res.TransferEncoding = nil
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return true, nil, false
|
||||
}
|
||||
handleConnect := p.handleConnectResponse
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
handleConnect = p.handleExtendedConnectResponse
|
||||
}
|
||||
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
|
||||
return false, err, false
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
|
||||
p.handleError(c, err)
|
||||
return false, err, false
|
||||
}
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(res.Header)
|
||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||
|
||||
if !p.modifyResponse(c, res, outreq) {
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||
|
|
@ -353,7 +408,10 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
defer res.Body.Close()
|
||||
c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
|
||||
p.logf(c, "reverse proxy body copy failed: %v", err)
|
||||
return
|
||||
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
|
|
@ -361,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// Keep the stdlib-compatible fallback here.
|
||||
// If the backend only exposes additional trailer keys after the body has been
|
||||
// fully read, the trailer map can grow and those values must be written using
|
||||
// the TrailerPrefix form instead of the pre-announced bare header keys.
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
|
||||
return
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
for key, values := range res.Trailer {
|
||||
|
|
@ -376,6 +430,228 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
c.Writer.Header().Add(prefixedKey, value)
|
||||
}
|
||||
}
|
||||
return true, nil, false
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||
outreq := c.Request.Clone(ctx)
|
||||
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||
outreq.Body = nil
|
||||
} else if c.Request.GetBody != nil {
|
||||
body, err := c.Request.GetBody()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
|
||||
}
|
||||
outreq.Body = body
|
||||
} else if outreq.Body != nil {
|
||||
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header)
|
||||
}
|
||||
outreq.Close = false
|
||||
var connectWriter *io.PipeWriter
|
||||
if outreq.Method == http.MethodConnect {
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
outreq.Body = pipeReader
|
||||
outreq.ContentLength = -1
|
||||
connectWriter = pipeWriter
|
||||
}
|
||||
cleanup := func() {
|
||||
if outreq.Body != nil {
|
||||
_ = outreq.Body.Close()
|
||||
}
|
||||
if connectWriter != nil {
|
||||
_ = connectWriter.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if outreq.Method == http.MethodConnect {
|
||||
if reverseProxyIsExtendedConnectRequest(outreq) {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
} else {
|
||||
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
|
||||
cleanup()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rewriteReverseProxyURL(outreq, upstream.target)
|
||||
if !p.config.PreserveHost {
|
||||
outreq.Host = ""
|
||||
}
|
||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||
}
|
||||
if updatedMaxForwards != "" {
|
||||
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||
}
|
||||
|
||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||
cleanup()
|
||||
return nil, nil, nil, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
|
||||
}
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
if reqUpType != "" {
|
||||
outreq.Header.Set("Connection", "Upgrade")
|
||||
outreq.Header.Set("Upgrade", reqUpType)
|
||||
}
|
||||
|
||||
p.addForwardingHeaders(c.Request, outreq)
|
||||
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
if p.config.ModifyRequest != nil {
|
||||
p.config.ModifyRequest(outreq)
|
||||
}
|
||||
|
||||
return outreq, connectWriter, cleanup, nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
|
||||
if p.config.Transport != nil {
|
||||
return p.config.Transport
|
||||
}
|
||||
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
|
||||
return upstream.extendedConnectTransport
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
|
||||
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
|
||||
return false
|
||||
}
|
||||
lb := p.config.LoadBalancing
|
||||
if lb.TryDuration > 0 {
|
||||
return time.Since(started) < lb.TryDuration
|
||||
}
|
||||
return attempts <= lb.Retries
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
|
||||
interval := p.config.LoadBalancing.TryInterval
|
||||
tryDuration := p.config.LoadBalancing.TryDuration
|
||||
if tryDuration > 0 && interval == 0 {
|
||||
interval = 250 * time.Millisecond
|
||||
}
|
||||
if tryDuration > 0 {
|
||||
remaining := tryDuration - time.Since(started)
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if interval <= 0 {
|
||||
return ctx.Err() == nil
|
||||
}
|
||||
if interval > remaining {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if interval <= 0 {
|
||||
return ctx.Err() == nil
|
||||
}
|
||||
timer := time.NewTimer(interval)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
||||
if c == nil || c.Request == nil {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
switch c.Request.Method {
|
||||
case http.MethodOptions, http.MethodTrace:
|
||||
default:
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards"))
|
||||
if rawValue == "" {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(rawValue)
|
||||
if err != nil || value < 0 {
|
||||
return "", false, &reverseProxyStatusError{
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("invalid Max-Forwards value %q", rawValue),
|
||||
}
|
||||
}
|
||||
if value == 0 {
|
||||
switch c.Request.Method {
|
||||
case http.MethodTrace:
|
||||
return "", true, p.writeLocalTraceResponse(c)
|
||||
case http.MethodOptions:
|
||||
p.writeLocalOptionsResponse(c)
|
||||
return "", true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return strconv.Itoa(value - 1), false, nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error {
|
||||
if c == nil || c.Request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
traceReq := c.Request.Clone(c.Request.Context())
|
||||
traceReq.Body = nil
|
||||
traceReq.ContentLength = 0
|
||||
traceReq.TransferEncoding = nil
|
||||
traceReq.RequestURI = c.Request.RequestURI
|
||||
if traceReq.RequestURI == "" && traceReq.URL != nil {
|
||||
traceReq.RequestURI = traceReq.URL.RequestURI()
|
||||
}
|
||||
traceReq.Header = traceReq.Header.Clone()
|
||||
for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} {
|
||||
traceReq.Header.Del(key)
|
||||
}
|
||||
|
||||
dump, err := httputil.DumpRequest(traceReq, false)
|
||||
if err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "message/http")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(dump)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.engine != nil {
|
||||
if c.Request != nil && c.Request.RequestURI != "*" {
|
||||
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 {
|
||||
c.Writer.Header().Set("Allow", strings.Join(allow, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
|
||||
|
|
@ -522,7 +798,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
|||
clientConn, brw, err := c.Writer.Hijack()
|
||||
if err != nil {
|
||||
backConn.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
status := http.StatusBadGateway
|
||||
if errors.Is(err, http.ErrNotSupported) {
|
||||
status = http.StatusNotImplemented
|
||||
}
|
||||
return &reverseProxyStatusError{status: status, err: err}
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
|
@ -561,6 +841,164 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
|||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if backWrite == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"),
|
||||
}
|
||||
}
|
||||
backRead := res.Body
|
||||
|
||||
clientConn, brw, err := c.Writer.Hijack()
|
||||
if err != nil {
|
||||
backRead.Close()
|
||||
_ = backWrite.Close()
|
||||
status := http.StatusBadGateway
|
||||
if errors.Is(err, http.ErrNotSupported) {
|
||||
status = http.StatusNotImplemented
|
||||
}
|
||||
return &reverseProxyStatusError{status: status, err: err}
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
defer backRead.Close()
|
||||
defer backWrite.Close()
|
||||
|
||||
backConnClosed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnClosed:
|
||||
}
|
||||
backRead.Close()
|
||||
_ = backWrite.Close()
|
||||
}()
|
||||
defer close(backConnClosed)
|
||||
|
||||
res.Body = nil
|
||||
if err := res.Write(brw); err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
if _, err := io.Copy(clientConn, backRead); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
if cw, ok := clientConn.(interface{ CloseWrite() error }); ok {
|
||||
errc <- cw.CloseWrite()
|
||||
return
|
||||
}
|
||||
errc <- errReverseProxyCopyDone
|
||||
}()
|
||||
go func() {
|
||||
if _, err := io.Copy(backWrite, clientConn); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
errc <- backWrite.Close()
|
||||
}()
|
||||
|
||||
firstErr := <-errc
|
||||
if firstErr == nil {
|
||||
firstErr = <-errc
|
||||
}
|
||||
if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||
if c == nil || c.Request == nil {
|
||||
res.Body.Close()
|
||||
if backWrite != nil {
|
||||
_ = backWrite.Close()
|
||||
}
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
|
||||
}
|
||||
if backWrite == nil {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
|
||||
}
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
|
||||
c.Writer.WriteHeader(res.StatusCode)
|
||||
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
res.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||
}
|
||||
|
||||
var closeOnce sync.Once
|
||||
closeTunnel := func() {
|
||||
closeOnce.Do(func() {
|
||||
_ = c.Request.Body.Close()
|
||||
_ = backWrite.Close()
|
||||
_ = res.Body.Close()
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
<-req.Context().Done()
|
||||
closeTunnel()
|
||||
}()
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(backWrite, c.Request.Body)
|
||||
closeErr := backWrite.Close()
|
||||
if err != nil && !reverseProxyIsBenignTunnelError(err) {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
go func() {
|
||||
copyErr := p.copyResponse(c.Writer, res.Body, -1)
|
||||
closeErr := res.Body.Close()
|
||||
if copyErr != nil {
|
||||
errc <- copyErr
|
||||
return
|
||||
}
|
||||
errc <- closeErr
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
for i := 0; i < 2; i++ {
|
||||
err := <-errc
|
||||
if reverseProxyIsBenignTunnelError(err) {
|
||||
continue
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
closeTunnel()
|
||||
}
|
||||
}
|
||||
closeTunnel()
|
||||
if reverseProxyIsBenignTunnelError(firstErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return firstErr
|
||||
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||
return -1
|
||||
|
|
@ -599,7 +1037,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt
|
|||
var written int64
|
||||
for {
|
||||
nr, rerr := src.Read(buf)
|
||||
if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) {
|
||||
if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) {
|
||||
p.logf(nil, "reverse proxy read error during body copy: %v", rerr)
|
||||
}
|
||||
if nr > 0 {
|
||||
|
|
@ -638,6 +1076,10 @@ func reverseProxyStatusCode(err error) int {
|
|||
if errors.As(err, &statusErr) && statusErr.status > 0 {
|
||||
return statusErr.status
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
|
|
@ -651,6 +1093,65 @@ func validateReverseProxyTarget(target *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
|
||||
if config.Target != nil && len(config.Targets) > 0 {
|
||||
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
|
||||
}
|
||||
|
||||
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
|
||||
if config.Target != nil {
|
||||
target := cloneReverseProxyURL(config.Target)
|
||||
normalizeReverseProxyTarget(target)
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
for i, rawTarget := range config.Targets {
|
||||
trimmed := strings.TrimSpace(rawTarget)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
|
||||
}
|
||||
target, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||
}
|
||||
normalizeReverseProxyTarget(target)
|
||||
if err := validateReverseProxyTarget(target); err != nil {
|
||||
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
if len(targets) == 0 {
|
||||
return nil, errReverseProxyNilTarget
|
||||
}
|
||||
|
||||
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
|
||||
for i, target := range targets {
|
||||
upstream := &reverseProxyUpstream{
|
||||
key: fmt.Sprintf("%d:%s", i, target.String()),
|
||||
target: target,
|
||||
index: i,
|
||||
}
|
||||
if config.Transport == nil {
|
||||
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
func validateReverseProxyForwardedBy(value string) error {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if !isValidForwardedNodeIdentifier(trimmed) {
|
||||
return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeReverseProxyTarget(target *url.URL) {
|
||||
switch strings.ToLower(target.Scheme) {
|
||||
case "ws":
|
||||
|
|
@ -732,6 +1233,94 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
|
|||
return strings.Join(pairs, ";")
|
||||
}
|
||||
|
||||
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||
}
|
||||
|
||||
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
|
||||
return reverseProxyExtendedConnectProtocol(req) != ""
|
||||
}
|
||||
|
||||
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
|
||||
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
|
||||
return ""
|
||||
}
|
||||
return textproto.TrimString(req.Header.Get(":protocol"))
|
||||
}
|
||||
|
||||
func isValidForwardedNodeIdentifier(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(value, "[") {
|
||||
closing := strings.IndexByte(value, ']')
|
||||
if closing <= 1 {
|
||||
return false
|
||||
}
|
||||
addr, err := netip.ParseAddr(value[1:closing])
|
||||
if err != nil || !addr.Is6() {
|
||||
return false
|
||||
}
|
||||
if closing == len(value)-1 {
|
||||
return true
|
||||
}
|
||||
if value[closing+1] != ':' {
|
||||
return false
|
||||
}
|
||||
return isValidForwardedNodePort(value[closing+2:])
|
||||
}
|
||||
|
||||
host, port, hasPort := strings.Cut(value, ":")
|
||||
if hasPort {
|
||||
switch {
|
||||
case host == "unknown", isValidForwardedObfuscatedIdentifier(host):
|
||||
return isValidForwardedNodePort(port)
|
||||
default:
|
||||
addr, err := netip.ParseAddr(host)
|
||||
return err == nil && addr.Is4() && isValidForwardedNodePort(port)
|
||||
}
|
||||
}
|
||||
|
||||
if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) {
|
||||
return true
|
||||
}
|
||||
addr, err := netip.ParseAddr(value)
|
||||
return err == nil && addr.Is4()
|
||||
}
|
||||
|
||||
func isValidForwardedNodePort(value string) bool {
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
if isValidForwardedObfuscatedIdentifier(value) {
|
||||
return true
|
||||
}
|
||||
if len(value) > 5 {
|
||||
return false
|
||||
}
|
||||
port, err := strconv.Atoi(value)
|
||||
return err == nil && port > 0 && port <= 65535
|
||||
}
|
||||
|
||||
func isValidForwardedObfuscatedIdentifier(value string) bool {
|
||||
if len(value) < 2 || value[0] != '_' {
|
||||
return false
|
||||
}
|
||||
for i := 1; i < len(value); i++ {
|
||||
b := value[i]
|
||||
if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') {
|
||||
continue
|
||||
}
|
||||
switch b {
|
||||
case '.', '_', '-':
|
||||
continue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatForwardedFor(clientIP string) string {
|
||||
addr, err := netip.ParseAddr(clientIP)
|
||||
if err != nil {
|
||||
|
|
@ -817,6 +1406,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
|
|||
}
|
||||
}
|
||||
|
||||
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
|
||||
connectTarget, err := reverseProxyConnectTarget(target)
|
||||
if err != nil {
|
||||
return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
|
||||
}
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = ""
|
||||
req.URL.RawPath = ""
|
||||
req.URL.RawQuery = ""
|
||||
req.URL.Opaque = connectTarget
|
||||
req.Host = connectTarget
|
||||
return nil
|
||||
}
|
||||
|
||||
func reverseProxyConnectTarget(target *url.URL) (string, error) {
|
||||
if target == nil {
|
||||
return "", errReverseProxyNilTarget
|
||||
}
|
||||
host := target.Hostname()
|
||||
if host == "" {
|
||||
return "", errReverseProxyInvalidTarget
|
||||
}
|
||||
port := target.Port()
|
||||
if port == "" {
|
||||
switch strings.ToLower(target.Scheme) {
|
||||
case "http":
|
||||
port = "80"
|
||||
case "https":
|
||||
port = "443"
|
||||
default:
|
||||
return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
|
||||
}
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum <= 0 || portNum > 65535 {
|
||||
return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
|
||||
}
|
||||
return net.JoinHostPort(host, port), nil
|
||||
}
|
||||
|
||||
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
|
||||
if base.RawPath == "" && incoming.RawPath == "" {
|
||||
return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
|
||||
|
|
@ -919,6 +1549,59 @@ func cleanReverseProxyQueryParams(rawQuery string) string {
|
|||
return values.Encode()
|
||||
}
|
||||
|
||||
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||
}
|
||||
|
||||
func reverseProxyCanRetryRequest(req *http.Request) bool {
|
||||
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
|
||||
return false
|
||||
}
|
||||
if req.Body == nil || req.ContentLength == 0 {
|
||||
return true
|
||||
}
|
||||
return req.GetBody != nil
|
||||
}
|
||||
|
||||
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
|
||||
if err == nil || reverseProxyIsBenignTunnelError(err) {
|
||||
return false
|
||||
}
|
||||
if req != nil && req.Context().Err() != nil {
|
||||
return false
|
||||
}
|
||||
return !errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
func reverseProxyMethodIsSafe(method string) bool {
|
||||
switch method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func reverseProxyIsBenignTunnelError(err error) bool {
|
||||
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
|
||||
}
|
||||
|
||||
func reverseProxyIsClosedBodyError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var streamErr http2.StreamError
|
||||
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
|
||||
return true
|
||||
}
|
||||
switch err.Error() {
|
||||
case "body closed by handler", "http2: response body closed", "response body closed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||
return UnwrapResponseWriter(writer)
|
||||
}
|
||||
|
|
|
|||
352
reverseproxy_lb.go
Normal file
352
reverseproxy_lb.go
Normal file
|
|
@ -0,0 +1,352 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
|
||||
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
|
||||
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
|
||||
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
|
||||
type ReverseProxyLBPolicy struct {
|
||||
kind reverseProxyLBPolicyKind
|
||||
key string
|
||||
fallback *ReverseProxyLBPolicy
|
||||
}
|
||||
|
||||
type reverseProxyLBPolicyKind uint8
|
||||
|
||||
const (
|
||||
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
|
||||
reverseProxyLBPolicyRoundRobin
|
||||
reverseProxyLBPolicyFirst
|
||||
reverseProxyLBPolicyLeastConn
|
||||
reverseProxyLBPolicyIPHash
|
||||
reverseProxyLBPolicyClientIPHash
|
||||
reverseProxyLBPolicyURIHash
|
||||
reverseProxyLBPolicyHeader
|
||||
reverseProxyLBPolicyQuery
|
||||
)
|
||||
|
||||
type reverseProxyUpstream struct {
|
||||
key string
|
||||
target *url.URL
|
||||
index int
|
||||
extendedConnectTransport http.RoundTripper
|
||||
inFlight atomic.Int64
|
||||
|
||||
passiveMu sync.Mutex
|
||||
failures []time.Time
|
||||
}
|
||||
|
||||
func LBRandom() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
|
||||
}
|
||||
|
||||
func LBRoundRobin() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
|
||||
}
|
||||
|
||||
func LBFirst() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
|
||||
}
|
||||
|
||||
func LBLeastConn() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
|
||||
}
|
||||
|
||||
func LBIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
|
||||
}
|
||||
|
||||
func LBClientIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
|
||||
}
|
||||
|
||||
func LBURIHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
|
||||
}
|
||||
|
||||
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
|
||||
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
|
||||
reverseProxyLBPolicyURIHash:
|
||||
return nil
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
|
||||
}
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
|
||||
}
|
||||
if policy.fallback != nil {
|
||||
return validateReverseProxyLBPolicy(*policy.fallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
|
||||
now := time.Now()
|
||||
policy := p.config.LoadBalancing.Policy
|
||||
candidates := p.availableUpstreams(now, excluded)
|
||||
if len(candidates) == 0 && len(excluded) > 0 {
|
||||
candidates = p.availableUpstreams(now, nil)
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, errReverseProxyNoAvailableUpstreams
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
|
||||
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
|
||||
for _, upstream := range p.upstreams {
|
||||
if _, skip := excluded[upstream.key]; skip {
|
||||
continue
|
||||
}
|
||||
if !upstream.healthy(now, p.config.PassiveHealth) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, upstream)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRoundRobin:
|
||||
return candidates[p.nextRoundRobinIndex(len(candidates))]
|
||||
case reverseProxyLBPolicyFirst:
|
||||
return candidates[0]
|
||||
case reverseProxyLBPolicyLeastConn:
|
||||
return p.selectLeastConnUpstream(candidates)
|
||||
case reverseProxyLBPolicyIPHash:
|
||||
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
|
||||
case reverseProxyLBPolicyClientIPHash:
|
||||
return reverseProxySelectHRW(candidates, c.RequestIP())
|
||||
case reverseProxyLBPolicyURIHash:
|
||||
if c.Request == nil || c.Request.URL == nil {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if c.Request != nil && c.Request.Header != nil {
|
||||
if values, ok := c.Request.Header[policy.key]; ok {
|
||||
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
if values, ok := c.Request.URL.Query()[policy.key]; ok {
|
||||
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyRandom:
|
||||
fallthrough
|
||||
default:
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int((p.roundRobin.Add(1) - 1) % uint64(size))
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
selected := candidates[0]
|
||||
lowest := selected.inFlight.Load()
|
||||
ties := []*reverseProxyUpstream{selected}
|
||||
for _, upstream := range candidates[1:] {
|
||||
count := upstream.inFlight.Load()
|
||||
switch {
|
||||
case count < lowest:
|
||||
selected = upstream
|
||||
lowest = count
|
||||
ties = []*reverseProxyUpstream{upstream}
|
||||
case count == lowest:
|
||||
ties = append(ties, upstream)
|
||||
}
|
||||
}
|
||||
if len(ties) == 1 {
|
||||
return selected
|
||||
}
|
||||
return ties[p.nextRoundRobinIndex(len(ties))]
|
||||
}
|
||||
|
||||
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(candidates) == 1 {
|
||||
return candidates[0]
|
||||
}
|
||||
return candidates[rand.IntN(len(candidates))]
|
||||
}
|
||||
|
||||
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if key == "" {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
selected := candidates[0]
|
||||
bestScore := reverseProxyHRWScore(key, selected.key)
|
||||
for _, upstream := range candidates[1:] {
|
||||
score := reverseProxyHRWScore(key, upstream.key)
|
||||
if score > bestScore {
|
||||
selected = upstream
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
|
||||
const (
|
||||
offset64 = 14695981039346656037
|
||||
prime64 = 1099511628211
|
||||
)
|
||||
h := uint64(offset64)
|
||||
for i := 0; i < len(key); i++ {
|
||||
h ^= uint64(key[i])
|
||||
h *= prime64
|
||||
}
|
||||
h ^= 0xff
|
||||
h *= prime64
|
||||
for i := 0; i < len(upstreamKey); i++ {
|
||||
h ^= uint64(upstreamKey[i])
|
||||
h *= prime64
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
if policy.fallback != nil {
|
||||
return *policy.fallback
|
||||
}
|
||||
return LBRandom()
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
return len(u.failures) < maxFails
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
u.failures = append(u.failures, now)
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
|
||||
if len(u.failures) == 0 || window <= 0 {
|
||||
if window <= 0 {
|
||||
u.failures = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
cutoff := now.Add(-window)
|
||||
keep := 0
|
||||
for _, failureAt := range u.failures {
|
||||
if failureAt.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
u.failures[keep] = failureAt
|
||||
keep++
|
||||
}
|
||||
u.failures = u.failures[:keep]
|
||||
}
|
||||
|
||||
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
|
||||
if config.FailDuration <= 0 {
|
||||
return 0
|
||||
}
|
||||
if config.MaxFails <= 0 {
|
||||
return 1
|
||||
}
|
||||
return config.MaxFails
|
||||
}
|
||||
|
||||
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
|
||||
if status <= 0 {
|
||||
return false
|
||||
}
|
||||
for _, unhealthyStatus := range config.UnhealthyStatus {
|
||||
if status == unhealthyStatus {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
1406
reverseproxy_test.go
1406
reverseproxy_test.go
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue