mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix: harden reverse proxy edge cases
Preserve final headers when forwarding 1xx responses, reject invalid 101 upgrade negotiations, and make the default Via token RFC-safe. Tighten the reverse proxy tests around goroutine synchronization and document the Via fallback behavior more clearly.
This commit is contained in:
parent
e4ca20e848
commit
1946216c0e
3 changed files with 277 additions and 15 deletions
|
|
@ -238,6 +238,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
}))
|
||||
```
|
||||
|
||||
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||
|
||||
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||
- 如果 `Via` 为空,则会回退到固定值 `touka-engine`
|
||||
|
||||
因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。
|
||||
|
||||
如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如:
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
Via: "edge-gateway",
|
||||
}))
|
||||
```
|
||||
|
||||
当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。
|
||||
|
||||
### Touka 会如何处理这些头?
|
||||
|
||||
Touka 会尽量遵循代理链语义:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ import (
|
|||
"net/netip"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -299,9 +298,12 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
|||
return nil
|
||||
}
|
||||
h := c.Writer.Header()
|
||||
saved := h.Clone()
|
||||
clear(h)
|
||||
reverseProxyCopyHeader(h, http.Header(header))
|
||||
rawWriter.WriteHeader(code)
|
||||
clear(h)
|
||||
reverseProxyCopyHeader(h, saved)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
|
@ -482,6 +484,13 @@ func (p *reverseProxyHandler) handleError(c *Context, err error) {
|
|||
func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error {
|
||||
reqUpType := reverseProxyUpgradeType(req.Header)
|
||||
resUpType := reverseProxyUpgradeType(res.Header)
|
||||
if reqUpType == "" || resUpType == "" {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
status: http.StatusBadGateway,
|
||||
err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType),
|
||||
}
|
||||
}
|
||||
if !isPrintableASCII(resUpType) {
|
||||
res.Body.Close()
|
||||
return &reverseProxyStatusError{
|
||||
|
|
@ -660,11 +669,7 @@ func reverseProxyReceivedBy(configValue string) string {
|
|||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && hostname != "" {
|
||||
return hostname
|
||||
}
|
||||
return "touka"
|
||||
return "touka-engine"
|
||||
}
|
||||
|
||||
func reverseProxyClientIP(remoteAddr string) string {
|
||||
|
|
|
|||
|
|
@ -2,11 +2,13 @@ package touka
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
|
@ -32,9 +34,9 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
UserAgent string
|
||||
}
|
||||
|
||||
var got backendRequestSnapshot
|
||||
gotCh := make(chan backendRequestSnapshot, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got = backendRequestSnapshot{
|
||||
gotCh <- backendRequestSnapshot{
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
Host: r.Host,
|
||||
|
|
@ -93,6 +95,13 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var got backendRequestSnapshot
|
||||
select {
|
||||
case got = <-gotCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend snapshot")
|
||||
}
|
||||
|
||||
if string(body) != "proxied" {
|
||||
t.Fatalf("unexpected body: %q", string(body))
|
||||
}
|
||||
|
|
@ -161,6 +170,39 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
viaCh := make(chan []string, 1)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
viaCh <- append([]string(nil), r.Header.Values("Via")...)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Fatalf("unexpected status: %d", rr.Code)
|
||||
}
|
||||
|
||||
select {
|
||||
case via := <-viaCh:
|
||||
if len(via) != 1 || via[0] != "1.1 touka-engine" {
|
||||
t.Fatalf("unexpected default Via header: %#v", via)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for backend Via header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -227,40 +269,46 @@ func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
|||
func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
errCh := make(chan error, 8)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
|
||||
t.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
|
||||
errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
t.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
|
||||
errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Fatal("backend response writer does not support hijack")
|
||||
errCh <- errors.New("backend response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("backend hijack failed: %v", err)
|
||||
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
t.Fatalf("backend flush failed: %v", err)
|
||||
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
line, err := brw.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("backend read failed: %v", err)
|
||||
errCh <- fmt.Errorf("backend read failed: %w", err)
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(brw, "echo:"+line)
|
||||
if err := brw.Flush(); err != nil {
|
||||
t.Fatalf("backend echo flush failed: %v", err)
|
||||
errCh <- fmt.Errorf("backend echo flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
|
@ -328,4 +376,195 @@ func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
|||
if message != "echo:ping\n" {
|
||||
t.Fatalf("unexpected tunneled payload: %q", message)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
errCh <- errors.New("backend response writer does not support hijack")
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
|
||||
if err := brw.Flush(); err != nil {
|
||||
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
t.Fatalf("set deadline: %v", err)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
||||
if err != nil {
|
||||
t.Fatalf("write upgrade request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("read response: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
type oneXXInfo struct {
|
||||
code int
|
||||
header http.Header
|
||||
}
|
||||
|
||||
backendTraceCh := make(chan struct{}, 1)
|
||||
oneXXCh := make(chan oneXXInfo, 1)
|
||||
|
||||
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
if trace == nil || trace.Got1xxResponse == nil {
|
||||
return nil, errors.New("missing Got1xxResponse trace")
|
||||
}
|
||||
backendTraceCh <- struct{}{}
|
||||
if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/plain"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
ContentLength: 2,
|
||||
Request: req,
|
||||
}, nil
|
||||
})
|
||||
|
||||
engine := New()
|
||||
engine.Use(func(c *Context) {
|
||||
c.Writer.Header().Set("X-Request-Id", "req-123")
|
||||
c.Next()
|
||||
})
|
||||
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||
Target: mustParseURL(t, "http://example.com"),
|
||||
Transport: transport,
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
client := proxy.Client()
|
||||
req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("perform request: %v", err)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
select {
|
||||
case <-backendTraceCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected proxy transport 1xx trace to be invoked")
|
||||
}
|
||||
|
||||
var oneXX oneXXInfo
|
||||
select {
|
||||
case oneXX = <-oneXXCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected client to receive 1xx response")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("unexpected body: %q", string(body))
|
||||
}
|
||||
if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
|
||||
t.Fatalf("final response lost preserved header: %q", got)
|
||||
}
|
||||
if got := resp.Header.Get("Link"); got != "" {
|
||||
t.Fatalf("interim 1xx header leaked into final response: %q", got)
|
||||
}
|
||||
if oneXX.code != http.StatusEarlyHints {
|
||||
t.Fatalf("unexpected interim status: %d", oneXX.code)
|
||||
}
|
||||
if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
|
||||
t.Fatalf("unexpected interim Link header: %q", got)
|
||||
}
|
||||
if got := oneXX.header.Get("X-Request-Id"); got != "" {
|
||||
t.Fatalf("final-only header leaked into interim response: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||
t.Helper()
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parse url %q: %v", raw, err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue