touka/reverseproxy_test.go
wjqserver 6d89b8674f fix: tighten reverse proxy safety handling
Avoid HTTP error writes after hijacking upgraded connections, document ModifyResponse constraints for 101 responses, and normalize forwarded query strings consistently to reduce parsing ambiguity across proxy chains.
2026-03-29 01:39:09 +08:00

570 lines
16 KiB
Go

package touka
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"testing"
"time"
)
// TestReverseProxyForwardingAndHopHeaders drives one GET through the proxy and
// asserts both what the backend received and what the client got back: path
// and query joining against the target, hop-by-hop header stripping in both
// directions, the X-Forwarded-* / Forwarded / Via bookkeeping, and trailer
// propagation.
func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
	// backendRequestSnapshot holds the request fields observed by the backend
	// so they can be asserted from the test goroutine.
	type backendRequestSnapshot struct {
		Path            string
		RawQuery        string
		Host            string
		Connection      string
		RemovedHeader   string
		Forwarded       string
		XForwardedFor   string
		XForwardedHost  string
		XForwardedProto string
		Via             []string
		TE              string
		UserAgent       string
	}

	// Buffered so the backend handler never blocks on the send.
	gotCh := make(chan backendRequestSnapshot, 1)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotCh <- backendRequestSnapshot{
			Path:            r.URL.Path,
			RawQuery:        r.URL.RawQuery,
			Host:            r.Host,
			Connection:      r.Header.Get("Connection"),
			RemovedHeader:   r.Header.Get("X-Remove-Me"),
			Forwarded:       r.Header.Get("Forwarded"),
			XForwardedFor:   r.Header.Get("X-Forwarded-For"),
			XForwardedHost:  r.Header.Get("X-Forwarded-Host"),
			XForwardedProto: r.Header.Get("X-Forwarded-Proto"),
			Via:             append([]string(nil), r.Header.Values("Via")...),
			TE:              r.Header.Get("Te"),
			UserAgent:       r.Header.Get("User-Agent"),
		}
		// X-Backend-Secret is named as a Connection token; the assertions
		// below expect neither header to reach the client.
		w.Header().Set("Connection", "X-Backend-Secret")
		w.Header().Set("X-Backend-Secret", "remove-me")
		w.Header().Add("Via", "1.0 upstream")
		// The trailer is announced before the body and its value is set
		// after the body has been written.
		w.Header().Add("Trailer", "X-Upstream-Trailer")
		w.Header().Set("Content-Type", "text/plain")
		_, _ = io.WriteString(w, "proxied")
		w.Header().Set("X-Upstream-Trailer", "done")
	}))
	defer backend.Close()

	// The target carries its own path and query so joining can be asserted.
	target, err := url.Parse(backend.URL + "/base?from=target")
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
		Target:           target,
		ForwardedHeaders: ForwardedBoth,
		ForwardedBy:      "proxy-node",
		Via:              "proxy.test",
	}))

	// The query contains a ';'-separated pair; the RawQuery assertion below
	// expects only "q=2" to survive and be appended to the target query.
	req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?bad=1;smuggle=2&q=2", nil)
	req.Host = "client.example"
	req.RemoteAddr = "198.51.100.10:4567"
	req.Header.Set("Connection", "X-Remove-Me")
	req.Header.Set("X-Remove-Me", "client-secret")
	req.Header.Set("X-Forwarded-For", "203.0.113.9")
	req.Header.Set("X-Forwarded-Host", "edge.example")
	req.Header.Set("X-Forwarded-Proto", "https")
	req.Header.Set("Forwarded", "for=203.0.113.9")
	req.Header.Set("Te", "trailers")

	rr := httptest.NewRecorder()
	engine.ServeHTTP(rr, req)

	resp := rr.Result()
	// Deferred so the body is closed on every exit path, including failures.
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}

	var got backendRequestSnapshot
	select {
	case got = <-gotCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for backend snapshot")
	}

	// Basic response checks.
	if string(body) != "proxied" {
		t.Fatalf("unexpected body: %q", string(body))
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}

	// Upstream request line and host.
	if got.Path != "/base/api/ping" {
		t.Fatalf("unexpected upstream path: %q", got.Path)
	}
	if got.RawQuery != "from=target&q=2" {
		t.Fatalf("unexpected upstream raw query: %q", got.RawQuery)
	}
	if got.Host != strings.TrimPrefix(backend.URL, "http://") {
		t.Fatalf("unexpected upstream host: %q", got.Host)
	}

	// Hop-by-hop headers must not reach the backend.
	if got.Connection != "" {
		t.Fatalf("connection header should be stripped, got %q", got.Connection)
	}
	if got.RemovedHeader != "" {
		t.Fatalf("connection-token header should be stripped, got %q", got.RemovedHeader)
	}

	// X-Forwarded-*: the client address is appended to the existing chain,
	// while the incoming host and proto values are expected unchanged.
	if got.XForwardedFor != "203.0.113.9, 198.51.100.10" {
		t.Fatalf("unexpected X-Forwarded-For: %q", got.XForwardedFor)
	}
	if got.XForwardedHost != "edge.example" {
		t.Fatalf("unexpected X-Forwarded-Host: %q", got.XForwardedHost)
	}
	if got.XForwardedProto != "https" {
		t.Fatalf("unexpected X-Forwarded-Proto: %q", got.XForwardedProto)
	}
	if got.TE != "trailers" {
		t.Fatalf("unexpected TE header: %q", got.TE)
	}
	// The client request set no User-Agent, and none may appear upstream.
	if got.UserAgent != "" {
		t.Fatalf("expected empty user-agent suppression, got %q", got.UserAgent)
	}

	// Forwarded: the prior hop plus this proxy's for/by/host/proto parameters.
	if !strings.Contains(got.Forwarded, "for=203.0.113.9") {
		t.Fatalf("forwarded header missing prior hop: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
		t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "by=proxy-node") {
		t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "host=client.example") {
		t.Fatalf("forwarded header missing host: %q", got.Forwarded)
	}
	if !strings.Contains(got.Forwarded, "proto=http") {
		t.Fatalf("forwarded header missing proto: %q", got.Forwarded)
	}
	if len(got.Via) != 1 || got.Via[0] != "1.1 proxy.test" {
		t.Fatalf("unexpected upstream Via headers: %#v", got.Via)
	}

	// Response side: hop-by-hop headers stripped, Via appended, trailer kept.
	if resp.Header.Get("Connection") != "" {
		t.Fatalf("response connection header should be stripped, got %q", resp.Header.Get("Connection"))
	}
	if resp.Header.Get("X-Backend-Secret") != "" {
		t.Fatalf("response connection-token header should be stripped, got %q", resp.Header.Get("X-Backend-Secret"))
	}
	if gotVia := resp.Header.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.0 upstream" || gotVia[1] != "1.1 proxy.test" {
		t.Fatalf("unexpected response Via headers: %#v", gotVia)
	}
	if resp.Trailer.Get("X-Upstream-Trailer") != "done" {
		t.Fatalf("unexpected proxied trailer: %q", resp.Trailer.Get("X-Upstream-Trailer"))
	}
}
// TestReverseProxyDefaultViaFallback checks that a proxy configured without a
// Via value identifies itself to the backend as "1.1 touka-engine".
func TestReverseProxyDefaultViaFallback(t *testing.T) {
	// Buffered so the backend handler never blocks on the send.
	viaCh := make(chan []string, 1)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Copy the values so the test does not alias the request's header slice.
		viaCh <- append([]string(nil), r.Header.Values("Via")...)
		w.WriteHeader(http.StatusNoContent)
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))

	rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
	if rr.Code != http.StatusNoContent {
		t.Fatalf("unexpected status: %d", rr.Code)
	}

	select {
	case via := <-viaCh:
		if len(via) != 1 || via[0] != "1.1 touka-engine" {
			t.Fatalf("unexpected default Via header: %#v", via)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for backend Via header")
	}
}
// TestReverseProxyCustomErrorHandler checks that a configured ErrorHandler
// produces the client-visible response when the upstream request fails.
func TestReverseProxyCustomErrorHandler(t *testing.T) {
	engine := New()

	// NOTE(review): presumably nothing listens on 127.0.0.1:1, so the upstream
	// request fails and the error handler runs — confirm this holds on CI hosts.
	target, err := url.Parse("http://127.0.0.1:1")
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
		Target: target,
		ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
			w.WriteHeader(http.StatusGatewayTimeout)
			// Write failures are ignored: the test only inspects what was recorded.
			_, _ = fmt.Fprintf(w, "proxy failure: %v", err)
		},
	}))

	rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
	if rr.Code != http.StatusGatewayTimeout {
		t.Fatalf("unexpected status: %d", rr.Code)
	}
	if !strings.Contains(rr.Body.String(), "proxy failure:") {
		t.Fatalf("unexpected body: %q", rr.Body.String())
	}
}
// TestReverseProxyUnannouncedTrailerForwarding checks that a trailer the
// backend sets through http.TrailerPrefix, without a "Trailer" announcement
// header, still arrives as a trailer on the proxied response.
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "later")
		w.WriteHeader(http.StatusOK)
		_, _ = io.WriteString(w, "streamed")
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/trailers", ReverseProxy(ReverseProxyConfig{Target: target}))

	rr := PerformRequest(engine, http.MethodGet, "/trailers", nil, nil)
	resp := rr.Result()
	// Deferred so the body is closed on every exit path, including failures.
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}
	if string(body) != "streamed" {
		t.Fatalf("unexpected body: %q", string(body))
	}
	if got := resp.Trailer.Get("X-Unannounced-Trailer"); got != "later" {
		t.Fatalf("unexpected unannounced trailer: %q", got)
	}
}
// TestReverseProxyProtocolUpgrade performs a raw HTTP/1.1 Upgrade handshake
// through the proxy and then exchanges one line over the upgraded connection.
// The backend hijacks its connection, answers 101 by hand and echoes a single
// line back prefixed with "echo:".
func TestReverseProxyProtocolUpgrade(t *testing.T) {
	// Backend-side failures are reported through errCh because the handler
	// runs on a server goroutine and must not call t.Fatal itself.
	errCh := make(chan error, 8)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The upgrade headers must survive the proxy hop.
		if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
			errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		hj, ok := w.(http.Hijacker)
		if !ok {
			errCh <- errors.New("backend response writer does not support hijack")
			return
		}
		conn, brw, err := hj.Hijack()
		if err != nil {
			errCh <- fmt.Errorf("backend hijack failed: %w", err)
			return
		}
		defer conn.Close()

		// Hand-written 101 response; the flush error below surfaces write failures.
		_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend flush failed: %w", err)
			return
		}

		// Echo exactly one line of tunneled payload.
		line, err := brw.ReadString('\n')
		if err != nil {
			errCh <- fmt.Errorf("backend read failed: %w", err)
			return
		}
		_, _ = io.WriteString(brw, "echo:"+line)
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend echo flush failed: %w", err)
			return
		}
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/ws", ReverseProxy(ReverseProxyConfig{
		Target: target,
		Via:    "proxy.test",
	}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	// Speak raw TCP to the proxy so the upgraded stream can be used directly.
	conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	// Bound every read and write so a stalled tunnel fails instead of hanging.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("set deadline: %v", err)
	}

	_, err = io.WriteString(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
	if err != nil {
		t.Fatalf("write upgrade request: %v", err)
	}

	// One buffered reader serves the status line, the headers and the tunneled
	// payload so bytes already buffered are not lost between stages.
	reader := bufio.NewReader(conn)
	statusLine, err := reader.ReadString('\n')
	if err != nil {
		t.Fatalf("read status line: %v", err)
	}
	if !strings.Contains(statusLine, "101") {
		t.Fatalf("unexpected status line: %q", statusLine)
	}

	headers, err := textproto.NewReader(reader).ReadMIMEHeader()
	if err != nil {
		t.Fatalf("read headers: %v", err)
	}
	respHeader := http.Header(headers)
	if !strings.EqualFold(respHeader.Get("Upgrade"), "websocket") {
		t.Fatalf("unexpected upgrade response header: %q", respHeader.Get("Upgrade"))
	}
	if !headerValuesContainToken(respHeader.Values("Connection"), "Upgrade") {
		t.Fatalf("unexpected connection response header: %#v", respHeader.Values("Connection"))
	}
	if gotVia := respHeader.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" {
		t.Fatalf("unexpected Via response header: %#v", gotVia)
	}

	// Tunnel round trip.
	if _, err := io.WriteString(conn, "ping\n"); err != nil {
		t.Fatalf("write tunneled payload: %v", err)
	}
	message, err := reader.ReadString('\n')
	if err != nil {
		t.Fatalf("read tunneled payload: %v", err)
	}
	if message != "echo:ping\n" {
		t.Fatalf("unexpected tunneled payload: %q", message)
	}

	// Non-blocking check for any error the backend reported.
	select {
	case err := <-errCh:
		t.Fatal(err)
	default:
	}
}
// TestReverseProxyRejectsEmptyUpgradeProtocol checks that when the backend
// answers an upgrade request with a 101 that carries no Upgrade header, the
// client receives 502 Bad Gateway from the proxy.
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
	// Backend-side failures are reported through errCh because the handler
	// runs on a server goroutine and must not call t.Fatal itself.
	errCh := make(chan error, 4)
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hj, ok := w.(http.Hijacker)
		if !ok {
			errCh <- errors.New("backend response writer does not support hijack")
			return
		}
		conn, brw, err := hj.Hijack()
		if err != nil {
			errCh <- fmt.Errorf("backend hijack failed: %w", err)
			return
		}
		defer conn.Close()

		// A bare 101 with neither Upgrade nor Connection headers.
		_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
		if err := brw.Flush(); err != nil {
			errCh <- fmt.Errorf("backend flush failed: %w", err)
			return
		}
	}))
	defer backend.Close()

	target, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parse target: %v", err)
	}

	engine := New()
	engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	// Speak raw TCP to the proxy so the upgrade request is sent verbatim.
	conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	// Bound every read and write so a stalled proxy fails instead of hanging.
	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("set deadline: %v", err)
	}

	_, err = io.WriteString(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
	if err != nil {
		t.Fatalf("write upgrade request: %v", err)
	}

	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		t.Fatalf("read response: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadGateway {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}

	// Non-blocking check for any error the backend reported.
	select {
	case err := <-errCh:
		t.Fatal(err)
	default:
	}
}
// TestReverseProxyRestoresHeadersAfter1xx checks header isolation between an
// interim 1xx response and the final response. A stub transport fires the
// Got1xxResponse client trace with 103 Early Hints and a Link header before
// returning 200. Middleware ahead of the proxy sets X-Request-Id on the
// writer. The client must see Link only on the interim response and
// X-Request-Id only on the final one.
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
	// oneXXInfo records the interim response as seen by the client.
	type oneXXInfo struct {
		code   int
		header http.Header
	}

	// Both channels are buffered so neither callback blocks on its send.
	backendTraceCh := make(chan struct{}, 1)
	oneXXCh := make(chan oneXXInfo, 1)

	// The stub transport requires a Got1xxResponse hook in the request context
	// and uses it to emit the interim response.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		trace := httptrace.ContextClientTrace(req.Context())
		if trace == nil || trace.Got1xxResponse == nil {
			return nil, errors.New("missing Got1xxResponse trace")
		}
		backendTraceCh <- struct{}{}
		if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
			return nil, err
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Header: http.Header{
				"Content-Type": []string{"text/plain"},
			},
			Body:          io.NopCloser(strings.NewReader("ok")),
			ContentLength: 2,
			Request:       req,
		}, nil
	})

	engine := New()
	// Set a response header before the proxy handler runs.
	engine.Use(func(c *Context) {
		c.Writer.Header().Set("X-Request-Id", "req-123")
		c.Next()
	})
	engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
		Target:    mustParseURL(t, "http://example.com"),
		Transport: transport,
	}))
	proxy := httptest.NewServer(engine)
	defer proxy.Close()

	client := proxy.Client()
	req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}
	// Capture the interim response on the client side. The header is cloned so
	// the snapshot does not alias the map handed to the callback.
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}
			return nil
		},
	}))

	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("perform request: %v", err)
	}
	// Deferred so the body is closed even when a later check fails.
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body: %v", err)
	}

	select {
	case <-backendTraceCh:
	case <-time.After(2 * time.Second):
		t.Fatal("expected proxy transport 1xx trace to be invoked")
	}
	var oneXX oneXXInfo
	select {
	case oneXX = <-oneXXCh:
	case <-time.After(2 * time.Second):
		t.Fatal("expected client to receive 1xx response")
	}

	// Final response: body intact, middleware header kept, no interim leak.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status: %d", resp.StatusCode)
	}
	if string(body) != "ok" {
		t.Fatalf("unexpected body: %q", string(body))
	}
	if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
		t.Fatalf("final response lost preserved header: %q", got)
	}
	if got := resp.Header.Get("Link"); got != "" {
		t.Fatalf("interim 1xx header leaked into final response: %q", got)
	}

	// Interim response: carries the Link hint and nothing from the final headers.
	if oneXX.code != http.StatusEarlyHints {
		t.Fatalf("unexpected interim status: %d", oneXX.code)
	}
	if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
		t.Fatalf("unexpected interim Link header: %q", got)
	}
	if got := oneXX.header.Get("X-Request-Id"); got != "" {
		t.Fatalf("final-only header leaked into interim response: %q", got)
	}
}
// roundTripperFunc adapts a plain function to the http.RoundTripper interface
// so a test can supply an inline transport.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that the adapter satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip calls the wrapped function with req and returns its results unchanged.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// mustParseURL parses raw with url.Parse and returns the result. A parse
// error fails the test immediately through t.Fatalf, reported at the
// caller's line because the function is marked as a test helper.
func mustParseURL(t *testing.T, raw string) *url.URL {
	t.Helper()

	parsed, parseErr := url.Parse(raw)
	if parseErr == nil {
		return parsed
	}

	t.Fatalf("parse url %q: %v", raw, parseErr)
	return nil // not reached: t.Fatalf stops the calling goroutine
}