From 1946216c0edf614abeb4d6d048b99fbfe4828931 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 29 Mar 2026 01:15:57 +0800 Subject: [PATCH] fix: harden reverse proxy edge cases Preserve final headers when forwarding 1xx responses, reject invalid 101 upgrade negotiations, and make the default Via token RFC-safe. Tighten the reverse proxy tests around goroutine synchronization and document the Via fallback behavior more clearly. --- docs/reverse-proxy.md | 18 +++ reverseproxy.go | 17 ++- reverseproxy_test.go | 257 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 277 insertions(+), 15 deletions(-) diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 626a3b0..e495a19 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -238,6 +238,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ })) ``` +`Via` 不是“留空即禁用”的开关。当前实现中: + +- 如果 `Via` 非空,则使用该值追加 `Via` +- 如果 `Via` 为空,则会回退到固定值 `touka-engine` + +因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。 + +如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如: + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + Via: "edge-gateway", +})) +``` + +当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。 + ### Touka 会如何处理这些头? Touka 会尽量遵循代理链语义: diff --git a/reverseproxy.go b/reverseproxy.go index 6ae368d..f486364 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -17,7 +17,6 @@ import ( "net/netip" "net/textproto" "net/url" - "os" "strconv" "strings" "sync" @@ -299,9 +298,12 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return nil } h := c.Writer.Header() + saved := h.Clone() + clear(h) reverseProxyCopyHeader(h, http.Header(header)) rawWriter.WriteHeader(code) clear(h) + reverseProxyCopyHeader(h, saved) return nil }, } @@ -482,6 +484,13 @@ func (p *reverseProxyHandler) handleError(c *Context, err error) { func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error { reqUpType := reverseProxyUpgradeType(req.Header) resUpType := reverseProxyUpgradeType(res.Header) + if reqUpType == "" || resUpType == "" { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType), + } + } if !isPrintableASCII(resUpType) { res.Body.Close() return &reverseProxyStatusError{ @@ -660,11 +669,7 @@ func reverseProxyReceivedBy(configValue string) string { if trimmed != "" { return trimmed } - hostname, err := os.Hostname() - if err == nil && hostname != "" { - return hostname - } - return "touka" + return "touka-engine" } func reverseProxyClientIP(remoteAddr string) string { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 5d9148d..1b643ef 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,11 +2,13 @@ package touka import ( "bufio" + "errors" "fmt" "io" "net" "net/http" "net/http/httptest" + "net/http/httptrace" "net/textproto" "net/url" "strings" @@ -32,9 +34,9 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { UserAgent string } - var got backendRequestSnapshot + gotCh := make(chan backendRequestSnapshot, 1) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - got = backendRequestSnapshot{ + gotCh <- backendRequestSnapshot{ Path: r.URL.Path, RawQuery: r.URL.RawQuery, Host: r.Host, @@ -93,6 +95,13 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } _ = resp.Body.Close() + var got backendRequestSnapshot + select { + case got = <-gotCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend snapshot") + } + if string(body) != "proxied" { t.Fatalf("unexpected body: %q", string(body)) } @@ -161,6 +170,39 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } +func TestReverseProxyDefaultViaFallback(t *testing.T) { + t.Helper() + + viaCh := make(chan []string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + viaCh <- append([]string(nil), r.Header.Values("Via")...) + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target})) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case via := <-viaCh: + if len(via) != 1 || via[0] != "1.1 touka-engine" { + t.Fatalf("unexpected default Via header: %#v", via) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Via header") + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -227,40 +269,46 @@ func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { func TestReverseProxyProtocolUpgrade(t *testing.T) { t.Helper() + errCh := make(chan error, 8) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !headerValuesContainToken(r.Header["Connection"], "Upgrade") { - t.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection")) + errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection")) w.WriteHeader(http.StatusBadRequest) return } if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { - t.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade")) + errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade")) w.WriteHeader(http.StatusBadRequest) return } hj, ok := w.(http.Hijacker) if !ok { - t.Fatal("backend response writer does not support hijack") + errCh <- errors.New("backend response writer does not support hijack") + return } conn, brw, err := hj.Hijack() if err != nil { - t.Fatalf("backend hijack failed: %v", err) + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return } defer conn.Close() _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") if err := brw.Flush(); err != nil { - t.Fatalf("backend flush failed: %v", err) + errCh <- fmt.Errorf("backend flush failed: %w", err) + return } line, err := brw.ReadString('\n') if err != nil { - t.Fatalf("backend read failed: %v", err) + errCh <- fmt.Errorf("backend read failed: %w", err) + return } _, _ = io.WriteString(brw, "echo:"+line) if err := brw.Flush(); err != nil { - t.Fatalf("backend echo flush failed: %v", err) + errCh <- fmt.Errorf("backend echo flush failed: %w", err) + return } })) defer backend.Close() @@ -328,4 +376,195 @@ func TestReverseProxyProtocolUpgrade(t *testing.T) { if message != "echo:ping\n" { t.Fatalf("unexpected tunneled payload: %q", message) } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { + t.Helper() + + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target})) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + if err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { + t.Helper() + + type oneXXInfo struct { + code int + header http.Header + } + + backendTraceCh := make(chan struct{}, 1) + oneXXCh := make(chan oneXXInfo, 1) + + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.Got1xxResponse == nil { + return nil, errors.New("missing Got1xxResponse trace") + } + backendTraceCh <- struct{}{} + if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"; rel=preload; as=style"}}); err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader("ok")), + ContentLength: 2, + Request: req, + }, nil + }) + + engine := New() + engine.Use(func(c *Context) { + c.Writer.Header().Set("X-Request-Id", "req-123") + c.Next() + }) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + client := proxy.Client() + req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()} + return nil + }, + })) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("perform request: %v", err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + select { + case <-backendTraceCh: + case <-time.After(2 * time.Second): + t.Fatal("expected proxy transport 1xx trace to be invoked") + } + + var oneXX oneXXInfo + select { + case oneXX = <-oneXXCh: + case <-time.After(2 * time.Second): + t.Fatal("expected client to receive 1xx response") + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if string(body) != "ok" { + t.Fatalf("unexpected body: %q", string(body)) + } + if got := resp.Header.Get("X-Request-Id"); got != "req-123" { + t.Fatalf("final response lost preserved header: %q", got) + } + if got := resp.Header.Get("Link"); got != "" { + t.Fatalf("interim 1xx header leaked into final response: %q", got) + } + if oneXX.code != http.StatusEarlyHints { + t.Fatalf("unexpected interim status: %d", oneXX.code) + } + if got := oneXX.header.Get("Link"); got != "; rel=preload; as=style" { + t.Fatalf("unexpected interim Link header: %q", got) + } + if got := oneXX.header.Get("X-Request-Id"); got != "" { + t.Fatalf("final-only header leaked into interim response: %q", got) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("parse url %q: %v", raw, err) + } + return u }