Merge pull request #77 from infinite-iroha/break/v1-enhance-reverse-proxy

feat(reverseproxy): add upstream balancing and protocol improvements
This commit is contained in:
WJQSERVER 2026-04-02 15:32:41 +08:00 committed by GitHub
commit 0d7721a24c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 2759 additions and 92 deletions

View file

@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
```go ```go
type ReverseProxyConfig struct { type ReverseProxyConfig struct {
Target *url.URL Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper Transport http.RoundTripper
FlushInterval time.Duration FlushInterval time.Duration
@ -78,12 +82,115 @@ type ReverseProxyConfig struct {
### `Target` ### `Target`
必填。表示后端目标地址,至少需要提供 `scheme``host` `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme``host`
```go ```go
target, _ := url.Parse("http://backend:9000") target, _ := url.Parse("http://backend:9000")
``` ```
### `Targets`
可选。用于配置多个后端目标地址。
- `Target``Targets` 互斥,只能使用其中一种
- `Targets` 的每一项都必须是完整 URL
- 每个 target 仍然可以自带 base path 和 query
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001/base?from=a",
"http://127.0.0.1:9002/base?from=b",
},
}))
```
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
### `LoadBalancing`
用于配置 upstream 选择策略和重试行为。
```go
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
```
当前内置策略:
- `touka.LBRandom()`
- `touka.LBRoundRobin()`
- `touka.LBFirst()`
- `touka.LBLeastConn()`
- `touka.LBIPHash()`
- `touka.LBClientIPHash()`
- `touka.LBURIHash()`
- `touka.LBHeader("X-Upstream", fallback)`
- `touka.LBQuery("tenant", fallback)`
其中:
- `LBFirst()` 适合主备/故障转移顺序
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback则默认回退到 `LBRandom()`
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
Retries: 1,
},
}))
```
重试说明:
- 只对未开始收到上游响应的失败进行重试
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
- `Retries` 表示额外重试次数
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
- `TryInterval` 表示两次重试之间的等待间隔
### `PassiveHealth`
用于配置被动健康检查。它不会后台探测 upstream而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
```go
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
```
- `FailDuration > 0` 时启用被动健康跟踪
- `MaxFails <= 0` 时默认按 `1` 处理
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBFirst(),
},
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
```
### `Transport` ### `Transport`
可选。用于自定义底层转发所使用的 `http.RoundTripper` 可选。用于自定义底层转发所使用的 `http.RoundTripper`
@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
在请求真正发往后端前,对出站请求做最后修改。 在请求真正发往后端前,对出站请求做最后修改。
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
常见用途: 常见用途:
- 覆盖 `Host` - 覆盖 `Host`
@ -242,11 +351,20 @@ const (
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "gateway-1", ForwardedBy: "_gateway-1",
Via: "edge-1", Via: "edge-1",
})) }))
``` ```
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
- IPv4`203.0.113.43`
- IPv6 / 带端口:`[2001:db8::17]:443`
- 匿名标识:`_gateway-1`
- 未知:`unknown`
`gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
`Via` 不是“留空即禁用”的开关。当前实现中: `Via` 不是“留空即禁用”的开关。当前实现中:
- 如果 `Via` 非空,则使用该值追加 `Via` - 如果 `Via` 非空,则使用该值追加 `Via`
@ -282,11 +400,14 @@ Touka 会尽量遵循代理链语义:
Touka 的反向代理实现支持以下能力: Touka 的反向代理实现支持以下能力:
- `CONNECT` 隧道转发HTTP/1.x
- HTTP/2 extended `CONNECT`
- `Connection: Upgrade` / `Upgrade` 协议升级转发 - `Connection: Upgrade` / `Upgrade` 协议升级转发
- WebSocket 等 101 Switching Protocols 场景 - WebSocket 等 101 Switching Protocols 场景
- SSEServer-Sent Events立即刷新 - SSEServer-Sent Events立即刷新
- Trailer 透传 - Trailer 透传
- 1xx 响应透传 - 1xx 响应透传
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
例如,代理 WebSocket 服务: 例如,代理 WebSocket 服务:
@ -341,7 +462,7 @@ func main() {
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "gateway-1", ForwardedBy: "_gateway-1",
Via: "gateway-1", Via: "gateway-1",
FlushInterval: 100 * time.Millisecond, FlushInterval: 100 * time.Millisecond,
ModifyRequest: func(req *http.Request) { ModifyRequest: func(req *http.Request) {

View file

@ -22,6 +22,8 @@ r.ANY("/any", handle)
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
``` ```
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
## 路径参数 (Named Parameters) ## 路径参数 (Named Parameters)
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。

2
ecw.go
View file

@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := ecw.w.(http.Hijacker) hijacker, ok := ecw.w.(http.Hijacker)
if !ok { if !ok {
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") return nil, nil, http.ErrNotSupported
} }
return hijacker.Hijack() return hijacker.Hijack()
} }

View file

@ -7,6 +7,7 @@ package touka
import ( import (
"context" "context"
"errors" "errors"
"io"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
if engine.serverProtocols != nil { if engine.serverProtocols != nil {
srv.Protocols = engine.serverProtocols srv.Protocols = engine.serverProtocols
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
panic(err)
}
}
} }
} }
@ -475,21 +481,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
func MethodNotAllowed() HandlerFunc { func MethodNotAllowed() HandlerFunc {
return func(c *Context) { return func(c *Context) {
httpMethod := c.Request.Method httpMethod := c.Request.Method
requestPath := c.Request.URL.Path requestPath := routeLookupPath(c.Request)
engine := c.engine engine := c.engine
// 是否是OPTIONS方式 // 是否是OPTIONS方式
if httpMethod == http.MethodOptions { if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法 // 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := []string{} allowedMethods := engine.allowedMethodsForPath(requestPath)
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
if len(allowedMethods) > 0 { if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
@ -704,8 +701,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// handleRequest 负责根据请求查找路由并执行相应的处理函数链 // handleRequest 负责根据请求查找路由并执行相应的处理函数链
// 这是路由查找和执行的核心逻辑 // 这是路由查找和执行的核心逻辑
func (engine *Engine) handleRequest(c *Context) { func (engine *Engine) handleRequest(c *Context) {
if isGeneralOptionsRequest(c.Request) {
engine.handleGeneralOptions(c)
return
}
httpMethod := c.Request.Method httpMethod := c.Request.Method
requestPath := c.Request.URL.Path requestPath := routeLookupPath(c.Request)
// 查找对应的路由树的根节点 // 查找对应的路由树的根节点
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
@ -725,7 +727,7 @@ func (engine *Engine) handleRequest(c *Context) {
} }
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
redirectPath := requestPath redirectPath := requestPath
@ -782,6 +784,55 @@ func (engine *Engine) handleRequest(c *Context) {
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送
} }
func routeLookupPath(req *http.Request) string {
if req == nil {
return ""
}
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
return "/" + req.RequestURI
}
if isGeneralOptionsRequest(req) {
return ""
}
if req.URL == nil {
return ""
}
return req.URL.Path
}
func isGeneralOptionsRequest(req *http.Request) bool {
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
}
func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
allowedMethods := make([]string, 0, len(engine.methodTrees))
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
return allowedMethods
}
func (engine *Engine) handleGeneralOptions(c *Context) {
if c == nil || c.Request == nil {
return
}
c.Writer.Header().Set("Content-Length", "0")
if c.Request.ContentLength != 0 {
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
c.Writer.WriteHeader(http.StatusOK)
c.Abort()
}
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
// 它可以用于在长连接 (如 SSE) 中监听关闭信号. // 它可以用于在长连接 (如 SSE) 中监听关闭信号.
func (engine *Engine) Context() context.Context { func (engine *Engine) Context() context.Context {

3
go.mod
View file

@ -8,9 +8,10 @@ require (
github.com/WJQSERVER/wanf v0.0.8 github.com/WJQSERVER/wanf v0.0.8
github.com/fenthope/reco v0.0.5 github.com/fenthope/reco v0.0.5
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
golang.org/x/net v0.52.0
) )
require ( require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/net v0.52.0 // indirect golang.org/x/text v0.35.0 // indirect
) )

2
go.sum
View file

@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=

53
http2xconnect.go Normal file
View file

@ -0,0 +1,53 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"sync"
_ "unsafe"
"golang.org/x/net/http2"
)
var enableHTTP2ExtendedConnectOnce sync.Once
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
var xnetDisableHTTP2ExtendedConnectProtocol bool
func enableHTTP2ExtendedConnectProtocol() {
enableHTTP2ExtendedConnectOnce.Do(func() {
xnetDisableHTTP2ExtendedConnectProtocol = false
})
}
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
if srv == nil {
return nil
}
enableHTTP2ExtendedConnectProtocol()
return http2.ConfigureServer(srv, nil)
}
func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper {
enableHTTP2ExtendedConnectProtocol()
transport := &http2.Transport{}
if target == nil || !strings.EqualFold(target.Scheme, "http") {
return transport
}
transport.AllowHTTP = true
transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
return transport
}

View file

@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// 尝试从底层 ResponseWriter 获取 Hijacker 接口 // 尝试从底层 ResponseWriter 获取 Hijacker 接口
hj, ok := rw.ResponseWriter.(http.Hijacker) hj, ok := rw.ResponseWriter.(http.Hijacker)
if !ok { if !ok {
return nil, nil, errors.New("http.Hijacker interface not supported") return nil, nil, http.ErrNotSupported
} }
// 调用底层的 Hijack 方法 // 调用底层的 Hijack 方法

View file

@ -14,6 +14,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/http/httputil"
"net/netip" "net/netip"
"net/textproto" "net/textproto"
"net/url" "net/url"
@ -22,6 +23,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/http2"
) )
// ForwardedHeadersPolicy controls how forwarding headers are generated. // ForwardedHeadersPolicy controls how forwarding headers are generated.
@ -43,7 +46,11 @@ type BufferPool interface {
// ReverseProxyConfig configures the reverse proxy handler. // ReverseProxyConfig configures the reverse proxy handler.
type ReverseProxyConfig struct { type ReverseProxyConfig struct {
Target *url.URL Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper Transport http.RoundTripper
FlushInterval time.Duration FlushInterval time.Duration
@ -60,16 +67,18 @@ type ReverseProxyConfig struct {
} }
var ( var (
errReverseProxyNilTarget = errors.New("reverse proxy target is nil") errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
) )
type reverseProxyHandler struct { type reverseProxyHandler struct {
config ReverseProxyConfig config ReverseProxyConfig
target *url.URL upstreams []*reverseProxyUpstream
receivedBy string receivedBy string
configError error configError error
roundRobin atomic.Uint64
} }
type reverseProxyStatusError struct { type reverseProxyStatusError struct {
@ -197,19 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
} }
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
target := cloneReverseProxyURL(config.Target)
if target != nil {
normalizeReverseProxyTarget(target)
}
proxy := &reverseProxyHandler{ proxy := &reverseProxyHandler{
config: config, config: config,
target: target,
receivedBy: reverseProxyReceivedBy(config.Via), receivedBy: reverseProxyReceivedBy(config.Via),
} }
if err := validateReverseProxyTarget(target); err != nil { upstreams, err := buildReverseProxyUpstreams(config)
if err != nil {
proxy.configError = err proxy.configError = err
} else {
proxy.upstreams = upstreams
} }
switch config.ForwardedHeaders { switch config.ForwardedHeaders {
@ -217,6 +223,17 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
default: default:
proxy.config.ForwardedHeaders = ForwardedBoth proxy.config.ForwardedHeaders = ForwardedBoth
} }
proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy)
if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) {
if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil {
proxy.configError = err
}
}
if proxy.configError == nil {
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
proxy.configError = err
}
}
return proxy return proxy
} }
@ -229,62 +246,75 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
return return
} }
transport := p.config.Transport updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
if transport == nil { if err != nil {
transport = http.DefaultTransport p.handleError(c, err)
return
}
if handledLocally {
return
} }
ctx, cancel := p.requestContext(c) ctx, cancel := p.requestContext(c)
defer cancel() defer cancel()
attempted := make(map[string]struct{}, len(p.upstreams))
attempts := 0
started := time.Now()
var lastErr error
outreq := c.Request.Clone(ctx) for {
if c.Request.ContentLength == 0 { upstream, err := p.selectUpstream(c, attempted)
outreq.Body = nil if err != nil {
} if lastErr != nil {
if outreq.Body != nil { p.handleError(c, lastErr)
outreq.Body = &noopCloseReader{readCloser: outreq.Body} return
defer outreq.Body.Close() }
} p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
if outreq.Header == nil { return
outreq.Header = make(http.Header) }
}
outreq.Close = false
rewriteReverseProxyURL(outreq, p.target) attempts++
if !p.config.PreserveHost { upstream.inFlight.Add(1)
outreq.Host = "" served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
} upstream.inFlight.Add(-1)
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
reqUpType := reverseProxyUpgradeType(outreq.Header) if served {
if reqUpType != "" && !isPrintableASCII(reqUpType) { return
p.handleError(c, &reverseProxyStatusError{ }
status: http.StatusBadRequest, if attemptErr != nil {
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), lastErr = attemptErr
}) }
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
attempted[upstream.key] = struct{}{}
if !p.waitRetryInterval(ctx, started) {
if lastErr != nil {
p.handleError(c, lastErr)
}
return
}
continue
}
if attemptErr != nil {
p.handleError(c, attemptErr)
return
}
if lastErr != nil {
p.handleError(c, lastErr)
return
}
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
return return
} }
}
removeHopByHopHeaders(outreq.Header) func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
if headerValuesContainToken(c.Request.Header["Te"], "trailers") { outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
outreq.Header.Set("Te", "trailers") if err != nil {
} return false, err, false
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
} }
defer cleanup()
transport := p.transportForUpstream(c.Request, upstream)
rawWriter := reverseProxyBaseResponseWriter(c.Writer) rawWriter := reverseProxyBaseResponseWriter(c.Writer)
var ( var (
roundTripMu sync.Mutex roundTripMu sync.Mutex
@ -314,26 +344,51 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
roundTripDone = true roundTripDone = true
roundTripMu.Unlock() roundTripMu.Unlock()
if err != nil { if err != nil {
p.handleError(c, err) if reverseProxyShouldCountPassiveFailure(outreq, err) {
return upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
return false, err, true
}
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
removeHopByHopHeaders(res.Header)
res.Header.Del("Content-Length")
res.Header.Del("Transfer-Encoding")
res.ContentLength = -1
res.TransferEncoding = nil
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return true, nil, false
}
handleConnect := p.handleConnectResponse
if reverseProxyIsExtendedConnectRequest(outreq) {
handleConnect = p.handleExtendedConnectResponse
}
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
return false, err, false
}
return true, nil, false
} }
if res.StatusCode == http.StatusSwitchingProtocols { if res.StatusCode == http.StatusSwitchingProtocols {
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) { if !p.modifyResponse(c, res, outreq) {
return return true, nil, false
} }
if err := p.handleUpgradeResponse(c, outreq, res); err != nil { if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
p.handleError(c, err) return false, err, false
} }
return return true, nil, false
} }
removeHopByHopHeaders(res.Header) removeHopByHopHeaders(res.Header)
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) { if !p.modifyResponse(c, res, outreq) {
return return true, nil, false
} }
reverseProxyCopyHeader(c.Writer.Header(), res.Header) reverseProxyCopyHeader(c.Writer.Header(), res.Header)
@ -353,7 +408,10 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
defer res.Body.Close() defer res.Body.Close()
c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
p.logf(c, "reverse proxy body copy failed: %v", err) p.logf(c, "reverse proxy body copy failed: %v", err)
return if reverseProxyShouldPanicOnCopyError(c.Request) {
panic(http.ErrAbortHandler)
}
return true, nil, false
} }
res.Body.Close() res.Body.Close()
@ -361,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Flush() c.Writer.Flush()
} }
// Keep the stdlib-compatible fallback here.
// If the backend only exposes additional trailer keys after the body has been
// fully read, the trailer map can grow and those values must be written using
// the TrailerPrefix form instead of the pre-announced bare header keys.
if len(res.Trailer) == announcedTrailers { if len(res.Trailer) == announcedTrailers {
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
return return true, nil, false
} }
for key, values := range res.Trailer { for key, values := range res.Trailer {
@ -376,6 +430,228 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Header().Add(prefixedKey, value) c.Writer.Header().Add(prefixedKey, value)
} }
} }
return true, nil, false
}
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
outreq := c.Request.Clone(ctx)
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
} else if c.Request.GetBody != nil {
body, err := c.Request.GetBody()
if err != nil {
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
}
outreq.Body = body
} else if outreq.Body != nil {
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
}
if outreq.Header == nil {
outreq.Header = make(http.Header)
}
outreq.Close = false
var connectWriter *io.PipeWriter
if outreq.Method == http.MethodConnect {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
connectWriter = pipeWriter
}
cleanup := func() {
if outreq.Body != nil {
_ = outreq.Body.Close()
}
if connectWriter != nil {
_ = connectWriter.Close()
}
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
cleanup()
return nil, nil, nil, err
}
}
} else {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
}
if updatedMaxForwards != "" {
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
}
reqUpType := reverseProxyUpgradeType(outreq.Header)
if reqUpType != "" && !isPrintableASCII(reqUpType) {
cleanup()
return nil, nil, nil, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
}
}
removeHopByHopHeaders(outreq.Header)
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
}
return outreq, connectWriter, cleanup, nil
}
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
if p.config.Transport != nil {
return p.config.Transport
}
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
return upstream.extendedConnectTransport
}
return http.DefaultTransport
}
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
return false
}
lb := p.config.LoadBalancing
if lb.TryDuration > 0 {
return time.Since(started) < lb.TryDuration
}
return attempts <= lb.Retries
}
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
interval := p.config.LoadBalancing.TryInterval
tryDuration := p.config.LoadBalancing.TryDuration
if tryDuration > 0 && interval == 0 {
interval = 250 * time.Millisecond
}
if tryDuration > 0 {
remaining := tryDuration - time.Since(started)
if remaining <= 0 {
return false
}
if interval <= 0 {
return ctx.Err() == nil
}
if interval > remaining {
return false
}
}
if interval <= 0 {
return ctx.Err() == nil
}
timer := time.NewTimer(interval)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
if c == nil || c.Request == nil {
return "", false, nil
}
switch c.Request.Method {
case http.MethodOptions, http.MethodTrace:
default:
return "", false, nil
}
rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards"))
if rawValue == "" {
return "", false, nil
}
value, err := strconv.Atoi(rawValue)
if err != nil || value < 0 {
return "", false, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("invalid Max-Forwards value %q", rawValue),
}
}
if value == 0 {
switch c.Request.Method {
case http.MethodTrace:
return "", true, p.writeLocalTraceResponse(c)
case http.MethodOptions:
p.writeLocalOptionsResponse(c)
return "", true, nil
}
}
return strconv.Itoa(value - 1), false, nil
}
func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error {
if c == nil || c.Request == nil {
return nil
}
traceReq := c.Request.Clone(c.Request.Context())
traceReq.Body = nil
traceReq.ContentLength = 0
traceReq.TransferEncoding = nil
traceReq.RequestURI = c.Request.RequestURI
if traceReq.RequestURI == "" && traceReq.URL != nil {
traceReq.RequestURI = traceReq.URL.RequestURI()
}
traceReq.Header = traceReq.Header.Clone()
for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} {
traceReq.Header.Del(key)
}
dump, err := httputil.DumpRequest(traceReq, false)
if err != nil {
return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err}
}
c.Writer.Header().Set("Content-Type", "message/http")
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(dump)
return err
}
func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
if c == nil {
return
}
if c.engine != nil {
if c.Request != nil && c.Request.RequestURI != "*" {
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 {
c.Writer.Header().Set("Allow", strings.Join(allow, ", "))
}
}
}
c.Writer.WriteHeader(http.StatusOK)
} }
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
@ -522,7 +798,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
clientConn, brw, err := c.Writer.Hijack() clientConn, brw, err := c.Writer.Hijack()
if err != nil { if err != nil {
backConn.Close() backConn.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} status := http.StatusBadGateway
if errors.Is(err, http.ErrNotSupported) {
status = http.StatusNotImplemented
}
return &reverseProxyStatusError{status: status, err: err}
} }
defer clientConn.Close() defer clientConn.Close()
@ -561,6 +841,164 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
return firstErr return firstErr
} }
func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
if backWrite == nil {
res.Body.Close()
return &reverseProxyStatusError{
status: http.StatusBadGateway,
err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"),
}
}
backRead := res.Body
clientConn, brw, err := c.Writer.Hijack()
if err != nil {
backRead.Close()
_ = backWrite.Close()
status := http.StatusBadGateway
if errors.Is(err, http.ErrNotSupported) {
status = http.StatusNotImplemented
}
return &reverseProxyStatusError{status: status, err: err}
}
defer clientConn.Close()
defer backRead.Close()
defer backWrite.Close()
backConnClosed := make(chan struct{})
go func() {
select {
case <-req.Context().Done():
case <-backConnClosed:
}
backRead.Close()
_ = backWrite.Close()
}()
defer close(backConnClosed)
res.Body = nil
if err := res.Write(brw); err != nil {
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
if err := brw.Flush(); err != nil {
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
errc := make(chan error, 2)
go func() {
if _, err := io.Copy(clientConn, backRead); err != nil {
errc <- err
return
}
if cw, ok := clientConn.(interface{ CloseWrite() error }); ok {
errc <- cw.CloseWrite()
return
}
errc <- errReverseProxyCopyDone
}()
go func() {
if _, err := io.Copy(backWrite, clientConn); err != nil {
errc <- err
return
}
errc <- backWrite.Close()
}()
firstErr := <-errc
if firstErr == nil {
firstErr = <-errc
}
if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
return nil
}
return firstErr
}
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
if c == nil || c.Request == nil {
res.Body.Close()
if backWrite != nil {
_ = backWrite.Close()
}
return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
}
if backWrite == nil {
res.Body.Close()
return &reverseProxyStatusError{
status: http.StatusBadGateway,
err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
}
}
controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
res.Body.Close()
_ = backWrite.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
c.Writer.WriteHeader(res.StatusCode)
if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
res.Body.Close()
_ = backWrite.Close()
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
}
var closeOnce sync.Once
closeTunnel := func() {
closeOnce.Do(func() {
_ = c.Request.Body.Close()
_ = backWrite.Close()
_ = res.Body.Close()
})
}
go func() {
<-req.Context().Done()
closeTunnel()
}()
errc := make(chan error, 2)
go func() {
_, err := io.Copy(backWrite, c.Request.Body)
closeErr := backWrite.Close()
if err != nil && !reverseProxyIsBenignTunnelError(err) {
errc <- err
return
}
errc <- closeErr
}()
go func() {
copyErr := p.copyResponse(c.Writer, res.Body, -1)
closeErr := res.Body.Close()
if copyErr != nil {
errc <- copyErr
return
}
errc <- closeErr
}()
var firstErr error
for i := 0; i < 2; i++ {
err := <-errc
if reverseProxyIsBenignTunnelError(err) {
continue
}
if firstErr == nil {
firstErr = err
closeTunnel()
}
}
closeTunnel()
if reverseProxyIsBenignTunnelError(firstErr) {
return nil
}
return firstErr
}
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
return -1 return -1
@ -599,7 +1037,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt
var written int64 var written int64
for { for {
nr, rerr := src.Read(buf) nr, rerr := src.Read(buf)
if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) {
p.logf(nil, "reverse proxy read error during body copy: %v", rerr) p.logf(nil, "reverse proxy read error during body copy: %v", rerr)
} }
if nr > 0 { if nr > 0 {
@ -638,6 +1076,10 @@ func reverseProxyStatusCode(err error) int {
if errors.As(err, &statusErr) && statusErr.status > 0 { if errors.As(err, &statusErr) && statusErr.status > 0 {
return statusErr.status return statusErr.status
} }
var netErr net.Error
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
return http.StatusGatewayTimeout
}
return http.StatusBadGateway return http.StatusBadGateway
} }
@ -651,6 +1093,65 @@ func validateReverseProxyTarget(target *url.URL) error {
return nil return nil
} }
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
if config.Target != nil && len(config.Targets) > 0 {
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
}
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
if config.Target != nil {
target := cloneReverseProxyURL(config.Target)
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, err
}
targets = append(targets, target)
}
for i, rawTarget := range config.Targets {
trimmed := strings.TrimSpace(rawTarget)
if trimmed == "" {
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
}
target, err := url.Parse(trimmed)
if err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
targets = append(targets, target)
}
if len(targets) == 0 {
return nil, errReverseProxyNilTarget
}
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
for i, target := range targets {
upstream := &reverseProxyUpstream{
key: fmt.Sprintf("%d:%s", i, target.String()),
target: target,
index: i,
}
if config.Transport == nil {
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
func validateReverseProxyForwardedBy(value string) error {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return nil
}
if !isValidForwardedNodeIdentifier(trimmed) {
return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
}
return nil
}
func normalizeReverseProxyTarget(target *url.URL) { func normalizeReverseProxyTarget(target *url.URL) {
switch strings.ToLower(target.Scheme) { switch strings.ToLower(target.Scheme) {
case "ws": case "ws":
@ -732,6 +1233,94 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
return strings.Join(pairs, ";") return strings.Join(pairs, ";")
} }
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
}
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
return reverseProxyExtendedConnectProtocol(req) != ""
}
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
if req == nil || req.Method != http.MethodConnect || req.Header == nil {
return ""
}
return textproto.TrimString(req.Header.Get(":protocol"))
}
func isValidForwardedNodeIdentifier(value string) bool {
if value == "" {
return false
}
if strings.HasPrefix(value, "[") {
closing := strings.IndexByte(value, ']')
if closing <= 1 {
return false
}
addr, err := netip.ParseAddr(value[1:closing])
if err != nil || !addr.Is6() {
return false
}
if closing == len(value)-1 {
return true
}
if value[closing+1] != ':' {
return false
}
return isValidForwardedNodePort(value[closing+2:])
}
host, port, hasPort := strings.Cut(value, ":")
if hasPort {
switch {
case host == "unknown", isValidForwardedObfuscatedIdentifier(host):
return isValidForwardedNodePort(port)
default:
addr, err := netip.ParseAddr(host)
return err == nil && addr.Is4() && isValidForwardedNodePort(port)
}
}
if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) {
return true
}
addr, err := netip.ParseAddr(value)
return err == nil && addr.Is4()
}
func isValidForwardedNodePort(value string) bool {
if value == "" {
return false
}
if isValidForwardedObfuscatedIdentifier(value) {
return true
}
if len(value) > 5 {
return false
}
port, err := strconv.Atoi(value)
return err == nil && port > 0 && port <= 65535
}
func isValidForwardedObfuscatedIdentifier(value string) bool {
if len(value) < 2 || value[0] != '_' {
return false
}
for i := 1; i < len(value); i++ {
b := value[i]
if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') {
continue
}
switch b {
case '.', '_', '-':
continue
default:
return false
}
}
return true
}
func formatForwardedFor(clientIP string) string { func formatForwardedFor(clientIP string) string {
addr, err := netip.ParseAddr(clientIP) addr, err := netip.ParseAddr(clientIP)
if err != nil { if err != nil {
@ -817,6 +1406,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
} }
} }
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
connectTarget, err := reverseProxyConnectTarget(target)
if err != nil {
return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
}
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = ""
req.URL.RawPath = ""
req.URL.RawQuery = ""
req.URL.Opaque = connectTarget
req.Host = connectTarget
return nil
}
func reverseProxyConnectTarget(target *url.URL) (string, error) {
if target == nil {
return "", errReverseProxyNilTarget
}
host := target.Hostname()
if host == "" {
return "", errReverseProxyInvalidTarget
}
port := target.Port()
if port == "" {
switch strings.ToLower(target.Scheme) {
case "http":
port = "80"
case "https":
port = "443"
default:
return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
}
}
portNum, err := strconv.Atoi(port)
if err != nil || portNum <= 0 || portNum > 65535 {
return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
}
return net.JoinHostPort(host, port), nil
}
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
if base.RawPath == "" && incoming.RawPath == "" { if base.RawPath == "" && incoming.RawPath == "" {
return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
@ -919,6 +1549,59 @@ func cleanReverseProxyQueryParams(rawQuery string) string {
return values.Encode() return values.Encode()
} }
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
return req != nil && req.Context().Value(http.ServerContextKey) != nil
}
func reverseProxyCanRetryRequest(req *http.Request) bool {
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
return false
}
if req.Body == nil || req.ContentLength == 0 {
return true
}
return req.GetBody != nil
}
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
if err == nil || reverseProxyIsBenignTunnelError(err) {
return false
}
if req != nil && req.Context().Err() != nil {
return false
}
return !errors.Is(err, context.Canceled)
}
func reverseProxyMethodIsSafe(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func reverseProxyIsBenignTunnelError(err error) bool {
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
}
func reverseProxyIsClosedBodyError(err error) bool {
if err == nil {
return false
}
var streamErr http2.StreamError
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
return true
}
switch err.Error() {
case "body closed by handler", "http2: response body closed", "response body closed":
return true
default:
return false
}
}
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
return UnwrapResponseWriter(writer) return UnwrapResponseWriter(writer)
} }

352
reverseproxy_lb.go Normal file
View file

@ -0,0 +1,352 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"fmt"
"math/rand/v2"
"net/http"
"net/textproto"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
)
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
type ReverseProxyLBPolicy struct {
kind reverseProxyLBPolicyKind
key string
fallback *ReverseProxyLBPolicy
}
type reverseProxyLBPolicyKind uint8
const (
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
reverseProxyLBPolicyRoundRobin
reverseProxyLBPolicyFirst
reverseProxyLBPolicyLeastConn
reverseProxyLBPolicyIPHash
reverseProxyLBPolicyClientIPHash
reverseProxyLBPolicyURIHash
reverseProxyLBPolicyHeader
reverseProxyLBPolicyQuery
)
type reverseProxyUpstream struct {
key string
target *url.URL
index int
extendedConnectTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex
failures []time.Time
}
func LBRandom() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
}
func LBRoundRobin() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
}
func LBFirst() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
}
func LBLeastConn() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
}
func LBIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
}
func LBClientIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
}
func LBURIHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
}
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
switch policy.kind {
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
reverseProxyLBPolicyURIHash:
return nil
case reverseProxyLBPolicyHeader:
if policy.key == "" {
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
}
case reverseProxyLBPolicyQuery:
if policy.key == "" {
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
}
default:
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
}
if policy.fallback != nil {
return validateReverseProxyLBPolicy(*policy.fallback)
}
return nil
}
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
now := time.Now()
policy := p.config.LoadBalancing.Policy
candidates := p.availableUpstreams(now, excluded)
if len(candidates) == 0 && len(excluded) > 0 {
candidates = p.availableUpstreams(now, nil)
}
if len(candidates) == 0 {
return nil, errReverseProxyNoAvailableUpstreams
}
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
}
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
for _, upstream := range p.upstreams {
if _, skip := excluded[upstream.key]; skip {
continue
}
if !upstream.healthy(now, p.config.PassiveHealth) {
continue
}
candidates = append(candidates, upstream)
}
return candidates
}
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
switch policy.kind {
case reverseProxyLBPolicyRoundRobin:
return candidates[p.nextRoundRobinIndex(len(candidates))]
case reverseProxyLBPolicyFirst:
return candidates[0]
case reverseProxyLBPolicyLeastConn:
return p.selectLeastConnUpstream(candidates)
case reverseProxyLBPolicyIPHash:
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
case reverseProxyLBPolicyClientIPHash:
return reverseProxySelectHRW(candidates, c.RequestIP())
case reverseProxyLBPolicyURIHash:
if c.Request == nil || c.Request.URL == nil {
return reverseProxySelectRandom(candidates)
}
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
case reverseProxyLBPolicyHeader:
if c.Request != nil && c.Request.Header != nil {
if values, ok := c.Request.Header[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyQuery:
if c.Request != nil && c.Request.URL != nil {
if values, ok := c.Request.URL.Query()[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyRandom:
fallthrough
default:
return reverseProxySelectRandom(candidates)
}
}
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
if size <= 1 {
return 0
}
return int((p.roundRobin.Add(1) - 1) % uint64(size))
}
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
selected := candidates[0]
lowest := selected.inFlight.Load()
ties := []*reverseProxyUpstream{selected}
for _, upstream := range candidates[1:] {
count := upstream.inFlight.Load()
switch {
case count < lowest:
selected = upstream
lowest = count
ties = []*reverseProxyUpstream{upstream}
case count == lowest:
ties = append(ties, upstream)
}
}
if len(ties) == 1 {
return selected
}
return ties[p.nextRoundRobinIndex(len(ties))]
}
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(candidates) == 1 {
return candidates[0]
}
return candidates[rand.IntN(len(candidates))]
}
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if key == "" {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWScore(key, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWScore(key, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for i := 0; i < len(key); i++ {
h ^= uint64(key[i])
h *= prime64
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
if policy.fallback != nil {
return *policy.fallback
}
return LBRandom()
}
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return true
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
return len(u.failures) < maxFails
}
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
u.failures = append(u.failures, now)
}
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
if len(u.failures) == 0 || window <= 0 {
if window <= 0 {
u.failures = nil
}
return
}
cutoff := now.Add(-window)
keep := 0
for _, failureAt := range u.failures {
if failureAt.Before(cutoff) {
continue
}
u.failures[keep] = failureAt
keep++
}
u.failures = u.failures[:keep]
}
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
if config.FailDuration <= 0 {
return 0
}
if config.MaxFails <= 0 {
return 1
}
return config.MaxFails
}
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
if status <= 0 {
return false
}
for _, unhealthyStatus := range config.UnhealthyStatus {
if status == unhealthyStatus {
return true
}
}
return false
}

File diff suppressed because it is too large Load diff