feat(reverseproxy): add upstream balancing and failover

This commit is contained in:
wjqserver 2026-04-02 14:40:56 +08:00
parent 59f190ce3a
commit 919236665b
4 changed files with 1394 additions and 116 deletions

View file

@ -13,7 +13,9 @@ import (
"net/http/httptrace"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) {
}
}
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
t.Helper()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Targets: []string{"http://example.net"},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusInternalServerError {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
t.Helper()
type snapshot struct {
Path string
RawQuery string
}
backendOneCh := make(chan snapshot, 1)
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwoCh := make(chan snapshot, 1)
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
Targets: []string{
backendOne.URL + "/one?from=one",
backendTwo.URL + "/two?from=two",
},
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
}))
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
if first.Code != http.StatusOK || first.Body.String() != "one" {
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
}
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
if second.Code != http.StatusOK || second.Body.String() != "two" {
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
}
select {
case got := <-backendOneCh:
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
t.Fatalf("unexpected first upstream request: %#v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first upstream request")
}
select {
case got := <-backendTwoCh:
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
t.Fatalf("unexpected second upstream request: %#v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for second upstream request")
}
}
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBHeader("X-Upstream", LBFirst()),
},
}))
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
}
headers := http.Header{"X-Upstream": {"tenant-a"}}
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
}
if firstSticky.Body.String() != secondSticky.Body.String() {
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
}
}
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBQuery("tenant", LBFirst()),
},
}))
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
}
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
}
if firstSticky.Body.String() != secondSticky.Body.String() {
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
}
}
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
t.Helper()
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBClientIPHash(),
},
}))
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
reqOne.RemoteAddr = "10.0.0.1:1234"
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
rrOne := httptest.NewRecorder()
engine.ServeHTTP(rrOne, reqOne)
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
reqTwo.RemoteAddr = "10.0.0.2:5678"
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
rrTwo := httptest.NewRecorder()
engine.ServeHTTP(rrTwo, reqTwo)
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
}
if rrOne.Body.String() != rrTwo.Body.String() {
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
}
}
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
}
}
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
}))
defer backend.Close()
var attempts atomic.Int64
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
ModifyRequest: func(req *http.Request) {
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rr.Code)
}
if rr.Body.String() != "2" {
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
}
}
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
t.Helper()
backendCalls := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendCalls <- struct{}{}
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 1,
},
}))
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
if rr.Code != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case <-backendCalls:
t.Fatal("unsafe POST request should not be retried to the next upstream")
default:
}
}
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
t.Helper()
backendOneStarted := make(chan struct{}, 1)
releaseBackendOne := make(chan struct{})
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendOneStarted <- struct{}{}
<-releaseBackendOne
_, _ = io.WriteString(w, "one")
}))
defer backendOne.Close()
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "two")
}))
defer backendTwo.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBLeastConn(),
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
client := proxy.Client()
client.Timeout = 5 * time.Second
firstRespCh := make(chan string, 1)
firstErrCh := make(chan error, 1)
go func() {
resp, err := client.Get(proxy.URL + "/proxy")
if err != nil {
firstErrCh <- err
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
firstErrCh <- err
return
}
firstRespCh <- string(body)
}()
select {
case <-backendOneStarted:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first backend request")
}
secondResp, err := client.Get(proxy.URL + "/proxy")
if err != nil {
close(releaseBackendOne)
t.Fatalf("second request failed: %v", err)
}
secondBody, err := io.ReadAll(secondResp.Body)
_ = secondResp.Body.Close()
if err != nil {
close(releaseBackendOne)
t.Fatalf("read second response: %v", err)
}
if string(secondBody) != "two" {
close(releaseBackendOne)
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
}
close(releaseBackendOne)
select {
case err := <-firstErrCh:
t.Fatalf("first request failed: %v", err)
case body := <-firstRespCh:
if body != "one" {
t.Fatalf("unexpected first response body: %q", body)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first response body")
}
}
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
t.Helper()
primaryCalls := make(chan struct{}, 4)
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
primaryCalls <- struct{}{}
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = io.WriteString(w, "primary down")
}))
defer primary.Close()
secondaryCalls := make(chan struct{}, 4)
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
secondaryCalls <- struct{}{}
_, _ = io.WriteString(w, "secondary up")
}))
defer secondary.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{primary.URL, secondary.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
}
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
}
select {
case <-primaryCalls:
default:
t.Fatal("expected primary to receive the first request")
}
select {
case <-secondaryCalls:
default:
t.Fatal("expected secondary to receive the second request")
}
select {
case <-primaryCalls:
t.Fatal("primary should not receive the second request while unhealthy")
default:
}
}
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
t.Helper()
started := make(chan struct{}, 1)
release := make(chan struct{})
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
started <- struct{}{}
<-release
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{backend.URL},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
client := proxy.Client()
respCh := make(chan error, 1)
go func() {
resp, err := client.Do(req)
if resp != nil {
_ = resp.Body.Close()
}
respCh <- err
}()
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend request")
}
cancel()
close(release)
select {
case <-respCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for canceled request to finish")
}
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
}
}
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
t.Helper()
backendCalls := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendCalls <- struct{}{}
_, _ = io.WriteString(w, "ok")
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Targets: []string{"http://127.0.0.1:1", backend.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBFirst(),
Retries: 3,
TryDuration: 100 * time.Millisecond,
TryInterval: 250 * time.Millisecond,
},
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case <-backendCalls:
t.Fatal("retry budget should expire before the next upstream attempt")
default:
}
}
func TestReverseProxyCustomErrorHandler(t *testing.T) {
t.Helper()
@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
}
}
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
t.Helper()
enableHTTP2ExtendedConnectProtocol()
errCh := make(chan error, 8)
newBackend := func(name string) *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if got := r.Header.Get(":protocol"); got != "websocket" {
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
w.WriteHeader(http.StatusBadRequest)
return
}
controller := http.NewResponseController(w)
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
return
}
w.WriteHeader(http.StatusOK)
_ = controller.Flush()
line, err := bufio.NewReader(r.Body).ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
return
}
if _, err := io.WriteString(w, name+":"+line); err != nil {
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
return
}
_ = controller.Flush()
}))
server.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
}
server.StartTLS()
return server
}
backendOne := newBackend("one")
defer backendOne.Close()
backendTwo := newBackend("two")
defer backendTwo.Close()
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Targets: []string{backendOne.URL, backendTwo.URL},
LoadBalancing: ReverseProxyLoadBalancingConfig{
Policy: LBRoundRobin(),
},
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Via: "proxy.test",
}))
proxy := httptest.NewUnstartedServer(engine)
proxy.EnableHTTP2 = true
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
t.Fatalf("configure proxy HTTP/2 server: %v", err)
}
proxy.StartTLS()
defer proxy.Close()
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
doRequest := func(payload string) string {
pr, pw := io.Pipe()
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
if err != nil {
t.Fatalf("new CONNECT request: %v", err)
}
req.Header.Set(":protocol", "websocket")
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("round trip extended CONNECT: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
t.Fatalf("write tunneled request body: %v", err)
}
if err := pw.Close(); err != nil {
t.Fatalf("close tunneled request body: %v", err)
}
message, err := bufio.NewReader(resp.Body).ReadString('\n')
if err != nil {
t.Fatalf("read tunneled response body: %v", err)
}
return message
}
if got := doRequest("ping"); got != "one:ping\n" {
t.Fatalf("unexpected first tunneled response: %q", got)
}
if got := doRequest("pong"); got != "two:pong\n" {
t.Fatalf("unexpected second tunneled response: %q", got)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
t.Helper()