fix: harden reverse proxy edge cases

Preserve final headers when forwarding 1xx responses, reject invalid 101 upgrade negotiations, and make the default Via token RFC-safe. Tighten the reverse proxy tests around goroutine synchronization and document the Via fallback behavior more clearly.
This commit is contained in:
wjqserver 2026-03-29 01:15:57 +08:00
parent e4ca20e848
commit 1946216c0e
3 changed files with 277 additions and 15 deletions

View file

@ -238,6 +238,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
})) }))
``` ```
`Via` 不是“留空即禁用”的开关。当前实现中:
- 如果 `Via` 非空,则使用该值追加 `Via`
- 如果 `Via` 为空,则会回退到固定值 `touka-engine`
因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`
如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如:
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target,
Via: "edge-gateway",
}))
```
当前版本没有提供“完全禁用追加 Via”的单独配置项因此不要把空字符串当作关闭手段。
### Touka 会如何处理这些头? ### Touka 会如何处理这些头?
Touka 会尽量遵循代理链语义: Touka 会尽量遵循代理链语义:

View file

@ -17,7 +17,6 @@ import (
"net/netip" "net/netip"
"net/textproto" "net/textproto"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -299,9 +298,12 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
return nil return nil
} }
h := c.Writer.Header() h := c.Writer.Header()
saved := h.Clone()
clear(h)
reverseProxyCopyHeader(h, http.Header(header)) reverseProxyCopyHeader(h, http.Header(header))
rawWriter.WriteHeader(code) rawWriter.WriteHeader(code)
clear(h) clear(h)
reverseProxyCopyHeader(h, saved)
return nil return nil
}, },
} }
@ -482,6 +484,13 @@ func (p *reverseProxyHandler) handleError(c *Context, err error) {
func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error { func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error {
reqUpType := reverseProxyUpgradeType(req.Header) reqUpType := reverseProxyUpgradeType(req.Header)
resUpType := reverseProxyUpgradeType(res.Header) resUpType := reverseProxyUpgradeType(res.Header)
if reqUpType == "" || resUpType == "" {
res.Body.Close()
return &reverseProxyStatusError{
status: http.StatusBadGateway,
err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType),
}
}
if !isPrintableASCII(resUpType) { if !isPrintableASCII(resUpType) {
res.Body.Close() res.Body.Close()
return &reverseProxyStatusError{ return &reverseProxyStatusError{
@ -660,11 +669,7 @@ func reverseProxyReceivedBy(configValue string) string {
if trimmed != "" { if trimmed != "" {
return trimmed return trimmed
} }
hostname, err := os.Hostname() return "touka-engine"
if err == nil && hostname != "" {
return hostname
}
return "touka"
} }
func reverseProxyClientIP(remoteAddr string) string { func reverseProxyClientIP(remoteAddr string) string {

View file

@ -2,11 +2,13 @@ package touka
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httptrace"
"net/textproto" "net/textproto"
"net/url" "net/url"
"strings" "strings"
@ -32,9 +34,9 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
UserAgent string UserAgent string
} }
var got backendRequestSnapshot gotCh := make(chan backendRequestSnapshot, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = backendRequestSnapshot{ gotCh <- backendRequestSnapshot{
Path: r.URL.Path, Path: r.URL.Path,
RawQuery: r.URL.RawQuery, RawQuery: r.URL.RawQuery,
Host: r.Host, Host: r.Host,
@ -93,6 +95,13 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
} }
_ = resp.Body.Close() _ = resp.Body.Close()
var got backendRequestSnapshot
select {
case got = <-gotCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend snapshot")
}
if string(body) != "proxied" { if string(body) != "proxied" {
t.Fatalf("unexpected body: %q", string(body)) t.Fatalf("unexpected body: %q", string(body))
} }
@ -161,6 +170,39 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
} }
} }
func TestReverseProxyDefaultViaFallback(t *testing.T) {
t.Helper()
viaCh := make(chan []string, 1)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
viaCh <- append([]string(nil), r.Header.Values("Via")...)
w.WriteHeader(http.StatusNoContent)
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
if rr.Code != http.StatusNoContent {
t.Fatalf("unexpected status: %d", rr.Code)
}
select {
case via := <-viaCh:
if len(via) != 1 || via[0] != "1.1 touka-engine" {
t.Fatalf("unexpected default Via header: %#v", via)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for backend Via header")
}
}
func TestReverseProxyCustomErrorHandler(t *testing.T) { func TestReverseProxyCustomErrorHandler(t *testing.T) {
t.Helper() t.Helper()
@ -227,40 +269,46 @@ func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
func TestReverseProxyProtocolUpgrade(t *testing.T) { func TestReverseProxyProtocolUpgrade(t *testing.T) {
t.Helper() t.Helper()
errCh := make(chan error, 8)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !headerValuesContainToken(r.Header["Connection"], "Upgrade") { if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
t.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection")) errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
t.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade")) errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
hj, ok := w.(http.Hijacker) hj, ok := w.(http.Hijacker)
if !ok { if !ok {
t.Fatal("backend response writer does not support hijack") errCh <- errors.New("backend response writer does not support hijack")
return
} }
conn, brw, err := hj.Hijack() conn, brw, err := hj.Hijack()
if err != nil { if err != nil {
t.Fatalf("backend hijack failed: %v", err) errCh <- fmt.Errorf("backend hijack failed: %w", err)
return
} }
defer conn.Close() defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
t.Fatalf("backend flush failed: %v", err) errCh <- fmt.Errorf("backend flush failed: %w", err)
return
} }
line, err := brw.ReadString('\n') line, err := brw.ReadString('\n')
if err != nil { if err != nil {
t.Fatalf("backend read failed: %v", err) errCh <- fmt.Errorf("backend read failed: %w", err)
return
} }
_, _ = io.WriteString(brw, "echo:"+line) _, _ = io.WriteString(brw, "echo:"+line)
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
t.Fatalf("backend echo flush failed: %v", err) errCh <- fmt.Errorf("backend echo flush failed: %w", err)
return
} }
})) }))
defer backend.Close() defer backend.Close()
@ -328,4 +376,195 @@ func TestReverseProxyProtocolUpgrade(t *testing.T) {
if message != "echo:ping\n" { if message != "echo:ping\n" {
t.Fatalf("unexpected tunneled payload: %q", message) t.Fatalf("unexpected tunneled payload: %q", message)
} }
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
t.Helper()
errCh := make(chan error, 4)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
errCh <- errors.New("backend response writer does not support hijack")
return
}
conn, brw, err := hj.Hijack()
if err != nil {
errCh <- fmt.Errorf("backend hijack failed: %w", err)
return
}
defer conn.Close()
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
if err := brw.Flush(); err != nil {
errCh <- fmt.Errorf("backend flush failed: %w", err)
return
}
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
t.Fatalf("set deadline: %v", err)
}
_, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
if err != nil {
t.Fatalf("write upgrade request: %v", err)
}
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
if err != nil {
t.Fatalf("read response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
t.Helper()
type oneXXInfo struct {
code int
header http.Header
}
backendTraceCh := make(chan struct{}, 1)
oneXXCh := make(chan oneXXInfo, 1)
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.Got1xxResponse == nil {
return nil, errors.New("missing Got1xxResponse trace")
}
backendTraceCh <- struct{}{}
if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/plain"},
},
Body: io.NopCloser(strings.NewReader("ok")),
ContentLength: 2,
Request: req,
}, nil
})
engine := New()
engine.Use(func(c *Context) {
c.Writer.Header().Set("X-Request-Id", "req-123")
c.Next()
})
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Transport: transport,
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
client := proxy.Client()
req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}
return nil
},
}))
resp, err := client.Do(req)
if err != nil {
t.Fatalf("perform request: %v", err)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
_ = resp.Body.Close()
select {
case <-backendTraceCh:
case <-time.After(2 * time.Second):
t.Fatal("expected proxy transport 1xx trace to be invoked")
}
var oneXX oneXXInfo
select {
case oneXX = <-oneXXCh:
case <-time.After(2 * time.Second):
t.Fatal("expected client to receive 1xx response")
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %d", resp.StatusCode)
}
if string(body) != "ok" {
t.Fatalf("unexpected body: %q", string(body))
}
if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
t.Fatalf("final response lost preserved header: %q", got)
}
if got := resp.Header.Get("Link"); got != "" {
t.Fatalf("interim 1xx header leaked into final response: %q", got)
}
if oneXX.code != http.StatusEarlyHints {
t.Fatalf("unexpected interim status: %d", oneXX.code)
}
if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
t.Fatalf("unexpected interim Link header: %q", got)
}
if got := oneXX.header.Get("X-Request-Id"); got != "" {
t.Fatalf("final-only header leaked into interim response: %q", got)
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
u, err := url.Parse(raw)
if err != nil {
t.Fatalf("parse url %q: %v", raw, err)
}
return u
} }