Merge pull request #81 from infinite-iroha/feat/optimize-route-match-hotpath

Feat/optimize route match hotpath
This commit is contained in:
WJQSERVER 2026-04-07 09:58:10 +08:00 committed by GitHub
commit fca9bbd3ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 731 additions and 131 deletions

3
.gitignore vendored
View file

@ -1 +1,2 @@
test test
/bench_route_match_baseline.txt

View file

@ -73,6 +73,12 @@ type Context struct {
// skippedNodes 用于记录跳过的节点信息,以便回溯 // skippedNodes 用于记录跳过的节点信息,以便回溯
// 通常在处理嵌套路由时使用 // 通常在处理嵌套路由时使用
SkippedNodes []skippedNode SkippedNodes []skippedNode
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
fixedPathBuf []byte
allowedMethodsBuf []string
allowHeaderBuf []byte
} }
// --- Context 相关方法实现 --- // --- Context 相关方法实现 ---
@ -97,7 +103,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} }
c.handlers = nil c.handlers = nil
c.index = -1 // 初始为 -1`Next()` 将其设置为 0 c.index = -1 // 初始为 -1`Next()` 将其设置为 0
c.Keys = make(map[string]any) // 每次请求重新创建 map避免数据污染 c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
c.Errors = c.Errors[:0] // 清空 Errors 切片 c.Errors = c.Errors[:0] // 清空 Errors 切片
c.queryCache = nil // 清空查询参数缓存 c.queryCache = nil // 清空查询参数缓存
c.formCache = nil // 清空表单数据缓存 c.formCache = nil // 清空表单数据缓存
@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} else { } else {
c.SkippedNodes = make([]skippedNode, 0, 256) c.SkippedNodes = make([]skippedNode, 0, 256)
} }
if cap(c.fixedPathBuf) > 0 {
c.fixedPathBuf = c.fixedPathBuf[:0]
}
if cap(c.allowedMethodsBuf) > 0 {
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
}
if cap(c.allowHeaderBuf) > 0 {
c.allowHeaderBuf = c.allowHeaderBuf[:0]
}
} }
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数

78
context_benchmark_test.go Normal file
View file

@ -0,0 +1,78 @@
package touka
import (
"net/http"
"testing"
)
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
c, _ := CreateTestContext(nil)
if c.Keys != nil {
t.Fatalf("expected fresh test context Keys to be nil before first Set")
}
c.Set("answer", 42)
if c.Keys == nil {
t.Fatalf("expected Set to allocate Keys map")
}
if value, exists := c.Get("answer"); !exists || value != 42 {
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(c.Writer, req)
if c.Keys != nil {
t.Fatalf("expected reset to clear Keys without allocating a new map")
}
if value, exists := c.Get("answer"); exists || value != nil {
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
}
ctxValue := c.Value("missing")
if ctxValue != nil {
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
}
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected MustGet to panic for missing key after reset")
}
}()
_ = c.MustGet("answer")
}
func BenchmarkContextReset(b *testing.B) {
b.Run("NoKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(c.Writer, req)
}
})
b.Run("WithKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(c.Writer, req)
c.Set("request-id", i)
}
})
}

227
engine.go
View file

@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"unicode/utf8"
"net/http" "net/http"
@ -82,6 +83,11 @@ type Engine struct {
// GlobalMaxRequestBodySize 全局请求体Body大小限制 // GlobalMaxRequestBodySize 全局请求体Body大小限制
GlobalMaxRequestBodySize int64 GlobalMaxRequestBodySize int64
notFoundChain HandlersChain
notFoundNoMethodChain HandlersChain
unmatchedFSChain HandlersChain
unmatchedFSNoMethodChain HandlersChain
} }
// HandleFunc 注册一个或多个 HTTP 方法的路由 // HandleFunc 注册一个或多个 HTTP 方法的路由
@ -117,6 +123,64 @@ type ErrorHandle struct {
type ErrorHandler func(c *Context, code int, err error) type ErrorHandler func(c *Context, code int, err error)
var errMethodNotAllowed = errors.New("method not allowed")
var errNotFound = errors.New("not found")
type defaultErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request)
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
c.allowedMethodsBuf = allowedMethods[:0]
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
allowHeader := c.allowHeaderBuf[:0]
for i, method := range allowedMethods {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", string(allowHeader))
c.Status(http.StatusOK)
return
}
return
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
*tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
if value.handlers != nil {
PutTempSkippedNodes(tempSkippedNodes)
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
return
}
}
PutTempSkippedNodes(tempSkippedNodes)
}
var notFoundHandler HandlerFunc = func(c *Context) {
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
}
// defaultErrorHandle 默认错误处理 // defaultErrorHandle 默认错误处理
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
select { select {
@ -132,11 +196,7 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
if err != nil { if err != nil {
errMsg = err.Error() errMsg = err.Error()
} }
c.JSON(code, H{ c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg})
"code": code,
"message": http.StatusText(code),
"error": errMsg,
})
c.Writer.Flush() c.Writer.Flush()
c.Abort() c.Abort()
return return
@ -211,6 +271,7 @@ func New() *Engine {
TLSServerConfigurator: nil, TLSServerConfigurator: nil,
GlobalMaxRequestBodySize: -1, GlobalMaxRequestBodySize: -1,
} }
engine.rebuildFallbackChains()
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
//engine.SetProtocols(GetDefaultProtocolsConfig()) //engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
@ -266,6 +327,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
// 是否开启MethodNotAllowed // 是否开启MethodNotAllowed
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
engine.HandleMethodNotAllowed = enable engine.HandleMethodNotAllowed = enable
engine.rebuildFallbackChains()
} }
// SetLogger传入实例 // SetLogger传入实例
@ -306,6 +368,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
engine.unMatchFS.ServeUnmatchedAsFS = false engine.unMatchFS.ServeUnmatchedAsFS = false
engine.UnMatchFSRoutes = nil engine.UnMatchFSRoutes = nil
} }
engine.rebuildFallbackChains()
} }
// 获取默认Protocol配置 // 获取默认Protocol配置
@ -479,57 +542,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
// 405中间件 // 405中间件
func MethodNotAllowed() HandlerFunc { func MethodNotAllowed() HandlerFunc {
return func(c *Context) { return methodNotAllowedHandler
httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request)
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := engine.allowedMethodsForPath(requestPath)
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
c.Status(http.StatusOK)
return
}
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
return
}
}
}
} }
// 404最后处理 // 404最后处理
func NotFound() HandlerFunc { func NotFound() HandlerFunc {
return func(c *Context) { return notFoundHandler
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
}
} }
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoute(handler HandlerFunc) { func (Engine *Engine) NoRoute(handler HandlerFunc) {
Engine.noRoute = handler Engine.noRoute = handler
Engine.noRoutes = nil Engine.noRoutes = nil
Engine.rebuildFallbackChains()
} }
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
Engine.noRoute = nil Engine.noRoute = nil
Engine.noRoutes = handlerFuncs Engine.noRoutes = handlerFuncs
Engine.rebuildFallbackChains()
}
func (engine *Engine) rebuildFallbackChains() {
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
if includeMethodNotAllowed {
finalSize++
}
if includeUnmatchedFS {
finalSize += len(engine.UnMatchFSRoutes)
}
if engine.noRoute != nil {
finalSize++
} else {
finalSize += len(engine.noRoutes)
}
chain := make(HandlersChain, 0, finalSize)
chain = append(chain, engine.globalHandlers...)
if includeMethodNotAllowed {
chain = append(chain, methodNotAllowedHandler)
}
if includeUnmatchedFS {
chain = append(chain, engine.UnMatchFSRoutes...)
}
if engine.noRoute != nil {
chain = append(chain, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
chain = append(chain, engine.noRoutes...)
}
chain = append(chain, notFoundHandler)
return chain
}
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
engine.notFoundNoMethodChain = buildChain(false, false)
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
} }
// combineHandlers 组合多个处理函数链为一个 // combineHandlers 组合多个处理函数链为一个
@ -546,6 +616,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains()
return engine return engine
} }
@ -739,47 +810,24 @@ func (engine *Engine) handleRequest(c *Context) {
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
return return
} }
// 尝试不区分大小写的查找 if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法 // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
if found && engine.RedirectFixedPath { if found {
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 c.fixedPathBuf = ciPath[:0]
return c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径
return
}
c.fixedPathBuf = c.fixedPathBuf[:0]
} }
} }
} }
// 构建处理链
// 组合全局中间件和路由处理函数
handlers := engine.globalHandlers
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
// 则在全局中间件之后添加 MethodNotAllowed 处理器
if engine.HandleMethodNotAllowed {
handlers = append(handlers, MethodNotAllowed())
}
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
// 则在处理链的最后添加 UnMatchFS 处理器
if engine.unMatchFS.ServeUnmatchedAsFS { if engine.unMatchFS.ServeUnmatchedAsFS {
/* c.handlers = engine.unmatchedFSChain
var unMatchFSHandle = c.engine.unMatchFileServer } else {
handlers = append(handlers, unMatchFSHandle) c.handlers = engine.notFoundChain
*/
handlers = append(handlers, engine.UnMatchFSRoutes...)
} }
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
// 则在处理链的最后添加 NoRoute 处理器
if engine.noRoute != nil {
handlers = append(handlers, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
handlers = append(handlers, engine.noRoutes...)
}
handlers = append(handlers, NotFound())
c.handlers = handlers
c.Next() // 执行处理函数链 c.Next() // 执行处理函数链
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送
} }
@ -805,17 +853,38 @@ func isGeneralOptionsRequest(req *http.Request) bool {
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
} }
func (engine *Engine) allowedMethodsForPath(requestPath string) []string { func shouldTryFixedPathLookup(path string, root *node) bool {
allowedMethods := make([]string, 0, len(engine.methodTrees)) if root != nil && root.hasCaseInsensitivePath {
return true
}
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
if cap(allowedMethods) < len(engine.methodTrees) {
allowedMethods = make([]string, 0, len(engine.methodTrees))
} else {
allowedMethods = allowedMethods[:0]
}
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees { for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes() *tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil { if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method) allowedMethods = append(allowedMethods, treeIter.method)
} }
} }
PutTempSkippedNodes(tempSkippedNodes)
return allowedMethods return allowedMethods
} }

71
engine_benchmark_test.go Normal file
View file

@ -0,0 +1,71 @@
package touka
import (
"net/http"
"net/http/httptest"
"testing"
)
var benchmarkStatusCode int
func buildServeHTTPBenchmarkEngine() *Engine {
engine := New()
engine.GET("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
return engine
}
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
b.Helper()
req, err := http.NewRequest(method, path, nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr = httptest.NewRecorder()
engine.ServeHTTP(rr, req)
}
benchmarkStatusCode = rr.Code
}
func BenchmarkServeHTTP(b *testing.B) {
engine := buildServeHTTPBenchmarkEngine()
b.Run("StaticHit", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
})
b.Run("NotFound", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
})
b.Run("MethodNotAllowed", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
})
b.Run("OptionsAllow", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
})
b.Run("FixedPathRedirect", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
})
}

141
engine_test.go Normal file
View file

@ -0,0 +1,141 @@
package touka
import (
"encoding/json"
"net/http"
"testing"
)
func TestHandleRequestRedirectFixedPath(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
}
}
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
}
}
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
}
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
}
}
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
defer func() {
if r := recover(); r != nil {
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
}
}()
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
}
}
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
engine := New()
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
}
}
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "" {
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
}
}
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
}
allow := rr.Header().Get("Allow")
if allow != "GET, POST" && allow != "POST, GET" {
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
}
}
func TestDefaultErrorHandleJSONShape(t *testing.T) {
engine := New()
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
}
var body struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
}
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
t.Fatalf("unexpected error payload: %+v", body)
}
}

View file

@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
if c.engine != nil { if c.engine != nil {
if c.Request != nil && c.Request.RequestURI != "*" { if c.Request != nil && c.Request.RequestURI != "*" {
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 {
c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) c.allowedMethodsBuf = allow[:0]
allowHeader := c.allowHeaderBuf[:0]
for i, method := range allow {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", string(allowHeader))
} }
} }
} }

View file

@ -0,0 +1,130 @@
package touka
import "testing"
var (
benchmarkRouteHandlers HandlersChain
benchmarkRouteFullPath string
benchmarkRouteParamsLen int
benchmarkRouteCIPath []byte
benchmarkRouteCIFound bool
)
func buildRouteMatchBenchmarkTree() *node {
tree := &node{}
routes := []string{
"/",
"/health",
"/contact",
"/api/v1/users",
"/api/v1/users/:id",
"/api/v1/users/:id/settings",
"/assets/*filepath",
"/abc/b",
"/abc/:p1/cde",
"/abc/:p1/:p2/def/*filepath",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
return tree
}
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
b.Helper()
params := make(Params, 0, 4)
skipped := make([]skippedNode, 0, 8)
value := tree.getValue(path, &params, &skipped, true)
if wantFullPath == "" {
if value.handlers != nil {
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
}
} else {
if value.handlers == nil {
b.Fatalf("expected match for %q, got nil handlers", path)
}
if value.fullPath != wantFullPath {
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
}
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
params = params[:0]
skipped = skipped[:0]
value = tree.getValue(path, &params, &skipped, true)
}
benchmarkRouteHandlers = value.handlers
benchmarkRouteFullPath = value.fullPath
if value.params != nil {
benchmarkRouteParamsLen = len(*value.params)
} else {
benchmarkRouteParamsLen = 0
}
}
func BenchmarkRouteMatch(b *testing.B) {
tree := buildRouteMatchBenchmarkTree()
b.Run("StaticHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
})
b.Run("ParamHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
})
b.Run("BacktrackingHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
})
b.Run("Miss", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
})
b.Run("CaseInsensitiveHit", func(b *testing.B) {
path := "/API/V1/USERS/123/SETTINGS"
out, found := tree.findCaseInsensitivePath(path, true)
if !found {
b.Fatalf("expected fixed-path match for %q", path)
}
if got := string(out); got != "/api/v1/users/123/settings" {
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
path := "/DOES/NOT/EXIST"
out, found := tree.findCaseInsensitivePath(path, true)
if found || out != nil {
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
}

116
tree.go
View file

@ -121,14 +121,28 @@ const (
// node 表示路由树中的一个节点. // node 表示路由树中的一个节点.
type node struct { type node struct {
path string // 当前节点的路径段 path string // 当前节点的路径段
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
priority uint32 // 节点的优先级, 用于查找时优先匹配 nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 priority uint32 // 节点的优先级, 用于查找时优先匹配
handlers HandlersChain // 绑定到此节点的处理函数链 children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
fullPath string // 完整路径, 用于调试和错误信息 handlers HandlersChain // 绑定到此节点的处理函数链
fullPath string // 完整路径, 用于调试和错误信息
}
func routeNeedsCaseInsensitiveLookup(path string) bool {
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
} }
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
func (n *node) addRoute(path string, handlers HandlersChain) { func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path // 记录完整的路径 fullPath := path // 记录完整的路径
n.priority++ // 增加当前节点的优先级 n.priority++ // 增加当前节点的优先级
if routeNeedsCaseInsensitiveLookup(path) {
n.hasCaseInsensitivePath = true
}
// 如果是空树(根节点) // 如果是空树(根节点)
if len(n.path) == 0 && len(n.children) == 0 { if len(n.path) == 0 && len(n.children) == 0 {
@ -452,12 +469,14 @@ type skippedNode struct {
// 建议进行 TSR(尾部斜杠重定向). // 建议进行 TSR(尾部斜杠重定向).
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
var globalParamsCount int16 // 全局参数计数 var globalParamsCount int16 // 全局参数计数
var backtrackToWildChild bool
walk: // 外部循环用于遍历路由树 walk: // 外部循环用于遍历路由树
for { for {
prefix := n.path // 当前节点的路径前缀 prefix := n.path // 当前节点的路径前缀
if len(path) > len(prefix) { if len(path) > len(prefix) {
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
pathAtNode := path
path = path[len(prefix):] // 移除已匹配的前缀 path = path[len(prefix):] // 移除已匹配的前缀
// 在访问 path[0] 之前进行安全检查 // 在访问 path[0] 之前进行安全检查
@ -467,30 +486,26 @@ walk: // 外部循环用于遍历路由树
// 优先尝试所有非通配符子节点, 通过匹配索引字符 // 优先尝试所有非通配符子节点, 通过匹配索引字符
idxc := path[0] // 剩余路径的第一个字符 idxc := path[0] // 剩余路径的第一个字符
for i, c := range []byte(n.indices) { if !backtrackToWildChild {
if c == idxc { // 如果找到匹配的索引字符 for i := 0; i < len(n.indices); i++ {
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 if n.indices[i] == idxc { // 如果找到匹配的索引字符
if n.wildChild { // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
index := len(*skippedNodes) if n.wildChild {
*skippedNodes = (*skippedNodes)[:index+1] index := len(*skippedNodes)
(*skippedNodes)[index] = skippedNode{ *skippedNodes = (*skippedNodes)[:index+1]
path: prefix + path, // 记录跳过的路径 (*skippedNodes)[index] = skippedNode{
node: &node{ // 复制当前节点的状态 path: pathAtNode, // 记录进入当前节点时的剩余路径
path: n.path, node: n,
wildChild: n.wildChild, paramsCount: globalParamsCount, // 记录当前参数计数
nType: n.nType, }
priority: n.priority,
children: n.children,
handlers: n.handlers,
fullPath: n.fullPath,
},
paramsCount: globalParamsCount, // 记录当前参数计数
} }
}
n = n.children[i] // 移动到匹配的子节点 n = n.children[i] // 移动到匹配的子节点
continue walk // 继续外部循环 continue walk // 继续外部循环
}
} }
} else {
backtrackToWildChild = false
} }
if !n.wildChild { if !n.wildChild {
@ -507,7 +522,8 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
} }
globalParamsCount = skippedNode.paramsCount // 恢复参数计数 globalParamsCount = skippedNode.paramsCount // 恢复参数计数
continue walk // 继续外部循环 backtrackToWildChild = true
continue walk // 继续外部循环
} }
} }
} }
@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path[:end] // 提取参数值 val := path[:end] // 提取参数值
if unescape { // 如果需要进行 URL 解码 if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
if v, err := url.QueryUnescape(val); err == nil { if v, err := url.QueryUnescape(val); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path // 参数值是剩余的整个路径 val := path // 参数值是剩余的整个路径
if unescape { // 如果需要进行 URL 解码 if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
if v, err := url.QueryUnescape(path); err == nil { if v, err := url.QueryUnescape(path); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树
} }
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == '/' { // 如果索引中包含 '/' if n.indices[i] == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树
// 它还可以选择修复尾部斜杠. // 它还可以选择修复尾部斜杠.
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
const stackBufSize = 128 // 栈上缓冲区的默认大小 return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
}
// 在常见情况下使用栈上静态大小的缓冲区. func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
// 如果路径太长, 则在堆上分配缓冲区. if buf != nil {
buf := make([]byte, 0, stackBufSize) buf = buf[:0]
if length := len(path) + 1; length > stackBufSize { }
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 if cap(buf) < len(path)+1 {
buf = make([]byte, 0, len(path)+1)
} }
ciPath := n.findCaseInsensitivePathRec( ciPath := n.findCaseInsensitivePathRec(
@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树
// 未找到处理函数. // 未找到处理函数.
// 尝试通过添加尾部斜杠来修复路径 // 尝试通过添加尾部斜杠来修复路径
if fixTrailingSlash { if fixTrailingSlash {
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == '/' { // 如果索引中包含 '/' if n.indices[i] == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树
if rb[0] != 0 { if rb[0] != 0 {
// 旧 rune 未处理完 // 旧 rune 未处理完
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
if c == idxc { if n.indices[i] == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)
@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
// 小写匹配 // 小写匹配
if c == idxc { if n.indices[i] == idxc {
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
if out := n.children[i].findCaseInsensitivePathRec( if out := n.children[i].findCaseInsensitivePathRec(
path, ciPath, rb, fixTrailingSlash, path, ciPath, rb, fixTrailingSlash,
@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i, c := range []byte(n.indices) { for i := 0; i < len(n.indices); i++ {
// 大写匹配 // 大写匹配
if c == idxc { if n.indices[i] == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)

View file

@ -11,6 +11,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"testing" "testing"
"time"
) )
// Used as a workaround since we can't compare functions or their addresses // Used as a workaround since we can't compare functions or their addresses
@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode {
return &ps return &ps
} }
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
t.Helper()
resultCh := make(chan nodeValue, 1)
go func() {
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
}()
select {
case value := <-resultCh:
return value
case <-time.After(2 * time.Second):
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
return nodeValue{}
}
}
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
unescape := false unescape := false
if len(unescapes) >= 1 { if len(unescapes) >= 1 {
@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
} }
} }
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
tests := []struct {
name string
routes []string
requestPath string
wantFullPath string
wantParams Params
}{
{
name: "param route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/details"},
requestPath: "/foo/bar/details",
wantFullPath: "/foo/:id/details",
wantParams: Params{{Key: "id", Value: "bar"}},
},
{
name: "catch-all route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/*rest"},
requestPath: "/foo/bar/baz.txt",
wantFullPath: "/foo/:id/*rest",
wantParams: Params{
{Key: "id", Value: "bar"},
{Key: "rest", Value: "/baz.txt"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tree := &node{}
for _, route := range tt.routes {
tree.addRoute(route, fakeHandler(route))
}
value := getValueWithTimeout(t, tree, tt.requestPath, false)
if value.handlers == nil {
t.Fatalf("expected handlers for %q", tt.requestPath)
}
if value.fullPath != tt.wantFullPath {
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
}
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
}
})
}
}