mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix: harden reverse proxy edge cases
Preserve final headers when forwarding 1xx responses, reject invalid 101 upgrade negotiations, and make the default Via token RFC-safe. Tighten the reverse proxy tests around goroutine synchronization and document the Via fallback behavior more clearly.
This commit is contained in:
parent
e4ca20e848
commit
1946216c0e
3 changed files with 277 additions and 15 deletions
|
|
@ -238,6 +238,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
}))
|
}))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||||
|
|
||||||
|
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||||
|
- 如果 `Via` 为空,则会回退到固定值 `touka-engine`
|
||||||
|
|
||||||
|
因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。
|
||||||
|
|
||||||
|
如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
Via: "edge-gateway",
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。
|
||||||
|
|
||||||
### Touka 会如何处理这些头?
|
### Touka 会如何处理这些头?
|
||||||
|
|
||||||
Touka 会尽量遵循代理链语义:
|
Touka 会尽量遵循代理链语义:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
@ -299,9 +298,12 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h := c.Writer.Header()
|
h := c.Writer.Header()
|
||||||
|
saved := h.Clone()
|
||||||
|
clear(h)
|
||||||
reverseProxyCopyHeader(h, http.Header(header))
|
reverseProxyCopyHeader(h, http.Header(header))
|
||||||
rawWriter.WriteHeader(code)
|
rawWriter.WriteHeader(code)
|
||||||
clear(h)
|
clear(h)
|
||||||
|
reverseProxyCopyHeader(h, saved)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -482,6 +484,13 @@ func (p *reverseProxyHandler) handleError(c *Context, err error) {
|
||||||
func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error {
|
func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error {
|
||||||
reqUpType := reverseProxyUpgradeType(req.Header)
|
reqUpType := reverseProxyUpgradeType(req.Header)
|
||||||
resUpType := reverseProxyUpgradeType(res.Header)
|
resUpType := reverseProxyUpgradeType(res.Header)
|
||||||
|
if reqUpType == "" || resUpType == "" {
|
||||||
|
res.Body.Close()
|
||||||
|
return &reverseProxyStatusError{
|
||||||
|
status: http.StatusBadGateway,
|
||||||
|
err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType),
|
||||||
|
}
|
||||||
|
}
|
||||||
if !isPrintableASCII(resUpType) {
|
if !isPrintableASCII(resUpType) {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
return &reverseProxyStatusError{
|
return &reverseProxyStatusError{
|
||||||
|
|
@ -660,11 +669,7 @@ func reverseProxyReceivedBy(configValue string) string {
|
||||||
if trimmed != "" {
|
if trimmed != "" {
|
||||||
return trimmed
|
return trimmed
|
||||||
}
|
}
|
||||||
hostname, err := os.Hostname()
|
return "touka-engine"
|
||||||
if err == nil && hostname != "" {
|
|
||||||
return hostname
|
|
||||||
}
|
|
||||||
return "touka"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyClientIP(remoteAddr string) string {
|
func reverseProxyClientIP(remoteAddr string) string {
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,13 @@ package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/http/httptrace"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -32,9 +34,9 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
UserAgent string
|
UserAgent string
|
||||||
}
|
}
|
||||||
|
|
||||||
var got backendRequestSnapshot
|
gotCh := make(chan backendRequestSnapshot, 1)
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
got = backendRequestSnapshot{
|
gotCh <- backendRequestSnapshot{
|
||||||
Path: r.URL.Path,
|
Path: r.URL.Path,
|
||||||
RawQuery: r.URL.RawQuery,
|
RawQuery: r.URL.RawQuery,
|
||||||
Host: r.Host,
|
Host: r.Host,
|
||||||
|
|
@ -93,6 +95,13 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
var got backendRequestSnapshot
|
||||||
|
select {
|
||||||
|
case got = <-gotCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for backend snapshot")
|
||||||
|
}
|
||||||
|
|
||||||
if string(body) != "proxied" {
|
if string(body) != "proxied" {
|
||||||
t.Fatalf("unexpected body: %q", string(body))
|
t.Fatalf("unexpected body: %q", string(body))
|
||||||
}
|
}
|
||||||
|
|
@ -161,6 +170,39 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
viaCh := make(chan []string, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
viaCh <- append([]string(nil), r.Header.Values("Via")...)
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case via := <-viaCh:
|
||||||
|
if len(via) != 1 || via[0] != "1.1 touka-engine" {
|
||||||
|
t.Fatalf("unexpected default Via header: %#v", via)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for backend Via header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -227,40 +269,46 @@ func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
||||||
func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
errCh := make(chan error, 8)
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
|
if !headerValuesContainToken(r.Header["Connection"], "Upgrade") {
|
||||||
t.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
|
errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection"))
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||||
t.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
|
errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade"))
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hj, ok := w.(http.Hijacker)
|
hj, ok := w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("backend response writer does not support hijack")
|
errCh <- errors.New("backend response writer does not support hijack")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
conn, brw, err := hj.Hijack()
|
conn, brw, err := hj.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend hijack failed: %v", err)
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||||
if err := brw.Flush(); err != nil {
|
if err := brw.Flush(); err != nil {
|
||||||
t.Fatalf("backend flush failed: %v", err)
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
line, err := brw.ReadString('\n')
|
line, err := brw.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend read failed: %v", err)
|
errCh <- fmt.Errorf("backend read failed: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
_, _ = io.WriteString(brw, "echo:"+line)
|
_, _ = io.WriteString(brw, "echo:"+line)
|
||||||
if err := brw.Flush(); err != nil {
|
if err := brw.Flush(); err != nil {
|
||||||
t.Fatalf("backend echo flush failed: %v", err)
|
errCh <- fmt.Errorf("backend echo flush failed: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
@ -328,4 +376,195 @@ func TestReverseProxyProtocolUpgrade(t *testing.T) {
|
||||||
if message != "echo:ping\n" {
|
if message != "echo:ping\n" {
|
||||||
t.Fatalf("unexpected tunneled payload: %q", message)
|
t.Fatalf("unexpected tunneled payload: %q", message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
errCh <- errors.New("backend response writer does not support hijack")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n")
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("set deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("write upgrade request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read response: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
type oneXXInfo struct {
|
||||||
|
code int
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
backendTraceCh := make(chan struct{}, 1)
|
||||||
|
oneXXCh := make(chan oneXXInfo, 1)
|
||||||
|
|
||||||
|
transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
trace := httptrace.ContextClientTrace(req.Context())
|
||||||
|
if trace == nil || trace.Got1xxResponse == nil {
|
||||||
|
return nil, errors.New("missing Got1xxResponse trace")
|
||||||
|
}
|
||||||
|
backendTraceCh <- struct{}{}
|
||||||
|
if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"</style.css>; rel=preload; as=style"}}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/plain"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
ContentLength: 2,
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Use(func(c *Context) {
|
||||||
|
c.Writer.Header().Set("X-Request-Id", "req-123")
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Transport: transport,
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
client := proxy.Client()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new request: %v", err)
|
||||||
|
}
|
||||||
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
|
||||||
|
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||||
|
oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("perform request: %v", err)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-backendTraceCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("expected proxy transport 1xx trace to be invoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
var oneXX oneXXInfo
|
||||||
|
select {
|
||||||
|
case oneXX = <-oneXXCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("expected client to receive 1xx response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if string(body) != "ok" {
|
||||||
|
t.Fatalf("unexpected body: %q", string(body))
|
||||||
|
}
|
||||||
|
if got := resp.Header.Get("X-Request-Id"); got != "req-123" {
|
||||||
|
t.Fatalf("final response lost preserved header: %q", got)
|
||||||
|
}
|
||||||
|
if got := resp.Header.Get("Link"); got != "" {
|
||||||
|
t.Fatalf("interim 1xx header leaked into final response: %q", got)
|
||||||
|
}
|
||||||
|
if oneXX.code != http.StatusEarlyHints {
|
||||||
|
t.Fatalf("unexpected interim status: %d", oneXX.code)
|
||||||
|
}
|
||||||
|
if got := oneXX.header.Get("Link"); got != "</style.css>; rel=preload; as=style" {
|
||||||
|
t.Fatalf("unexpected interim Link header: %q", got)
|
||||||
|
}
|
||||||
|
if got := oneXX.header.Get("X-Request-Id"); got != "" {
|
||||||
|
t.Fatalf("final-only header leaked into interim response: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
|
t.Helper()
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse url %q: %v", raw, err)
|
||||||
|
}
|
||||||
|
return u
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue