mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat(reverseproxy): add upstream balancing and failover
This commit is contained in:
parent
59f190ce3a
commit
919236665b
4 changed files with 1394 additions and 116 deletions
|
|
@ -13,7 +13,9 @@ import (
|
|||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Targets: []string{"http://example.net"},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
type snapshot struct {
|
||||
Path string
|
||||
RawQuery string
|
||||
}
|
||||
|
||||
backendOneCh := make(chan snapshot, 1)
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwoCh := make(chan snapshot, 1)
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
backendOne.URL + "/one?from=one",
|
||||
backendTwo.URL + "/two?from=two",
|
||||
},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
||||
}))
|
||||
|
||||
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
|
||||
if first.Code != http.StatusOK || first.Body.String() != "one" {
|
||||
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||
}
|
||||
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
|
||||
if second.Code != http.StatusOK || second.Body.String() != "two" {
|
||||
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-backendOneCh:
|
||||
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
|
||||
t.Fatalf("unexpected first upstream request: %#v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first upstream request")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-backendTwoCh:
|
||||
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
|
||||
t.Fatalf("unexpected second upstream request: %#v", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for second upstream request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBHeader("X-Upstream", LBFirst()),
|
||||
},
|
||||
}))
|
||||
|
||||
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||
}
|
||||
|
||||
headers := http.Header{"X-Upstream": {"tenant-a"}}
|
||||
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
||||
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||
}
|
||||
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBQuery("tenant", LBFirst()),
|
||||
},
|
||||
}))
|
||||
|
||||
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
||||
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
||||
}
|
||||
|
||||
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
||||
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
||||
}
|
||||
if firstSticky.Body.String() != secondSticky.Body.String() {
|
||||
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBClientIPHash(),
|
||||
},
|
||||
}))
|
||||
|
||||
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||
reqOne.RemoteAddr = "10.0.0.1:1234"
|
||||
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||
rrOne := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rrOne, reqOne)
|
||||
|
||||
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||
reqTwo.RemoteAddr = "10.0.0.2:5678"
|
||||
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
||||
rrTwo := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rrTwo, reqTwo)
|
||||
|
||||
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
|
||||
}
|
||||
if rrOne.Body.String() != rrTwo.Body.String() {
|
||||
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
var attempts atomic.Int64
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
ModifyRequest: func(req *http.Request) {
|
||||
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
if rr.Body.String() != "2" {
|
||||
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendCalls := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-backendCalls:
|
||||
t.Fatal("unsafe POST request should not be retried to the next upstream")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendOneStarted := make(chan struct{}, 1)
|
||||
releaseBackendOne := make(chan struct{})
|
||||
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendOneStarted <- struct{}{}
|
||||
<-releaseBackendOne
|
||||
_, _ = io.WriteString(w, "one")
|
||||
}))
|
||||
defer backendOne.Close()
|
||||
|
||||
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "two")
|
||||
}))
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBLeastConn(),
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
client := proxy.Client()
|
||||
client.Timeout = 5 * time.Second
|
||||
|
||||
firstRespCh := make(chan string, 1)
|
||||
firstErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := client.Get(proxy.URL + "/proxy")
|
||||
if err != nil {
|
||||
firstErrCh <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
firstErrCh <- err
|
||||
return
|
||||
}
|
||||
firstRespCh <- string(body)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-backendOneStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first backend request")
|
||||
}
|
||||
|
||||
secondResp, err := client.Get(proxy.URL + "/proxy")
|
||||
if err != nil {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("second request failed: %v", err)
|
||||
}
|
||||
secondBody, err := io.ReadAll(secondResp.Body)
|
||||
_ = secondResp.Body.Close()
|
||||
if err != nil {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("read second response: %v", err)
|
||||
}
|
||||
if string(secondBody) != "two" {
|
||||
close(releaseBackendOne)
|
||||
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
|
||||
}
|
||||
|
||||
close(releaseBackendOne)
|
||||
select {
|
||||
case err := <-firstErrCh:
|
||||
t.Fatalf("first request failed: %v", err)
|
||||
case body := <-firstRespCh:
|
||||
if body != "one" {
|
||||
t.Fatalf("unexpected first response body: %q", body)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for first response body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
primaryCalls := make(chan struct{}, 4)
|
||||
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
primaryCalls <- struct{}{}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = io.WriteString(w, "primary down")
|
||||
}))
|
||||
defer primary.Close()
|
||||
|
||||
secondaryCalls := make(chan struct{}, 4)
|
||||
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
secondaryCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "secondary up")
|
||||
}))
|
||||
defer secondary.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{primary.URL, secondary.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
},
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||
},
|
||||
}))
|
||||
|
||||
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
|
||||
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
||||
}
|
||||
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
|
||||
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-primaryCalls:
|
||||
default:
|
||||
t.Fatal("expected primary to receive the first request")
|
||||
}
|
||||
select {
|
||||
case <-secondaryCalls:
|
||||
default:
|
||||
t.Fatal("expected secondary to receive the second request")
|
||||
}
|
||||
select {
|
||||
case <-primaryCalls:
|
||||
t.Fatal("primary should not receive the second request while unhealthy")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
started := make(chan struct{}, 1)
|
||||
release := make(chan struct{})
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
started <- struct{}{}
|
||||
<-release
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backend.URL},
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
client := proxy.Client()
|
||||
respCh := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := client.Do(req)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
respCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend request")
|
||||
}
|
||||
cancel()
|
||||
close(release)
|
||||
select {
|
||||
case <-respCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for canceled request to finish")
|
||||
}
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
||||
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendCalls := make(chan struct{}, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCalls <- struct{}{}
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBFirst(),
|
||||
Retries: 3,
|
||||
TryDuration: 100 * time.Millisecond,
|
||||
TryInterval: 250 * time.Millisecond,
|
||||
},
|
||||
}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-backendCalls:
|
||||
t.Fatal("retry budget should expire before the next upstream attempt")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
|
||||
errCh := make(chan error, 8)
|
||||
newBackend := func(name string) *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get(":protocol"); got != "websocket" {
|
||||
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
controller := http.NewResponseController(w)
|
||||
if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = controller.Flush()
|
||||
|
||||
line, err := bufio.NewReader(r.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
if _, err := io.WriteString(w, name+":"+line); err != nil {
|
||||
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
||||
return
|
||||
}
|
||||
_ = controller.Flush()
|
||||
}))
|
||||
server.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil {
|
||||
t.Fatalf("configure %s HTTP/2 server: %v", name, err)
|
||||
}
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
backendOne := newBackend("one")
|
||||
defer backendOne.Close()
|
||||
backendTwo := newBackend("two")
|
||||
defer backendTwo.Close()
|
||||
|
||||
engine := New()
|
||||
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||
Targets: []string{backendOne.URL, backendTwo.URL},
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
||||
Policy: LBRoundRobin(),
|
||||
},
|
||||
Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
|
||||
Via: "proxy.test",
|
||||
}))
|
||||
|
||||
proxy := httptest.NewUnstartedServer(engine)
|
||||
proxy.EnableHTTP2 = true
|
||||
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
||||
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
||||
}
|
||||
proxy.StartTLS()
|
||||
defer proxy.Close()
|
||||
|
||||
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
defer transport.CloseIdleConnections()
|
||||
|
||||
doRequest := func(payload string) string {
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
||||
if err != nil {
|
||||
t.Fatalf("new CONNECT request: %v", err)
|
||||
}
|
||||
req.Header.Set(":protocol", "websocket")
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("round trip extended CONNECT: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
||||
t.Fatalf("write tunneled request body: %v", err)
|
||||
}
|
||||
if err := pw.Close(); err != nil {
|
||||
t.Fatalf("close tunneled request body: %v", err)
|
||||
}
|
||||
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("read tunneled response body: %v", err)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
if got := doRequest("ping"); got != "one:ping\n" {
|
||||
t.Fatalf("unexpected first tunneled response: %q", got)
|
||||
}
|
||||
if got := doRequest("pong"); got != "two:pong\n" {
|
||||
t.Fatalf("unexpected second tunneled response: %q", got)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue