diff --git a/README.md b/README.md index a449962..a7b99fd 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Touka(灯花) 是一个基于 Go 语言构建的多层次、高性能 Web 框架 - **[中间件 (middleware.md)](docs/middleware.md)** - **[统一错误处理 (error-handling.md)](docs/error-handling.md)** - **[静态文件与资源 (static-files.md)](docs/static-files.md)** +- **[反向代理 (reverse-proxy.md)](docs/reverse-proxy.md)** - **[Server-Sent Events (sse.md)](docs/sse.md)** - **[高级特性与优化 (advanced.md)](docs/advanced.md)** diff --git a/docs/advanced.md b/docs/advanced.md index 4b68f93..a7cb9a2 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -184,6 +184,8 @@ r.SetRemoteIPHeaders([]string{ }) ``` +如果您同时使用 Touka 的 `ReverseProxy` 把请求继续转发给其他后端,请再参考 `docs/reverse-proxy.md` 中关于 `Forwarded`、`X-Forwarded-*` 与 `Via` 的说明。前者解决“当前请求的客户端 IP 如何被 Touka 正确解析”,后者解决“代理后的请求如何把链路信息继续传给下一跳”。 + ## 请求体大小限制 为了防止恶意的大数据包攻击(如慢速 HTTP 攻击或内存溢出),Touka 内置了请求体大小限制机制。 diff --git a/docs/introduction.md b/docs/introduction.md index d1aec3e..94a7310 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -14,6 +14,7 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 - **最小化内存分配**: 在热点路径上尽可能减少临时对象的产生。 - **统一错误处理**: 独创的 `errorCapturingResponseWriter` 机制,能够捕获包括标准库 `http.FileServer` 在内的所有组件产生的错误状态码,并交由全局处理器统一处理。 - **无缝集成 SSE**: 内置对 Server-Sent Events 的支持,提供简单易用的回调式 API 和高度灵活的通道式 API。 +- **内置反向代理**: 支持请求转发、协议升级、转发头维护、Trailer 与流式响应透传。 - **静态资源增强**: 针对本地文件、目录以及 Go 嵌入式文件系统(embed.FS)提供了开箱即用的支持。 - **标准库兼容**: 提供了适配器,可以轻松将现有的 `http.Handler` 或 `http.HandlerFunc` 集成到 Touka 中。 diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md new file mode 100644 index 0000000..5dfcbd1 --- /dev/null +++ b/docs/reverse-proxy.md @@ -0,0 +1,377 @@ +# 反向代理 + +Touka 内置了反向代理能力,可以直接把某一组请求转发到后端服务,同时保留 Touka 的路由、中间件与统一错误处理风格。 + +`touka.ReverseProxy` 返回一个 `HandlerFunc`,因此它可以像普通路由处理器一样直接挂到 `GET`、`ANY`、路由组等位置。 + +## 最简单的用法 + +```go +package main + +import ( + "log" + "net/url" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + target, err := url.Parse("http://127.0.0.1:9000") + if err != nil { + log.Fatal(err) + } + + r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + })) + + _ = r.Run(":8080") +} +``` + +当客户端访问 `http://127.0.0.1:8080/api/users` 时,请求会被转发到 `http://127.0.0.1:9000/api/users`。 + +## 带基础路径的代理 + +如果目标服务部署在一个子路径下,可以直接把目标地址写成带路径的 URL: + +```go +target, _ := url.Parse("http://127.0.0.1:9000/backend") + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, +})) +``` + +此时: + +- `/api/users` 会转发到 `/backend/api/users` +- `/api/orders?id=10` 会转发到 `/backend/api/orders?id=10` + +目标 URL 自身携带的查询参数也会被保留并与原请求查询参数合并。 +合并后的出站查询串会再经过一次规范化处理,因此某些非标准分隔符(例如 `;`)或非法参数片段可能被重编码、折叠或直接丢弃。 +这是为了尽量让代理链各跳对查询参数的解析结果保持一致,并减少参数走私这类解析歧义风险。 + +## 配置项说明 + +```go +type ReverseProxyConfig struct { + Target *url.URL + + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + + ModifyRequest func(*http.Request) + ModifyResponse func(*http.Response) error + ErrorHandler func(http.ResponseWriter, *http.Request, error) + + ForwardedHeaders ForwardedHeadersPolicy + ForwardedBy string + Via string + PreserveHost bool +} +``` + +### `Target` + +必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 + +```go +target, _ := url.Parse("http://backend:9000") +``` + +### `Transport` + +可选。用于自定义底层转发所使用的 `http.RoundTripper`。 + +如果留空,则默认使用 `http.DefaultTransport`。 + +```go +proxyTransport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, +} + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + Transport: proxyTransport, +})) +``` + +### `FlushInterval` + +控制代理在复制响应体时的主动刷新间隔: + +- `0`:不额外定时刷新 +- `> 0`:按指定间隔刷新 +- `< 0`:每次写入后立即刷新 + +对于 SSE 和无 `Content-Length` 的流式响应,Touka 会自动立即刷新,不依赖该配置。 + +### `BufferPool` + +可选。用于为响应体复制过程提供可复用的字节缓冲区,以减少大响应或高并发代理场景下的临时内存分配。 + +如果留空,Touka 会在复制响应体时按需分配默认缓冲区。 + +```go +type bytePool struct { + pool sync.Pool +} + +func (p *bytePool) Get() []byte { + if buf, ok := p.pool.Get().([]byte); ok { + return buf + } + return make([]byte, 32*1024) +} + +func (p *bytePool) Put(buf []byte) { + if cap(buf) >= 32*1024 { + p.pool.Put(buf[:32*1024]) + } +} + +proxyPool := &bytePool{} + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + BufferPool: proxyPool, +})) +``` + +通常只有在您已经观察到明显的分配压力,或代理的响应体较大、吞吐较高时,才需要专门配置它。 + +### `ModifyRequest` + +在请求真正发往后端前,对出站请求做最后修改。 + +常见用途: + +- 覆盖 `Host` +- 增加鉴权头 +- 重写路径 +- 注入内部追踪头 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Internal-Token", "gateway-token") + }, +})) +``` + +### `ModifyResponse` + +在后端返回响应后、写回客户端前,对响应做额外处理。 + +注意:`ModifyResponse` 也会作用于 `101 Switching Protocols` 响应。 +如果该代理路由需要转发 WebSocket 或其他 Upgrade 流量,请不要在这里消费、完全缓冲,或替换 `resp.Body` 为只读对象;后续升级流程仍然要求它保留 `io.ReadWriteCloser` 能力。 +更稳妥的做法是对 `101` 响应直接跳过这类处理。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ModifyResponse: func(resp *http.Response) error { + if resp.StatusCode == http.StatusSwitchingProtocols { + return nil + } + resp.Header.Set("X-Proxy", "touka") + return nil + }, +})) +``` + +如果该函数返回错误,会转入 `ErrorHandler` 或默认的 `502 Bad Gateway` 处理流程。 + +### `ErrorHandler` + +用于处理连接后端失败、协议升级失败、`ModifyResponse` 返回错误等情况。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("upstream unavailable")) + }, +})) +``` + +### `PreserveHost` + +默认情况下,代理请求的 `Host` 会跟随后端目标地址。 + +如果设置为 `true`,则会保留客户端原始 `Host`。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + PreserveHost: true, +})) +``` + +这在某些依赖原始域名进行路由或租户识别的后端服务中会比较有用。 + +## 转发头策略 + +Touka 支持两类常见的代理转发头: + +- 兼容性更好的 `X-Forwarded-*` +- 标准化的 `Forwarded`(RFC 7239) + +可选值: + +```go +const ( + ForwardedBoth ForwardedHeadersPolicy = iota + ForwardedNone + ForwardedXForwardedOnly + ForwardedRFC7239Only +) +``` + +推荐默认使用 `ForwardedBoth`。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ForwardedHeaders: touka.ForwardedBoth, + ForwardedBy: "gateway-1", + Via: "edge-1", +})) +``` + +`Via` 不是“留空即禁用”的开关。当前实现中: + +- 如果 `Via` 非空,则使用该值追加 `Via` +- 如果 `Via` 为空,则会回退到固定值 `touka-engine` + +因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。 + +如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如: + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + Via: "edge-gateway", +})) +``` + +当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。 + +### Touka 会如何处理这些头? + +Touka 会尽量遵循代理链语义: + +- 已有的 `X-Forwarded-For` 会保留,并在末尾追加当前 hop 的客户端 IP +- 已有的 `Forwarded` 会保留,并在末尾追加当前 hop 的条目 +- 已有的 `X-Forwarded-Host` 与 `X-Forwarded-Proto` 会优先保留;如果缺失,则由当前请求补齐 +- `Via` 会追加当前代理标识 + +这意味着在 Touka 前面还有一层可信代理(如 Nginx、Traefik、Cloudflare、网关)时,上游服务仍然可以看到完整的代理链。 + +如果您**不信任**客户端传入的这些头,请在进入 `ReverseProxy` 之前自行清理,或在 `ModifyRequest` 中显式重写。 + +## 协议升级与流式响应 + +Touka 的反向代理实现支持以下能力: + +- `Connection: Upgrade` / `Upgrade` 协议升级转发 +- WebSocket 等 101 Switching Protocols 场景 +- SSE(Server-Sent Events)立即刷新 +- Trailer 透传 +- 1xx 响应透传 + +例如,代理 WebSocket 服务: + +```go +target, _ := url.Parse("http://127.0.0.1:9001") + +r.ANY("/ws/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, +})) +``` + +## Hop-by-hop 头处理 + +根据 HTTP 代理语义,Touka 在转发时会移除连接级别的 hop-by-hop 头,避免把只应作用于单跳连接的头继续传给下游。 + +典型包括: + +- `Connection` +- `Proxy-Connection` +- `Keep-Alive` +- `Proxy-Authenticate` +- `Proxy-Authorization` +- `TE` +- `Trailer` +- `Transfer-Encoding` +- `Upgrade` + +同时,若请求本身是合法的协议升级请求,Touka 会在剥离后重新补回必要的 `Connection: Upgrade` 与 `Upgrade` 头。 + +## 一个更完整的例子 + +```go +package main + +import ( + "log" + "net/http" + "net/url" + "time" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + target, err := url.Parse("http://127.0.0.1:9000") + if err != nil { + log.Fatal(err) + } + + r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ForwardedHeaders: touka.ForwardedBoth, + ForwardedBy: "gateway-1", + Via: "gateway-1", + FlushInterval: 100 * time.Millisecond, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Gateway", "touka") + }, + ModifyResponse: func(resp *http.Response) error { + resp.Header.Set("X-Proxy", "touka") + return nil + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("bad gateway")) + }, + })) + + if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + log.Fatal(err) + } +} +``` + +## 与 `SetForwardByClientIP` 的关系 + +`ReverseProxy` 负责把请求转发给后端,并维护代理链头。 + +而 `SetForwardByClientIP` / `SetRemoteIPHeaders` 是 Touka 在**接收请求**时,用于解析当前请求客户端 IP 的逻辑。 + +两者通常会一起出现,但解决的是两个不同方向的问题: + +- `ReverseProxy`:出站转发 +- `SetForwardByClientIP`:入站解析 + +如果您的 Touka 本身就部署在其他代理之后,建议同时正确配置这两部分。 diff --git a/respw.go b/respw.go index 2cf6700..dd94db3 100644 --- a/respw.go +++ b/respw.go @@ -45,6 +45,15 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter { } } +// UnwrapResponseWriter returns the underlying stdlib response writer when the +// provided writer is Touka's internal wrapper. +func UnwrapResponseWriter(w ResponseWriter) http.ResponseWriter { + if wrapped, ok := w.(*responseWriterImpl); ok && wrapped.ResponseWriter != nil { + return wrapped.ResponseWriter + } + return w +} + func (rw *responseWriterImpl) reset(w http.ResponseWriter) { rw.ResponseWriter = w rw.status = 0 @@ -56,6 +65,10 @@ func (rw *responseWriterImpl) WriteHeader(statusCode int) { if rw.hijacked { return } + if statusCode >= 100 && statusCode < 200 && statusCode != http.StatusSwitchingProtocols { + rw.ResponseWriter.WriteHeader(statusCode) + return + } if rw.status == 0 { // 确保只设置一次 rw.status = statusCode rw.ResponseWriter.WriteHeader(statusCode) diff --git a/reverseproxy.go b/reverseproxy.go new file mode 100644 index 0000000..1730b1e --- /dev/null +++ b/reverseproxy.go @@ -0,0 +1,933 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/netip" + "net/textproto" + "net/url" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ForwardedHeadersPolicy controls how forwarding headers are generated. +// The zero value uses both X-Forwarded-* and RFC 7239 Forwarded headers. +type ForwardedHeadersPolicy int + +const ( + ForwardedBoth ForwardedHeadersPolicy = iota + ForwardedNone + ForwardedXForwardedOnly + ForwardedRFC7239Only +) + +// BufferPool provides temporary buffers for response body copying. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +// ReverseProxyConfig configures the reverse proxy handler. +type ReverseProxyConfig struct { + Target *url.URL + + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + + ModifyRequest func(*http.Request) + ModifyResponse func(*http.Response) error + ErrorHandler func(http.ResponseWriter, *http.Request, error) + + ForwardedHeaders ForwardedHeadersPolicy + ForwardedBy string + Via string + PreserveHost bool +} + +var ( + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") +) + +type reverseProxyHandler struct { + config ReverseProxyConfig + target *url.URL + receivedBy string + configError error +} + +type reverseProxyStatusError struct { + status int + err error +} + +func (e *reverseProxyStatusError) Error() string { + if e == nil || e.err == nil { + return "" + } + return e.err.Error() +} + +func (e *reverseProxyStatusError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +type noopCloseReader struct { + readCloser io.ReadCloser + closed atomic.Bool +} + +func (n *noopCloseReader) Read(p []byte) (int, error) { + if n.closed.Load() { + return 0, errors.New("reverse proxy read on closed body") + } + return n.readCloser.Read(p) +} + +func (n *noopCloseReader) Close() error { + n.closed.Store(true) + return nil +} + +type maxLatencyWriter struct { + dst ResponseWriter + latency time.Duration + + mu sync.Mutex + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + n, err := m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return n, err + } + if m.flushPending { + return n, err + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return n, err +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.flushPending { + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +type switchProtocolCopier struct { + user io.ReadWriter + backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + if _, err := io.Copy(c.user, c.backend); err != nil { + errc <- err + return + } + if cw, ok := c.user.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + if _, err := io.Copy(c.backend, c.user); err != nil { + errc <- err + return + } + if cw, ok := c.backend.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone +} + +// ReverseProxy returns a handler that proxies requests to the configured backend. +func ReverseProxy(config ReverseProxyConfig) HandlerFunc { + proxy := newReverseProxyHandler(config) + return func(c *Context) { + proxy.ServeHTTP(c) + } +} + +func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { + target := cloneReverseProxyURL(config.Target) + if target != nil { + normalizeReverseProxyTarget(target) + } + + proxy := &reverseProxyHandler{ + config: config, + target: target, + receivedBy: reverseProxyReceivedBy(config.Via), + } + + if err := validateReverseProxyTarget(target); err != nil { + proxy.configError = err + } + + switch config.ForwardedHeaders { + case ForwardedBoth, ForwardedNone, ForwardedXForwardedOnly, ForwardedRFC7239Only: + default: + proxy.config.ForwardedHeaders = ForwardedBoth + } + + return proxy +} + +func (p *reverseProxyHandler) ServeHTTP(c *Context) { + defer c.Abort() + + if p.configError != nil { + p.handleError(c, &reverseProxyStatusError{status: http.StatusInternalServerError, err: p.configError}) + return + } + + transport := p.config.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx, cancel := p.requestContext(c) + defer cancel() + + outreq := c.Request.Clone(ctx) + if c.Request.ContentLength == 0 { + outreq.Body = nil + } + if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false + + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + p.handleError(c, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + }) + return + } + + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) + } + + rawWriter := reverseProxyBaseResponseWriter(c.Writer) + var ( + roundTripMu sync.Mutex + roundTripDone bool + ) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMu.Lock() + defer roundTripMu.Unlock() + if roundTripDone { + return nil + } + h := c.Writer.Header() + saved := h.Clone() + clear(h) + reverseProxyCopyHeader(h, http.Header(header)) + rawWriter.WriteHeader(code) + clear(h) + reverseProxyCopyHeader(h, saved) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + roundTripMu.Lock() + roundTripDone = true + roundTripMu.Unlock() + if err != nil { + p.handleError(c, err) + return + } + + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return + } + if err := p.handleUpgradeResponse(c, outreq, res); err != nil { + p.handleError(c, err) + } + return + } + + removeHopByHopHeaders(res.Header) + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + + if !p.modifyResponse(c, res, outreq) { + return + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for key := range res.Trailer { + trailerKeys = append(trailerKeys, key) + } + c.Writer.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + c.Writer.WriteHeader(res.StatusCode) + + if err := p.copyResponse(c.Writer, res.Body, p.flushInterval(res)); err != nil { + defer res.Body.Close() + c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) + p.logf(c, "reverse proxy body copy failed: %v", err) + return + } + res.Body.Close() + + if len(res.Trailer) > 0 { + c.Writer.Flush() + } + + // Keep the stdlib-compatible fallback here. + // If the backend only exposes additional trailer keys after the body has been + // fully read, the trailer map can grow and those values must be written using + // the TrailerPrefix form instead of the pre-announced bare header keys. + if len(res.Trailer) == announcedTrailers { + reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) + return + } + + for key, values := range res.Trailer { + prefixedKey := http.TrailerPrefix + key + for _, value := range values { + c.Writer.Header().Add(prefixedKey, value) + } + } +} + +func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { + ctx := c.Request.Context() + if ctx.Done() != nil { + return ctx, func() {} + } + + // Follow the same compatibility path as net/http/httputil.ReverseProxy: + // request contexts are normally cancelable, but middleware can still replace + // c.Request with one backed by context.Background/TODO or another context with + // a nil Done channel. In that case CloseNotifier still provides disconnect + // propagation for the upstream round trip. + rawWriter := reverseProxyBaseResponseWriter(c.Writer) + cn, ok := rawWriter.(http.CloseNotifier) + if !ok { + return ctx, func() {} + } + + ctx, cancel := context.WithCancel(ctx) + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} + +func (p *reverseProxyHandler) addForwardingHeaders(in *http.Request, out *http.Request) { + if p.config.ForwardedHeaders == ForwardedNone { + return + } + + clientIP := reverseProxyClientIP(in.RemoteAddr) + scheme := reverseProxyRequestScheme(in) + host := in.Host + + if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedXForwardedOnly { + if clientIP != "" { + appendXForwardedFor(out.Header, clientIP) + } + if host != "" { + if len(out.Header.Values("X-Forwarded-Host")) == 0 { + out.Header.Set("X-Forwarded-Host", host) + } + } + if scheme != "" { + if len(out.Header.Values("X-Forwarded-Proto")) == 0 { + out.Header.Set("X-Forwarded-Proto", scheme) + } + } + } + + if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedRFC7239Only { + if forwardedValue := buildForwardedHeaderValue(clientIP, p.config.ForwardedBy, host, scheme); forwardedValue != "" { + if prior := out.Header.Values("Forwarded"); len(prior) > 0 { + forwardedValue = strings.Join(prior, ", ") + ", " + forwardedValue + out.Header.Del("Forwarded") + } + out.Header.Add("Forwarded", forwardedValue) + } + } +} + +func appendXForwardedFor(header http.Header, clientIP string) { + if clientIP == "" { + return + } + prior := header.Values("X-Forwarded-For") + if len(prior) == 0 { + header.Set("X-Forwarded-For", clientIP) + return + } + header.Set("X-Forwarded-For", strings.Join(prior, ", ")+", "+clientIP) +} + +func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ModifyResponse == nil { + return true + } + if err := p.config.ModifyResponse(res); err != nil { + res.Body.Close() + p.handleError(c, err) + return false + } + return true +} + +func (p *reverseProxyHandler) handleError(c *Context, err error) { + if err == nil { + return + } + c.AddError(err) + if c.Writer.IsHijacked() { + p.logf(c, "reverse proxy error after hijack: %v", err) + return + } + if p.config.ErrorHandler != nil { + p.config.ErrorHandler(c.Writer, c.Request, err) + if c.Writer.Written() || c.Writer.IsHijacked() { + return + } + } + c.ErrorUseHandle(reverseProxyStatusCode(err), err) +} + +func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error { + reqUpType := reverseProxyUpgradeType(req.Header) + resUpType := reverseProxyUpgradeType(res.Header) + if reqUpType == "" || resUpType == "" { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType), + } + } + if !isPrintableASCII(resUpType) { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), + } + } + if !strings.EqualFold(reqUpType, resUpType) { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), + } + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned 101 response without writable body"), + } + } + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + defer clientConn.Close() + defer backConn.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backConn.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: clientConn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + +func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { + if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { + return -1 + } + if res.ContentLength == -1 { + return -1 + } + return p.config.FlushInterval +} + +func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var writer io.Writer = dst + + if flushInterval != 0 { + mlw := &maxLatencyWriter{dst: dst, latency: flushInterval} + defer mlw.stop() + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + writer = mlw + } + + var buf []byte + if p.config.BufferPool != nil { + buf = p.config.BufferPool.Get() + defer p.config.BufferPool.Put(buf) + } + _, err := p.copyBuffer(writer, src, buf) + return err +} + +func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + p.logf(nil, "reverse proxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if errors.Is(rerr, io.EOF) { + return written, nil + } + return written, rerr + } + } +} + +func (p *reverseProxyHandler) logf(c *Context, format string, args ...any) { + if c != nil { + if logger := c.GetLogger(); logger != nil { + logger.Errorf(format, args...) + return + } + } + log.Printf(format, args...) +} + +func reverseProxyStatusCode(err error) int { + var statusErr *reverseProxyStatusError + if errors.As(err, &statusErr) && statusErr.status > 0 { + return statusErr.status + } + return http.StatusBadGateway +} + +func validateReverseProxyTarget(target *url.URL) error { + if target == nil { + return errReverseProxyNilTarget + } + if target.Scheme == "" || target.Host == "" { + return errReverseProxyInvalidTarget + } + return nil +} + +func normalizeReverseProxyTarget(target *url.URL) { + switch strings.ToLower(target.Scheme) { + case "ws": + target.Scheme = "http" + case "wss": + target.Scheme = "https" + } +} + +func cloneReverseProxyURL(target *url.URL) *url.URL { + if target == nil { + return nil + } + clone := *target + return &clone +} + +func reverseProxyReceivedBy(configValue string) string { + trimmed := strings.TrimSpace(configValue) + if trimmed != "" { + return trimmed + } + return "touka-engine" +} + +func reverseProxyClientIP(remoteAddr string) string { + if remoteAddr == "" { + return "" + } + if addrPort, err := netip.ParseAddrPort(remoteAddr); err == nil { + return addrPort.Addr().String() + } + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + if addr, parseErr := netip.ParseAddr(host); parseErr == nil { + return addr.String() + } + return host + } + if addr, err := netip.ParseAddr(remoteAddr); err == nil { + return addr.String() + } + return remoteAddr +} + +func reverseProxyRequestScheme(req *http.Request) string { + if req == nil { + return "" + } + if req.TLS != nil { + return "https" + } + if req.URL != nil { + scheme := strings.ToLower(req.URL.Scheme) + if scheme != "" { + return scheme + } + } + return "http" +} + +func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { + pairs := make([]string, 0, 4) + if by != "" { + pairs = append(pairs, "by="+formatForwardedParameterValue(by)) + } + if clientIP != "" { + pairs = append(pairs, "for="+formatForwardedFor(clientIP)) + } + if host != "" { + pairs = append(pairs, "host="+formatForwardedParameterValue(host)) + } + if scheme != "" { + pairs = append(pairs, "proto="+formatForwardedParameterValue(strings.ToLower(scheme))) + } + if len(pairs) == 0 { + return "" + } + return strings.Join(pairs, ";") +} + +func formatForwardedFor(clientIP string) string { + addr, err := netip.ParseAddr(clientIP) + if err != nil { + return formatForwardedParameterValue(clientIP) + } + if addr.Is6() { + return quoteForwardedString("[" + addr.String() + "]") + } + return addr.String() +} + +func formatForwardedParameterValue(value string) string { + if isToken(value) { + return value + } + return quoteForwardedString(value) +} + +func quoteForwardedString(value string) string { + replacer := strings.NewReplacer(`\`, `\\`, `"`, `\"`) + return `"` + replacer.Replace(value) + `"` +} + +func isToken(value string) bool { + if value == "" { + return false + } + for i := 0; i < len(value); i++ { + if !isTokenChar(value[i]) { + return false + } + } + return true +} + +func isTokenChar(b byte) bool { + if b >= '0' && b <= '9' { + return true + } + if b >= 'A' && b <= 'Z' { + return true + } + if b >= 'a' && b <= 'z' { + return true + } + switch b { + case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~': + return true + default: + return false + } +} + +func appendViaHeader(header http.Header, protocol, receivedBy string) { + if header == nil || receivedBy == "" { + return + } + if protocol == "" { + protocol = "1.1" + } + header.Add("Via", protocol+" "+receivedBy) +} + +func reverseProxyViaProtocol(major, minor int, raw string) string { + if major > 0 { + return strconv.Itoa(major) + "." + strconv.Itoa(minor) + } + if strings.HasPrefix(raw, "HTTP/") { + return strings.TrimPrefix(raw, "HTTP/") + } + return raw +} + +func rewriteReverseProxyURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinReverseProxyURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { + if base.RawPath == "" && incoming.RawPath == "" { + return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" + } + + baseEscaped := base.EscapedPath() + incomingEscaped := incoming.EscapedPath() + + baseSlash := strings.HasSuffix(baseEscaped, "/") + incomingSlash := strings.HasPrefix(incomingEscaped, "/") + + switch { + case baseSlash && incomingSlash: + return base.Path + incoming.Path[1:], baseEscaped + incomingEscaped[1:] + case !baseSlash && !incomingSlash: + return base.Path + "/" + incoming.Path, baseEscaped + "/" + incomingEscaped + default: + return base.Path + incoming.Path, baseEscaped + incomingEscaped + } +} + +func reverseProxySingleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + default: + return a + b + } +} + +func reverseProxyCopyHeader(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +var reverseProxyHopHeaders = []string{ + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +func removeHopByHopHeaders(header http.Header) { + for _, connectionValue := range header["Connection"] { + for _, token := range strings.Split(connectionValue, ",") { + trimmed := textproto.TrimString(token) + if trimmed != "" { + header.Del(trimmed) + } + } + } + for _, hopHeader := range reverseProxyHopHeaders { + header.Del(hopHeader) + } +} + +func reverseProxyUpgradeType(header http.Header) string { + if !headerValuesContainToken(header["Connection"], "Upgrade") { + return "" + } + return header.Get("Upgrade") +} + +func headerValuesContainToken(values []string, token string) bool { + if token == "" { + return false + } + for _, value := range values { + for _, part := range strings.Split(value, ",") { + if strings.EqualFold(textproto.TrimString(part), token) { + return true + } + } + } + return false +} + +func cleanReverseProxyQueryParams(rawQuery string) string { + if rawQuery == "" { + return "" + } + // Normalize the outgoing query string so the proxy and upstream do not see + // different semantics for non-standard separators or malformed pairs. + // This can change the exact textual form of the original query and may drop + // parts that net/url rejects, but it keeps proxy-chain parsing behavior more + // consistent and reduces parameter-smuggling ambiguity. + values, _ := url.ParseQuery(rawQuery) + return values.Encode() +} + +func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { + return UnwrapResponseWriter(writer) +} + +func isPrintableASCII(value string) bool { + for i := 0; i < len(value); i++ { + if value[i] < 0x20 || value[i] > 0x7e { + return false + } + } + return true +} diff --git a/reverseproxy_test.go b/reverseproxy_test.go new file mode 100644 index 0000000..f82aff9 --- /dev/null +++ b/reverseproxy_test.go @@ -0,0 +1,570 @@ +package touka + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/textproto" + "net/url" + "strings" + "testing" + "time" +) + +func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { + t.Helper() + + type backendRequestSnapshot struct { + Path string + RawQuery string + Host string + Connection string + RemovedHeader string + Forwarded string + XForwardedFor string + XForwardedHost string + XForwardedProto string + Via []string + TE string + UserAgent string + } + + gotCh := make(chan backendRequestSnapshot, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotCh <- backendRequestSnapshot{ + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + Host: r.Host, + Connection: r.Header.Get("Connection"), + RemovedHeader: r.Header.Get("X-Remove-Me"), + Forwarded: r.Header.Get("Forwarded"), + XForwardedFor: r.Header.Get("X-Forwarded-For"), + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + XForwardedProto: r.Header.Get("X-Forwarded-Proto"), + Via: append([]string(nil), r.Header.Values("Via")...), + TE: r.Header.Get("Te"), + UserAgent: r.Header.Get("User-Agent"), + } + + w.Header().Set("Connection", "X-Backend-Secret") + w.Header().Set("X-Backend-Secret", "remove-me") + w.Header().Add("Via", "1.0 upstream") + w.Header().Add("Trailer", "X-Upstream-Trailer") + w.Header().Set("Content-Type", "text/plain") + _, _ = io.WriteString(w, "proxied") + w.Header().Set("X-Upstream-Trailer", "done") + })) + defer backend.Close() + + target, err := url.Parse(backend.URL + "/base?from=target") + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ + Target: target, + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + Via: "proxy.test", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?bad=1;smuggle=2&q=2", nil) + req.Host = "client.example" + req.RemoteAddr = "198.51.100.10:4567" + req.Header.Set("Connection", "X-Remove-Me") + req.Header.Set("X-Remove-Me", "client-secret") + req.Header.Set("X-Forwarded-For", "203.0.113.9") + req.Header.Set("X-Forwarded-Host", "edge.example") + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("Forwarded", "for=203.0.113.9") + req.Header.Set("Te", "trailers") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + var got backendRequestSnapshot + select { + case got = <-gotCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend snapshot") + } + + if string(body) != "proxied" { + t.Fatalf("unexpected body: %q", string(body)) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got.Path != "/base/api/ping" { + t.Fatalf("unexpected upstream path: %q", got.Path) + } + if got.RawQuery != "from=target&q=2" { + t.Fatalf("unexpected upstream raw query: %q", got.RawQuery) + } + if got.Host != strings.TrimPrefix(backend.URL, "http://") { + t.Fatalf("unexpected upstream host: %q", got.Host) + } + if got.Connection != "" { + t.Fatalf("connection header should be stripped, got %q", got.Connection) + } + if got.RemovedHeader != "" { + t.Fatalf("connection-token header should be stripped, got %q", got.RemovedHeader) + } + if got.XForwardedFor != "203.0.113.9, 198.51.100.10" { + t.Fatalf("unexpected X-Forwarded-For: %q", got.XForwardedFor) + } + if got.XForwardedHost != "edge.example" { + t.Fatalf("unexpected X-Forwarded-Host: %q", got.XForwardedHost) + } + if got.XForwardedProto != "https" { + t.Fatalf("unexpected X-Forwarded-Proto: %q", got.XForwardedProto) + } + if got.TE != "trailers" { + t.Fatalf("unexpected TE header: %q", got.TE) + } + if got.UserAgent != "" { + t.Fatalf("expected empty user-agent suppression, got %q", got.UserAgent) + } + if !strings.Contains(got.Forwarded, "for=203.0.113.9") { + t.Fatalf("forwarded header missing prior hop: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "for=198.51.100.10") { + t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "by=proxy-node") { + t.Fatalf("forwarded header missing by token: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "host=client.example") { + t.Fatalf("forwarded header missing host: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "proto=http") { + t.Fatalf("forwarded header missing proto: %q", got.Forwarded) + } + if len(got.Via) != 1 || got.Via[0] != "1.1 proxy.test" { + t.Fatalf("unexpected upstream Via headers: %#v", got.Via) + } + if resp.Header.Get("Connection") != "" { + t.Fatalf("response connection header should be stripped, got %q", resp.Header.Get("Connection")) + } + if resp.Header.Get("X-Backend-Secret") != "" { + t.Fatalf("response connection-token header should be stripped, got %q", resp.Header.Get("X-Backend-Secret")) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.0 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected response Via headers: %#v", gotVia) + } + if resp.Trailer.Get("X-Upstream-Trailer") != "done" { + t.Fatalf("unexpected proxied trailer: %q", resp.Trailer.Get("X-Upstream-Trailer")) + } +} + +func TestReverseProxyDefaultViaFallback(t *testing.T) { + t.Helper() + + viaCh := make(chan []string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + viaCh <- append([]string(nil), r.Header.Values("Via")...) + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target})) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case via := <-viaCh: + if len(via) != 1 || via[0] != "1.1 touka-engine" { + t.Fatalf("unexpected default Via header: %#v", via) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Via header") + } +} + +func TestReverseProxyCustomErrorHandler(t *testing.T) { + t.Helper() + + engine := New() + target, err := url.Parse("http://127.0.0.1:1") + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: target, + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(http.StatusGatewayTimeout) + _, _ = io.WriteString(w, fmt.Sprintf("proxy failure: %v", err)) + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "proxy failure:") { + t.Fatalf("unexpected body: %q", rr.Body.String()) + } +} + +func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "later") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "streamed") + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/trailers", ReverseProxy(ReverseProxyConfig{Target: target})) + + rr := PerformRequest(engine, http.MethodGet, "/trailers", nil, nil) + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if string(body) != "streamed" { + t.Fatalf("unexpected body: %q", string(body)) + } + if got := resp.Trailer.Get("X-Unannounced-Trailer"); got != "later" { + t.Fatalf("unexpected unannounced trailer: %q", got) + } +} + +func TestReverseProxyProtocolUpgrade(t *testing.T) { + t.Helper() + + errCh := make(chan error, 8) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !headerValuesContainToken(r.Header["Connection"], "Upgrade") { + errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, "echo:"+line) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend echo flush failed: %w", err) + return + } + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{ + Target: target, + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + if err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if !strings.EqualFold(respHeader.Get("Upgrade"), "websocket") { + t.Fatalf("unexpected upgrade response header: %q", respHeader.Get("Upgrade")) + } + if !headerValuesContainToken(respHeader.Values("Connection"), "Upgrade") { + t.Fatalf("unexpected connection response header: %#v", respHeader.Values("Connection")) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { + t.Helper() + + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target})) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + if err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { + t.Helper() + + type oneXXInfo struct { + code int + header http.Header + } + + backendTraceCh := make(chan struct{}, 1) + oneXXCh := make(chan oneXXInfo, 1) + + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.Got1xxResponse == nil { + return nil, errors.New("missing Got1xxResponse trace") + } + backendTraceCh <- struct{}{} + if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"; rel=preload; as=style"}}); err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader("ok")), + ContentLength: 2, + Request: req, + }, nil + }) + + engine := New() + engine.Use(func(c *Context) { + c.Writer.Header().Set("X-Request-Id", "req-123") + c.Next() + }) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + client := proxy.Client() + req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()} + return nil + }, + })) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("perform request: %v", err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + select { + case <-backendTraceCh: + case <-time.After(2 * time.Second): + t.Fatal("expected proxy transport 1xx trace to be invoked") + } + + var oneXX oneXXInfo + select { + case oneXX = <-oneXXCh: + case <-time.After(2 * time.Second): + t.Fatal("expected client to receive 1xx response") + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if string(body) != "ok" { + t.Fatalf("unexpected body: %q", string(body)) + } + if got := resp.Header.Get("X-Request-Id"); got != "req-123" { + t.Fatalf("final response lost preserved header: %q", got) + } + if got := resp.Header.Get("Link"); got != "" { + t.Fatalf("interim 1xx header leaked into final response: %q", got) + } + if oneXX.code != http.StatusEarlyHints { + t.Fatalf("unexpected interim status: %d", oneXX.code) + } + if got := oneXX.header.Get("Link"); got != "; rel=preload; as=style" { + t.Fatalf("unexpected interim Link header: %q", got) + } + if got := oneXX.header.Get("X-Request-Id"); got != "" { + t.Fatalf("final-only header leaked into interim response: %q", got) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("parse url %q: %v", raw, err) + } + return u +}