fix: cut redirect and allow-path routing overhead

Reuse fixed-path and Allow-header buffers so redirect and OPTIONS handling stop rebuilding temporary data on every request. Cache fallback chains and add regression coverage for redirect, 404, 405, and Allow behavior to keep the faster miss paths stable.
This commit is contained in:
wjqserver 2026-04-07 09:06:56 +08:00
parent 5d979e5670
commit 2d4aefc86e
6 changed files with 264 additions and 48 deletions

View file

@ -73,6 +73,12 @@ type Context struct {
// skippedNodes 用于记录跳过的节点信息,以便回溯 // skippedNodes 用于记录跳过的节点信息,以便回溯
// 通常在处理嵌套路由时使用 // 通常在处理嵌套路由时使用
SkippedNodes []skippedNode SkippedNodes []skippedNode
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
fixedPathBuf []byte
allowedMethodsBuf []string
allowHeaderBuf []byte
} }
// --- Context 相关方法实现 --- // --- Context 相关方法实现 ---
@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} else { } else {
c.SkippedNodes = make([]skippedNode, 0, 256) c.SkippedNodes = make([]skippedNode, 0, 256)
} }
if cap(c.fixedPathBuf) > 0 {
c.fixedPathBuf = c.fixedPathBuf[:0]
}
if cap(c.allowedMethodsBuf) > 0 {
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
}
if cap(c.allowHeaderBuf) > 0 {
c.allowHeaderBuf = c.allowHeaderBuf[:0]
}
} }
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数

125
engine.go
View file

@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"unicode/utf8"
"net/http" "net/http"
@ -82,6 +83,11 @@ type Engine struct {
// GlobalMaxRequestBodySize 全局请求体Body大小限制 // GlobalMaxRequestBodySize 全局请求体Body大小限制
GlobalMaxRequestBodySize int64 GlobalMaxRequestBodySize int64
notFoundChain HandlersChain
notFoundNoMethodChain HandlersChain
unmatchedFSChain HandlersChain
unmatchedFSNoMethodChain HandlersChain
} }
// HandleFunc 注册一个或多个 HTTP 方法的路由 // HandleFunc 注册一个或多个 HTTP 方法的路由
@ -127,10 +133,19 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) {
// 是否是OPTIONS方式 // 是否是OPTIONS方式
if httpMethod == http.MethodOptions { if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法 // 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := engine.allowedMethodsForPath(requestPath) allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
c.allowedMethodsBuf = allowedMethods[:0]
if len(allowedMethods) > 0 { if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) allowHeader := c.allowHeaderBuf[:0]
for i, method := range allowedMethods {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", BytesToString(allowHeader))
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
} }
@ -251,6 +266,7 @@ func New() *Engine {
TLSServerConfigurator: nil, TLSServerConfigurator: nil,
GlobalMaxRequestBodySize: -1, GlobalMaxRequestBodySize: -1,
} }
engine.rebuildFallbackChains()
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
//engine.SetProtocols(GetDefaultProtocolsConfig()) //engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
@ -306,6 +322,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
// 是否开启MethodNotAllowed // 是否开启MethodNotAllowed
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
engine.HandleMethodNotAllowed = enable engine.HandleMethodNotAllowed = enable
engine.rebuildFallbackChains()
} }
// SetLogger传入实例 // SetLogger传入实例
@ -346,6 +363,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
engine.unMatchFS.ServeUnmatchedAsFS = false engine.unMatchFS.ServeUnmatchedAsFS = false
engine.UnMatchFSRoutes = nil engine.UnMatchFSRoutes = nil
} }
engine.rebuildFallbackChains()
} }
// 获取默认Protocol配置 // 获取默认Protocol配置
@ -531,12 +549,52 @@ func NotFound() HandlerFunc {
func (Engine *Engine) NoRoute(handler HandlerFunc) { func (Engine *Engine) NoRoute(handler HandlerFunc) {
Engine.noRoute = handler Engine.noRoute = handler
Engine.noRoutes = nil Engine.noRoutes = nil
Engine.rebuildFallbackChains()
} }
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
Engine.noRoute = nil Engine.noRoute = nil
Engine.noRoutes = handlerFuncs Engine.noRoutes = handlerFuncs
Engine.rebuildFallbackChains()
}
func (engine *Engine) rebuildFallbackChains() {
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
if includeMethodNotAllowed {
finalSize++
}
if includeUnmatchedFS {
finalSize += len(engine.UnMatchFSRoutes)
}
if engine.noRoute != nil {
finalSize++
} else {
finalSize += len(engine.noRoutes)
}
chain := make(HandlersChain, 0, finalSize)
chain = append(chain, engine.globalHandlers...)
if includeMethodNotAllowed {
chain = append(chain, methodNotAllowedHandler)
}
if includeUnmatchedFS {
chain = append(chain, engine.UnMatchFSRoutes...)
}
if engine.noRoute != nil {
chain = append(chain, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
chain = append(chain, engine.noRoutes...)
}
chain = append(chain, notFoundHandler)
return chain
}
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
engine.notFoundNoMethodChain = buildChain(false, false)
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
} }
// combineHandlers 组合多个处理函数链为一个 // combineHandlers 组合多个处理函数链为一个
@ -553,6 +611,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains()
return engine return engine
} }
@ -746,48 +805,24 @@ func (engine *Engine) handleRequest(c *Context) {
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
return return
} }
if engine.RedirectFixedPath { if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
// 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
if found { if found {
c.fixedPathBuf = ciPath[:0]
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
return return
} }
c.fixedPathBuf = ciPath[:0]
} }
} }
} }
// 构建处理链
// 组合全局中间件和路由处理函数
handlers := engine.globalHandlers
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
// 则在全局中间件之后添加 MethodNotAllowed 处理器
if engine.HandleMethodNotAllowed {
handlers = append(handlers, MethodNotAllowed())
}
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
// 则在处理链的最后添加 UnMatchFS 处理器
if engine.unMatchFS.ServeUnmatchedAsFS { if engine.unMatchFS.ServeUnmatchedAsFS {
/* c.handlers = engine.unmatchedFSChain
var unMatchFSHandle = c.engine.unMatchFileServer } else {
handlers = append(handlers, unMatchFSHandle) c.handlers = engine.notFoundChain
*/
handlers = append(handlers, engine.UnMatchFSRoutes...)
} }
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
// 则在处理链的最后添加 NoRoute 处理器
if engine.noRoute != nil {
handlers = append(handlers, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
handlers = append(handlers, engine.noRoutes...)
}
handlers = append(handlers, NotFound())
c.handlers = handlers
c.Next() // 执行处理函数链 c.Next() // 执行处理函数链
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送
} }
@ -813,8 +848,28 @@ func isGeneralOptionsRequest(req *http.Request) bool {
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
} }
func (engine *Engine) allowedMethodsForPath(requestPath string) []string { func shouldTryFixedPathLookup(path string, root *node) bool {
allowedMethods := make([]string, 0, len(engine.methodTrees)) if root != nil && root.hasCaseInsensitivePath {
return true
}
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
if cap(allowedMethods) < len(engine.methodTrees) {
allowedMethods = make([]string, 0, len(engine.methodTrees))
} else {
allowedMethods = allowedMethods[:0]
}
for _, treeIter := range engine.methodTrees { for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes() tempSkippedNodes := GetTempSkippedNodes()

View file

@ -16,6 +16,9 @@ func buildServeHTTPBenchmarkEngine() *Engine {
engine.GET("/api/v1/users/:id", func(c *Context) { engine.GET("/api/v1/users/:id", func(c *Context) {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
}) })
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/api/v1/users", func(c *Context) { engine.POST("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
}) })
@ -61,4 +64,8 @@ func BenchmarkServeHTTP(b *testing.B) {
b.Run("OptionsAllow", func(b *testing.B) { b.Run("OptionsAllow", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
}) })
b.Run("FixedPathRedirect", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
})
} }

102
engine_test.go Normal file
View file

@ -0,0 +1,102 @@
package touka
import (
"net/http"
"testing"
)
func TestHandleRequestRedirectFixedPath(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
}
}
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
}
}
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
}
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
}
}
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
engine := New()
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
}
}
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "" {
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
}
}
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
}
allow := rr.Header().Get("Allow")
if allow != "GET, POST" && allow != "POST, GET" {
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
}
}

View file

@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
if c.engine != nil { if c.engine != nil {
if c.Request != nil && c.Request.RequestURI != "*" { if c.Request != nil && c.Request.RequestURI != "*" {
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 {
c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) c.allowedMethodsBuf = allow[:0]
allowHeader := c.allowHeaderBuf[:0]
for i, method := range allow {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", BytesToString(allowHeader))
} }
} }
} }

34
tree.go
View file

@ -124,6 +124,7 @@ type node struct {
path string // 当前节点的路径段 path string // 当前节点的路径段
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
priority uint32 // 节点的优先级, 用于查找时优先匹配 priority uint32 // 节点的优先级, 用于查找时优先匹配
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
@ -131,6 +132,19 @@ type node struct {
fullPath string // 完整路径, 用于调试和错误信息 fullPath string // 完整路径, 用于调试和错误信息
} }
func routeNeedsCaseInsensitiveLookup(path string) bool {
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
func (n *node) incrementChildPrio(pos int) int { func (n *node) incrementChildPrio(pos int) int {
cs := n.children // 获取子节点切片 cs := n.children // 获取子节点切片
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
func (n *node) addRoute(path string, handlers HandlersChain) { func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path // 记录完整的路径 fullPath := path // 记录完整的路径
n.priority++ // 增加当前节点的优先级 n.priority++ // 增加当前节点的优先级
if routeNeedsCaseInsensitiveLookup(path) {
n.hasCaseInsensitivePath = true
}
// 如果是空树(根节点) // 如果是空树(根节点)
if len(n.path) == 0 && len(n.children) == 0 { if len(n.path) == 0 && len(n.children) == 0 {
@ -702,13 +719,24 @@ walk: // 外部循环用于遍历路由树
// 它还可以选择修复尾部斜杠. // 它还可以选择修复尾部斜杠.
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
}
func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
const stackBufSize = 128 // 栈上缓冲区的默认大小 const stackBufSize = 128 // 栈上缓冲区的默认大小
// 在常见情况下使用栈上静态大小的缓冲区. // 在常见情况下使用栈上静态大小的缓冲区.
// 如果路径太长, 则在堆上分配缓冲区. // 如果路径太长, 则在堆上分配缓冲区.
buf := make([]byte, 0, stackBufSize) if buf != nil {
if length := len(path) + 1; length > stackBufSize { buf = buf[:0]
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 }
if cap(buf) < len(path)+1 {
var stackBuf [stackBufSize]byte
if len(path)+1 <= stackBufSize {
buf = stackBuf[:0]
} else {
buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区
}
} }
ciPath := n.findCaseInsensitivePathRec( ciPath := n.findCaseInsensitivePathRec(