fix(reverseproxy): align forwarding and tunnel semantics

This commit is contained in:
wjqserver 2026-04-02 03:18:49 +08:00
parent c019f24e99
commit ed44c592d3
6 changed files with 864 additions and 26 deletions

View file

@ -2,6 +2,7 @@ package touka
import (
"bufio"
"context"
"errors"
"fmt"
"io"
@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
Target: target,
ForwardedHeaders: ForwardedBoth,
ForwardedBy: "proxy-node",
ForwardedBy: "_proxy-node",
Via: "proxy.test",
}))
@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
}
if !strings.Contains(got.Forwarded, "by=proxy-node") {
if !strings.Contains(got.Forwarded, "by=_proxy-node") {
t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
}
if !strings.Contains(got.Forwarded, "host=client.example") {
@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
}
}
func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) {
t.Helper()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
ForwardedHeaders: ForwardedBoth,
ForwardedBy: "proxy-node",
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusInternalServerError {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) {
t.Helper()
forwardedCh := make(chan string, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
forwardedCh <- r.Header.Get("Forwarded")
w.WriteHeader(http.StatusNoContent)
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, backend.URL),
ForwardedHeaders: ForwardedBoth,
ForwardedBy: " _proxy-node ",
}))
req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
req.RemoteAddr = "198.51.100.10:4567"
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case forwarded := <-forwardedCh:
if !strings.Contains(forwarded, "by=_proxy-node") {
t.Fatalf("unexpected Forwarded header: %q", forwarded)
}
if strings.Contains(forwarded, `by=" _proxy-node "`) {
t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend Forwarded header")
}
}
func TestReverseProxyDefaultViaFallback(t *testing.T) {
t.Helper()
@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
}
}
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
t.Helper()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
return nil, context.DeadlineExceeded
}),
}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusGatewayTimeout {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
t.Helper()
@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
}
}
func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
t.Fatal("backend response writer does not support hijack")
}
conn, brw, err := hj.Hijack()
if err != nil {
t.Fatalf("backend hijack failed: %v", err)
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
_ = brw.Flush()
}))
defer backend.Close()
engine := New()
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNotImplemented {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) {
t.Helper()
called := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called <- struct{}{}
w.WriteHeader(http.StatusNoContent)
}))
defer backend.Close()
engine := New()
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
req.RequestURI = "/trace"
req.Header.Set("Max-Forwards", "0")
req.Header.Set("Authorization", "secret")
req.Header.Set("Cookie", "a=b")
req.Header.Set("Forwarded", "for=192.0.2.1")
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
resp := rr.Result()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if got := resp.Header.Get("Content-Type"); got != "message/http" {
t.Fatalf("unexpected content type: %q", got)
}
if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") {
t.Fatalf("trace body missing request line: %q", string(body))
}
if strings.Contains(string(body), "Authorization:") {
t.Fatalf("trace body leaked authorization header: %q", string(body))
}
if strings.Contains(string(body), "Cookie:") {
t.Fatalf("trace body leaked cookie header: %q", string(body))
}
if strings.Contains(string(body), "Forwarded:") {
t.Fatalf("trace body leaked forwarded header: %q", string(body))
}
select {
case <-called:
t.Fatal("backend should not be called when Max-Forwards is zero")
default:
}
}
func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) {
t.Helper()
maxForwardsCh := make(chan string, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
maxForwardsCh <- r.Header.Get("Max-Forwards")
w.WriteHeader(http.StatusNoContent)
}))
defer backend.Close()
engine := New()
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
req.Header.Set("Max-Forwards", "2")
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNoContent {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case got := <-maxForwardsCh:
if got != "1" {
t.Fatalf("unexpected Max-Forwards header: %q", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend Max-Forwards")
}
}
func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
t.Helper()
called := make(chan struct{}, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called <- struct{}{}
w.WriteHeader(http.StatusNoContent)
}))
defer backend.Close()
engine := New()
engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) })
engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil)
req.Header.Set("Max-Forwards", "0")
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rr.Code)
}
allow := rr.Header().Get("Allow")
if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) {
t.Fatalf("unexpected Allow header: %q", allow)
}
select {
case <-called:
t.Fatal("backend should not be called when Max-Forwards is zero")
default:
}
}
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
t.Helper()
engine := New()
engine.OPTIONS("/", func(c *Context) {
c.Status(http.StatusNoContent)
})
req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil)
req.RequestURI = "*"
req.URL.Path = ""
req.URL.RawPath = ""
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNotFound {
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
}
}
func TestReverseProxyConnectTunnel(t *testing.T) {
t.Helper()
backendAddr := ""
errCh := make(chan error, 4)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
errCh <- fmt.Errorf("unexpected method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if got, want := r.RequestURI, backendAddr; got != want {
errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want)
w.WriteHeader(http.StatusBadRequest)
return
}
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("backend response writer does not support hijack")
return
}
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("backend hijack failed: %w", err)
return
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("backend flush failed: %w", err)
return
}
line, err := brw.ReadString('\n')
if err != nil {
errCh <- fmt.Errorf("backend read failed: %w", err)
return
}
_, _ = io.WriteString(brw, strings.ToUpper(line))
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("backend write failed: %w", err)
return
}
}))
defer backend.Close()
backendAddr = strings.TrimPrefix(backend.URL, "http://")
engine := New()
engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, backend.URL),
Via: "proxy.test",
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
t.Fatalf("set deadline: %v", err)
}
_, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n")
if err != nil {
t.Fatalf("write connect request: %v", err)
}
reader := bufio.NewReader(conn)
statusLine, err := reader.ReadString('\n')
if err != nil {
t.Fatalf("read status line: %v", err)
}
if !strings.Contains(statusLine, "200") {
t.Fatalf("unexpected status line: %q", statusLine)
}
headers, err := textproto.NewReader(reader).ReadMIMEHeader()
if err != nil {
t.Fatalf("read headers: %v", err)
}
respHeader := http.Header(headers)
if got := respHeader.Get("Content-Length"); got != "" {
t.Fatalf("CONNECT response should not include Content-Length, got %q", got)
}
if got := respHeader.Get("Transfer-Encoding"); got != "" {
t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got)
}
if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" {
t.Fatalf("unexpected Via response header: %#v", gotVia)
}
if _, err := io.WriteString(conn, "ping\n"); err != nil {
t.Fatalf("write tunneled payload: %v", err)
}
message, err := reader.ReadString('\n')
if err != nil {
t.Fatalf("read tunneled payload: %v", err)
}
if message != "PING\n" {
t.Fatalf("unexpected tunneled payload: %q", message)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
t.Fatal("backend response writer does not support hijack")
}
conn, brw, err := hj.Hijack()
if err != nil {
t.Fatalf("backend hijack failed: %v", err)
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n")
_ = brw.Flush()
}))
defer backend.Close()
engine := New()
engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil)
req.URL.Path = "/tunnel"
req.RequestURI = "/tunnel"
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
if rr.Code != http.StatusNotImplemented {
t.Fatalf("unexpected status: %d", rr.Code)
}
}
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
t.Helper()
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/plain"},
},
Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")},
ContentLength: -1,
Request: req,
}, nil
}),
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := proxy.Client().Get(proxy.URL + "/proxy")
if err != nil {
t.Fatalf("perform request: %v", err)
}
_, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err == nil {
t.Fatal("expected body read to fail after upstream copy error")
}
}
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
t.Helper()
@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
}
return u
}
type failingReadCloser struct {
chunks []string
err error
}
func (r *failingReadCloser) Read(p []byte) (int, error) {
if len(r.chunks) == 0 {
return 0, r.err
}
n := copy(p, r.chunks[0])
r.chunks = r.chunks[1:]
return n, nil
}
func (r *failingReadCloser) Close() error {
return nil
}