feat(reverseproxy): add upstream balancing and failover

This commit is contained in:
wjqserver 2026-04-02 14:40:56 +08:00
parent 59f190ce3a
commit 919236665b
4 changed files with 1394 additions and 116 deletions

View file

@ -23,6 +23,8 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http2"
)
// ForwardedHeadersPolicy controls how forwarding headers are generated.
@ -44,7 +46,11 @@ type BufferPool interface {
// ReverseProxyConfig configures the reverse proxy handler.
type ReverseProxyConfig struct {
Target *url.URL
Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper
FlushInterval time.Duration
@ -61,17 +67,18 @@ type ReverseProxyConfig struct {
}
var (
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
)
type reverseProxyHandler struct {
config ReverseProxyConfig
target *url.URL
receivedBy string
configError error
extendedConnectTransport http.RoundTripper
config ReverseProxyConfig
upstreams []*reverseProxyUpstream
receivedBy string
configError error
roundRobin atomic.Uint64
}
type reverseProxyStatusError struct {
@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
}
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
target := cloneReverseProxyURL(config.Target)
if target != nil {
normalizeReverseProxyTarget(target)
}
proxy := &reverseProxyHandler{
config: config,
target: target,
receivedBy: reverseProxyReceivedBy(config.Via),
}
if config.Transport == nil {
proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
if err := validateReverseProxyTarget(target); err != nil {
upstreams, err := buildReverseProxyUpstreams(config)
if err != nil {
proxy.configError = err
} else {
proxy.upstreams = upstreams
}
switch config.ForwardedHeaders {
@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
proxy.configError = err
}
}
if proxy.configError == nil {
if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
proxy.configError = err
}
}
return proxy
}
@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
return
}
transport := p.config.Transport
if transport == nil {
if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil {
transport = p.extendedConnectTransport
} else {
transport = http.DefaultTransport
}
}
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
if err != nil {
p.handleError(c, err)
@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
ctx, cancel := p.requestContext(c)
defer cancel()
attempted := make(map[string]struct{}, len(p.upstreams))
attempts := 0
started := time.Now()
var lastErr error
outreq := c.Request.Clone(ctx)
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
}
if outreq.Body != nil {
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header)
}
outreq.Close = false
var connectWriter *io.PipeWriter
defer func() {
if connectWriter != nil {
_ = connectWriter.Close()
}
}()
if outreq.Method == http.MethodConnect {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
defer outreq.Body.Close()
connectWriter = pipeWriter
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, p.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
p.handleError(c, err)
for {
upstream, err := p.selectUpstream(c, attempted)
if err != nil {
if lastErr != nil {
p.handleError(c, lastErr)
return
}
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
return
}
} else {
rewriteReverseProxyURL(outreq, p.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
}
if updatedMaxForwards != "" {
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
}
reqUpType := reverseProxyUpgradeType(outreq.Header)
if reqUpType != "" && !isPrintableASCII(reqUpType) {
p.handleError(c, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
})
attempts++
upstream.inFlight.Add(1)
served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
upstream.inFlight.Add(-1)
if served {
return
}
if attemptErr != nil {
lastErr = attemptErr
}
if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
attempted[upstream.key] = struct{}{}
if !p.waitRetryInterval(ctx, started) {
if lastErr != nil {
p.handleError(c, lastErr)
}
return
}
continue
}
if attemptErr != nil {
p.handleError(c, attemptErr)
return
}
if lastErr != nil {
p.handleError(c, lastErr)
return
}
p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
return
}
}
removeHopByHopHeaders(outreq.Header)
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
if err != nil {
return false, err, false
}
defer cleanup()
transport := p.transportForUpstream(c.Request, upstream)
rawWriter := reverseProxyBaseResponseWriter(c.Writer)
var (
roundTripMu sync.Mutex
@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
roundTripDone = true
roundTripMu.Unlock()
if err != nil {
p.handleError(c, err)
return
if reverseProxyShouldCountPassiveFailure(outreq, err) {
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
return false, err, true
}
if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
upstream.recordFailure(time.Now(), p.config.PassiveHealth)
}
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
res.TransferEncoding = nil
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
handleConnect := p.handleConnectResponse
if reverseProxyIsExtendedConnectRequest(outreq) {
handleConnect = p.handleExtendedConnectResponse
}
if err := handleConnect(c, outreq, res, connectWriter); err != nil {
p.handleError(c, err)
return false, err, false
}
connectWriter = nil
return
return true, nil, false
}
if res.StatusCode == http.StatusSwitchingProtocols {
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
p.handleError(c, err)
return false, err, false
}
return
return true, nil, false
}
removeHopByHopHeaders(res.Header)
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
if !p.modifyResponse(c, res, outreq) {
return
return true, nil, false
}
reverseProxyCopyHeader(c.Writer.Header(), res.Header)
@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
if reverseProxyShouldPanicOnCopyError(c.Request) {
panic(http.ErrAbortHandler)
}
return
return true, nil, false
}
res.Body.Close()
@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Flush()
}
// Keep the stdlib-compatible fallback here.
// If the backend only exposes additional trailer keys after the body has been
// fully read, the trailer map can grow and those values must be written using
// the TrailerPrefix form instead of the pre-announced bare header keys.
if len(res.Trailer) == announcedTrailers {
reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
return
return true, nil, false
}
for key, values := range res.Trailer {
@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
c.Writer.Header().Add(prefixedKey, value)
}
}
return true, nil, false
}
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
outreq := c.Request.Clone(ctx)
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
outreq.Body = nil
} else if c.Request.GetBody != nil {
body, err := c.Request.GetBody()
if err != nil {
return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
}
outreq.Body = body
} else if outreq.Body != nil {
outreq.Body = &noopCloseReader{readCloser: outreq.Body}
}
if outreq.Header == nil {
outreq.Header = make(http.Header)
}
outreq.Close = false
var connectWriter *io.PipeWriter
if outreq.Method == http.MethodConnect {
pipeReader, pipeWriter := io.Pipe()
outreq.Body = pipeReader
outreq.ContentLength = -1
connectWriter = pipeWriter
}
cleanup := func() {
if outreq.Body != nil {
_ = outreq.Body.Close()
}
if connectWriter != nil {
_ = connectWriter.Close()
}
}
if outreq.Method == http.MethodConnect {
if reverseProxyIsExtendedConnectRequest(outreq) {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
} else {
if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
cleanup()
return nil, nil, nil, err
}
}
} else {
rewriteReverseProxyURL(outreq, upstream.target)
if !p.config.PreserveHost {
outreq.Host = ""
}
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
}
if updatedMaxForwards != "" {
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
}
reqUpType := reverseProxyUpgradeType(outreq.Header)
if reqUpType != "" && !isPrintableASCII(reqUpType) {
cleanup()
return nil, nil, nil, &reverseProxyStatusError{
status: http.StatusBadRequest,
err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
}
}
removeHopByHopHeaders(outreq.Header)
if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
p.addForwardingHeaders(c.Request, outreq)
appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
if _, ok := outreq.Header["User-Agent"]; !ok {
outreq.Header.Set("User-Agent", "")
}
if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq)
}
return outreq, connectWriter, cleanup, nil
}
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
if p.config.Transport != nil {
return p.config.Transport
}
if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil {
return upstream.extendedConnectTransport
}
return http.DefaultTransport
}
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) {
return false
}
lb := p.config.LoadBalancing
if lb.TryDuration > 0 {
return time.Since(started) < lb.TryDuration
}
return attempts <= lb.Retries
}
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
interval := p.config.LoadBalancing.TryInterval
tryDuration := p.config.LoadBalancing.TryDuration
if tryDuration > 0 && interval == 0 {
interval = 250 * time.Millisecond
}
if tryDuration > 0 {
remaining := tryDuration - time.Since(started)
if remaining <= 0 {
return false
}
if interval <= 0 {
return ctx.Err() == nil
}
if interval > remaining {
return false
}
}
if interval <= 0 {
return ctx.Err() == nil
}
timer := time.NewTimer(interval)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error {
return nil
}
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
if config.Target != nil && len(config.Targets) > 0 {
return nil, errors.New("reverse proxy Target and Targets cannot be used together")
}
targets := make([]*url.URL, 0, max(1, len(config.Targets)))
if config.Target != nil {
target := cloneReverseProxyURL(config.Target)
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, err
}
targets = append(targets, target)
}
for i, rawTarget := range config.Targets {
trimmed := strings.TrimSpace(rawTarget)
if trimmed == "" {
return nil, fmt.Errorf("reverse proxy target at index %d is empty", i)
}
target, err := url.Parse(trimmed)
if err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
normalizeReverseProxyTarget(target)
if err := validateReverseProxyTarget(target); err != nil {
return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err)
}
targets = append(targets, target)
}
if len(targets) == 0 {
return nil, errReverseProxyNilTarget
}
upstreams := make([]*reverseProxyUpstream, 0, len(targets))
for i, target := range targets {
upstream := &reverseProxyUpstream{
key: fmt.Sprintf("%d:%s", i, target.String()),
target: target,
index: i,
}
if config.Transport == nil {
upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
func validateReverseProxyForwardedBy(value string) error {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
return req != nil && req.Context().Value(http.ServerContextKey) != nil
}
func reverseProxyCanRetryRequest(req *http.Request) bool {
if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) {
return false
}
if req.Body == nil || req.ContentLength == 0 {
return true
}
return req.GetBody != nil
}
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
if err == nil || reverseProxyIsBenignTunnelError(err) {
return false
}
if req != nil && req.Context().Err() != nil {
return false
}
return !errors.Is(err, context.Canceled)
}
func reverseProxyMethodIsSafe(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func reverseProxyIsBenignTunnelError(err error) bool {
return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err)
}
@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool {
if err == nil {
return false
}
var streamErr http2.StreamError
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
return true
}
switch err.Error() {
case "body closed by handler", "http2: response body closed", "response body closed":
return true