mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
fix(reverseproxy): align forwarding and tunnel semantics
This commit is contained in:
parent
c019f24e99
commit
ed44c592d3
6 changed files with 864 additions and 26 deletions
|
|
@ -242,11 +242,20 @@ const (
|
||||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
Target: target,
|
Target: target,
|
||||||
ForwardedHeaders: touka.ForwardedBoth,
|
ForwardedHeaders: touka.ForwardedBoth,
|
||||||
ForwardedBy: "gateway-1",
|
ForwardedBy: "_gateway-1",
|
||||||
Via: "edge-1",
|
Via: "edge-1",
|
||||||
}))
|
}))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
|
||||||
|
|
||||||
|
- IPv4:`203.0.113.43`
|
||||||
|
- IPv6 / 带端口:`[2001:db8::17]:443`
|
||||||
|
- 匿名标识:`_gateway-1`
|
||||||
|
- 未知:`unknown`
|
||||||
|
|
||||||
|
像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
|
||||||
|
|
||||||
`Via` 不是“留空即禁用”的开关。当前实现中:
|
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||||
|
|
||||||
- 如果 `Via` 非空,则使用该值追加 `Via`
|
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||||
|
|
@ -282,11 +291,13 @@ Touka 会尽量遵循代理链语义:
|
||||||
|
|
||||||
Touka 的反向代理实现支持以下能力:
|
Touka 的反向代理实现支持以下能力:
|
||||||
|
|
||||||
|
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||||
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||||
- WebSocket 等 101 Switching Protocols 场景
|
- WebSocket 等 101 Switching Protocols 场景
|
||||||
- SSE(Server-Sent Events)立即刷新
|
- SSE(Server-Sent Events)立即刷新
|
||||||
- Trailer 透传
|
- Trailer 透传
|
||||||
- 1xx 响应透传
|
- 1xx 响应透传
|
||||||
|
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
|
||||||
|
|
||||||
例如,代理 WebSocket 服务:
|
例如,代理 WebSocket 服务:
|
||||||
|
|
||||||
|
|
@ -341,7 +352,7 @@ func main() {
|
||||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
Target: target,
|
Target: target,
|
||||||
ForwardedHeaders: touka.ForwardedBoth,
|
ForwardedHeaders: touka.ForwardedBoth,
|
||||||
ForwardedBy: "gateway-1",
|
ForwardedBy: "_gateway-1",
|
||||||
Via: "gateway-1",
|
Via: "gateway-1",
|
||||||
FlushInterval: 100 * time.Millisecond,
|
FlushInterval: 100 * time.Millisecond,
|
||||||
ModifyRequest: func(req *http.Request) {
|
ModifyRequest: func(req *http.Request) {
|
||||||
|
|
|
||||||
2
ecw.go
2
ecw.go
|
|
@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
|
||||||
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
hijacker, ok := ecw.w.(http.Hijacker)
|
hijacker, ok := ecw.w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
|
return nil, nil, http.ErrNotSupported
|
||||||
}
|
}
|
||||||
return hijacker.Hijack()
|
return hijacker.Hijack()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
52
engine.go
52
engine.go
|
|
@ -475,21 +475,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
||||||
func MethodNotAllowed() HandlerFunc {
|
func MethodNotAllowed() HandlerFunc {
|
||||||
return func(c *Context) {
|
return func(c *Context) {
|
||||||
httpMethod := c.Request.Method
|
httpMethod := c.Request.Method
|
||||||
requestPath := c.Request.URL.Path
|
requestPath := routeLookupPath(c.Request)
|
||||||
engine := c.engine
|
engine := c.engine
|
||||||
// 是否是OPTIONS方式
|
// 是否是OPTIONS方式
|
||||||
if httpMethod == http.MethodOptions {
|
if httpMethod == http.MethodOptions {
|
||||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||||
allowedMethods := []string{}
|
allowedMethods := engine.allowedMethodsForPath(requestPath)
|
||||||
for _, treeIter := range engine.methodTrees {
|
|
||||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
|
||||||
tempSkippedNodes := GetTempSkippedNodes()
|
|
||||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
|
||||||
PutTempSkippedNodes(tempSkippedNodes)
|
|
||||||
if value.handlers != nil {
|
|
||||||
allowedMethods = append(allowedMethods, treeIter.method)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(allowedMethods) > 0 {
|
if len(allowedMethods) > 0 {
|
||||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||||
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
||||||
|
|
@ -705,7 +696,7 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
// 这是路由查找和执行的核心逻辑
|
// 这是路由查找和执行的核心逻辑
|
||||||
func (engine *Engine) handleRequest(c *Context) {
|
func (engine *Engine) handleRequest(c *Context) {
|
||||||
httpMethod := c.Request.Method
|
httpMethod := c.Request.Method
|
||||||
requestPath := c.Request.URL.Path
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
|
||||||
// 查找对应的路由树的根节点
|
// 查找对应的路由树的根节点
|
||||||
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
||||||
|
|
@ -725,7 +716,7 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
||||||
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
|
||||||
if value.tsr && engine.RedirectTrailingSlash {
|
if value.tsr && engine.RedirectTrailingSlash {
|
||||||
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
||||||
redirectPath := requestPath
|
redirectPath := requestPath
|
||||||
|
|
@ -782,6 +773,41 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func routeLookupPath(req *http.Request) string {
|
||||||
|
if req == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
|
||||||
|
return "/" + req.RequestURI
|
||||||
|
}
|
||||||
|
if isGeneralOptionsRequest(req) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if req.URL == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return req.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGeneralOptionsRequest(req *http.Request) bool {
|
||||||
|
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
|
||||||
|
allowedMethods := make([]string, 0, len(engine.methodTrees))
|
||||||
|
for _, treeIter := range engine.methodTrees {
|
||||||
|
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||||
|
tempSkippedNodes := GetTempSkippedNodes()
|
||||||
|
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
if value.handlers != nil {
|
||||||
|
allowedMethods = append(allowedMethods, treeIter.method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allowedMethods
|
||||||
|
}
|
||||||
|
|
||||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||||
func (engine *Engine) Context() context.Context {
|
func (engine *Engine) Context() context.Context {
|
||||||
|
|
|
||||||
2
respw.go
2
respw.go
|
|
@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
||||||
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("http.Hijacker interface not supported")
|
return nil, nil, http.ErrNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用底层的 Hijack 方法
|
// 调用底层的 Hijack 方法
|
||||||
|
|
|
||||||
358
reverseproxy.go
358
reverseproxy.go
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
|
"net/http/httputil"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -217,6 +218,12 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler {
|
||||||
default:
|
default:
|
||||||
proxy.config.ForwardedHeaders = ForwardedBoth
|
proxy.config.ForwardedHeaders = ForwardedBoth
|
||||||
}
|
}
|
||||||
|
proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy)
|
||||||
|
if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) {
|
||||||
|
if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil {
|
||||||
|
proxy.configError = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
|
@ -234,11 +241,20 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
transport = http.DefaultTransport
|
transport = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c)
|
||||||
|
if err != nil {
|
||||||
|
p.handleError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if handledLocally {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := p.requestContext(c)
|
ctx, cancel := p.requestContext(c)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
outreq := c.Request.Clone(ctx)
|
outreq := c.Request.Clone(ctx)
|
||||||
if c.Request.ContentLength == 0 {
|
if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 {
|
||||||
outreq.Body = nil
|
outreq.Body = nil
|
||||||
}
|
}
|
||||||
if outreq.Body != nil {
|
if outreq.Body != nil {
|
||||||
|
|
@ -249,12 +265,35 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
outreq.Header = make(http.Header)
|
outreq.Header = make(http.Header)
|
||||||
}
|
}
|
||||||
outreq.Close = false
|
outreq.Close = false
|
||||||
|
var connectWriter *io.PipeWriter
|
||||||
|
defer func() {
|
||||||
|
if connectWriter != nil {
|
||||||
|
_ = connectWriter.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if outreq.Method == http.MethodConnect {
|
||||||
|
pipeReader, pipeWriter := io.Pipe()
|
||||||
|
outreq.Body = pipeReader
|
||||||
|
outreq.ContentLength = -1
|
||||||
|
defer outreq.Body.Close()
|
||||||
|
connectWriter = pipeWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
if outreq.Method == http.MethodConnect {
|
||||||
|
if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil {
|
||||||
|
p.handleError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
rewriteReverseProxyURL(outreq, p.target)
|
rewriteReverseProxyURL(outreq, p.target)
|
||||||
if !p.config.PreserveHost {
|
if !p.config.PreserveHost {
|
||||||
outreq.Host = ""
|
outreq.Host = ""
|
||||||
}
|
}
|
||||||
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery)
|
||||||
|
}
|
||||||
|
if updatedMaxForwards != "" {
|
||||||
|
outreq.Header.Set("Max-Forwards", updatedMaxForwards)
|
||||||
|
}
|
||||||
|
|
||||||
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
reqUpType := reverseProxyUpgradeType(outreq.Header)
|
||||||
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
if reqUpType != "" && !isPrintableASCII(reqUpType) {
|
||||||
|
|
@ -318,6 +357,23 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices {
|
||||||
|
removeHopByHopHeaders(res.Header)
|
||||||
|
res.Header.Del("Content-Length")
|
||||||
|
res.Header.Del("Transfer-Encoding")
|
||||||
|
res.ContentLength = -1
|
||||||
|
res.TransferEncoding = nil
|
||||||
|
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||||
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil {
|
||||||
|
p.handleError(c, err)
|
||||||
|
}
|
||||||
|
connectWriter = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy)
|
||||||
if !p.modifyResponse(c, res, outreq) {
|
if !p.modifyResponse(c, res, outreq) {
|
||||||
|
|
@ -353,6 +409,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
|
c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err))
|
||||||
p.logf(c, "reverse proxy body copy failed: %v", err)
|
p.logf(c, "reverse proxy body copy failed: %v", err)
|
||||||
|
if reverseProxyShouldPanicOnCopyError(c.Request) {
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
|
@ -378,6 +437,86 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c.Request.Method {
|
||||||
|
case http.MethodOptions, http.MethodTrace:
|
||||||
|
default:
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards"))
|
||||||
|
if rawValue == "" {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := strconv.Atoi(rawValue)
|
||||||
|
if err != nil || value < 0 {
|
||||||
|
return "", false, &reverseProxyStatusError{
|
||||||
|
status: http.StatusBadRequest,
|
||||||
|
err: fmt.Errorf("invalid Max-Forwards value %q", rawValue),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value == 0 {
|
||||||
|
switch c.Request.Method {
|
||||||
|
case http.MethodTrace:
|
||||||
|
return "", true, p.writeLocalTraceResponse(c)
|
||||||
|
case http.MethodOptions:
|
||||||
|
p.writeLocalOptionsResponse(c)
|
||||||
|
return "", true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strconv.Itoa(value - 1), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
traceReq := c.Request.Clone(c.Request.Context())
|
||||||
|
traceReq.Body = nil
|
||||||
|
traceReq.ContentLength = 0
|
||||||
|
traceReq.TransferEncoding = nil
|
||||||
|
traceReq.RequestURI = c.Request.RequestURI
|
||||||
|
if traceReq.RequestURI == "" && traceReq.URL != nil {
|
||||||
|
traceReq.RequestURI = traceReq.URL.RequestURI()
|
||||||
|
}
|
||||||
|
traceReq.Header = traceReq.Header.Clone()
|
||||||
|
for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} {
|
||||||
|
traceReq.Header.Del(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
dump, err := httputil.DumpRequest(traceReq, false)
|
||||||
|
if err != nil {
|
||||||
|
return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "message/http")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, err = c.Writer.Write(dump)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.engine != nil {
|
||||||
|
if c.Request != nil && c.Request.RequestURI != "*" {
|
||||||
|
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 {
|
||||||
|
c.Writer.Header().Set("Allow", strings.Join(allow, ", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
|
func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
if ctx.Done() != nil {
|
if ctx.Done() != nil {
|
||||||
|
|
@ -522,7 +661,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
||||||
clientConn, brw, err := c.Writer.Hijack()
|
clientConn, brw, err := c.Writer.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
backConn.Close()
|
backConn.Close()
|
||||||
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
status := http.StatusBadGateway
|
||||||
|
if errors.Is(err, http.ErrNotSupported) {
|
||||||
|
status = http.StatusNotImplemented
|
||||||
|
}
|
||||||
|
return &reverseProxyStatusError{status: status, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
@ -561,6 +704,80 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques
|
||||||
return firstErr
|
return firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error {
|
||||||
|
if backWrite == nil {
|
||||||
|
res.Body.Close()
|
||||||
|
return &reverseProxyStatusError{
|
||||||
|
status: http.StatusBadGateway,
|
||||||
|
err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
backRead := res.Body
|
||||||
|
|
||||||
|
clientConn, brw, err := c.Writer.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
backRead.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
status := http.StatusBadGateway
|
||||||
|
if errors.Is(err, http.ErrNotSupported) {
|
||||||
|
status = http.StatusNotImplemented
|
||||||
|
}
|
||||||
|
return &reverseProxyStatusError{status: status, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer clientConn.Close()
|
||||||
|
defer backRead.Close()
|
||||||
|
defer backWrite.Close()
|
||||||
|
|
||||||
|
backConnClosed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-req.Context().Done():
|
||||||
|
case <-backConnClosed:
|
||||||
|
}
|
||||||
|
backRead.Close()
|
||||||
|
_ = backWrite.Close()
|
||||||
|
}()
|
||||||
|
defer close(backConnClosed)
|
||||||
|
|
||||||
|
res.Body = nil
|
||||||
|
if err := res.Write(brw); err != nil {
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||||
|
}
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadGateway, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(clientConn, backRead); err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cw, ok := clientConn.(interface{ CloseWrite() error }); ok {
|
||||||
|
errc <- cw.CloseWrite()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc <- errReverseProxyCopyDone
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
if _, err := io.Copy(backWrite, clientConn); err != nil {
|
||||||
|
errc <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc <- backWrite.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
firstErr := <-errc
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = <-errc
|
||||||
|
}
|
||||||
|
if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration {
|
||||||
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" {
|
||||||
return -1
|
return -1
|
||||||
|
|
@ -638,6 +855,10 @@ func reverseProxyStatusCode(err error) int {
|
||||||
if errors.As(err, &statusErr) && statusErr.status > 0 {
|
if errors.As(err, &statusErr) && statusErr.status > 0 {
|
||||||
return statusErr.status
|
return statusErr.status
|
||||||
}
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) {
|
||||||
|
return http.StatusGatewayTimeout
|
||||||
|
}
|
||||||
return http.StatusBadGateway
|
return http.StatusBadGateway
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -651,6 +872,17 @@ func validateReverseProxyTarget(target *url.URL) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateReverseProxyForwardedBy(value string) error {
|
||||||
|
trimmed := strings.TrimSpace(value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !isValidForwardedNodeIdentifier(trimmed) {
|
||||||
|
return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeReverseProxyTarget(target *url.URL) {
|
func normalizeReverseProxyTarget(target *url.URL) {
|
||||||
switch strings.ToLower(target.Scheme) {
|
switch strings.ToLower(target.Scheme) {
|
||||||
case "ws":
|
case "ws":
|
||||||
|
|
@ -732,6 +964,83 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string {
|
||||||
return strings.Join(pairs, ";")
|
return strings.Join(pairs, ";")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||||
|
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidForwardedNodeIdentifier(value string) bool {
|
||||||
|
if value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(value, "[") {
|
||||||
|
closing := strings.IndexByte(value, ']')
|
||||||
|
if closing <= 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddr(value[1:closing])
|
||||||
|
if err != nil || !addr.Is6() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if closing == len(value)-1 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if value[closing+1] != ':' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isValidForwardedNodePort(value[closing+2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
host, port, hasPort := strings.Cut(value, ":")
|
||||||
|
if hasPort {
|
||||||
|
switch {
|
||||||
|
case host == "unknown", isValidForwardedObfuscatedIdentifier(host):
|
||||||
|
return isValidForwardedNodePort(port)
|
||||||
|
default:
|
||||||
|
addr, err := netip.ParseAddr(host)
|
||||||
|
return err == nil && addr.Is4() && isValidForwardedNodePort(port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddr(value)
|
||||||
|
return err == nil && addr.Is4()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidForwardedNodePort(value string) bool {
|
||||||
|
if value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isValidForwardedObfuscatedIdentifier(value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(value) > 5 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(value)
|
||||||
|
return err == nil && port > 0 && port <= 65535
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidForwardedObfuscatedIdentifier(value string) bool {
|
||||||
|
if len(value) < 2 || value[0] != '_' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 1; i < len(value); i++ {
|
||||||
|
b := value[i]
|
||||||
|
if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch b {
|
||||||
|
case '.', '_', '-':
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func formatForwardedFor(clientIP string) string {
|
func formatForwardedFor(clientIP string) string {
|
||||||
addr, err := netip.ParseAddr(clientIP)
|
addr, err := netip.ParseAddr(clientIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -817,6 +1126,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error {
|
||||||
|
connectTarget, err := reverseProxyConnectTarget(target)
|
||||||
|
if err != nil {
|
||||||
|
return &reverseProxyStatusError{status: http.StatusBadRequest, err: err}
|
||||||
|
}
|
||||||
|
req.URL.Scheme = target.Scheme
|
||||||
|
req.URL.Host = target.Host
|
||||||
|
req.URL.Path = ""
|
||||||
|
req.URL.RawPath = ""
|
||||||
|
req.URL.RawQuery = ""
|
||||||
|
req.URL.Opaque = connectTarget
|
||||||
|
req.Host = connectTarget
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyConnectTarget(target *url.URL) (string, error) {
|
||||||
|
if target == nil {
|
||||||
|
return "", errReverseProxyNilTarget
|
||||||
|
}
|
||||||
|
host := target.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return "", errReverseProxyInvalidTarget
|
||||||
|
}
|
||||||
|
port := target.Port()
|
||||||
|
if port == "" {
|
||||||
|
switch strings.ToLower(target.Scheme) {
|
||||||
|
case "http":
|
||||||
|
port = "80"
|
||||||
|
case "https":
|
||||||
|
port = "443"
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil || portNum <= 0 || portNum > 65535 {
|
||||||
|
return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port)
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(host, port), nil
|
||||||
|
}
|
||||||
|
|
||||||
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
|
func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) {
|
||||||
if base.RawPath == "" && incoming.RawPath == "" {
|
if base.RawPath == "" && incoming.RawPath == "" {
|
||||||
return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
|
return reverseProxySingleJoiningSlash(base.Path, incoming.Path), ""
|
||||||
|
|
@ -919,6 +1269,10 @@ func cleanReverseProxyQueryParams(rawQuery string) string {
|
||||||
return values.Encode()
|
return values.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reverseProxyShouldPanicOnCopyError(req *http.Request) bool {
|
||||||
|
return req != nil && req.Context().Value(http.ServerContextKey) != nil
|
||||||
|
}
|
||||||
|
|
||||||
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter {
|
||||||
return UnwrapResponseWriter(writer)
|
return UnwrapResponseWriter(writer)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
|
||||||
Target: target,
|
Target: target,
|
||||||
ForwardedHeaders: ForwardedBoth,
|
ForwardedHeaders: ForwardedBoth,
|
||||||
ForwardedBy: "proxy-node",
|
ForwardedBy: "_proxy-node",
|
||||||
Via: "proxy.test",
|
Via: "proxy.test",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
|
if !strings.Contains(got.Forwarded, "for=198.51.100.10") {
|
||||||
t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
|
t.Fatalf("forwarded header missing client ip: %q", got.Forwarded)
|
||||||
}
|
}
|
||||||
if !strings.Contains(got.Forwarded, "by=proxy-node") {
|
if !strings.Contains(got.Forwarded, "by=_proxy-node") {
|
||||||
t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
|
t.Fatalf("forwarded header missing by token: %q", got.Forwarded)
|
||||||
}
|
}
|
||||||
if !strings.Contains(got.Forwarded, "host=client.example") {
|
if !strings.Contains(got.Forwarded, "host=client.example") {
|
||||||
|
|
@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
ForwardedHeaders: ForwardedBoth,
|
||||||
|
ForwardedBy: "proxy-node",
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
forwardedCh := make(chan string, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
forwardedCh <- r.Header.Get("Forwarded")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, backend.URL),
|
||||||
|
ForwardedHeaders: ForwardedBoth,
|
||||||
|
ForwardedBy: " _proxy-node ",
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil)
|
||||||
|
req.RemoteAddr = "198.51.100.10:4567"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case forwarded := <-forwardedCh:
|
||||||
|
if !strings.Contains(forwarded, "by=_proxy-node") {
|
||||||
|
t.Fatalf("unexpected Forwarded header: %q", forwarded)
|
||||||
|
}
|
||||||
|
if strings.Contains(forwarded, `by=" _proxy-node "`) {
|
||||||
|
t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for backend Forwarded header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
func TestReverseProxyDefaultViaFallback(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil)
|
||||||
|
if rr.Code != http.StatusGatewayTimeout {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("backend response writer does not support hijack")
|
||||||
|
}
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend hijack failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")
|
||||||
|
_ = brw.Flush()
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil)
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
called := make(chan struct{}, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called <- struct{}{}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
||||||
|
req.RequestURI = "/trace"
|
||||||
|
req.Header.Set("Max-Forwards", "0")
|
||||||
|
req.Header.Set("Authorization", "secret")
|
||||||
|
req.Header.Set("Cookie", "a=b")
|
||||||
|
req.Header.Set("Forwarded", "for=192.0.2.1")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
resp := rr.Result()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read body: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if got := resp.Header.Get("Content-Type"); got != "message/http" {
|
||||||
|
t.Fatalf("unexpected content type: %q", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") {
|
||||||
|
t.Fatalf("trace body missing request line: %q", string(body))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(body), "Authorization:") {
|
||||||
|
t.Fatalf("trace body leaked authorization header: %q", string(body))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(body), "Cookie:") {
|
||||||
|
t.Fatalf("trace body leaked cookie header: %q", string(body))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(body), "Forwarded:") {
|
||||||
|
t.Fatalf("trace body leaked forwarded header: %q", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
t.Fatal("backend should not be called when Max-Forwards is zero")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
maxForwardsCh := make(chan string, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
maxForwardsCh <- r.Header.Get("Max-Forwards")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil)
|
||||||
|
req.Header.Set("Max-Forwards", "2")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case got := <-maxForwardsCh:
|
||||||
|
if got != "1" {
|
||||||
|
t.Fatalf("unexpected Max-Forwards header: %q", got)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for backend Max-Forwards")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
called := make(chan struct{}, 1)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called <- struct{}{}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) })
|
||||||
|
engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil)
|
||||||
|
req.Header.Set("Max-Forwards", "0")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
allow := rr.Header().Get("Allow")
|
||||||
|
if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) {
|
||||||
|
t.Fatalf("unexpected Allow header: %q", allow)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
t.Fatal("backend should not be called when Max-Forwards is zero")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.OPTIONS("/", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil)
|
||||||
|
req.RequestURI = "*"
|
||||||
|
req.URL.Path = ""
|
||||||
|
req.URL.RawPath = ""
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyConnectTunnel(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backendAddr := ""
|
||||||
|
errCh := make(chan error, 4)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
errCh <- fmt.Errorf("unexpected method: %s", r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got, want := r.RequestURI, backendAddr; got != want {
|
||||||
|
errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
errCh <- errors.New("backend response writer does not support hijack")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend hijack failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n")
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend flush failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := brw.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend read failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = io.WriteString(brw, strings.ToUpper(line))
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
errCh <- fmt.Errorf("backend write failed: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
backendAddr = strings.TrimPrefix(backend.URL, "http://")
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, backend.URL),
|
||||||
|
Via: "proxy.test",
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("set deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("write connect request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
statusLine, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read status line: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(statusLine, "200") {
|
||||||
|
t.Fatalf("unexpected status line: %q", statusLine)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, err := textproto.NewReader(reader).ReadMIMEHeader()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read headers: %v", err)
|
||||||
|
}
|
||||||
|
respHeader := http.Header(headers)
|
||||||
|
if got := respHeader.Get("Content-Length"); got != "" {
|
||||||
|
t.Fatalf("CONNECT response should not include Content-Length, got %q", got)
|
||||||
|
}
|
||||||
|
if got := respHeader.Get("Transfer-Encoding"); got != "" {
|
||||||
|
t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got)
|
||||||
|
}
|
||||||
|
if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" {
|
||||||
|
t.Fatalf("unexpected Via response header: %#v", gotVia)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.WriteString(conn, "ping\n"); err != nil {
|
||||||
|
t.Fatalf("write tunneled payload: %v", err)
|
||||||
|
}
|
||||||
|
message, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read tunneled payload: %v", err)
|
||||||
|
}
|
||||||
|
if message != "PING\n" {
|
||||||
|
t.Fatalf("unexpected tunneled payload: %q", message)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyConnectNeedsHijacker(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("backend response writer does not support hijack")
|
||||||
|
}
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend hijack failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n")
|
||||||
|
_ = brw.Flush()
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil)
|
||||||
|
req.URL.Path = "/tunnel"
|
||||||
|
req.RequestURI = "/tunnel"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"text/plain"},
|
||||||
|
},
|
||||||
|
Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")},
|
||||||
|
ContentLength: -1,
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := proxy.Client().Get(proxy.URL + "/proxy")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("perform request: %v", err)
|
||||||
|
}
|
||||||
|
_, err = io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body read to fail after upstream copy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
}
|
}
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failingReadCloser struct {
|
||||||
|
chunks []string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *failingReadCloser) Read(p []byte) (int, error) {
|
||||||
|
if len(r.chunks) == 0 {
|
||||||
|
return 0, r.err
|
||||||
|
}
|
||||||
|
n := copy(p, r.chunks[0])
|
||||||
|
r.chunks = r.chunks[1:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *failingReadCloser) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue