touka/reverseproxy_test.go
wjqserver 1946216c0e fix: harden reverse proxy edge cases
Preserve final headers when forwarding 1xx responses, reject invalid 101 upgrade negotiations, and make the default Via token RFC-safe. Tighten the reverse proxy tests around goroutine synchronization and document the Via fallback behavior more clearly.
2026-03-29 01:15:57 +08:00

570 lines
16 KiB
Go

package touka
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"testing"
"time"
)
// TestReverseProxyForwardingAndHopHeaders sends one GET through a wildcard
// proxy route and checks both legs of the exchange:
//   - the target's base path and query are combined with the incoming ones;
//   - Connection and the headers it names are stripped in each direction;
//   - X-Forwarded-For gains the client IP while the incoming
//     X-Forwarded-Host / X-Forwarded-Proto values are kept;
//   - Forwarded keeps the prior hop and gains for/by/host/proto for this hop;
//   - "TE: trailers" is forwarded and no User-Agent is sent upstream;
//   - Via is appended with the configured pseudonym on request and response;
//   - a trailer declared by the upstream reaches the client.
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
	// backendRequestSnapshot records what the upstream received so the
	// assertions can run on the test goroutine after the round trip.
	type backendRequestSnapshot struct {
		Path            string
		RawQuery        string
		Host            string
		Connection      string
		RemovedHeader   string
		Forwarded       string
		XForwardedFor   string
		XForwardedHost  string
		XForwardedProto string
		Via             []string
		TE              string
		UserAgent       string
	}

	// forwardedHasParam reports whether any element of a Forwarded header
	// value carries exactly name=value (surrounding quotes on the value are
	// ignored). A substring match is not enough: "proto=http" also matches
	// "proto=https".
	forwardedHasParam := func(header, name, value string) bool {
		for _, element := range strings.Split(header, ",") {
			for _, pair := range strings.Split(element, ";") {
				k, v, ok := strings.Cut(strings.TrimSpace(pair), "=")
				if ok && strings.EqualFold(k, name) && strings.Trim(v, `"`) == value {
					return true
				}
			}
		}
		return false
	}

	// Buffered so the handler's single send never blocks.
	gotCh := make(chan backendRequestSnapshot, 1)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotCh <- backendRequestSnapshot{
			Path:            r.URL.Path,
			RawQuery:        r.URL.RawQuery,
			Host:            r.Host,
			Connection:      r.Header.Get("Connection"),
			RemovedHeader:   r.Header.Get("X-Remove-Me"),
			Forwarded:       r.Header.Get("Forwarded"),
			XForwardedFor:   r.Header.Get("X-Forwarded-For"),
			XForwardedHost:  r.Header.Get("X-Forwarded-Host"),
			XForwardedProto: r.Header.Get("X-Forwarded-Proto"),
			// Copy so the snapshot does not alias the request's header map.
			Via:       append([]string(nil), r.Header.Values("Via")...),
			TE:        r.Header.Get("Te"),
			UserAgent: r.Header.Get("User-Agent"),
		}
		// Response-side Connection token naming a header the proxy must drop.
		w.Header().Set("Connection", "X-Backend-Secret")
		w.Header().Set("X-Backend-Secret", "remove-me")
		w.Header().Add("Via", "1.0 upstream")
		// Declare the trailer before the body; its value is set afterwards.
		w.Header().Add("Trailer", "X-Upstream-Trailer")
		w.Header().Set("Content-Type", "text/plain")
		_, _ = io.WriteString(w, "proxied")
		w.Header().Set("X-Upstream-Trailer", "done")
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL + "/base?from=target")
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
		Target:           target,
		ForwardedHeaders: ForwardedBoth,
		ForwardedBy:      "proxy-node",
		Via:              "proxy.test",
	}))

	req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?q=2", nil)
	req.Host = "client.example"
	req.RemoteAddr = "198.51.100.10:4567"
	// Request-side Connection token naming a header the proxy must drop.
	req.Header.Set("Connection", "X-Remove-Me")
	req.Header.Set("X-Remove-Me", "client-secret")
	// Headers from a previous hop that the proxy must extend or keep.
	req.Header.Set("X-Forwarded-For", "203.0.113.9")
	req.Header.Set("X-Forwarded-Host", "edge.example")
	req.Header.Set("X-Forwarded-Proto", "https")
	req.Header.Set("Forwarded", "for=203.0.113.9")
	req.Header.Set("Te", "trailers")

	rr := httptest.NewRecorder()
	engine.ServeHTTP(rr, req)

	resp := rr.Result()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	_ = resp.Body.Close()

	// Check the status before waiting on the backend: if the proxy failed, the
	// backend was never reached and the snapshot wait would only time out.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}
	if string(body) != "proxied" {
		t.Fatalf("unexpected body: %q", string(body))
	}

	var got backendRequestSnapshot
	select {
	case got = <-gotCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for backend snapshot")
	}

	// Upstream request line and Host.
	if got.Path != "/base/api/ping" {
		t.Fatalf("unexpected upstream path: %q", got.Path)
	}
	if got.RawQuery != "from=target&q=2" {
		t.Fatalf("unexpected upstream raw query: %q", got.RawQuery)
	}
	if got.Host != target.Host {
		t.Fatalf("unexpected upstream host: %q", got.Host)
	}

	// Request-side hop-by-hop stripping.
	if got.Connection != "" {
		t.Fatalf("connection header should be stripped, got %q", got.Connection)
	}
	if got.RemovedHeader != "" {
		t.Fatalf("connection-token header should be stripped, got %q", got.RemovedHeader)
	}

	// X-Forwarded-* handling.
	if got.XForwardedFor != "203.0.113.9, 198.51.100.10" {
		t.Fatalf("unexpected X-Forwarded-For: %q", got.XForwardedFor)
	}
	if got.XForwardedHost != "edge.example" {
		t.Fatalf("unexpected X-Forwarded-Host: %q", got.XForwardedHost)
	}
	if got.XForwardedProto != "https" {
		t.Fatalf("unexpected X-Forwarded-Proto: %q", got.XForwardedProto)
	}

	if got.TE != "trailers" {
		t.Fatalf("unexpected TE header: %q", got.TE)
	}
	if got.UserAgent != "" {
		t.Fatalf("expected empty user-agent suppression, got %q", got.UserAgent)
	}

	// Forwarded handling.
	if !strings.Contains(got.Forwarded, "for=203.0.113.9") {
		t.Fatalf("forwarded header missing prior hop: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
		t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "by=proxy-node") {
		t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "host=client.example") {
		t.Fatalf("forwarded header missing host: %q", got.Forwarded)
	}
	if !forwardedHasParam(got.Forwarded, "proto", "http") {
		t.Fatalf("forwarded header missing proto: %q", got.Forwarded)
	}

	if len(got.Via) != 1 || got.Via[0] != "1.1 proxy.test" {
		t.Fatalf("unexpected upstream Via headers: %#v", got.Via)
	}

	// Response-side hop-by-hop stripping, Via and trailers.
	if resp.Header.Get("Connection") != "" {
		t.Fatalf("response connection header should be stripped, got %q", resp.Header.Get("Connection"))
	}
	if resp.Header.Get("X-Backend-Secret") != "" {
		t.Fatalf("response connection-token header should be stripped, got %q", resp.Header.Get("X-Backend-Secret"))
	}
	if gotVia := resp.Header.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.0 upstream" || gotVia[1] != "1.1 proxy.test" {
		t.Fatalf("unexpected response Via headers: %#v", gotVia)
	}
	if resp.Trailer.Get("X-Upstream-Trailer") != "done" {
		t.Fatalf("unexpected proxied trailer: %q", resp.Trailer.Get("X-Upstream-Trailer"))
	}
}
// TestReverseProxyDefaultViaFallback checks that leaving ReverseProxyConfig.Via
// empty still adds exactly one upstream Via entry, "1.1 touka-engine".
func TestReverseProxyDefaultViaFallback(t *testing.T) {
	// Buffered so the handler's single send never blocks.
	viaCh := make(chan []string, 1)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Copy so the value does not alias the request's header map.
		viaCh <- append([]string(nil), r.Header.Values("Via")...)
		w.WriteHeader(http.StatusNoContent)
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))

	rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
	if rr.Code != http.StatusNoContent {
		t.Fatalf("unexpected status: %d", rr.Code)
	}

	select {
	case via := <-viaCh:
		if len(via) != 1 || via[0] != "1.1 touka-engine" {
			t.Fatalf("unexpected default Via header: %#v", via)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for backend Via header")
	}
}
// TestReverseProxyCustomErrorHandler points the proxy at an address with no
// listener and checks that the configured ErrorHandler writes the response
// (504 plus a "proxy failure:" body) and is given a non-nil error.
func TestReverseProxyCustomErrorHandler(t *testing.T) {
	engine := New()
	// NOTE(review): assumes nothing listens on 127.0.0.1:1 on the test host.
	target, err := url.Parse("http://127.0.0.1:1")
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	// Capture the error handed to the handler. The send is non-blocking so
	// an unexpected second invocation cannot stall the request.
	handlerErrCh := make(chan error, 1)
	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
		Target: target,
		ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
			select {
			case handlerErrCh <- err:
			default:
			}
			w.WriteHeader(http.StatusGatewayTimeout)
			_, _ = fmt.Fprintf(w, "proxy failure: %v", err)
		},
	}))

	rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
	if rr.Code != http.StatusGatewayTimeout {
		t.Fatalf("unexpected status: %d", rr.Code)
	}
	if !strings.Contains(rr.Body.String(), "proxy failure:") {
		t.Fatalf("unexpected body: %q", rr.Body.String())
	}

	// A "proxy failure: <nil>" body would satisfy the check above, so also
	// confirm the handler really received the transport error.
	select {
	case handlerErr := <-handlerErrCh:
		if handlerErr == nil {
			t.Fatal("custom error handler was invoked with a nil error")
		}
	default:
		t.Fatal("custom error handler was not invoked")
	}
}
// TestReverseProxyUnannouncedTrailerForwarding checks that a trailer the
// upstream sets using the http.TrailerPrefix convention still reaches the
// client through the proxy. The upstream never lists it in a Trailer header
// beforehand.
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "later")
		w.WriteHeader(http.StatusOK)
		_, _ = io.WriteString(w, "streamed")
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/trailers", ReverseProxy(ReverseProxyConfig{Target: target}))

	rr := PerformRequest(engine, http.MethodGet, "/trailers", nil, nil)
	resp := rr.Result()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}
	_ = resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}
	if string(body) != "streamed" {
		t.Fatalf("unexpected body: %q", string(body))
	}
	if got := resp.Trailer.Get("X-Unannounced-Trailer"); got != "later" {
		t.Fatalf("unexpected unannounced trailer: %q", got)
	}
}
// TestReverseProxyProtocolUpgrade runs a raw HTTP/1.1 "Upgrade: websocket"
// handshake through a live proxy server. It checks four things:
//   - the upstream sees the upgrade request;
//   - the client receives a 101 carrying Upgrade and Connection;
//   - the 101 also carries the proxy's Via entry;
//   - a line written afterwards is echoed back through the tunnel.
func TestReverseProxyProtocolUpgrade(t *testing.T) {
	// The backend runs on server goroutines where t.Fatal must not be used,
	// so it reports failures on errCh instead.
	errCh := make(chan error, 8)

	// fatalf fails the test. If the backend has already reported a problem,
	// it appends that error, so a broken handshake shows its upstream cause
	// and not only the client-side symptom.
	fatalf := func(format string, args ...any) {
		t.Helper()
		msg := fmt.Sprintf(format, args...)
		select {
		case backendErr := <-errCh:
			t.Fatalf("%s (backend error: %v)", msg, backendErr)
		default:
			t.Fatal(msg)
		}
	}

	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
			errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		hj, ok := w.(http.Hijacker)
		if !ok {
			errCh <- errors.New("backend response writer does not support hijack")
			return
		}
		conn, brw, err := hj.Hijack()
		if err != nil {
			errCh <- fmt.Errorf("backend hijack failed: %w", err)
			return
		}
		defer conn.Close()
		// Complete the handshake by hand, then echo one line back so the
		// client can prove the tunnel carries data both ways.
		_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend flush failed: %w", err)
			return
		}
		line, err := brw.ReadString('\n')
		if err != nil {
			errCh <- fmt.Errorf("backend read failed: %w", err)
			return
		}
		_, _ = io.WriteString(brw, "echo:"+line)
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend echo flush failed: %w", err)
			return
		}
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/ws", ReverseProxy(ReverseProxyConfig{
		Target: target,
		Via:    "proxy.test",
	}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	// One deadline bounds every read and write on the client connection.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("set deadline: %v", err)
	}

	if _, err := io.WriteString(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"); err != nil {
		fatalf("write upgrade request: %v", err)
	}

	reader := bufio.NewReader(conn)
	statusLine, err := reader.ReadString('\n')
	if err != nil {
		fatalf("read status line: %v", err)
	}
	if !strings.Contains(statusLine, "101") {
		fatalf("unexpected status line: %q", statusLine)
	}
	headers, err := textproto.NewReader(reader).ReadMIMEHeader()
	if err != nil {
		fatalf("read headers: %v", err)
	}
	respHeader := http.Header(headers)
	if !strings.EqualFold(respHeader.Get("Upgrade"), "websocket") {
		fatalf("unexpected upgrade response header: %q", respHeader.Get("Upgrade"))
	}
	if !headerValuesContainToken(respHeader.Values("Connection"), "Upgrade") {
		fatalf("unexpected connection response header: %#v", respHeader.Values("Connection"))
	}
	if gotVia := respHeader.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
		fatalf("unexpected Via response header: %#v", gotVia)
	}

	// Push a line through the tunnel and expect the backend's echo.
	if _, err := io.WriteString(conn, "ping\n"); err != nil {
		fatalf("write tunneled payload: %v", err)
	}
	message, err := reader.ReadString('\n')
	if err != nil {
		fatalf("read tunneled payload: %v", err)
	}
	if message != "echo:ping\n" {
		fatalf("unexpected tunneled payload: %q", message)
	}

	select {
	case err := <-errCh:
		t.Fatal(err)
	default:
	}
}
// TestReverseProxyRejectsEmptyUpgradeProtocol has the upstream answer an
// upgrade request with a bare "101 Switching Protocols" that carries no
// Upgrade header. It checks that the client receives 502 Bad Gateway rather
// than a protocol switch.
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
	// The backend runs on server goroutines where t.Fatal must not be used,
	// so it reports failures on errCh instead.
	errCh := make(chan error, 4)

	// fatalf fails the test and appends any error the backend already
	// reported, so an unexpected status shows its upstream cause.
	fatalf := func(format string, args ...any) {
		t.Helper()
		msg := fmt.Sprintf(format, args...)
		select {
		case backendErr := <-errCh:
			t.Fatalf("%s (backend error: %v)", msg, backendErr)
		default:
			t.Fatal(msg)
		}
	}

	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hj, ok := w.(http.Hijacker)
		if !ok {
			errCh <- errors.New("backend response writer does not support hijack")
			return
		}
		conn, brw, err := hj.Hijack()
		if err != nil {
			errCh <- fmt.Errorf("backend hijack failed: %w", err)
			return
		}
		defer conn.Close()
		// Deliberately omit the Upgrade (and Connection) headers.
		_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend flush failed: %w", err)
			return
		}
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	// One deadline bounds every read and write on the client connection.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("set deadline: %v", err)
	}

	if _, err := io.WriteString(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"); err != nil {
		fatalf("write upgrade request: %v", err)
	}

	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		fatalf("read response: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadGateway {
		fatalf("unexpected status: %d", resp.StatusCode)
	}

	select {
	case err := <-errCh:
		t.Fatal(err)
	default:
	}
}
// TestReverseProxyRestoresHeadersAfter1xx sends a 103 Early Hints response
// through the proxy and checks that interim and final headers stay separate:
//   - the single 1xx carries the upstream's Link header but not the
//     X-Request-Id a middleware staged for the final response;
//   - the final 200 keeps X-Request-Id and does not inherit the interim Link.
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
	type oneXXInfo struct {
		code   int
		header http.Header
	}

	backendTraceCh := make(chan struct{}, 1)
	// Spare capacity plus a non-blocking send lets a duplicated 1xx be
	// detected below. A blocking send could wedge the client's reader.
	oneXXCh := make(chan oneXXInfo, 4)

	// The stub transport stands in for the upstream. It emits a 103 through
	// the ClientTrace found on the outgoing request's context, then returns
	// the final 200.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		trace := httptrace.ContextClientTrace(req.Context())
		if trace == nil || trace.Got1xxResponse == nil {
			return nil, errors.New("missing Got1xxResponse trace")
		}
		backendTraceCh <- struct{}{}
		if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
			return nil, err
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/plain"},
			},
			Body:          io.NopCloser(strings.NewReader("ok")),
			ContentLength: 2,
			Request:       req,
		}, nil
	})

	engine := New()
	// Stage a final-response header before the proxy runs.
	engine.Use(func(c *Context) {
		c.Writer.Header().Set("X-Request-Id", "req-123")
		c.Next()
	})
	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
		Target:    mustParseURL(t, "http://example.com"),
		Transport: transport,
	}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	// Reuse the test server's transport but bound the request, so a proxy
	// that never answers fails the test instead of hanging it.
	client := &http.Client{
		Transport: proxy.Client().Transport,
		Timeout:   5 * time.Second,
	}

	req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			select {
			case oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}:
			default:
			}
			return nil
		},
	}))

	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("perform request: %v", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}

	select {
	case <-backendTraceCh:
	case <-time.After(2 * time.Second):
		t.Fatal("expected proxy transport 1xx trace to be invoked")
	}

	var oneXX oneXXInfo
	select {
	case oneXX = <-oneXXCh:
	case <-time.After(2 * time.Second):
		t.Fatal("expected client to receive 1xx response")
	}
	// Exactly one interim response was sent upstream, so exactly one should
	// reach the client.
	select {
	case extra := <-oneXXCh:
		t.Fatalf("client received more than one 1xx response; extra status %d", extra.code)
	default:
	}

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}
	if string(body) != "ok" {
		t.Fatalf("unexpected body: %q", string(body))
	}
	if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
		t.Fatalf("final response lost preserved header: %q", got)
	}
	if got := resp.Header.Get("Link"); got != "" {
		t.Fatalf("interim 1xx header leaked into final response: %q", got)
	}
	if oneXX.code != http.StatusEarlyHints {
		t.Fatalf("unexpected interim status: %d", oneXX.code)
	}
	if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
		t.Fatalf("unexpected interim Link header: %q", got)
	}
	if got := oneXX.header.Get("X-Request-Id"); got != "" {
		t.Fatalf("final-only header leaked into interim response: %q", got)
	}
}
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface so a test can stub out an upstream transport inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time proof that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip passes r to the wrapped function and returns its response and
// error unchanged.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}
// mustParseURL parses raw with url.Parse and returns the result. If parsing
// fails, it stops the calling test immediately.
func mustParseURL(t *testing.T, raw string) *url.URL {
	t.Helper()
	parsed, parseErr := url.Parse(raw)
	if parseErr == nil {
		return parsed
	}
	t.Fatalf("parse url %q: %v", raw, parseErr)
	return nil // unreachable: Fatalf ends the test goroutine
}