touka/reverseproxy.go
wjqserver 2d4aefc86e fix: cut redirect and allow-path routing overhead
Reuse fixed-path and Allow-header buffers so redirect and OPTIONS handling stop rebuilding temporary data on every request. Cache fallback chains and add regression coverage for redirect, 404, 405, and Allow behavior to keep the faster miss paths stable.
2026-04-07 09:06:56 +08:00

1797 lines
46 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/netip"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http2"
)
// ForwardedHeadersPolicy controls how forwarding headers are generated.
// The zero value uses both X-Forwarded-* and RFC 7239 Forwarded headers.
// Values outside the constants below are replaced with ForwardedBoth when the
// handler is constructed.
type ForwardedHeadersPolicy int
const (
	// ForwardedBoth adds the X-Forwarded-For/-Host/-Proto headers and the
	// RFC 7239 Forwarded header to the outgoing request.
	ForwardedBoth ForwardedHeadersPolicy = iota
	// ForwardedNone adds no forwarding headers at all.
	ForwardedNone
	// ForwardedXForwardedOnly adds only the X-Forwarded-* headers.
	ForwardedXForwardedOnly
	// ForwardedRFC7239Only adds only the RFC 7239 Forwarded header.
	ForwardedRFC7239Only
)
// BufferPool provides temporary buffers for response body copying.
//
// The proxy calls Get once before copying a response body and hands the same
// slice back through Put when the copy returns. When Get yields an empty
// slice the copy loop allocates its own 32 KiB buffer instead.
type BufferPool interface {
	// Get returns a scratch buffer for one body copy.
	Get() []byte
	// Put receives the slice previously obtained from Get.
	Put([]byte)
}
// ReverseProxyConfig configures the reverse proxy handler.
type ReverseProxyConfig struct {
	// Target is a single upstream URL. It cannot be combined with Targets.
	Target *url.URL
	// Targets lists upstream URLs as strings; each entry is trimmed, parsed
	// with url.Parse and must contain a scheme and a host.
	Targets []string
	// LoadBalancing carries the selection policy and the retry settings
	// (Retries, TryDuration, TryInterval) consulted between attempts.
	LoadBalancing ReverseProxyLoadBalancingConfig
	// PassiveHealth is handed to the upstream failure bookkeeping whenever a
	// round trip fails or returns a status treated as unhealthy.
	PassiveHealth ReverseProxyPassiveHealthConfig
	// Transport, when non-nil, is used for every upstream round trip and
	// disables creation of the per-upstream transports.
	Transport http.RoundTripper
	// FlushInterval is the flush latency used while copying response bodies.
	// text/event-stream responses and responses of unknown length are always
	// flushed immediately regardless of this value.
	FlushInterval time.Duration
	// BufferPool optionally supplies the scratch buffer for body copies.
	BufferPool BufferPool
	// AllowH2CUpstream marks every upstream as h2c-capable, in addition to
	// targets that use the "h2c" scheme explicitly.
	AllowH2CUpstream bool
	// ModifyRequest, when set, is called with the outgoing request after all
	// proxy-side rewriting has been applied.
	ModifyRequest func(*http.Request)
	// ModifyResponse, when set, is called with the upstream response before
	// it is relayed. A non-nil error closes the response body and is routed
	// to the error handling path.
	ModifyResponse func(*http.Response) error
	// ErrorHandler, when set, is invoked for proxy errors. If it neither
	// writes a response nor hijacks the connection, the default error
	// response is still produced.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
	// ForwardedHeaders selects which forwarding headers are generated.
	ForwardedHeaders ForwardedHeadersPolicy
	// ForwardedBy is the value used for the "by" part of the Forwarded
	// header. It is trimmed and, when the Forwarded header is in use, must be
	// a valid RFC 7239 node identifier.
	ForwardedBy string
	// Via is the source for the received-by token appended to Via headers.
	Via string
	// PreserveHost keeps the inbound Host on the outgoing request. When false
	// the outgoing Host field is cleared for everything except classic
	// CONNECT requests.
	PreserveHost bool
}
// Sentinel errors shared by the reverse proxy implementation.
var (
	// errReverseProxyNilTarget reports a missing target (nil URL or no
	// targets configured at all).
	errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
	// errReverseProxyInvalidTarget reports a target without scheme or host.
	errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
	// errReverseProxyCopyDone marks a tunnel copy direction that finished
	// cleanly on a connection that offers no CloseWrite method.
	errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
	// errReverseProxyNoAvailableUpstreams is the fallback error when an
	// attempt ended without being served and without any recorded error.
	errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
)
// reverseProxyHandler holds the validated configuration and the runtime state
// of one reverse proxy instance.
type reverseProxyHandler struct {
	// config is the caller's configuration with ForwardedHeaders and
	// ForwardedBy normalized at construction time.
	config ReverseProxyConfig
	// upstreams are the backends built from Target/Targets; left nil when the
	// configuration is invalid.
	upstreams []*reverseProxyUpstream
	// receivedBy is the token appended to Via headers, derived from config.Via.
	receivedBy string
	// configError, when non-nil, makes every request fail with a 500.
	configError error
	// roundRobin is a shared counter; presumably consumed by upstream
	// selection, which is defined outside this section.
	roundRobin atomic.Uint64
}
// reverseProxyStatusError pairs an error with the HTTP status code that should
// be reported to the client for it.
type reverseProxyStatusError struct {
	// status is the HTTP status to respond with; values <= 0 are ignored by
	// reverseProxyStatusCode.
	status int
	// err is the underlying cause, exposed through Error and Unwrap.
	err error
}
// reverseProxyExtendedConnectBridge carries the state needed to bridge an
// extended CONNECT request to the backend.
type reverseProxyExtendedConnectBridge struct {
	// body is used as the read side of the client connection once the
	// bridged tunnel is established.
	body io.ReadCloser
}
// reverseProxyH2ReadWriteCloser glues a request body (read side) and a
// response writer (write side) into a single stream-like value so it can be
// used as one end of a tunnel.
type reverseProxyH2ReadWriteCloser struct {
	io.ReadCloser
	ResponseWriter
	// controller flushes the response writer after every write.
	controller *http.ResponseController
}

// Write forwards p to the response writer and flushes right away so tunnel
// data is not held back in a buffer. A flush that is merely unsupported is
// not treated as a failure.
func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
	written, writeErr := rwc.ResponseWriter.Write(p)
	if writeErr != nil {
		return written, writeErr
	}
	flushErr := rwc.controller.Flush()
	if flushErr == nil || errors.Is(flushErr, http.ErrNotSupported) {
		return written, nil
	}
	return written, flushErr
}

// Close closes the read side if there is one; the response writer itself is
// left untouched.
func (rwc *reverseProxyH2ReadWriteCloser) Close() error {
	if body := rwc.ReadCloser; body != nil {
		return body.Close()
	}
	return nil
}
// Error returns the message of the wrapped error, or an empty string when the
// receiver or the wrapped error is nil.
func (e *reverseProxyStatusError) Error() string {
	if e != nil && e.err != nil {
		return e.err.Error()
	}
	return ""
}
// Unwrap exposes the wrapped error to errors.Is / errors.As; a nil receiver
// unwraps to nil.
func (e *reverseProxyStatusError) Unwrap() error {
	if e != nil {
		return e.err
	}
	return nil
}
// noopCloseReader wraps a body so that closing the wrapper never closes the
// wrapped body; it only blocks further reads through the wrapper.
type noopCloseReader struct {
	readCloser io.ReadCloser
	closed     atomic.Bool
}

// Read delegates to the wrapped body until Close has been called, after which
// every read fails.
func (n *noopCloseReader) Read(p []byte) (int, error) {
	if !n.closed.Load() {
		return n.readCloser.Read(p)
	}
	return 0, errors.New("reverse proxy read on closed body")
}

// Close marks the wrapper as closed and always succeeds. The wrapped body is
// deliberately left open.
func (n *noopCloseReader) Close() error {
	n.closed.Store(true)
	return nil
}
// maxLatencyWriter writes to dst and makes sure written data is flushed no
// later than latency after the write. A negative latency flushes after every
// write.
type maxLatencyWriter struct {
	dst     ResponseWriter
	latency time.Duration

	mu           sync.Mutex // guards t and flushPending, and serializes writes
	t            *time.Timer
	flushPending bool
}

// Write passes p to dst and then either flushes immediately (negative
// latency) or arms the flush timer unless a flush is already scheduled.
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	written, err := m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.dst.Flush()
	case m.flushPending:
		// A flush is already scheduled; it will cover this write too.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return written, err
}

// delayedFlush is the timer callback. It flushes only when a flush is still
// pending, which stop may have cancelled in the meantime.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.flushPending {
		m.dst.Flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer if one was created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
// switchProtocolCopier shuttles bytes between the client side (user) and the
// upstream side (backend) of an upgraded connection.
type switchProtocolCopier struct {
	user    io.ReadWriter
	backend io.ReadWriter
}

// copyFromBackend streams backend -> user and reports exactly one value on
// errc: the copy error, the result of half-closing the user side when that is
// supported, or errReverseProxyCopyDone otherwise.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	_, err := io.Copy(c.user, c.backend)
	if err == nil {
		err = errReverseProxyCopyDone
		if halfCloser, ok := c.user.(interface{ CloseWrite() error }); ok {
			err = halfCloser.CloseWrite()
		}
	}
	errc <- err
}

// copyToBackend streams user -> backend and reports exactly one value on
// errc, following the same rules as copyFromBackend with the roles swapped.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	_, err := io.Copy(c.backend, c.user)
	if err == nil {
		err = errReverseProxyCopyDone
		if halfCloser, ok := c.backend.(interface{ CloseWrite() error }); ok {
			err = halfCloser.CloseWrite()
		}
	}
	errc <- err
}
// ReverseProxy builds a reverse proxy from config once and returns the handler
// that serves every matching request through it. Configuration problems are
// not reported here; they surface as error responses when requests are served.
func ReverseProxy(config ReverseProxyConfig) HandlerFunc {
	handler := newReverseProxyHandler(config)
	return handler.ServeHTTP
}
// newReverseProxyHandler builds a handler from config.
//
// It never fails outright: the first configuration problem found is stored in
// configError and reported on every request. Checks run in this order:
// upstream targets, ForwardedBy (only when the Forwarded header is emitted),
// load-balancing policy. A later check never replaces an earlier error, so the
// reported error always points at the first thing that needs fixing.
//
// Normalization applied to the stored config: an unknown ForwardedHeaders
// value becomes ForwardedBoth, and ForwardedBy is trimmed.
func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
	proxy := &reverseProxyHandler{
		config:     config,
		receivedBy: reverseProxyReceivedBy(config.Via),
	}
	upstreams, err := buildReverseProxyUpstreams(config)
	if err != nil {
		proxy.configError = err
	} else {
		proxy.upstreams = upstreams
	}
	switch config.ForwardedHeaders {
	case ForwardedBoth, ForwardedNone, ForwardedXForwardedOnly, ForwardedRFC7239Only:
	default:
		proxy.config.ForwardedHeaders = ForwardedBoth
	}
	proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy)
	// Keep the first error: an invalid ForwardedBy must not hide an upstream
	// configuration error found above.
	if proxy.configError == nil && reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) {
		if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil {
			proxy.configError = err
		}
	}
	if proxy.configError == nil {
		if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil {
			proxy.configError = err
		}
	}
	return proxy
}
// ServeHTTP proxies the request held by c and aborts the remaining handler
// chain when it returns.
//
// Steps:
//  1. A stored configuration error is answered with a 500.
//  2. Max-Forwards handling may answer OPTIONS/TRACE locally.
//  3. Upstreams are selected and tried in a loop. An attempt that reports
//     itself retriable is repeated on another upstream as long as the retry
//     budget (shouldRetryAttempt / waitRetryInterval) allows it; upstreams
//     already tried are passed to the selector through the attempted set.
//  4. When no attempt serves the request, the most recent attempt error is
//     reported; without any error a 502 is produced.
func (p *reverseProxyHandler) ServeHTTP(c *Context) {
	defer c.Abort()
	if p.configError != nil {
		p.handleError(c, &reverseProxyStatusError{status: http.StatusInternalServerError, err: p.configError})
		return
	}
	updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
	if err != nil {
		p.handleError(c, err)
		return
	}
	if handledLocally {
		return
	}
	ctx, cancel := p.requestContext(c)
	defer cancel()

	// runAttempt brackets one attempt with the upstream's in-flight counter.
	// The decrement is deferred because an attempt can panic (for example
	// with http.ErrAbortHandler after a failed body copy); without the defer
	// the counter would stay incremented forever.
	runAttempt := func(upstream *reverseProxyUpstream) (bool, error, bool) {
		upstream.inFlight.Add(1)
		defer upstream.inFlight.Add(-1)
		return p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards)
	}

	attempted := make(map[string]struct{}, len(p.upstreams))
	attempts := 0
	started := time.Now()
	var lastErr error
	for {
		upstream, err := p.selectUpstream(c, attempted)
		if err != nil {
			// Prefer the error of a real attempt over the selection error.
			if lastErr != nil {
				p.handleError(c, lastErr)
				return
			}
			p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err})
			return
		}
		attempts++
		served, attemptErr, retriable := runAttempt(upstream)
		if served {
			return
		}
		if attemptErr != nil {
			lastErr = attemptErr
		}
		if retriable && p.shouldRetryAttempt(c.Request, attempts, started) {
			attempted[upstream.key] = struct{}{}
			if !p.waitRetryInterval(ctx, started) {
				if lastErr != nil {
					p.handleError(c, lastErr)
				}
				return
			}
			continue
		}
		if attemptErr != nil {
			p.handleError(c, attemptErr)
			return
		}
		if lastErr != nil {
			p.handleError(c, lastErr)
			return
		}
		p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams})
		return
	}
}
// serveUpstreamAttempt performs one round trip against upstream and relays the
// result to the client.
//
// Results, in order:
//   - served: the attempt is finished from the caller's point of view (a
//     response was relayed or an error response was already produced).
//   - error: the failure to report when served is false.
//   - retriable: the caller may try another upstream. Only round-trip
//     failures are retriable.
//
// Depending on the response the attempt ends in one of four ways: a bridged
// extended CONNECT tunnel, a CONNECT tunnel (2xx), a protocol upgrade (101),
// or a regular response whose headers, body and trailers are copied to the
// client.
func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) {
	outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards)
	if err != nil {
		return false, err, false
	}
	defer cleanup()
	transport := p.transportForUpstream(outreq, upstream)
	rawWriter := reverseProxyBaseResponseWriter(c.Writer)
	var (
		roundTripMu   sync.Mutex
		roundTripDone bool
	)
	trace := &httptrace.ClientTrace{
		// Relay 1xx informational responses to the client. Once RoundTrip has
		// returned, late callbacks are ignored so they cannot interfere with
		// the final response headers.
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMu.Lock()
			defer roundTripMu.Unlock()
			if roundTripDone {
				return nil
			}
			// Temporarily swap the header map contents for the 1xx headers,
			// then restore what was there before.
			h := c.Writer.Header()
			saved := h.Clone()
			clear(h)
			reverseProxyCopyHeader(h, http.Header(header))
			rawWriter.WriteHeader(code)
			clear(h)
			reverseProxyCopyHeader(h, saved)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
	res, err := transport.RoundTrip(outreq)
	roundTripMu.Lock()
	roundTripDone = true
	roundTripMu.Unlock()
	if err != nil {
		if reverseProxyShouldCountPassiveFailure(outreq, err) {
			upstream.recordFailure(time.Now(), p.config.PassiveHealth)
		}
		return false, err, true
	}
	if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) {
		upstream.recordFailure(time.Now(), p.config.PassiveHealth)
	}
	if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil {
		if res.StatusCode == http.StatusSwitchingProtocols {
			appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
			if !p.modifyResponse(c, res, outreq) {
				return true, nil, false
			}
			if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil {
				return false, err, false
			}
			return true, nil, false
		}
		// The response is discarded, so its body must be closed here;
		// otherwise the upstream connection is never released.
		res.Body.Close()
		return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false
	}
	if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
		// A successful CONNECT response carries a tunnel, not a message body.
		removeHopByHopHeaders(res.Header)
		res.Header.Del("Content-Length")
		res.Header.Del("Transfer-Encoding")
		res.ContentLength = -1
		res.TransferEncoding = nil
		appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
		if !p.modifyResponse(c, res, outreq) {
			return true, nil, false
		}
		handleConnect := p.handleConnectResponse
		if reverseProxyIsExtendedConnectRequest(outreq) {
			handleConnect = p.handleExtendedConnectResponse
		}
		if err := handleConnect(c, outreq, res, connectWriter); err != nil {
			return false, err, false
		}
		return true, nil, false
	}
	if res.StatusCode == http.StatusSwitchingProtocols {
		appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
		if !p.modifyResponse(c, res, outreq) {
			return true, nil, false
		}
		if err := p.handleUpgradeResponse(c, outreq, res); err != nil {
			return false, err, false
		}
		return true, nil, false
	}
	removeHopByHopHeaders(res.Header)
	appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
	if !p.modifyResponse(c, res, outreq) {
		return true, nil, false
	}
	reverseProxyCopyHeader(c.Writer.Header(), res.Header)
	// Announce the trailer names known up front so they can be sent as
	// regular trailers after the body.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for key := range res.Trailer {
			trailerKeys = append(trailerKeys, key)
		}
		c.Writer.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}
	c.Writer.WriteHeader(res.StatusCode)
	if err := p.copyResponse(c.Writer, res.Body, p.flushInterval(res)); err != nil {
		defer res.Body.Close()
		c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
		p.logf(c, "reverse proxy body copy failed: %v", err)
		// The status line is already out, so the only way to signal the
		// failure to the client is to abort the response.
		if reverseProxyShouldPanicOnCopyError(c.Request) {
			panic(http.ErrAbortHandler)
		}
		return true, nil, false
	}
	res.Body.Close()
	if len(res.Trailer) > 0 {
		// Flush before setting trailers so the body is sent without a
		// computed Content-Length.
		c.Writer.Flush()
	}
	if len(res.Trailer) == announcedTrailers {
		// Only announced trailers arrived: set them under their plain names.
		reverseProxyCopyHeader(c.Writer.Header(), res.Trailer)
		return true, nil, false
	}
	// Trailers that were not announced must use the TrailerPrefix form.
	for key, values := range res.Trailer {
		prefixedKey := http.TrailerPrefix + key
		for _, value := range values {
			c.Writer.Header().Add(prefixedKey, value)
		}
	}
	return true, nil, false
}
// buildOutgoingRequest derives the request sent to upstream from the inbound
// request in c.
//
// It returns the outgoing request, the write end of the CONNECT pipe (nil
// unless the request is a non-bridged CONNECT), a cleanup function that closes
// the outgoing body and the pipe writer, and an error. On error the other
// results are nil and any resources created so far have been released.
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
	outreq := c.Request.Clone(ctx)
	// Presumably detects an extended CONNECT that has to be bridged to the
	// backend and prepares a context carrying the bridge state; the helper
	// is defined outside this section.
	bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq)
	if err != nil {
		return nil, nil, nil, err
	}
	if bridged {
		outreq = outreq.WithContext(bridgeCtx)
	}
	// Body selection: none for CONNECT or empty requests, a fresh copy when
	// the request can replay its body, otherwise a wrapper whose Close does
	// not close the inbound body.
	if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
		outreq.Body = nil
	} else if c.Request.GetBody != nil {
		body, err := c.Request.GetBody()
		if err != nil {
			return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err)
		}
		outreq.Body = body
	} else if outreq.Body != nil {
		outreq.Body = &noopCloseReader{readCloser: outreq.Body}
	}
	if outreq.Header == nil {
		outreq.Header = make(http.Header)
	}
	outreq.Close = false
	var connectWriter *io.PipeWriter
	// A non-bridged CONNECT gets a pipe as request body; the write end is
	// returned so the tunnel handler can stream client data to the backend.
	if outreq.Method == http.MethodConnect && !bridged {
		pipeReader, pipeWriter := io.Pipe()
		outreq.Body = pipeReader
		outreq.ContentLength = -1
		connectWriter = pipeWriter
	}
	cleanup := func() {
		if outreq.Body != nil {
			_ = outreq.Body.Close()
		}
		if connectWriter != nil {
			_ = connectWriter.Close()
		}
	}
	// Classic CONNECT requests get their own rewrite; every other request is
	// re-targeted at the upstream URL.
	if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) {
		if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil {
			cleanup()
			return nil, nil, nil, err
		}
	} else {
		rewriteReverseProxyURL(outreq, upstream.target)
		if !p.config.PreserveHost {
			outreq.Host = ""
		}
		outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
	}
	if updatedMaxForwards != "" {
		outreq.Header.Set("Max-Forwards", updatedMaxForwards)
	}
	// Capture the requested upgrade protocol before hop-by-hop headers are
	// stripped; it is restored below.
	reqUpType := reverseProxyUpgradeType(outreq.Header)
	if reqUpType != "" && !isPrintableASCII(reqUpType) {
		cleanup()
		return nil, nil, nil, &reverseProxyStatusError{
			status: http.StatusBadRequest,
			err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType),
		}
	}
	removeHopByHopHeaders(outreq.Header)
	// "Te: trailers" and the upgrade headers are re-added after stripping
	// because the backend needs them for this hop.
	if headerValuesContainToken(c.Request.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}
	p.addForwardingHeaders(c.Request, outreq)
	appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy)
	// An explicitly empty User-Agent keeps the HTTP client from adding its
	// default one when the inbound request had none (same approach as
	// net/http/httputil.ReverseProxy).
	if _, ok := outreq.Header["User-Agent"]; !ok {
		outreq.Header.Set("User-Agent", "")
	}
	// The user hook runs last so it sees the fully rewritten request.
	if p.config.ModifyRequest != nil {
		p.config.ModifyRequest(outreq)
	}
	return outreq, connectWriter, cleanup, nil
}
// transportForUpstream picks the round tripper for req.
//
// Priority: the configured Transport; for bridged extended CONNECT requests
// the upstream's bridge transport; the h2c transport for h2c upstreams; the
// extended CONNECT transport for extended CONNECT requests; and finally
// http.DefaultTransport whenever the preferred transport is not available.
func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper {
	if custom := p.config.Transport; custom != nil {
		return custom
	}
	bridged := reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil
	switch {
	case bridged && upstream.bridgeTransport != nil:
		return upstream.bridgeTransport
	case bridged:
		return http.DefaultTransport
	case upstream.useH2C && upstream.h2cTransport != nil:
		return upstream.h2cTransport
	case reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil:
		return upstream.extendedConnectTransport
	default:
		return http.DefaultTransport
	}
}
// shouldRetryAttempt reports whether another attempt is allowed after
// attempts tries that began at started.
//
// A nil request, a request whose context has ended, or a request that is not
// retryable never retries. Otherwise a positive TryDuration acts as a time
// budget; without one, the number of attempts is compared against Retries.
func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool {
	if req == nil {
		return false
	}
	if req.Context().Err() != nil {
		return false
	}
	if !reverseProxyCanRetryRequest(req) {
		return false
	}
	policy := p.config.LoadBalancing
	if window := policy.TryDuration; window > 0 {
		return time.Since(started) < window
	}
	return attempts <= policy.Retries
}
// waitRetryInterval pauses between two attempts and reports whether the next
// attempt should go ahead.
//
// With a TryDuration budget, a zero TryInterval defaults to 250ms, and the
// wait is refused when the budget is used up or the pause would not fit into
// what remains of it. A non-positive pause does not sleep and only checks that
// ctx is still alive. Otherwise the call blocks for the pause and returns
// false if ctx ends first.
func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool {
	pause := p.config.LoadBalancing.TryInterval
	window := p.config.LoadBalancing.TryDuration
	if window > 0 {
		if pause == 0 {
			pause = 250 * time.Millisecond
		}
		left := window - time.Since(started)
		switch {
		case left <= 0:
			return false
		case pause <= 0:
			return ctx.Err() == nil
		case pause > left:
			return false
		}
	}
	if pause <= 0 {
		return ctx.Err() == nil
	}
	timer := time.NewTimer(pause)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
// handleMaxForwards applies the Max-Forwards header of OPTIONS and TRACE
// requests.
//
// Results: the decremented value to forward ("" when the header is not
// rewritten), whether the request was answered locally, and an error. Other
// methods and requests without the header pass through unchanged. A value
// that is not a non-negative integer yields a 400 status error. A value of
// zero is answered locally instead of being forwarded.
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
	if c == nil || c.Request == nil {
		return "", false, nil
	}
	method := c.Request.Method
	if method != http.MethodOptions && method != http.MethodTrace {
		return "", false, nil
	}
	rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards"))
	if rawValue == "" {
		return "", false, nil
	}
	hops, err := strconv.Atoi(rawValue)
	if err != nil || hops < 0 {
		return "", false, &reverseProxyStatusError{
			status: http.StatusBadRequest,
			err:    fmt.Errorf("invalid Max-Forwards value %q", rawValue),
		}
	}
	if hops > 0 {
		return strconv.Itoa(hops - 1), false, nil
	}
	// Zero hops left: this proxy is the final recipient.
	if method == http.MethodTrace {
		return "", true, p.writeLocalTraceResponse(c)
	}
	p.writeLocalOptionsResponse(c)
	return "", true, nil
}
// writeLocalTraceResponse answers a TRACE request locally by echoing the
// request line and headers as a message/http body with status 200.
//
// Credentials, cookies, forwarding headers and framing headers are removed
// from the echo. A failure to render the request is reported as a 500 status
// error; a failure to write the body is returned as is.
func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error {
	if c == nil || c.Request == nil {
		return nil
	}
	echo := c.Request.Clone(c.Request.Context())
	echo.Body = nil
	echo.ContentLength = 0
	echo.TransferEncoding = nil
	if echo.RequestURI = c.Request.RequestURI; echo.RequestURI == "" && echo.URL != nil {
		echo.RequestURI = echo.URL.RequestURI()
	}
	echo.Header = echo.Header.Clone()
	omitted := []string{
		"Authorization", "Proxy-Authorization", "Cookie",
		"Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto",
		"Content-Length", "Transfer-Encoding", "Trailer",
	}
	for _, name := range omitted {
		echo.Header.Del(name)
	}
	dump, err := httputil.DumpRequest(echo, false)
	if err != nil {
		return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err}
	}
	c.Writer.Header().Set("Content-Type", "message/http")
	c.Writer.WriteHeader(http.StatusOK)
	_, writeErr := c.Writer.Write(dump)
	return writeErr
}
// writeLocalOptionsResponse answers an OPTIONS request locally with status
// 200. Unless the request targets "*", the methods the engine reports for the
// request path are listed in an Allow header. The context's scratch buffers
// are reused for the lookup and for building the header value.
func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
	if c == nil {
		return
	}
	if c.engine != nil && c.Request != nil && c.Request.RequestURI != "*" {
		methods := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0])
		if len(methods) > 0 {
			// Keep the (possibly grown) backing arrays for later reuse.
			c.allowedMethodsBuf = methods[:0]
			joined := c.allowHeaderBuf[:0]
			for idx := range methods {
				if idx != 0 {
					joined = append(joined, ',', ' ')
				}
				joined = append(joined, methods[idx]...)
			}
			c.allowHeaderBuf = joined[:0]
			// NOTE(review): if BytesToString aliases the buffer instead of
			// copying, the header value shares memory with allowHeaderBuf —
			// confirm the buffer is not rewritten before the header is sent.
			c.Writer.Header().Set("Allow", BytesToString(joined))
		}
	}
	c.Writer.WriteHeader(http.StatusOK)
}
// requestContext returns the context used for the upstream round trip together
// with a function that releases it.
//
// A request context that can be cancelled is used directly. For a context
// that can never be cancelled (nil Done channel) and a writer that supports
// http.CloseNotifier, a derived context is cancelled when the client
// connection goes away.
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
	base := c.Request.Context()
	if base.Done() != nil {
		return base, func() {}
	}
	// Same compatibility path as net/http/httputil.ReverseProxy: request
	// contexts are normally cancellable, but middleware may have swapped
	// c.Request for one built on context.Background/TODO or another context
	// without a Done channel. CloseNotifier then remains the only source of
	// client-disconnect information for the upstream round trip.
	notifier, ok := reverseProxyBaseResponseWriter(c.Writer).(http.CloseNotifier)
	if !ok {
		return base, func() {}
	}
	derived, cancel := context.WithCancel(base)
	gone := notifier.CloseNotify()
	go func() {
		select {
		case <-gone:
			cancel()
		case <-derived.Done():
		}
	}()
	return derived, cancel
}
// addForwardingHeaders adds the forwarding headers selected by the configured
// policy to out, using the client address, host and scheme of in.
//
// X-Forwarded-For is extended with the client IP; X-Forwarded-Host and
// X-Forwarded-Proto are only set when out does not carry them yet. An
// existing Forwarded header is collapsed into one value with the new element
// appended at the end.
func (p *reverseProxyHandler) addForwardingHeaders(in *http.Request, out *http.Request) {
	policy := p.config.ForwardedHeaders
	if policy == ForwardedNone {
		return
	}
	clientIP := reverseProxyClientIP(in.RemoteAddr)
	scheme := reverseProxyRequestScheme(in)
	host := in.Host

	if policy == ForwardedBoth || policy == ForwardedXForwardedOnly {
		if clientIP != "" {
			appendXForwardedFor(out.Header, clientIP)
		}
		if host != "" && len(out.Header.Values("X-Forwarded-Host")) == 0 {
			out.Header.Set("X-Forwarded-Host", host)
		}
		if scheme != "" && len(out.Header.Values("X-Forwarded-Proto")) == 0 {
			out.Header.Set("X-Forwarded-Proto", scheme)
		}
	}

	if policy != ForwardedBoth && policy != ForwardedRFC7239Only {
		return
	}
	forwardedValue := buildForwardedHeaderValue(clientIP, p.config.ForwardedBy, host, scheme)
	if forwardedValue == "" {
		return
	}
	if prior := out.Header.Values("Forwarded"); len(prior) > 0 {
		forwardedValue = strings.Join(prior, ", ") + ", " + forwardedValue
		out.Header.Del("Forwarded")
	}
	out.Header.Add("Forwarded", forwardedValue)
}
// appendXForwardedFor appends clientIP to the X-Forwarded-For chain in header.
// Existing values are folded into one comma-separated value with clientIP at
// the end. An empty clientIP leaves header unchanged.
func appendXForwardedFor(header http.Header, clientIP string) {
	if clientIP == "" {
		return
	}
	chain := clientIP
	if prior := header.Values("X-Forwarded-For"); len(prior) > 0 {
		chain = strings.Join(prior, ", ") + ", " + clientIP
	}
	header.Set("X-Forwarded-For", chain)
}
// modifyResponse runs the configured ModifyResponse hook on res and reports
// whether relaying may continue. When the hook fails, the response body is
// closed, the error goes through handleError and false is returned. The req
// parameter is currently unused.
func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool {
	hook := p.config.ModifyResponse
	if hook == nil {
		return true
	}
	err := hook(res)
	if err == nil {
		return true
	}
	res.Body.Close()
	p.handleError(c, err)
	return false
}
// handleError records err on the context and produces the error response.
//
// After a hijack nothing can be written, so the error is only logged. A
// configured ErrorHandler gets the first chance to respond; if it leaves the
// response unwritten (and the connection unhijacked), the default error
// response with the status derived from err is sent.
func (p *reverseProxyHandler) handleError(c *Context, err error) {
	if err == nil {
		return
	}
	c.AddError(err)
	if c.Writer.IsHijacked() {
		p.logf(c, "reverse proxy error after hijack: %v", err)
		return
	}
	if custom := p.config.ErrorHandler; custom != nil {
		custom(c.Writer, c.Request, err)
		responded := c.Writer.Written() || c.Writer.IsHijacked()
		if responded {
			return
		}
	}
	c.ErrorUseHandle(reverseProxyStatusCode(err), err)
}
// handleUpgradeResponse completes a protocol upgrade after the backend
// answered with 101 Switching Protocols.
//
// It validates that request and response agree on the upgrade protocol,
// hijacks the client connection, writes the 101 response to it and then
// copies bytes in both directions until one side finishes. Negotiation and
// hijack failures are returned as status errors (502, or 501 when hijacking
// is unsupported); a tunnel that ended normally returns nil.
func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error {
	reqUpType := reverseProxyUpgradeType(req.Header)
	resUpType := reverseProxyUpgradeType(res.Header)
	if reqUpType == "" || resUpType == "" {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType),
		}
	}
	if !isPrintableASCII(resUpType) {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType),
		}
	}
	if !strings.EqualFold(reqUpType, resUpType) {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType),
		}
	}
	// The response body doubles as the backend connection and must be
	// writable for the tunnel to work.
	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: errors.New("backend returned 101 response without writable body"),
		}
	}
	clientConn, brw, err := c.Writer.Hijack()
	if err != nil {
		backConn.Close()
		status := http.StatusBadGateway
		if errors.Is(err, http.ErrNotSupported) {
			status = http.StatusNotImplemented
		}
		return &reverseProxyStatusError{status: status, err: err}
	}
	defer clientConn.Close()
	defer backConn.Close()
	// Close the backend connection as soon as the request context ends; the
	// watcher goroutine exits when this function returns.
	backConnClosed := make(chan struct{})
	go func() {
		select {
		case <-req.Context().Done():
		case <-backConnClosed:
		}
		backConn.Close()
	}()
	defer close(backConnClosed)
	// Send the 101 response head to the client without a body.
	res.Body = nil
	if err := res.Write(brw); err != nil {
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	if err := brw.Flush(); err != nil {
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	// Buffered for both copiers so neither blocks after this function returns.
	errc := make(chan error, 2)
	copyer := switchProtocolCopier{user: clientConn, backend: backConn}
	go copyer.copyToBackend(errc)
	go copyer.copyFromBackend(errc)
	// A nil result means one direction half-closed cleanly; wait for the
	// other direction in that case.
	firstErr := <-errc
	if firstErr == nil {
		firstErr = <-errc
	}
	// Normal ways for a tunnel to end are not reported as errors.
	if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
		return nil
	}
	return firstErr
}
// handleConnectResponse turns a successful CONNECT response into a tunnel
// between the hijacked client connection and the backend.
//
// Data from the backend is read from res.Body; data for the backend is written
// to backWrite, the write end of the pipe used as the outgoing request body.
// A missing backWrite and hijack failures are returned as status errors; a
// tunnel that ended normally returns nil.
func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
	if backWrite == nil {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"),
		}
	}
	backRead := res.Body
	clientConn, brw, err := c.Writer.Hijack()
	if err != nil {
		backRead.Close()
		_ = backWrite.Close()
		status := http.StatusBadGateway
		if errors.Is(err, http.ErrNotSupported) {
			status = http.StatusNotImplemented
		}
		return &reverseProxyStatusError{status: status, err: err}
	}
	defer clientConn.Close()
	defer backRead.Close()
	defer backWrite.Close()
	// Tear down both backend halves when the request context ends; the
	// watcher goroutine exits when this function returns.
	backConnClosed := make(chan struct{})
	go func() {
		select {
		case <-req.Context().Done():
		case <-backConnClosed:
		}
		backRead.Close()
		_ = backWrite.Close()
	}()
	defer close(backConnClosed)
	// Send the response head to the client without a body.
	res.Body = nil
	if err := res.Write(brw); err != nil {
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	if err := brw.Flush(); err != nil {
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	// Buffered for both copy goroutines so neither blocks after return.
	errc := make(chan error, 2)
	// backend -> client; half-close the client side when supported.
	go func() {
		if _, err := io.Copy(clientConn, backRead); err != nil {
			errc <- err
			return
		}
		if cw, ok := clientConn.(interface{ CloseWrite() error }); ok {
			errc <- cw.CloseWrite()
			return
		}
		errc <- errReverseProxyCopyDone
	}()
	// client -> backend; closing the pipe signals end of input upstream.
	go func() {
		if _, err := io.Copy(backWrite, clientConn); err != nil {
			errc <- err
			return
		}
		errc <- backWrite.Close()
	}()
	// A nil result means one direction finished cleanly; wait for the other.
	firstErr := <-errc
	if firstErr == nil {
		firstErr = <-errc
	}
	// Normal ways for a tunnel to end are not reported as errors.
	if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
		return nil
	}
	return firstErr
}
// handleBridgedExtendedConnectResponse relays a bridged extended CONNECT
// tunnel: the backend answered with 101 and a writable body, while the client
// side is served through the response writer and bridge.body.
//
// The client receives status 200 with the backend's headers minus hop-by-hop
// headers and Sec-WebSocket-Accept (presumably because the client used
// extended CONNECT, where 200 signals a successful tunnel — confirm against
// the bridge setup code). Both copy directions are awaited; the first
// non-benign error is returned.
func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error {
	if c == nil || c.Request == nil {
		res.Body.Close()
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")}
	}
	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: errors.New("backend returned bridged websocket response without writable body"),
		}
	}
	// Full duplex lets the handler keep reading the request body while it
	// writes the response; lack of support is tolerated.
	controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
	if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
		backConn.Close()
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	responseHeader := c.Writer.Header()
	reverseProxyCopyHeader(responseHeader, res.Header)
	removeHopByHopHeaders(responseHeader)
	responseHeader.Del("Sec-WebSocket-Accept")
	c.Writer.WriteHeader(http.StatusOK)
	// Push the response head out before any tunnel data flows.
	if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
		backConn.Close()
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller}
	// closeTunnel closes both ends exactly once, whoever triggers it first.
	var closeOnce sync.Once
	closeTunnel := func() {
		closeOnce.Do(func() {
			_ = conn.Close()
			_ = backConn.Close()
		})
	}
	// This goroutine lives until the outgoing request's context is done.
	go func() {
		<-req.Context().Done()
		closeTunnel()
	}()
	errc := make(chan error, 2)
	copyer := switchProtocolCopier{user: conn, backend: backConn}
	go copyer.copyToBackend(errc)
	go copyer.copyFromBackend(errc)
	// Wait for both directions. A benign end of one direction leaves the
	// other running; the first real error closes the tunnel to unblock it.
	var firstErr error
	for i := 0; i < 2; i++ {
		err := <-errc
		if reverseProxyIsBenignTunnelError(err) {
			continue
		}
		if firstErr == nil {
			firstErr = err
			closeTunnel()
		}
	}
	closeTunnel()
	if reverseProxyIsBenignTunnelError(firstErr) {
		return nil
	}
	return firstErr
}
// handleExtendedConnectResponse relays an extended CONNECT tunnel without
// hijacking: client data is read from the inbound request body and written to
// backWrite, backend data is copied from res.Body to the response writer with
// immediate flushing.
//
// Invalid arguments and setup failures are returned as 502 status errors. Both
// copy directions are awaited; the first non-benign error is returned.
func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
	if c == nil || c.Request == nil {
		res.Body.Close()
		if backWrite != nil {
			_ = backWrite.Close()
		}
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")}
	}
	if backWrite == nil {
		res.Body.Close()
		return &reverseProxyStatusError{
			status: http.StatusBadGateway,
			err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"),
		}
	}
	// Full duplex lets the handler keep reading the request body while it
	// writes the response; lack of support is tolerated.
	controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer))
	if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
		res.Body.Close()
		_ = backWrite.Close()
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	reverseProxyCopyHeader(c.Writer.Header(), res.Header)
	c.Writer.WriteHeader(res.StatusCode)
	// Push the response head out before any tunnel data flows.
	if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
		res.Body.Close()
		_ = backWrite.Close()
		return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
	}
	// closeTunnel releases all three stream ends exactly once.
	var closeOnce sync.Once
	closeTunnel := func() {
		closeOnce.Do(func() {
			_ = c.Request.Body.Close()
			_ = backWrite.Close()
			_ = res.Body.Close()
		})
	}
	// This goroutine lives until the outgoing request's context is done.
	go func() {
		<-req.Context().Done()
		closeTunnel()
	}()
	errc := make(chan error, 2)
	// client -> backend; the pipe is closed in every case so the backend
	// sees end of input.
	go func() {
		_, err := io.Copy(backWrite, c.Request.Body)
		closeErr := backWrite.Close()
		if err != nil && !reverseProxyIsBenignTunnelError(err) {
			errc <- err
			return
		}
		errc <- closeErr
	}()
	// backend -> client; a flush interval of -1 flushes after every write.
	go func() {
		copyErr := p.copyResponse(c.Writer, res.Body, -1)
		closeErr := res.Body.Close()
		if copyErr != nil {
			errc <- copyErr
			return
		}
		errc <- closeErr
	}()
	// Wait for both directions. A benign end of one direction leaves the
	// other running; the first real error closes the tunnel to unblock it.
	var firstErr error
	for i := 0; i < 2; i++ {
		err := <-errc
		if reverseProxyIsBenignTunnelError(err) {
			continue
		}
		if firstErr == nil {
			firstErr = err
			closeTunnel()
		}
	}
	closeTunnel()
	if reverseProxyIsBenignTunnelError(firstErr) {
		return nil
	}
	return firstErr
}
// flushInterval returns the flush latency for copying the body of res.
// Server-sent event streams and responses of unknown length are flushed
// immediately (-1); everything else uses the configured FlushInterval.
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
	mediaType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type"))
	switch {
	case mediaType == "text/event-stream":
		return -1
	case res.ContentLength == -1:
		return -1
	default:
		return p.config.FlushInterval
	}
}
// copyResponse streams src to dst and returns the first copy error.
//
// A non-zero flushInterval routes writes through a maxLatencyWriter. Its timer
// is armed before the first write so response headers are flushed even when
// the body is slow to arrive. The copy buffer comes from the configured
// BufferPool when there is one.
func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, flushInterval time.Duration) error {
	out := io.Writer(dst)
	if flushInterval != 0 {
		latencyWriter := &maxLatencyWriter{dst: dst, latency: flushInterval, flushPending: true}
		defer latencyWriter.stop()
		latencyWriter.t = time.AfterFunc(flushInterval, latencyWriter.delayedFlush)
		out = latencyWriter
	}
	var scratch []byte
	if pool := p.config.BufferPool; pool != nil {
		scratch = pool.Get()
		defer pool.Put(scratch)
	}
	_, err := p.copyBuffer(out, src, scratch)
	return err
}
// copyBuffer copies src to dst through buf and returns the number of bytes
// written. An empty buf is replaced by a fresh 32 KiB buffer.
//
// Data returned together with a read error is still written before the error
// is acted on. io.EOF ends the copy successfully; any other read error is
// returned, and it is also logged unless it is a benign tunnel error. A write
// error or a short write stops the copy immediately.
func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}
	var total int64
	for {
		n, readErr := src.Read(buf)
		atEOF := errors.Is(readErr, io.EOF)
		if readErr != nil && !atEOF && !reverseProxyIsBenignTunnelError(readErr) {
			p.logf(nil, "reverse proxy read error during body copy: %v", readErr)
		}
		if n > 0 {
			wrote, writeErr := dst.Write(buf[:n])
			if wrote > 0 {
				total += int64(wrote)
			}
			if writeErr != nil {
				return total, writeErr
			}
			if wrote != n {
				return total, io.ErrShortWrite
			}
		}
		switch {
		case readErr == nil:
			continue
		case atEOF:
			return total, nil
		default:
			return total, readErr
		}
	}
}
// logf reports a formatted error message through the logger of c, falling back
// to the standard library logger when c is nil or has no logger.
func (p *reverseProxyHandler) logf(c *Context, format string, args ...any) {
	if c == nil {
		log.Printf(format, args...)
		return
	}
	logger := c.GetLogger()
	if logger == nil {
		log.Printf(format, args...)
		return
	}
	logger.Errorf(format, args...)
}
// reverseProxyStatusCode maps err to the HTTP status sent to the client: the
// status carried by a reverseProxyStatusError in the chain (when positive),
// 504 for deadline or network timeouts, and 502 for everything else.
func reverseProxyStatusCode(err error) int {
	if statusErr := (*reverseProxyStatusError)(nil); errors.As(err, &statusErr) && statusErr.status > 0 {
		return statusErr.status
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return http.StatusGatewayTimeout
	}
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return http.StatusGatewayTimeout
	}
	return http.StatusBadGateway
}
// validateReverseProxyTarget checks that target is usable as an upstream: it
// must be non-nil and carry both a scheme and a host.
func validateReverseProxyTarget(target *url.URL) error {
	switch {
	case target == nil:
		return errReverseProxyNilTarget
	case target.Scheme == "", target.Host == "":
		return errReverseProxyInvalidTarget
	default:
		return nil
	}
}
// buildReverseProxyUpstreams turns the configured Target or Targets into the
// list of upstreams the proxy balances across.
//
// Target and Targets are mutually exclusive. Every target is normalized with
// normalizeReverseProxyTarget and checked with validateReverseProxyTarget;
// entries of Targets are additionally trimmed and parsed, and errors for them
// carry the offending index. With no targets at all the result is
// errReverseProxyNilTarget.
//
// A target whose scheme is "h2c" (case-insensitive) is rewritten to "http" on
// a copy and marked useH2C; AllowH2CUpstream marks every upstream useH2C.
// Per-upstream transports are created only when config.Transport is nil.
func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) {
	hasSingleTarget := config.Target != nil
	if hasSingleTarget && len(config.Targets) > 0 {
		return nil, errors.New("reverse proxy Target and Targets cannot be used together")
	}

	resolved := make([]*url.URL, 0, max(1, len(config.Targets)))
	if hasSingleTarget {
		// Work on a copy so the caller's URL is never mutated.
		single := cloneReverseProxyURL(config.Target)
		normalizeReverseProxyTarget(single)
		if err := validateReverseProxyTarget(single); err != nil {
			return nil, err
		}
		resolved = append(resolved, single)
	}
	for idx, raw := range config.Targets {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			return nil, fmt.Errorf("reverse proxy target at index %d is empty", idx)
		}
		parsed, err := url.Parse(candidate)
		if err == nil {
			normalizeReverseProxyTarget(parsed)
			err = validateReverseProxyTarget(parsed)
		}
		if err != nil {
			return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", idx, err)
		}
		resolved = append(resolved, parsed)
	}
	if len(resolved) == 0 {
		return nil, errReverseProxyNilTarget
	}

	ownsTransports := config.Transport == nil
	upstreams := make([]*reverseProxyUpstream, len(resolved))
	for idx, target := range resolved {
		schemeIsH2C := strings.EqualFold(target.Scheme, "h2c")
		if schemeIsH2C {
			// "h2c" is only a marker; the upstream URL itself uses "http".
			target = cloneReverseProxyURL(target)
			target.Scheme = "http"
		}

		up := &reverseProxyUpstream{
			key:    fmt.Sprintf("%d:%s", idx, target.String()),
			target: target,
			index:  idx,
			useH2C: schemeIsH2C || config.AllowH2CUpstream,
		}
		if ownsTransports {
			up.extendedConnectTransport = newHTTP2ExtendedConnectTransport()
			up.bridgeTransport = newHTTP1BridgeTransport()
			if up.useH2C {
				up.h2cTransport = newH2CTransport()
			}
		}
		upstreams[idx] = up
	}
	return upstreams, nil
}
// validateReverseProxyForwardedBy checks the configured ForwardedBy value.
// A blank value (after trimming whitespace) is accepted; anything else must
// satisfy isValidForwardedNodeIdentifier. The error quotes the value exactly
// as it was supplied, not the trimmed form.
func validateReverseProxyForwardedBy(value string) error {
	candidate := strings.TrimSpace(value)
	if candidate == "" || isValidForwardedNodeIdentifier(candidate) {
		return nil
	}
	return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
}
// normalizeReverseProxyTarget rewrites WebSocket schemes in place: "ws"
// becomes "http" and "wss" becomes "https" (matched case-insensitively).
// Every other scheme is left untouched. target must not be nil.
func normalizeReverseProxyTarget(target *url.URL) {
	scheme := strings.ToLower(target.Scheme)
	if scheme == "ws" {
		target.Scheme = "http"
	} else if scheme == "wss" {
		target.Scheme = "https"
	}
}
// cloneReverseProxyURL returns a shallow copy of target, or nil when target is
// nil. Pointer fields of the URL (such as User) are shared with the original.
func cloneReverseProxyURL(target *url.URL) *url.URL {
	if target != nil {
		duplicate := new(url.URL)
		*duplicate = *target
		return duplicate
	}
	return nil
}
// reverseProxyReceivedBy returns configValue with surrounding whitespace
// removed, or the default name "touka-engine" when nothing remains.
func reverseProxyReceivedBy(configValue string) string {
	const fallback = "touka-engine"

	name := strings.TrimSpace(configValue)
	if len(name) == 0 {
		return fallback
	}
	return name
}
// reverseProxyClientIP extracts the client address from a RemoteAddr-style
// string.
//
// It first tries to parse the input as "ip:port". Failing that, it splits off
// a port with net.SplitHostPort and uses the host part; if there is no port it
// uses the whole input. In both fallback cases a parseable IP is returned in
// netip's canonical text form and anything else is returned as-is. An empty
// input yields an empty string.
func reverseProxyClientIP(remoteAddr string) string {
	if len(remoteAddr) == 0 {
		return ""
	}
	if addrPort, err := netip.ParseAddrPort(remoteAddr); err == nil {
		return addrPort.Addr().String()
	}

	// Fall back to the host portion when a port can be split off, otherwise
	// treat the whole input as the candidate address.
	candidate := remoteAddr
	if host, _, splitErr := net.SplitHostPort(remoteAddr); splitErr == nil {
		candidate = host
	}
	if ip, err := netip.ParseAddr(candidate); err == nil {
		return ip.String()
	}
	return candidate
}
// reverseProxyRequestScheme reports the scheme the client used for req.
// A nil request yields "". A request with TLS state is "https"; otherwise the
// lower-cased URL scheme is used when present, and "http" is the default.
func reverseProxyRequestScheme(req *http.Request) string {
	switch {
	case req == nil:
		return ""
	case req.TLS != nil:
		return "https"
	case req.URL == nil || req.URL.Scheme == "":
		return "http"
	default:
		return strings.ToLower(req.URL.Scheme)
	}
}
// buildForwardedHeaderValue assembles one RFC 7239 Forwarded element from the
// given parts. Empty parts are omitted; the remaining pairs are emitted in the
// fixed order by, for, host, proto and joined with ";". The proto value is
// lower-cased before formatting. When every part is empty the result is "".
func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
	fields := make([]string, 0, 4)
	add := func(name, formatted string) {
		fields = append(fields, name+"="+formatted)
	}

	if by != "" {
		add("by", formatForwardedParameterValue(by))
	}
	if clientIP != "" {
		add("for", formatForwardedFor(clientIP))
	}
	if host != "" {
		add("host", formatForwardedParameterValue(host))
	}
	if scheme != "" {
		add("proto", formatForwardedParameterValue(strings.ToLower(scheme)))
	}

	// Joining an empty slice yields "", which is the documented empty result.
	return strings.Join(fields, ";")
}
// reverseProxyUsesForwardedHeader reports whether policy calls for emitting the
// RFC 7239 Forwarded header, which is the case for ForwardedBoth and
// ForwardedRFC7239Only.
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
	switch policy {
	case ForwardedBoth, ForwardedRFC7239Only:
		return true
	default:
		return false
	}
}
// reverseProxyPrepareExtendedConnectBridge converts an extended CONNECT request
// whose :protocol is "websocket" (case-insensitive) into an HTTP/1.1-style
// WebSocket upgrade request, mutating req in place.
//
// For a nil request it returns context.Background() and false. For any request
// that is not such a CONNECT it returns the request's own context and false,
// leaving req untouched. For a bridged request it stores the original body in
// a reverseProxyExtendedConnectBridge attached to the returned context, turns
// the request into a body-less GET with Upgrade/Connection/Sec-WebSocket-*
// headers, and returns true. If generating the WebSocket key fails, the
// returned context is nil and the error is wrapped; req has already been
// rewritten at that point.
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
	if req == nil {
		return context.Background(), false, nil
	}
	// An empty protocol never equals "websocket", so one comparison covers both
	// the "not extended CONNECT" and the "other protocol" cases.
	if req.Method != http.MethodConnect || !strings.EqualFold(reverseProxyExtendedConnectProtocol(req), "websocket") {
		return req.Context(), false, nil
	}

	// Keep the original request body reachable through the context before it
	// is replaced below.
	bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
	bridgeCtx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge)

	header := req.Header
	header.Del(":protocol")
	req.Method = http.MethodGet
	req.Body = http.NoBody
	req.ContentLength = 0
	header.Set("Upgrade", "websocket")
	header.Set("Connection", "Upgrade")
	header.Set("Sec-WebSocket-Version", "13")

	wsKey, err := reverseProxyGenerateWebSocketKey()
	if err != nil {
		return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err)
	}
	header.Set("Sec-WebSocket-Key", wsKey)
	return bridgeCtx, true, nil
}
// reverseProxyExtendedConnectBridgeFromContext returns the bridge stored in ctx
// under the reverseProxyExtendedConnectBridge{} key, or nil when ctx is nil or
// holds no such value.
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
	if ctx == nil {
		return nil
	}
	if bridge, ok := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge); ok {
		return bridge
	}
	return nil
}
// reverseProxyGenerateWebSocketKey returns 16 bytes read from crypto/rand,
// encoded with standard base64, for use as a Sec-WebSocket-Key value. It
// returns the read error unchanged if the random source fails.
func reverseProxyGenerateWebSocketKey() (string, error) {
	var nonce [16]byte
	_, err := io.ReadFull(rand.Reader, nonce[:])
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
// reverseProxyIsExtendedConnectRequest reports whether
// reverseProxyExtendedConnectProtocol finds a non-empty protocol for req.
func reverseProxyIsExtendedConnectRequest(req *http.Request) bool {
	protocol := reverseProxyExtendedConnectProtocol(req)
	return len(protocol) > 0
}
// reverseProxyExtendedConnectProtocol returns the whitespace-trimmed value of
// the ":protocol" header of a CONNECT request. It returns "" for a nil request,
// a non-CONNECT method, or a nil header map.
func reverseProxyExtendedConnectProtocol(req *http.Request) string {
	switch {
	case req == nil, req.Method != http.MethodConnect, req.Header == nil:
		return ""
	}
	return textproto.TrimString(req.Header.Get(":protocol"))
}
// isValidForwardedNodeIdentifier reports whether value is acceptable as an
// RFC 7239 node: a node name optionally followed by ":" and a node port.
//
// Accepted node names are a bracketed IPv6 address, an IPv4 address, the
// literal "unknown", or an obfuscated identifier. The port, when present, is
// validated by isValidForwardedNodePort.
func isValidForwardedNodeIdentifier(value string) bool {
	if value == "" {
		return false
	}

	if value[0] == '[' {
		// Bracketed form: "[" IPv6 "]" optionally followed by ":" port.
		end := strings.IndexByte(value, ']')
		if end <= 1 {
			return false
		}
		addr, err := netip.ParseAddr(value[1:end])
		if err != nil || !addr.Is6() {
			return false
		}
		rest := value[end+1:]
		if rest == "" {
			return true
		}
		return rest[0] == ':' && isValidForwardedNodePort(rest[1:])
	}

	// Unbracketed form: everything before the first ":" is the node name.
	// Without a ":" the name is the whole value.
	name, port, hasPort := strings.Cut(value, ":")
	if hasPort && !isValidForwardedNodePort(port) {
		return false
	}
	if name == "unknown" || isValidForwardedObfuscatedIdentifier(name) {
		return true
	}
	addr, err := netip.ParseAddr(name)
	return err == nil && addr.Is4()
}
// isValidForwardedNodePort reports whether value is an acceptable RFC 7239
// node-port: either an obfuscated port (see
// isValidForwardedObfuscatedIdentifier) or one to five decimal digits whose
// numeric value lies in 1..65535.
func isValidForwardedNodePort(value string) bool {
	if value == "" {
		return false
	}
	if isValidForwardedObfuscatedIdentifier(value) {
		return true
	}
	if len(value) > 5 {
		return false
	}
	// strconv.Atoi accepts a leading "+" or "-", which the RFC grammar
	// (1*5DIGIT) does not, so require plain digits before converting.
	for i := 0; i < len(value); i++ {
		if value[i] < '0' || value[i] > '9' {
			return false
		}
	}
	port, err := strconv.Atoi(value)
	return err == nil && port > 0 && port <= 65535
}
// isValidForwardedObfuscatedIdentifier reports whether value has the RFC 7239
// obfuscated form: a leading "_" followed by at least one character, each of
// which is an ASCII letter, a digit, ".", "_" or "-".
func isValidForwardedObfuscatedIdentifier(value string) bool {
	if len(value) < 2 || value[0] != '_' {
		return false
	}
	for _, b := range []byte(value[1:]) {
		switch {
		case '0' <= b && b <= '9', 'A' <= b && b <= 'Z', 'a' <= b && b <= 'z':
			// alphanumeric
		case b == '.', b == '_', b == '-':
			// permitted punctuation
		default:
			return false
		}
	}
	return true
}
// formatForwardedFor renders clientIP as the value of a Forwarded "for"
// parameter. An IPv6 address is wrapped in brackets and quoted, an IPv4
// address is emitted in canonical form, and anything that does not parse as an
// IP address is passed through formatForwardedParameterValue.
func formatForwardedFor(clientIP string) string {
	addr, err := netip.ParseAddr(clientIP)
	switch {
	case err != nil:
		return formatForwardedParameterValue(clientIP)
	case addr.Is6():
		return quoteForwardedString("[" + addr.String() + "]")
	default:
		return addr.String()
	}
}
// formatForwardedParameterValue returns value unchanged when isToken accepts
// it, and the quoteForwardedString form otherwise.
func formatForwardedParameterValue(value string) string {
	if !isToken(value) {
		return quoteForwardedString(value)
	}
	return value
}
// quoteForwardedString returns value as an HTTP quoted-string: wrapped in
// double quotes, with every backslash and double quote preceded by a
// backslash.
//
// The escaping is done in a single pass over the bytes rather than through a
// strings.Replacer built on each call. Both escaped characters are ASCII, so
// scanning bytes is safe for UTF-8 input.
func quoteForwardedString(value string) string {
	// Fast path: nothing to escape, so one concatenation is enough.
	if !strings.ContainsAny(value, `\"`) {
		return `"` + value + `"`
	}

	var b strings.Builder
	b.Grow(len(value) + 4)
	b.WriteByte('"')
	for i := 0; i < len(value); i++ {
		c := value[i]
		if c == '\\' || c == '"' {
			b.WriteByte('\\')
		}
		b.WriteByte(c)
	}
	b.WriteByte('"')
	return b.String()
}
// isToken reports whether value is non-empty and consists solely of bytes
// accepted by isTokenChar.
func isToken(value string) bool {
	for _, b := range []byte(value) {
		if !isTokenChar(b) {
			return false
		}
	}
	// An empty string has no offending byte but is still not a token.
	return value != ""
}
// isTokenChar reports whether b may appear in an HTTP token: an ASCII letter,
// a digit, or one of the punctuation bytes !#$%&'*+-.^_`|~.
func isTokenChar(b byte) bool {
	switch {
	case '0' <= b && b <= '9', 'A' <= b && b <= 'Z', 'a' <= b && b <= 'z':
		return true
	}

	const punctuation = "!#$%&'*+-.^_`|~"
	for i := 0; i < len(punctuation); i++ {
		if punctuation[i] == b {
			return true
		}
	}
	return false
}
// appendViaHeader adds a "Via: <protocol> <receivedBy>" value to header.
// Nothing happens when header is nil or receivedBy is empty; an empty protocol
// defaults to "1.1". Existing Via values are kept.
func appendViaHeader(header http.Header, protocol, receivedBy string) {
	switch {
	case header == nil, receivedBy == "":
		return
	case protocol == "":
		protocol = "1.1"
	}
	header.Add("Via", protocol+" "+receivedBy)
}
// reverseProxyViaProtocol returns the protocol version to put in a Via header.
// With a positive major version it is "major.minor"; otherwise it is raw with
// a leading "HTTP/" removed if present.
func reverseProxyViaProtocol(major, minor int, raw string) string {
	if major <= 0 {
		// TrimPrefix leaves raw untouched when the prefix is absent.
		return strings.TrimPrefix(raw, "HTTP/")
	}
	return strconv.Itoa(major) + "." + strconv.Itoa(minor)
}
// rewriteReverseProxyURL points req.URL at target: scheme and host are taken
// from target, the paths are combined by joinReverseProxyURLPath, and the
// query strings are concatenated with the target's query first, separated by
// "&" only when both are non-empty.
func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
	out := req.URL
	out.Scheme, out.Host = target.Scheme, target.Host
	out.Path, out.RawPath = joinReverseProxyURLPath(target, out)

	switch baseQuery, reqQuery := target.RawQuery, out.RawQuery; {
	case baseQuery == "":
		// Nothing to prepend; the request query stays as it is.
	case reqQuery == "":
		out.RawQuery = baseQuery
	default:
		out.RawQuery = baseQuery + "&" + reqQuery
	}
}
// rewriteReverseProxyConnectRequest retargets a CONNECT request at target.
// The "host:port" authority computed by reverseProxyConnectTarget becomes both
// req.Host and req.URL.Opaque; scheme and host are copied from target and the
// path and query are cleared. If the authority cannot be derived, the error is
// wrapped in a reverseProxyStatusError carrying 400 Bad Request and req is
// left unchanged.
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
	authority, err := reverseProxyConnectTarget(target)
	if err != nil {
		return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
	}

	u := req.URL
	u.Scheme, u.Host = target.Scheme, target.Host
	u.Path, u.RawPath, u.RawQuery = "", "", ""
	u.Opaque = authority
	req.Host = authority
	return nil
}
// reverseProxyConnectTarget returns the "host:port" authority to use for a
// CONNECT request to target.
//
// When target has no explicit port, 80 is assumed for "http" and 443 for
// "https" (scheme matched case-insensitively); any other scheme is an error.
// The port must be a number in 1..65535. A nil target yields
// errReverseProxyNilTarget and an empty hostname errReverseProxyInvalidTarget.
func reverseProxyConnectTarget(target *url.URL) (string, error) {
	if target == nil {
		return "", errReverseProxyNilTarget
	}
	hostname := target.Hostname()
	if hostname == "" {
		return "", errReverseProxyInvalidTarget
	}

	port := target.Port()
	if port == "" {
		// No explicit port: fall back to the scheme's well-known default.
		if scheme := strings.ToLower(target.Scheme); scheme == "http" {
			port = "80"
		} else if scheme == "https" {
			port = "443"
		} else {
			return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
		}
	}

	if number, err := strconv.Atoi(port); err != nil || number <= 0 || number > 65535 {
		return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
	}
	return net.JoinHostPort(hostname, port), nil
}
// joinReverseProxyURLPath joins the paths of base and incoming with exactly one
// slash between them and returns the decoded path and the escaped path.
//
// When neither URL carries a RawPath, only the decoded paths are joined and the
// escaped result is "". Otherwise the slash decision is made on the escaped
// forms and both representations are joined the same way.
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
	if base.RawPath == "" && incoming.RawPath == "" {
		return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
	}

	rawBase, rawIncoming := base.EscapedPath(), incoming.EscapedPath()
	trailing := strings.HasSuffix(rawBase, "/")
	leading := strings.HasPrefix(rawIncoming, "/")

	if trailing && leading {
		// Both sides contribute a slash: drop the incoming one.
		return base.Path + incoming.Path[1:], rawBase + rawIncoming[1:]
	}
	if trailing || leading {
		// Exactly one slash is already present.
		return base.Path + incoming.Path, rawBase + rawIncoming
	}
	return base.Path + "/" + incoming.Path, rawBase + "/" + rawIncoming
}
// reverseProxySingleJoiningSlash concatenates a and b so that exactly one "/"
// separates them: a duplicate slash is dropped and a missing one is inserted.
func reverseProxySingleJoiningSlash(a, b string) string {
	trailing, leading := strings.HasSuffix(a, "/"), strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + b[1:]
	}
	if !trailing && !leading {
		return a + "/" + b
	}
	return a + b
}
// reverseProxyCopyHeader appends every value of every header in src to dst
// using dst.Add, so values already present in dst are preserved.
func reverseProxyCopyHeader(dst, src http.Header) {
	for name, vals := range src {
		for i := 0; i < len(vals); i++ {
			dst.Add(name, vals[i])
		}
	}
}
// reverseProxyHopHeaders lists the header names that removeHopByHopHeaders
// deletes unconditionally, in addition to any names listed in a message's
// Connection header. The names are written in canonical MIME header form.
var reverseProxyHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
// removeHopByHopHeaders strips hop-by-hop headers from header in place.
// It first deletes every header named by a comma-separated token in the
// Connection header values, then every entry of reverseProxyHopHeaders.
func removeHopByHopHeaders(header http.Header) {
	for _, listed := range header["Connection"] {
		// Walk the comma-separated tokens one at a time; the final piece (after
		// the last comma, or the whole value) is processed with more == false.
		rest := listed
		for more := true; more; {
			var name string
			name, rest, more = strings.Cut(rest, ",")
			if name = textproto.TrimString(name); name != "" {
				header.Del(name)
			}
		}
	}
	for i := range reverseProxyHopHeaders {
		header.Del(reverseProxyHopHeaders[i])
	}
}
// reverseProxyUpgradeType returns the Upgrade header value when the Connection
// header contains the "Upgrade" token, and "" otherwise.
func reverseProxyUpgradeType(header http.Header) string {
	if headerValuesContainToken(header["Connection"], "Upgrade") {
		return header.Get("Upgrade")
	}
	return ""
}
// headerValuesContainToken reports whether any of the comma-separated items in
// values equals token, ignoring case and surrounding whitespace. An empty token
// never matches.
func headerValuesContainToken(values []string, token string) bool {
	if token == "" {
		return false
	}
	for _, line := range values {
		// Consume one comma-separated item per iteration; the last item is
		// handled when Cut reports that no further comma was found.
		for rest, more := line, true; more; {
			var item string
			item, rest, more = strings.Cut(rest, ",")
			if strings.EqualFold(textproto.TrimString(item), token) {
				return true
			}
		}
	}
	return false
}
// cleanReverseProxyQueryParams re-encodes rawQuery through net/url so the
// upstream receives a canonical query string.
//
// Parsing errors are deliberately ignored: pairs that net/url rejects are
// dropped and the remaining pairs are re-encoded. The output can therefore
// differ textually from the input, but the proxy and the upstream no longer
// disagree about non-standard separators or malformed pairs, which narrows the
// room for parameter smuggling.
func cleanReverseProxyQueryParams(rawQuery string) string {
	if len(rawQuery) == 0 {
		return ""
	}
	// The error is intentionally discarded; ParseQuery still returns the pairs
	// it could decode.
	parsed, _ := url.ParseQuery(rawQuery)
	return parsed.Encode()
}
// reverseProxyShouldPanicOnCopyError reports whether req's context carries a
// value under http.ServerContextKey. A nil request yields false.
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
	if req == nil {
		return false
	}
	return req.Context().Value(http.ServerContextKey) != nil
}
// reverseProxyCanRetryRequest reports whether req may be sent again.
//
// Retrying is refused for a nil request, CONNECT, upgrade requests, and
// methods that reverseProxyMethodIsSafe rejects. Otherwise a request without a
// body (nil Body or zero ContentLength) is retryable, and a request with a
// body is retryable only when GetBody is set.
func reverseProxyCanRetryRequest(req *http.Request) bool {
	switch {
	case req == nil,
		req.Method == http.MethodConnect,
		reverseProxyUpgradeType(req.Header) != "",
		!reverseProxyMethodIsSafe(req.Method):
		return false
	case req.Body == nil, req.ContentLength == 0:
		return true
	}
	return req.GetBody != nil
}
// reverseProxyShouldCountPassiveFailure reports whether err should count as a
// failure for passive health tracking. It returns false for a nil error, for
// errors accepted by reverseProxyIsBenignTunnelError, when the request's own
// context has already ended, and for context cancellation.
func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool {
	switch {
	case err == nil, reverseProxyIsBenignTunnelError(err):
		return false
	case req != nil && req.Context().Err() != nil:
		// The request context ended, so the error is not attributed to the
		// upstream.
		return false
	case errors.Is(err, context.Canceled):
		return false
	}
	return true
}
// reverseProxyMethodIsSafe reports whether method is GET, HEAD, OPTIONS or
// TRACE. The comparison is exact (case-sensitive).
func reverseProxyMethodIsSafe(method string) bool {
	return method == http.MethodGet ||
		method == http.MethodHead ||
		method == http.MethodOptions ||
		method == http.MethodTrace
}
// reverseProxyIsBenignTunnelError reports whether err is an expected way for a
// copy or tunnel to end. nil counts as benign, as does any error matching
// errReverseProxyCopyDone, net.ErrClosed, io.EOF, io.ErrClosedPipe,
// context.Canceled or http.ErrAbortHandler, or one recognised by
// reverseProxyIsClosedBodyError.
func reverseProxyIsBenignTunnelError(err error) bool {
	if err == nil {
		return true
	}
	expected := [...]error{
		errReverseProxyCopyDone,
		net.ErrClosed,
		io.EOF,
		io.ErrClosedPipe,
		context.Canceled,
		http.ErrAbortHandler,
	}
	for _, sentinel := range expected {
		if errors.Is(err, sentinel) {
			return true
		}
	}
	return reverseProxyIsClosedBodyError(err)
}
// reverseProxyIsClosedBodyError reports whether err indicates that a body was
// closed or a stream was cancelled: an http2.StreamError with ErrCodeCancel
// anywhere in the chain, or an error whose message is exactly one of the known
// "body closed" strings. The message comparison uses err.Error() directly and
// does not unwrap. A nil error yields false.
func reverseProxyIsClosedBodyError(err error) bool {
	if err == nil {
		return false
	}
	if streamErr := (http2.StreamError{}); errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel {
		return true
	}
	msg := err.Error()
	return msg == "body closed by handler" ||
		msg == "http2: response body closed" ||
		msg == "response body closed"
}
// reverseProxyBaseResponseWriter returns the http.ResponseWriter obtained by
// passing writer to UnwrapResponseWriter.
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
	return UnwrapResponseWriter(writer)
}
// isPrintableASCII reports whether every byte of value lies in the printable
// ASCII range, space (0x20) through tilde (0x7e). The empty string qualifies.
func isPrintableASCII(value string) bool {
	for _, b := range []byte(value) {
		if b < ' ' || b > '~' {
			return false
		}
	}
	return true
}