// Mirror of https://github.com/infinite-iroha/touka.git
// (synced 2026-06-13 15:47:38 +08:00; original listing: 2485 lines, 72 KiB, Go).

package touka
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
crand "crypto/rand"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httptrace"
|
|
"net/textproto"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|
t.Helper()
|
|
|
|
type backendRequestSnapshot struct {
|
|
Path string
|
|
RawQuery string
|
|
Host string
|
|
Connection string
|
|
RemovedHeader string
|
|
Forwarded string
|
|
XForwardedFor string
|
|
XForwardedHost string
|
|
XForwardedProto string
|
|
Via []string
|
|
TE string
|
|
UserAgent string
|
|
}
|
|
|
|
gotCh := make(chan backendRequestSnapshot, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotCh <- backendRequestSnapshot{
|
|
Path: r.URL.Path,
|
|
RawQuery: r.URL.RawQuery,
|
|
Host: r.Host,
|
|
Connection: r.Header.Get("Connection"),
|
|
RemovedHeader: r.Header.Get("X-Remove-Me"),
|
|
Forwarded: r.Header.Get("Forwarded"),
|
|
XForwardedFor: r.Header.Get("X-Forwarded-For"),
|
|
XForwardedHost: r.Header.Get("X-Forwarded-Host"),
|
|
XForwardedProto: r.Header.Get("X-Forwarded-Proto"),
|
|
Via: append([]string(nil), r.Header.Values("Via")...),
|
|
TE: r.Header.Get("Te"),
|
|
UserAgent: r.Header.Get("User-Agent"),
|
|
}
|
|
|
|
w.Header().Set("Connection", "X-Backend-Secret")
|
|
w.Header().Set("X-Backend-Secret", "remove-me")
|
|
w.Header().Add("Via", "1.0 upstream")
|
|
w.Header().Add("Trailer", "X-Upstream-Trailer")
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
_, _ = io.WriteString(w, "proxied")
|
|
w.Header().Set("X-Upstream-Trailer", "done")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
target, err := url.Parse(backend.URL + "/base?from=target")
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
|
Target: target,
|
|
ForwardedHeaders: ForwardedBoth,
|
|
ForwardedBy: "_proxy-node",
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?bad=1;smuggle=2&q=2", nil)
|
|
req.Host = "client.example"
|
|
req.RemoteAddr = "198.51.100.10:4567"
|
|
req.Header.Set("Connection", "X-Remove-Me")
|
|
req.Header.Set("X-Remove-Me", "client-secret")
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.9")
|
|
req.Header.Set("X-Forwarded-Host", "edge.example")
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("Forwarded", "for=203.0.113.9")
|
|
req.Header.Set("Te", "trailers")
|
|
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
resp := rr.Result()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
var got backendRequestSnapshot
|
|
select {
|
|
case got = <-gotCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for backend snapshot")
|
|
}
|
|
|
|
if string(body) != "proxied" {
|
|
t.Fatalf("unexpected body: %q", string(body))
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
|
}
|
|
if got.Path != "/base/api/ping" {
|
|
t.Fatalf("unexpected upstream path: %q", got.Path)
|
|
}
|
|
if got.RawQuery != "from=target&q=2" {
|
|
t.Fatalf("unexpected upstream raw query: %q", got.RawQuery)
|
|
}
|
|
if got.Host != strings.TrimPrefix(backend.URL, "http://") {
|
|
t.Fatalf("unexpected upstream host: %q", got.Host)
|
|
}
|
|
if got.Connection != "" {
|
|
t.Fatalf("connection header should be stripped, got %q", got.Connection)
|
|
}
|
|
if got.RemovedHeader != "" {
|
|
t.Fatalf("connection-token header should be stripped, got %q", got.RemovedHeader)
|
|
}
|
|
if got.XForwardedFor != "203.0.113.9, 198.51.100.10" {
|
|
t.Fatalf("unexpected X-Forwarded-For: %q", got.XForwardedFor)
|
|
}
|
|
if got.XForwardedHost != "edge.example" {
|
|
t.Fatalf("unexpected X-Forwarded-Host: %q", got.XForwardedHost)
|
|
}
|
|
if got.XForwardedProto != "https" {
|
|
t.Fatalf("unexpected X-Forwarded-Proto: %q", got.XForwardedProto)
|
|
}
|
|
if got.TE != "trailers" {
|
|
t.Fatalf("unexpected TE header: %q", got.TE)
|
|
}
|
|
if got.UserAgent != "" {
|
|
t.Fatalf("expected empty user-agent suppression, got %q", got.UserAgent)
|
|
}
|
|
if !strings.Contains(got.Forwarded, "for=203.0.113.9") {
|
|
t.Fatalf("forwarded header missing prior hop: %q", got.Forwarded)
|
|
}
|
|
if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
|
|
t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
|
|
}
|
|
if !strings.Contains(got.Forwarded, "by=_proxy-node") {
|
|
t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
|
|
}
|
|
if !strings.Contains(got.Forwarded, "host=client.example") {
|
|
t.Fatalf("forwarded header missing host: %q", got.Forwarded)
|
|
}
|
|
if !strings.Contains(got.Forwarded, "proto=http") {
|
|
t.Fatalf("forwarded header missing proto: %q", got.Forwarded)
|
|
}
|
|
if len(got.Via) != 1 || got.Via[0] != "1.1 proxy.test" {
|
|
t.Fatalf("unexpected upstream Via headers: %#v", got.Via)
|
|
}
|
|
if resp.Header.Get("Connection") != "" {
|
|
t.Fatalf("response connection header should be stripped, got %q", resp.Header.Get("Connection"))
|
|
}
|
|
if resp.Header.Get("X-Backend-Secret") != "" {
|
|
t.Fatalf("response connection-token header should be stripped, got %q", resp.Header.Get("X-Backend-Secret"))
|
|
}
|
|
if gotVia := resp.Header.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.0 upstream" || gotVia[1] != "1.1 proxy.test" {
|
|
t.Fatalf("unexpected response Via headers: %#v", gotVia)
|
|
}
|
|
if resp.Trailer.Get("X-Upstream-Trailer") != "done" {
|
|
t.Fatalf("unexpected proxied trailer: %q", resp.Trailer.Get("X-Upstream-Trailer"))
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
ForwardedHeaders: ForwardedBoth,
|
|
ForwardedBy: "proxy-node",
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) {
|
|
t.Helper()
|
|
|
|
forwardedCh := make(chan string, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
forwardedCh <- r.Header.Get("Forwarded")
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, backend.URL),
|
|
ForwardedHeaders: ForwardedBoth,
|
|
ForwardedBy: " _proxy-node ",
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
|
req.RemoteAddr = "198.51.100.10:4567"
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNoContent {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
|
|
select {
|
|
case forwarded := <-forwardedCh:
|
|
if !strings.Contains(forwarded, "by=_proxy-node") {
|
|
t.Fatalf("unexpected Forwarded header: %q", forwarded)
|
|
}
|
|
if strings.Contains(forwarded, `by=" _proxy-node "`) {
|
|
t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for backend Forwarded header")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
|
t.Helper()
|
|
|
|
viaCh := make(chan []string, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
viaCh <- append([]string(nil), r.Header.Values("Via")...)
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
target, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusNoContent {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
|
|
select {
|
|
case via := <-viaCh:
|
|
if len(via) != 1 || via[0] != "1.1 touka-engine" {
|
|
t.Fatalf("unexpected default Via header: %#v", via)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for backend Via header")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Targets: []string{"http://example.net"},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) {
|
|
t.Helper()
|
|
|
|
type snapshot struct {
|
|
Path string
|
|
RawQuery string
|
|
}
|
|
|
|
backendOneCh := make(chan snapshot, 1)
|
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
|
_, _ = io.WriteString(w, "one")
|
|
}))
|
|
defer backendOne.Close()
|
|
|
|
backendTwoCh := make(chan snapshot, 1)
|
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery}
|
|
_, _ = io.WriteString(w, "two")
|
|
}))
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{
|
|
backendOne.URL + "/one?from=one",
|
|
backendTwo.URL + "/two?from=two",
|
|
},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
|
}))
|
|
|
|
first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil)
|
|
if first.Code != http.StatusOK || first.Body.String() != "one" {
|
|
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
|
}
|
|
second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil)
|
|
if second.Code != http.StatusOK || second.Body.String() != "two" {
|
|
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
|
}
|
|
|
|
select {
|
|
case got := <-backendOneCh:
|
|
if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" {
|
|
t.Fatalf("unexpected first upstream request: %#v", got)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for first upstream request")
|
|
}
|
|
|
|
select {
|
|
case got := <-backendTwoCh:
|
|
if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" {
|
|
t.Fatalf("unexpected second upstream request: %#v", got)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for second upstream request")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "one")
|
|
}))
|
|
defer backendOne.Close()
|
|
|
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "two")
|
|
}))
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBHeader("X-Upstream", LBFirst()),
|
|
},
|
|
}))
|
|
|
|
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
|
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
|
}
|
|
|
|
headers := http.Header{"X-Upstream": {"tenant-a"}}
|
|
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
|
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers)
|
|
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
|
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
|
}
|
|
if firstSticky.Body.String() != secondSticky.Body.String() {
|
|
t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "one")
|
|
}))
|
|
defer backendOne.Close()
|
|
|
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "two")
|
|
}))
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBQuery("tenant", LBFirst()),
|
|
},
|
|
}))
|
|
|
|
fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" {
|
|
t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String())
|
|
}
|
|
|
|
firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
|
secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil)
|
|
if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK {
|
|
t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code)
|
|
}
|
|
if firstSticky.Body.String() != secondSticky.Body.String() {
|
|
t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "one")
|
|
}))
|
|
defer backendOne.Close()
|
|
|
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "two")
|
|
}))
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"})
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBClientIPHash(),
|
|
},
|
|
}))
|
|
|
|
reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
|
reqOne.RemoteAddr = "10.0.0.1:1234"
|
|
reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
|
rrOne := httptest.NewRecorder()
|
|
engine.ServeHTTP(rrOne, reqOne)
|
|
|
|
reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
|
reqTwo.RemoteAddr = "10.0.0.2:5678"
|
|
reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10")
|
|
rrTwo := httptest.NewRecorder()
|
|
engine.ServeHTTP(rrTwo, reqTwo)
|
|
|
|
if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK {
|
|
t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code)
|
|
}
|
|
if rrOne.Body.String() != rrTwo.Body.String() {
|
|
t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "ok")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBFirst(),
|
|
Retries: 1,
|
|
},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
|
t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, r.Header.Get("X-Attempt"))
|
|
}))
|
|
defer backend.Close()
|
|
|
|
var attempts atomic.Int64
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBFirst(),
|
|
Retries: 1,
|
|
},
|
|
ModifyRequest: func(req *http.Request) {
|
|
req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10))
|
|
},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
if rr.Body.String() != "2" {
|
|
t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendCalls := make(chan struct{}, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
backendCalls <- struct{}{}
|
|
_, _ = io.WriteString(w, "ok")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBFirst(),
|
|
Retries: 1,
|
|
},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil)
|
|
if rr.Code != http.StatusBadGateway {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
|
|
select {
|
|
case <-backendCalls:
|
|
t.Fatal("unsafe POST request should not be retried to the next upstream")
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendOneStarted := make(chan struct{}, 1)
|
|
releaseBackendOne := make(chan struct{})
|
|
backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
backendOneStarted <- struct{}{}
|
|
<-releaseBackendOne
|
|
_, _ = io.WriteString(w, "one")
|
|
}))
|
|
defer backendOne.Close()
|
|
|
|
backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.WriteString(w, "two")
|
|
}))
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBLeastConn(),
|
|
},
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
client := proxy.Client()
|
|
client.Timeout = 5 * time.Second
|
|
|
|
firstRespCh := make(chan string, 1)
|
|
firstErrCh := make(chan error, 1)
|
|
go func() {
|
|
resp, err := client.Get(proxy.URL + "/proxy")
|
|
if err != nil {
|
|
firstErrCh <- err
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
firstErrCh <- err
|
|
return
|
|
}
|
|
firstRespCh <- string(body)
|
|
}()
|
|
|
|
select {
|
|
case <-backendOneStarted:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for first backend request")
|
|
}
|
|
|
|
secondResp, err := client.Get(proxy.URL + "/proxy")
|
|
if err != nil {
|
|
close(releaseBackendOne)
|
|
t.Fatalf("second request failed: %v", err)
|
|
}
|
|
secondBody, err := io.ReadAll(secondResp.Body)
|
|
_ = secondResp.Body.Close()
|
|
if err != nil {
|
|
close(releaseBackendOne)
|
|
t.Fatalf("read second response: %v", err)
|
|
}
|
|
if string(secondBody) != "two" {
|
|
close(releaseBackendOne)
|
|
t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody))
|
|
}
|
|
|
|
close(releaseBackendOne)
|
|
select {
|
|
case err := <-firstErrCh:
|
|
t.Fatalf("first request failed: %v", err)
|
|
case body := <-firstRespCh:
|
|
if body != "one" {
|
|
t.Fatalf("unexpected first response body: %q", body)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for first response body")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) {
|
|
t.Helper()
|
|
|
|
primaryCalls := make(chan struct{}, 4)
|
|
primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
primaryCalls <- struct{}{}
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
_, _ = io.WriteString(w, "primary down")
|
|
}))
|
|
defer primary.Close()
|
|
|
|
secondaryCalls := make(chan struct{}, 4)
|
|
secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
secondaryCalls <- struct{}{}
|
|
_, _ = io.WriteString(w, "secondary up")
|
|
}))
|
|
defer secondary.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{primary.URL, secondary.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBFirst(),
|
|
},
|
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
|
FailDuration: time.Minute,
|
|
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
|
},
|
|
}))
|
|
|
|
first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" {
|
|
t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String())
|
|
}
|
|
second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if second.Code != http.StatusOK || second.Body.String() != "secondary up" {
|
|
t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String())
|
|
}
|
|
|
|
select {
|
|
case <-primaryCalls:
|
|
default:
|
|
t.Fatal("expected primary to receive the first request")
|
|
}
|
|
select {
|
|
case <-secondaryCalls:
|
|
default:
|
|
t.Fatal("expected secondary to receive the second request")
|
|
}
|
|
select {
|
|
case <-primaryCalls:
|
|
t.Fatal("primary should not receive the second request while unhealthy")
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) {
|
|
t.Helper()
|
|
|
|
started := make(chan struct{}, 1)
|
|
release := make(chan struct{})
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
started <- struct{}{}
|
|
<-release
|
|
_, _ = io.WriteString(w, "ok")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backend.URL},
|
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
|
FailDuration: time.Minute,
|
|
},
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil)
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
client := proxy.Client()
|
|
respCh := make(chan error, 1)
|
|
go func() {
|
|
resp, err := client.Do(req)
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
respCh <- err
|
|
}()
|
|
|
|
select {
|
|
case <-started:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for backend request")
|
|
}
|
|
cancel()
|
|
close(release)
|
|
select {
|
|
case <-respCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for canceled request to finish")
|
|
}
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
|
t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendCalls := make(chan struct{}, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
backendCalls <- struct{}{}
|
|
_, _ = io.WriteString(w, "ok")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{"http://127.0.0.1:1", backend.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBFirst(),
|
|
Retries: 3,
|
|
TryDuration: 100 * time.Millisecond,
|
|
TryInterval: 250 * time.Millisecond,
|
|
},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusBadGateway {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
|
|
select {
|
|
case <-backendCalls:
|
|
t.Fatal("retry budget should expire before the next upstream attempt")
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyAllowH2CUpstream(t *testing.T) {
|
|
t.Helper()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen h2c upstream: %v", err)
|
|
}
|
|
server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Upstream-Proto", r.Proto)
|
|
_, _ = io.WriteString(w, "ok")
|
|
})}
|
|
server.Protocols = new(http.Protocols)
|
|
server.Protocols.SetUnencryptedHTTP2(true)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- server.Serve(listener)
|
|
}()
|
|
defer func() {
|
|
_ = server.Close()
|
|
<-errCh
|
|
}()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://"+listener.Addr().String()),
|
|
AllowH2CUpstream: true,
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusOK || rr.Body.String() != "ok" {
|
|
t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String())
|
|
}
|
|
if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" {
|
|
t.Fatalf("expected h2c upstream proto, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
target, err := url.Parse("http://127.0.0.1:1")
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: target,
|
|
ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
|
|
w.WriteHeader(http.StatusGatewayTimeout)
|
|
_, _ = io.WriteString(w, fmt.Sprintf("proxy failure: %v", err))
|
|
},
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusGatewayTimeout {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
if !strings.Contains(rr.Body.String(), "proxy failure:") {
|
|
t.Fatalf("unexpected body: %q", rr.Body.String())
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) {
|
|
t.Helper()
|
|
|
|
flushErr := errors.New("flush failed")
|
|
writer := &flushErrorResponseWriter{flushErr: flushErr}
|
|
conn := &reverseProxyH2ReadWriteCloser{
|
|
ReadCloser: io.NopCloser(strings.NewReader("")),
|
|
ResponseWriter: writer,
|
|
controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)),
|
|
}
|
|
|
|
n, err := conn.Write([]byte("ping"))
|
|
if n != len("ping") {
|
|
t.Fatalf("unexpected bytes written: %d", n)
|
|
}
|
|
if !errors.Is(err, flushErr) {
|
|
t.Fatalf("unexpected write error: %v", err)
|
|
}
|
|
if got := writer.body.String(); got != "ping" {
|
|
t.Fatalf("unexpected buffered body: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) {
|
|
t.Helper()
|
|
|
|
transportCalled := atomic.Bool{}
|
|
entropyErr := errors.New("entropy source unavailable")
|
|
originalReader := crand.Reader
|
|
crand.Reader = errorReader{err: entropyErr}
|
|
t.Cleanup(func() {
|
|
crand.Reader = originalReader
|
|
})
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
transportCalled.Store(true)
|
|
return nil, errors.New("unexpected round trip")
|
|
}),
|
|
ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
|
|
w.WriteHeader(reverseProxyStatusCode(err))
|
|
_, _ = io.WriteString(w, err.Error())
|
|
},
|
|
}))
|
|
|
|
headers := make(http.Header)
|
|
headers.Set(":protocol", "websocket")
|
|
rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers)
|
|
|
|
if transportCalled.Load() {
|
|
t.Fatal("transport should not be called when websocket key generation fails")
|
|
}
|
|
if rr.Code != http.StatusBadGateway {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) {
|
|
t.Fatalf("unexpected error body: %q", body)
|
|
}
|
|
}
|
|
|
|
func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) {
|
|
t.Helper()
|
|
|
|
originalDefaultTransport := http.DefaultTransport
|
|
http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
return nil, errors.New("unexpected round trip")
|
|
})
|
|
t.Cleanup(func() {
|
|
http.DefaultTransport = originalDefaultTransport
|
|
})
|
|
|
|
assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) {
|
|
t.Helper()
|
|
transport, ok := rt.(*http.Transport)
|
|
if !ok {
|
|
t.Fatalf("%s returned %T, want *http.Transport", name, rt)
|
|
}
|
|
check(transport)
|
|
}
|
|
|
|
assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) {
|
|
if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() {
|
|
t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols)
|
|
}
|
|
})
|
|
assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) {
|
|
if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() {
|
|
t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols)
|
|
}
|
|
if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" {
|
|
t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig)
|
|
}
|
|
})
|
|
assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) {
|
|
if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() {
|
|
t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) {
|
|
t.Helper()
|
|
|
|
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
|
rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig)
|
|
transport, ok := rt.(*http.Transport)
|
|
if !ok {
|
|
t.Fatalf("unexpected transport type: %T", rt)
|
|
}
|
|
if transport.TLSClientConfig == nil {
|
|
t.Fatal("expected TLS client config")
|
|
}
|
|
if transport.TLSClientConfig == tlsConfig {
|
|
t.Fatal("expected bridge transport to clone TLS config")
|
|
}
|
|
if len(tlsConfig.NextProtos) != 0 {
|
|
t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos)
|
|
}
|
|
if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" {
|
|
t.Fatalf("unexpected transport NextProtos: %#v", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
return nil, context.DeadlineExceeded
|
|
}),
|
|
}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
|
if rr.Code != http.StatusGatewayTimeout {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "later")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, "streamed")
|
|
}))
|
|
defer backend.Close()
|
|
|
|
target, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
engine.GET("/trailers", ReverseProxy(ReverseProxyConfig{Target: target}))
|
|
|
|
rr := PerformRequest(engine, http.MethodGet, "/trailers", nil, nil)
|
|
resp := rr.Result()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
if string(body) != "streamed" {
|
|
t.Fatalf("unexpected body: %q", string(body))
|
|
}
|
|
if got := resp.Trailer.Get("X-Unannounced-Trailer"); got != "later" {
|
|
t.Fatalf("unexpected unannounced trailer: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
|
t.Helper()
|
|
|
|
errCh := make(chan error, 8)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
|
|
errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
|
errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("backend response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
|
return
|
|
}
|
|
|
|
line, err := brw.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("backend read failed: %w", err)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(brw, "echo:"+line)
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("backend echo flush failed: %w", err)
|
|
return
|
|
}
|
|
}))
|
|
defer backend.Close()
|
|
|
|
target, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: target,
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("dial proxy: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
|
t.Fatalf("set deadline: %v", err)
|
|
}
|
|
|
|
_, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("write upgrade request: %v", err)
|
|
}
|
|
|
|
reader := bufio.NewReader(conn)
|
|
statusLine, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read status line: %v", err)
|
|
}
|
|
if !strings.Contains(statusLine, "101") {
|
|
t.Fatalf("unexpected status line: %q", statusLine)
|
|
}
|
|
|
|
headers, err := textproto.NewReader(reader).ReadMIMEHeader()
|
|
if err != nil {
|
|
t.Fatalf("read headers: %v", err)
|
|
}
|
|
respHeader := http.Header(headers)
|
|
if !strings.EqualFold(respHeader.Get("Upgrade"), "websocket") {
|
|
t.Fatalf("unexpected upgrade response header: %q", respHeader.Get("Upgrade"))
|
|
}
|
|
if !headerValuesContainToken(respHeader.Values("Connection"), "Upgrade") {
|
|
t.Fatalf("unexpected connection response header: %#v", respHeader.Values("Connection"))
|
|
}
|
|
if gotVia := respHeader.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
|
|
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
|
}
|
|
|
|
if _, err := io.WriteString(conn, "ping\n"); err != nil {
|
|
t.Fatalf("write tunneled payload: %v", err)
|
|
}
|
|
message, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read tunneled payload: %v", err)
|
|
}
|
|
if message != "echo:ping\n" {
|
|
t.Fatalf("unexpected tunneled payload: %q", message)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
|
|
t.Helper()
|
|
|
|
errCh := make(chan error, 4)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("backend response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
|
return
|
|
}
|
|
}))
|
|
defer backend.Close()
|
|
|
|
target, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatalf("parse target: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("dial proxy: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
|
t.Fatalf("set deadline: %v", err)
|
|
}
|
|
|
|
_, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("write upgrade request: %v", err)
|
|
}
|
|
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
|
if err != nil {
|
|
t.Fatalf("read response: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusBadGateway {
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
t.Fatal("backend response writer does not support hijack")
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
t.Fatalf("backend hijack failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
_ = brw.Flush()
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil)
|
|
req.Header.Set("Connection", "Upgrade")
|
|
req.Header.Set("Upgrade", "websocket")
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNotImplemented {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) {
|
|
t.Helper()
|
|
|
|
called := make(chan struct{}, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called <- struct{}{}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
|
|
|
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
|
req.RequestURI = "/trace"
|
|
req.Header.Set("Max-Forwards", "0")
|
|
req.Header.Set("Authorization", "secret")
|
|
req.Header.Set("Cookie", "a=b")
|
|
req.Header.Set("Forwarded", "for=192.0.2.1")
|
|
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
resp := rr.Result()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
if got := resp.Header.Get("Content-Type"); got != "message/http" {
|
|
t.Fatalf("unexpected content type: %q", got)
|
|
}
|
|
if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") {
|
|
t.Fatalf("trace body missing request line: %q", string(body))
|
|
}
|
|
if strings.Contains(string(body), "Authorization:") {
|
|
t.Fatalf("trace body leaked authorization header: %q", string(body))
|
|
}
|
|
if strings.Contains(string(body), "Cookie:") {
|
|
t.Fatalf("trace body leaked cookie header: %q", string(body))
|
|
}
|
|
if strings.Contains(string(body), "Forwarded:") {
|
|
t.Fatalf("trace body leaked forwarded header: %q", string(body))
|
|
}
|
|
|
|
select {
|
|
case <-called:
|
|
t.Fatal("backend should not be called when Max-Forwards is zero")
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) {
|
|
t.Helper()
|
|
|
|
maxForwardsCh := make(chan string, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
maxForwardsCh <- r.Header.Get("Max-Forwards")
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
|
|
|
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
|
req.Header.Set("Max-Forwards", "2")
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNoContent {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
|
|
select {
|
|
case got := <-maxForwardsCh:
|
|
if got != "1" {
|
|
t.Fatalf("unexpected Max-Forwards header: %q", got)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for backend Max-Forwards")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
|
t.Helper()
|
|
|
|
called := make(chan struct{}, 1)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called <- struct{}{}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) })
|
|
engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil)
|
|
req.Header.Set("Max-Forwards", "0")
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
allow := rr.Header().Get("Allow")
|
|
if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) {
|
|
t.Fatalf("unexpected Allow header: %q", allow)
|
|
}
|
|
|
|
select {
|
|
case <-called:
|
|
t.Fatal("backend should not be called when Max-Forwards is zero")
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
engine.OPTIONS("/", func(c *Context) {
|
|
c.Status(http.StatusNoContent)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil)
|
|
req.RequestURI = "*"
|
|
req.URL.Path = ""
|
|
req.URL.RawPath = ""
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
|
}
|
|
if got := rr.Header().Get("Content-Length"); got != "0" {
|
|
t.Fatalf("unexpected Content-Length header: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyConnectTunnel(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backendAddr := ""
|
|
errCh := make(chan error, 4)
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodConnect {
|
|
errCh <- fmt.Errorf("unexpected method: %s", r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if got, want := r.RequestURI, backendAddr; got != want {
|
|
errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("backend response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
|
return
|
|
}
|
|
|
|
line, err := brw.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("backend read failed: %w", err)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(brw, strings.ToUpper(line))
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("backend write failed: %w", err)
|
|
return
|
|
}
|
|
}))
|
|
defer backend.Close()
|
|
backendAddr = strings.TrimPrefix(backend.URL, "http://")
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, backend.URL),
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("dial proxy: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
|
t.Fatalf("set deadline: %v", err)
|
|
}
|
|
|
|
_, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n")
|
|
if err != nil {
|
|
t.Fatalf("write connect request: %v", err)
|
|
}
|
|
|
|
reader := bufio.NewReader(conn)
|
|
statusLine, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read status line: %v", err)
|
|
}
|
|
if !strings.Contains(statusLine, "200") {
|
|
t.Fatalf("unexpected status line: %q", statusLine)
|
|
}
|
|
|
|
headers, err := textproto.NewReader(reader).ReadMIMEHeader()
|
|
if err != nil {
|
|
t.Fatalf("read headers: %v", err)
|
|
}
|
|
respHeader := http.Header(headers)
|
|
if got := respHeader.Get("Content-Length"); got != "" {
|
|
t.Fatalf("CONNECT response should not include Content-Length, got %q", got)
|
|
}
|
|
if got := respHeader.Get("Transfer-Encoding"); got != "" {
|
|
t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got)
|
|
}
|
|
if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" {
|
|
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
|
}
|
|
|
|
if _, err := io.WriteString(conn, "ping\n"); err != nil {
|
|
t.Fatalf("write tunneled payload: %v", err)
|
|
}
|
|
message, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read tunneled payload: %v", err)
|
|
}
|
|
if message != "PING\n" {
|
|
t.Fatalf("unexpected tunneled payload: %q", message)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
|
t.Helper()
|
|
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
t.Fatal("backend response writer does not support hijack")
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
t.Fatalf("backend hijack failed: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n")
|
|
_ = brw.Flush()
|
|
}))
|
|
defer backend.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
|
|
|
req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil)
|
|
req.URL.Path = "/tunnel"
|
|
req.RequestURI = "/tunnel"
|
|
rr := httptest.NewRecorder()
|
|
engine.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNotImplemented {
|
|
t.Fatalf("unexpected status: %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
errCh := make(chan error, 4)
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if got := r.Header.Get(":protocol"); got != "" {
|
|
errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
|
errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
|
errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
|
errCh <- errors.New("missing upstream Sec-WebSocket-Key header")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.URL.Path; got != "/ws" {
|
|
errCh <- fmt.Errorf("unexpected upstream path: %q", got)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("upstream response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
|
return
|
|
}
|
|
|
|
line, err := brw.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
|
return
|
|
}
|
|
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
|
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
|
return
|
|
}
|
|
_ = brw.Flush()
|
|
}))
|
|
defer upstream.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, upstream.URL),
|
|
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer transport.CloseIdleConnections()
|
|
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
if got := resp.Header.Get("Upgrade"); got != "" {
|
|
t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got)
|
|
}
|
|
if got := resp.Header.Get("X-Hop-Token"); got != "" {
|
|
t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got)
|
|
}
|
|
if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
|
|
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
|
}
|
|
|
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
|
t.Fatalf("write tunneled request body: %v", err)
|
|
}
|
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read tunneled response body: %v", err)
|
|
}
|
|
if message != "echo:ping\n" {
|
|
t.Fatalf("unexpected tunneled response body: %q", message)
|
|
}
|
|
if err := pw.Close(); err != nil {
|
|
t.Fatalf("close tunneled request body: %v", err)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
closeCalls := atomic.Int32{}
|
|
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
if req.Method != http.MethodGet {
|
|
return nil, fmt.Errorf("unexpected upstream method: %s", req.Method)
|
|
}
|
|
backend := &countingReadWriteCloser{
|
|
readData: []byte("echo:ping\n"),
|
|
closeCalls: &closeCalls,
|
|
closeWriteErr: http.ErrNotSupported,
|
|
}
|
|
return &http.Response{
|
|
StatusCode: http.StatusSwitchingProtocols,
|
|
Header: http.Header{
|
|
"Connection": []string{"Upgrade"},
|
|
"Upgrade": []string{"websocket"},
|
|
"Sec-WebSocket-Accept": []string{"ignored"},
|
|
},
|
|
Body: backend,
|
|
Request: req,
|
|
}, nil
|
|
})
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Transport: transport,
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer clientTransport.CloseIdleConnections()
|
|
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := clientTransport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
_ = resp.Body.Close()
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
|
_ = resp.Body.Close()
|
|
t.Fatalf("write tunneled request body: %v", err)
|
|
}
|
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
|
if err != nil {
|
|
_ = resp.Body.Close()
|
|
t.Fatalf("read tunneled response body: %v", err)
|
|
}
|
|
if message != "echo:ping\n" {
|
|
_ = resp.Body.Close()
|
|
t.Fatalf("unexpected tunneled response body: %q", message)
|
|
}
|
|
if err := pw.Close(); err != nil {
|
|
_ = resp.Body.Close()
|
|
t.Fatalf("close tunneled request body: %v", err)
|
|
}
|
|
if err := resp.Body.Close(); err != nil {
|
|
t.Fatalf("close response body: %v", err)
|
|
}
|
|
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
if closeCalls.Load() > 0 {
|
|
break
|
|
}
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
if got := closeCalls.Load(); got != 1 {
|
|
t.Fatalf("expected backend connection to close exactly once, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
errCh := make(chan error, 4)
|
|
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.ProtoMajor != 1 {
|
|
errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if r.Method != http.MethodGet {
|
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
|
errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("upstream response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
|
return
|
|
}
|
|
|
|
line, err := brw.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
|
return
|
|
}
|
|
if _, err := io.WriteString(brw, "echo:"+line); err != nil {
|
|
errCh <- fmt.Errorf("write tunneled response body failed: %w", err)
|
|
return
|
|
}
|
|
_ = brw.Flush()
|
|
}))
|
|
upstream.EnableHTTP2 = true
|
|
upstream.StartTLS()
|
|
defer upstream.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, upstream.URL),
|
|
Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer transport.CloseIdleConnections()
|
|
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
|
}
|
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
|
t.Fatalf("write tunneled request body: %v", err)
|
|
}
|
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read tunneled response body: %v", err)
|
|
}
|
|
if message != "echo:ping\n" {
|
|
t.Fatalf("unexpected tunneled response body: %q", message)
|
|
}
|
|
_ = pw.Close()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
errCh := make(chan error, 8)
|
|
newBackend := func(name string) *httptest.Server {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if got := r.Header.Get(":protocol"); got != "" {
|
|
errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") {
|
|
errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
|
errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade"))
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Sec-WebSocket-Key"); got == "" {
|
|
errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name)
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err)
|
|
return
|
|
}
|
|
|
|
line, err := brw.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err)
|
|
return
|
|
}
|
|
if _, err := io.WriteString(brw, name+":"+line); err != nil {
|
|
errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err)
|
|
return
|
|
}
|
|
_ = brw.Flush()
|
|
}))
|
|
return server
|
|
}
|
|
|
|
backendOne := newBackend("one")
|
|
defer backendOne.Close()
|
|
backendTwo := newBackend("two")
|
|
defer backendTwo.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Targets: []string{backendOne.URL, backendTwo.URL},
|
|
LoadBalancing: ReverseProxyLoadBalancingConfig{
|
|
Policy: LBRoundRobin(),
|
|
},
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer transport.CloseIdleConnections()
|
|
|
|
doRequest := func(payload string) string {
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
|
}
|
|
if _, err := io.WriteString(pw, payload+"\n"); err != nil {
|
|
t.Fatalf("write tunneled request body: %v", err)
|
|
}
|
|
if err := pw.Close(); err != nil {
|
|
t.Fatalf("close tunneled request body: %v", err)
|
|
}
|
|
message, err := bufio.NewReader(resp.Body).ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read tunneled response body: %v", err)
|
|
}
|
|
return message
|
|
}
|
|
|
|
if got := doRequest("ping"); got != "one:ping\n" {
|
|
t.Fatalf("unexpected first tunneled response: %q", got)
|
|
}
|
|
if got := doRequest("pong"); got != "two:pong\n" {
|
|
t.Fatalf("unexpected second tunneled response: %q", got)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
errCh := make(chan error, 4)
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("upstream response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
if err := brw.Flush(); err != nil {
|
|
errCh <- fmt.Errorf("upstream flush failed: %w", err)
|
|
return
|
|
}
|
|
|
|
reader := bufio.NewReader(brw)
|
|
line, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("read tunneled request body failed: %w", err)
|
|
return
|
|
}
|
|
if _, err := io.WriteString(brw, "ack:"+line); err != nil {
|
|
errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err)
|
|
return
|
|
}
|
|
_ = brw.Flush()
|
|
|
|
if _, err := io.Copy(io.Discard, reader); err != nil {
|
|
errCh <- fmt.Errorf("wait for request half-close failed: %w", err)
|
|
return
|
|
}
|
|
if _, err := io.WriteString(brw, "after-close\n"); err != nil {
|
|
errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err)
|
|
return
|
|
}
|
|
_ = brw.Flush()
|
|
}))
|
|
defer upstream.Close()
|
|
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, upstream.URL),
|
|
Via: "proxy.test",
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer transport.CloseIdleConnections()
|
|
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
|
}
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
if _, err := io.WriteString(pw, "ping\n"); err != nil {
|
|
t.Fatalf("write tunneled request body: %v", err)
|
|
}
|
|
message, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read immediate tunneled response: %v", err)
|
|
}
|
|
if message != "ack:ping\n" {
|
|
t.Fatalf("unexpected immediate tunneled response: %q", message)
|
|
}
|
|
if err := pw.Close(); err != nil {
|
|
t.Fatalf("close tunneled request body: %v", err)
|
|
}
|
|
|
|
message, err = reader.ReadString('\n')
|
|
if err != nil {
|
|
t.Fatalf("read post-close tunneled response: %v", err)
|
|
}
|
|
if message != "after-close\n" {
|
|
t.Fatalf("unexpected post-close tunneled response: %q", message)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) {
|
|
t.Helper()
|
|
|
|
enableHTTP2ExtendedConnectProtocol()
|
|
|
|
errCh := make(chan error, 4)
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method)
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
hj, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
errCh <- errors.New("upstream response writer does not support hijack")
|
|
return
|
|
}
|
|
conn, brw, err := hj.Hijack()
|
|
if err != nil {
|
|
errCh <- fmt.Errorf("upstream hijack failed: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
|
_ = brw.Flush()
|
|
|
|
<-r.Context().Done()
|
|
}))
|
|
defer upstream.Close()
|
|
|
|
proxyErrCh := make(chan error, 1)
|
|
engine := New()
|
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, upstream.URL),
|
|
Via: "proxy.test",
|
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
|
select {
|
|
case proxyErrCh <- err:
|
|
default:
|
|
}
|
|
},
|
|
}))
|
|
|
|
proxy := httptest.NewUnstartedServer(engine)
|
|
proxy.EnableHTTP2 = true
|
|
if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil {
|
|
t.Fatalf("configure proxy HTTP/2 server: %v", err)
|
|
}
|
|
proxy.StartTLS()
|
|
defer proxy.Close()
|
|
|
|
transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
defer transport.CloseIdleConnections()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr)
|
|
if err != nil {
|
|
t.Fatalf("new CONNECT request: %v", err)
|
|
}
|
|
req.Header.Set(":protocol", "websocket")
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("round trip extended CONNECT: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body))
|
|
}
|
|
|
|
writeErrCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := io.WriteString(pw, strings.Repeat("x", 1<<20))
|
|
writeErrCh <- err
|
|
}()
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
cancel()
|
|
_ = pw.CloseWithError(context.Canceled)
|
|
select {
|
|
case <-writeErrCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for request body writer to unblock")
|
|
}
|
|
|
|
select {
|
|
case err := <-proxyErrCh:
|
|
t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err)
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
|
t.Helper()
|
|
|
|
engine := New()
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
"Content-Type": []string{"text/plain"},
|
|
},
|
|
Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")},
|
|
ContentLength: -1,
|
|
Request: req,
|
|
}, nil
|
|
}),
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
resp, err := proxy.Client().Get(proxy.URL + "/proxy")
|
|
if err != nil {
|
|
t.Fatalf("perform request: %v", err)
|
|
}
|
|
_, err = io.ReadAll(resp.Body)
|
|
_ = resp.Body.Close()
|
|
if err == nil {
|
|
t.Fatal("expected body read to fail after upstream copy error")
|
|
}
|
|
}
|
|
|
|
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
|
t.Helper()
|
|
|
|
type oneXXInfo struct {
|
|
code int
|
|
header http.Header
|
|
}
|
|
|
|
backendTraceCh := make(chan struct{}, 1)
|
|
oneXXCh := make(chan oneXXInfo, 1)
|
|
|
|
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
|
trace := httptrace.ContextClientTrace(req.Context())
|
|
if trace == nil || trace.Got1xxResponse == nil {
|
|
return nil, errors.New("missing Got1xxResponse trace")
|
|
}
|
|
backendTraceCh <- struct{}{}
|
|
if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
|
|
return nil, err
|
|
}
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: http.Header{
|
|
"Content-Type": []string{"text/plain"},
|
|
},
|
|
Body: io.NopCloser(strings.NewReader("ok")),
|
|
ContentLength: 2,
|
|
Request: req,
|
|
}, nil
|
|
})
|
|
|
|
engine := New()
|
|
engine.Use(func(c *Context) {
|
|
c.Writer.Header().Set("X-Request-Id", "req-123")
|
|
c.Next()
|
|
})
|
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
|
Target: mustParseURL(t, "http://example.com"),
|
|
Transport: transport,
|
|
}))
|
|
|
|
proxy := httptest.NewServer(engine)
|
|
defer proxy.Close()
|
|
|
|
client := proxy.Client()
|
|
req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
|
|
if err != nil {
|
|
t.Fatalf("new request: %v", err)
|
|
}
|
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
|
|
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
|
oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}
|
|
return nil
|
|
},
|
|
}))
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("perform request: %v", err)
|
|
}
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("read body: %v", err)
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
select {
|
|
case <-backendTraceCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected proxy transport 1xx trace to be invoked")
|
|
}
|
|
|
|
var oneXX oneXXInfo
|
|
select {
|
|
case oneXX = <-oneXXCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected client to receive 1xx response")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
|
}
|
|
if string(body) != "ok" {
|
|
t.Fatalf("unexpected body: %q", string(body))
|
|
}
|
|
if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
|
|
t.Fatalf("final response lost preserved header: %q", got)
|
|
}
|
|
if got := resp.Header.Get("Link"); got != "" {
|
|
t.Fatalf("interim 1xx header leaked into final response: %q", got)
|
|
}
|
|
if oneXX.code != http.StatusEarlyHints {
|
|
t.Fatalf("unexpected interim status: %d", oneXX.code)
|
|
}
|
|
if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
|
|
t.Fatalf("unexpected interim Link header: %q", got)
|
|
}
|
|
if got := oneXX.header.Get("X-Request-Id"); got != "" {
|
|
t.Fatalf("final-only header leaked into interim response: %q", got)
|
|
}
|
|
}
|
|
|
|
// roundTripperFunc adapts a plain function to http.RoundTripper so tests can
// stub out a transport inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time proof that roundTripperFunc implements http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip calls the wrapped function with r and returns its results as-is.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := f(r)
	return resp, err
}
|
|
|
|
// flushErrorResponseWriter is an in-memory http.ResponseWriter for tests. Its
// FlushError always returns flushErr, Flush does nothing, and Hijack always
// fails with http.ErrNotSupported. It also exposes Status, Size, Written and
// IsHijacked accessors.
type flushErrorResponseWriter struct {
	header   http.Header
	body     bytes.Buffer
	status   int
	written  bool
	flushErr error
}

// Header returns the response header map, allocating it on first use.
func (w *flushErrorResponseWriter) Header() http.Header {
	if w.header != nil {
		return w.header
	}
	w.header = make(http.Header)
	return w.header
}

// WriteHeader records statusCode on the first call; later calls are ignored.
func (w *flushErrorResponseWriter) WriteHeader(statusCode int) {
	if !w.written {
		w.status, w.written = statusCode, true
	}
}

// Write commits an implicit 200 (a no-op if a status was already recorded)
// and appends p to the in-memory body.
func (w *flushErrorResponseWriter) Write(p []byte) (int, error) {
	w.WriteHeader(http.StatusOK)
	return w.body.Write(p)
}

// Flush is intentionally a no-op.
func (w *flushErrorResponseWriter) Flush() {}

// FlushError commits an implicit 200 (a no-op if a status was already
// recorded) and reports the configured flushErr.
func (w *flushErrorResponseWriter) FlushError() error {
	w.WriteHeader(http.StatusOK)
	return w.flushErr
}

// Hijack always fails; this writer cannot hand out a connection.
func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, http.ErrNotSupported
}

// Status reports the recorded status code (0 until one is written).
func (w *flushErrorResponseWriter) Status() int { return w.status }

// Size reports the number of body bytes buffered so far.
func (w *flushErrorResponseWriter) Size() int { return w.body.Len() }

// Written reports whether a status code has been recorded.
func (w *flushErrorResponseWriter) Written() bool { return w.written }

// IsHijacked always reports false.
func (w *flushErrorResponseWriter) IsHijacked() bool { return false }
|
|
|
|
// errorReader is an io.Reader that never yields data: every Read reports zero
// bytes together with the configured err.
type errorReader struct {
	err error
}

// Read ignores its buffer and returns (0, e.err).
func (e errorReader) Read(_ []byte) (int, error) {
	failure := e.err
	return 0, failure
}
|
|
|
|
// countingReadWriteCloser is an in-memory read/write/close double. Reads drain
// readData and then return io.EOF. Writes are collected in writeBuf. Close
// increments *closeCalls when it is non-nil, and CloseWrite returns
// closeWriteErr.
type countingReadWriteCloser struct {
	readData      []byte
	writeBuf      bytes.Buffer
	closeCalls    *atomic.Int32
	closeWriteErr error
}

// Read copies as much of the remaining readData as fits into p, or reports
// io.EOF once readData is exhausted.
func (c *countingReadWriteCloser) Read(p []byte) (int, error) {
	if len(c.readData) > 0 {
		copied := copy(p, c.readData)
		c.readData = c.readData[copied:]
		return copied, nil
	}
	return 0, io.EOF
}

// Write appends p to writeBuf.
func (c *countingReadWriteCloser) Write(p []byte) (int, error) {
	written, err := c.writeBuf.Write(p)
	return written, err
}

// Close counts the call (when a counter is attached) and always succeeds.
func (c *countingReadWriteCloser) Close() error {
	if counter := c.closeCalls; counter != nil {
		counter.Add(1)
	}
	return nil
}

// CloseWrite returns the configured closeWriteErr.
func (c *countingReadWriteCloser) CloseWrite() error { return c.closeWriteErr }
|
|
|
|
// mustParseURL parses raw with url.Parse and fails the calling test
// immediately (via t.Fatalf) when raw is not a valid URL.
func mustParseURL(t *testing.T, raw string) *url.URL {
	t.Helper()

	parsed, parseErr := url.Parse(raw)
	if parseErr == nil {
		return parsed
	}
	t.Fatalf("parse url %q: %v", raw, parseErr)
	return parsed
}
|
|
|
|
// failingReadCloser is an io.ReadCloser that yields the strings in chunks, in
// order, and then fails every subsequent Read with err. It simulates a body
// that breaks mid-stream.
type failingReadCloser struct {
	chunks []string
	err    error
}

// Read returns bytes from the current chunk. Its behavior:
// - A chunk longer than p is delivered across several calls instead of being
//   truncated.
// - Empty chunks are skipped, so a non-empty p never gets a (0, nil) result.
// - Once every chunk has been delivered, Read returns (0, r.err).
// The caller-supplied chunks slice is never modified in place.
func (r *failingReadCloser) Read(p []byte) (int, error) {
	for len(r.chunks) > 0 && r.chunks[0] == "" {
		r.chunks = r.chunks[1:]
	}
	if len(r.chunks) == 0 {
		return 0, r.err
	}
	if len(p) == 0 {
		// Nothing can be copied; keep the chunk for the next call.
		return 0, nil
	}

	head := r.chunks[0]
	n := copy(p, head)
	if n == len(head) {
		r.chunks = r.chunks[1:]
	} else {
		// Build a fresh slice rather than assigning to r.chunks[0], which would
		// write into the caller's backing array.
		r.chunks = append([]string{head[n:]}, r.chunks[1:]...)
	}
	return n, nil
}

// Close is a no-op that always succeeds.
func (r *failingReadCloser) Close() error {
	return nil
}
|