touka/reverseproxy_lb.go

355 lines
10 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"fmt"
"math/rand/v2"
"net/http"
"net/textproto"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
)
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
type ReverseProxyLBPolicy struct {
kind reverseProxyLBPolicyKind
key string
fallback *ReverseProxyLBPolicy
}
type reverseProxyLBPolicyKind uint8
const (
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
reverseProxyLBPolicyRoundRobin
reverseProxyLBPolicyFirst
reverseProxyLBPolicyLeastConn
reverseProxyLBPolicyIPHash
reverseProxyLBPolicyClientIPHash
reverseProxyLBPolicyURIHash
reverseProxyLBPolicyHeader
reverseProxyLBPolicyQuery
)
type reverseProxyUpstream struct {
key string
target *url.URL
index int
useH2C bool
extendedConnectTransport http.RoundTripper
bridgeTransport http.RoundTripper
h2cTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex
failures []time.Time
}
func LBRandom() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
}
func LBRoundRobin() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
}
func LBFirst() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
}
func LBLeastConn() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
}
func LBIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
}
func LBClientIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
}
func LBURIHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
}
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
switch policy.kind {
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
reverseProxyLBPolicyURIHash:
return nil
case reverseProxyLBPolicyHeader:
if policy.key == "" {
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
}
case reverseProxyLBPolicyQuery:
if policy.key == "" {
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
}
default:
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
}
if policy.fallback != nil {
return validateReverseProxyLBPolicy(*policy.fallback)
}
return nil
}
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
now := time.Now()
policy := p.config.LoadBalancing.Policy
candidates := p.availableUpstreams(now, excluded)
if len(candidates) == 0 && len(excluded) > 0 {
candidates = p.availableUpstreams(now, nil)
}
if len(candidates) == 0 {
return nil, errReverseProxyNoAvailableUpstreams
}
return p.selectUpstreamWithPolicy(c, candidates, policy), nil
}
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams))
for _, upstream := range p.upstreams {
if _, skip := excluded[upstream.key]; skip {
continue
}
if !upstream.healthy(now, p.config.PassiveHealth) {
continue
}
candidates = append(candidates, upstream)
}
return candidates
}
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
switch policy.kind {
case reverseProxyLBPolicyRoundRobin:
return candidates[p.nextRoundRobinIndex(len(candidates))]
case reverseProxyLBPolicyFirst:
return candidates[0]
case reverseProxyLBPolicyLeastConn:
return p.selectLeastConnUpstream(candidates)
case reverseProxyLBPolicyIPHash:
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
case reverseProxyLBPolicyClientIPHash:
return reverseProxySelectHRW(candidates, c.RequestIP())
case reverseProxyLBPolicyURIHash:
if c.Request == nil || c.Request.URL == nil {
return reverseProxySelectRandom(candidates)
}
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
case reverseProxyLBPolicyHeader:
if c.Request != nil && c.Request.Header != nil {
if values, ok := c.Request.Header[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyQuery:
if c.Request != nil && c.Request.URL != nil {
if values, ok := c.Request.URL.Query()[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyRandom:
fallthrough
default:
return reverseProxySelectRandom(candidates)
}
}
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
if size <= 1 {
return 0
}
return int((p.roundRobin.Add(1) - 1) % uint64(size))
}
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
selected := candidates[0]
lowest := selected.inFlight.Load()
ties := []*reverseProxyUpstream{selected}
for _, upstream := range candidates[1:] {
count := upstream.inFlight.Load()
switch {
case count < lowest:
selected = upstream
lowest = count
ties = []*reverseProxyUpstream{upstream}
case count == lowest:
ties = append(ties, upstream)
}
}
if len(ties) == 1 {
return selected
}
return ties[p.nextRoundRobinIndex(len(ties))]
}
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(candidates) == 1 {
return candidates[0]
}
return candidates[rand.IntN(len(candidates))]
}
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if key == "" {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWScore(key, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWScore(key, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for i := 0; i < len(key); i++ {
h ^= uint64(key[i])
h *= prime64
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
if policy.fallback != nil {
return *policy.fallback
}
return LBRandom()
}
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return true
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
return len(u.failures) < maxFails
}
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
u.failures = append(u.failures, now)
}
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
if len(u.failures) == 0 || window <= 0 {
if window <= 0 {
u.failures = nil
}
return
}
cutoff := now.Add(-window)
keep := 0
for _, failureAt := range u.failures {
if failureAt.Before(cutoff) {
continue
}
u.failures[keep] = failureAt
keep++
}
u.failures = u.failures[:keep]
}
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
if config.FailDuration <= 0 {
return 0
}
if config.MaxFails <= 0 {
return 1
}
return config.MaxFails
}
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
if status <= 0 {
return false
}
for _, unhealthyStatus := range config.UnhealthyStatus {
if status == unhealthyStatus {
return true
}
}
return false
}