diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index e495a19..1f8a353 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -168,10 +168,17 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ 在后端返回响应后、写回客户端前,对响应做额外处理。 +注意:`ModifyResponse` 也会作用于 `101 Switching Protocols` 响应。 +如果该代理路由需要转发 WebSocket 或其他 Upgrade 流量,请不要在这里消费、完全缓冲,或替换 `resp.Body` 为只读对象;后续升级流程仍然要求它保留 `io.ReadWriteCloser` 能力。 +更稳妥的做法是对 `101` 响应直接跳过这类处理。 + ```go r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ModifyResponse: func(resp *http.Response) error { + if resp.StatusCode == http.StatusSwitchingProtocols { + return nil + } resp.Header.Set("X-Proxy", "touka") return nil }, diff --git a/reverseproxy.go b/reverseproxy.go index f486364..c635a1f 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -472,6 +472,10 @@ func (p *reverseProxyHandler) handleError(c *Context, err error) { return } c.AddError(err) + if c.Writer.IsHijacked() { + p.logf(c, "reverse proxy error after hijack: %v", err) + return + } if p.config.ErrorHandler != nil { p.config.ErrorHandler(c.Writer, c.Request, err) if c.Writer.Written() || c.Writer.IsHijacked() { @@ -906,10 +910,7 @@ func cleanReverseProxyQueryParams(rawQuery string) string { if rawQuery == "" { return "" } - values, err := url.ParseQuery(rawQuery) - if err == nil { - return rawQuery - } + values, _ := url.ParseQuery(rawQuery) return values.Encode() } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 1b643ef..f82aff9 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -74,7 +74,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { Via: "proxy.test", })) - req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?q=2", nil) + req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?bad=1;smuggle=2&q=2", nil) req.Host = "client.example" req.RemoteAddr = "198.51.100.10:4567" req.Header.Set("Connection", "X-Remove-Me")