mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #81 from infinite-iroha/feat/optimize-route-match-hotpath
Feat/optimize route match hotpath
This commit is contained in:
commit
fca9bbd3ef
10 changed files with 731 additions and 131 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -1 +1,2 @@
|
||||||
test
|
test
|
||||||
|
/bench_route_match_baseline.txt
|
||||||
|
|
|
||||||
17
context.go
17
context.go
|
|
@ -73,6 +73,12 @@ type Context struct {
|
||||||
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
||||||
// 通常在处理嵌套路由时使用
|
// 通常在处理嵌套路由时使用
|
||||||
SkippedNodes []skippedNode
|
SkippedNodes []skippedNode
|
||||||
|
|
||||||
|
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
|
||||||
|
fixedPathBuf []byte
|
||||||
|
|
||||||
|
allowedMethodsBuf []string
|
||||||
|
allowHeaderBuf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Context 相关方法实现 ---
|
// --- Context 相关方法实现 ---
|
||||||
|
|
@ -97,7 +103,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
c.handlers = nil
|
c.handlers = nil
|
||||||
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
||||||
c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染
|
c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
|
||||||
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
||||||
c.queryCache = nil // 清空查询参数缓存
|
c.queryCache = nil // 清空查询参数缓存
|
||||||
c.formCache = nil // 清空表单数据缓存
|
c.formCache = nil // 清空表单数据缓存
|
||||||
|
|
@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||||||
} else {
|
} else {
|
||||||
c.SkippedNodes = make([]skippedNode, 0, 256)
|
c.SkippedNodes = make([]skippedNode, 0, 256)
|
||||||
}
|
}
|
||||||
|
if cap(c.fixedPathBuf) > 0 {
|
||||||
|
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||||
|
}
|
||||||
|
if cap(c.allowedMethodsBuf) > 0 {
|
||||||
|
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
|
||||||
|
}
|
||||||
|
if cap(c.allowHeaderBuf) > 0 {
|
||||||
|
c.allowHeaderBuf = c.allowHeaderBuf[:0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next 在处理链中执行下一个处理函数
|
// Next 在处理链中执行下一个处理函数
|
||||||
|
|
|
||||||
78
context_benchmark_test.go
Normal file
78
context_benchmark_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
if c.Keys != nil {
|
||||||
|
t.Fatalf("expected fresh test context Keys to be nil before first Set")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("answer", 42)
|
||||||
|
if c.Keys == nil {
|
||||||
|
t.Fatalf("expected Set to allocate Keys map")
|
||||||
|
}
|
||||||
|
if value, exists := c.Get("answer"); !exists || value != 42 {
|
||||||
|
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
c.reset(c.Writer, req)
|
||||||
|
|
||||||
|
if c.Keys != nil {
|
||||||
|
t.Fatalf("expected reset to clear Keys without allocating a new map")
|
||||||
|
}
|
||||||
|
if value, exists := c.Get("answer"); exists || value != nil {
|
||||||
|
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxValue := c.Value("missing")
|
||||||
|
if ctxValue != nil {
|
||||||
|
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatalf("expected MustGet to panic for missing key after reset")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
_ = c.MustGet("answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextReset(b *testing.B) {
|
||||||
|
b.Run("NoKeysUse", func(b *testing.B) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
c.reset(c.Writer, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WithKeysUse", func(b *testing.B) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
c.reset(c.Writer, req)
|
||||||
|
c.Set("request-id", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
227
engine.go
227
engine.go
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
|
@ -82,6 +83,11 @@ type Engine struct {
|
||||||
|
|
||||||
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
||||||
GlobalMaxRequestBodySize int64
|
GlobalMaxRequestBodySize int64
|
||||||
|
|
||||||
|
notFoundChain HandlersChain
|
||||||
|
notFoundNoMethodChain HandlersChain
|
||||||
|
unmatchedFSChain HandlersChain
|
||||||
|
unmatchedFSNoMethodChain HandlersChain
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||||
|
|
@ -117,6 +123,64 @@ type ErrorHandle struct {
|
||||||
|
|
||||||
type ErrorHandler func(c *Context, code int, err error)
|
type ErrorHandler func(c *Context, code int, err error)
|
||||||
|
|
||||||
|
var errMethodNotAllowed = errors.New("method not allowed")
|
||||||
|
var errNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
type defaultErrorResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
|
||||||
|
httpMethod := c.Request.Method
|
||||||
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
engine := c.engine
|
||||||
|
// 是否是OPTIONS方式
|
||||||
|
if httpMethod == http.MethodOptions {
|
||||||
|
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||||
|
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
|
||||||
|
c.allowedMethodsBuf = allowedMethods[:0]
|
||||||
|
if len(allowedMethods) > 0 {
|
||||||
|
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||||
|
allowHeader := c.allowHeaderBuf[:0]
|
||||||
|
for i, method := range allowedMethods {
|
||||||
|
if i > 0 {
|
||||||
|
allowHeader = append(allowHeader, ',', ' ')
|
||||||
|
}
|
||||||
|
allowHeader = append(allowHeader, method...)
|
||||||
|
}
|
||||||
|
c.allowHeaderBuf = allowHeader[:0]
|
||||||
|
c.Writer.Header().Set("Allow", string(allowHeader))
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
||||||
|
tempSkippedNodes := GetTempSkippedNodes()
|
||||||
|
for _, treeIter := range engine.methodTrees {
|
||||||
|
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||||
|
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||||
|
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||||
|
if value.handlers != nil {
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
// 使用定义的ErrorHandle处理
|
||||||
|
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
var notFoundHandler HandlerFunc = func(c *Context) {
|
||||||
|
engine := c.engine
|
||||||
|
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
// defaultErrorHandle 默认错误处理
|
// defaultErrorHandle 默认错误处理
|
||||||
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
||||||
select {
|
select {
|
||||||
|
|
@ -132,11 +196,7 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg = err.Error()
|
errMsg = err.Error()
|
||||||
}
|
}
|
||||||
c.JSON(code, H{
|
c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg})
|
||||||
"code": code,
|
|
||||||
"message": http.StatusText(code),
|
|
||||||
"error": errMsg,
|
|
||||||
})
|
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
|
|
@ -211,6 +271,7 @@ func New() *Engine {
|
||||||
TLSServerConfigurator: nil,
|
TLSServerConfigurator: nil,
|
||||||
GlobalMaxRequestBodySize: -1,
|
GlobalMaxRequestBodySize: -1,
|
||||||
}
|
}
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
|
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
|
||||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||||
engine.SetDefaultProtocols()
|
engine.SetDefaultProtocols()
|
||||||
|
|
@ -266,6 +327,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
|
||||||
// 是否开启MethodNotAllowed
|
// 是否开启MethodNotAllowed
|
||||||
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
|
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
|
||||||
engine.HandleMethodNotAllowed = enable
|
engine.HandleMethodNotAllowed = enable
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogger传入实例
|
// SetLogger传入实例
|
||||||
|
|
@ -306,6 +368,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = false
|
engine.unMatchFS.ServeUnmatchedAsFS = false
|
||||||
engine.UnMatchFSRoutes = nil
|
engine.UnMatchFSRoutes = nil
|
||||||
}
|
}
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取默认Protocol配置
|
// 获取默认Protocol配置
|
||||||
|
|
@ -479,57 +542,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
||||||
|
|
||||||
// 405中间件
|
// 405中间件
|
||||||
func MethodNotAllowed() HandlerFunc {
|
func MethodNotAllowed() HandlerFunc {
|
||||||
return func(c *Context) {
|
return methodNotAllowedHandler
|
||||||
httpMethod := c.Request.Method
|
|
||||||
requestPath := routeLookupPath(c.Request)
|
|
||||||
engine := c.engine
|
|
||||||
// 是否是OPTIONS方式
|
|
||||||
if httpMethod == http.MethodOptions {
|
|
||||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
|
||||||
allowedMethods := engine.allowedMethodsForPath(requestPath)
|
|
||||||
if len(allowedMethods) > 0 {
|
|
||||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
|
||||||
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
|
||||||
for _, treeIter := range engine.methodTrees {
|
|
||||||
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
|
||||||
tempSkippedNodes := GetTempSkippedNodes()
|
|
||||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
|
||||||
PutTempSkippedNodes(tempSkippedNodes)
|
|
||||||
if value.handlers != nil {
|
|
||||||
// 使用定义的ErrorHandle处理
|
|
||||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 404最后处理
|
// 404最后处理
|
||||||
func NotFound() HandlerFunc {
|
func NotFound() HandlerFunc {
|
||||||
return func(c *Context) {
|
return notFoundHandler
|
||||||
engine := c.engine
|
|
||||||
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
|
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||||
func (Engine *Engine) NoRoute(handler HandlerFunc) {
|
func (Engine *Engine) NoRoute(handler HandlerFunc) {
|
||||||
Engine.noRoute = handler
|
Engine.noRoute = handler
|
||||||
Engine.noRoutes = nil
|
Engine.noRoutes = nil
|
||||||
|
Engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
|
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||||
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
|
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
|
||||||
Engine.noRoute = nil
|
Engine.noRoute = nil
|
||||||
Engine.noRoutes = handlerFuncs
|
Engine.noRoutes = handlerFuncs
|
||||||
|
Engine.rebuildFallbackChains()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) rebuildFallbackChains() {
|
||||||
|
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
|
||||||
|
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
|
||||||
|
if includeMethodNotAllowed {
|
||||||
|
finalSize++
|
||||||
|
}
|
||||||
|
if includeUnmatchedFS {
|
||||||
|
finalSize += len(engine.UnMatchFSRoutes)
|
||||||
|
}
|
||||||
|
if engine.noRoute != nil {
|
||||||
|
finalSize++
|
||||||
|
} else {
|
||||||
|
finalSize += len(engine.noRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
chain := make(HandlersChain, 0, finalSize)
|
||||||
|
chain = append(chain, engine.globalHandlers...)
|
||||||
|
if includeMethodNotAllowed {
|
||||||
|
chain = append(chain, methodNotAllowedHandler)
|
||||||
|
}
|
||||||
|
if includeUnmatchedFS {
|
||||||
|
chain = append(chain, engine.UnMatchFSRoutes...)
|
||||||
|
}
|
||||||
|
if engine.noRoute != nil {
|
||||||
|
chain = append(chain, engine.noRoute)
|
||||||
|
} else if len(engine.noRoutes) > 0 {
|
||||||
|
chain = append(chain, engine.noRoutes...)
|
||||||
|
}
|
||||||
|
chain = append(chain, notFoundHandler)
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
|
||||||
|
engine.notFoundNoMethodChain = buildChain(false, false)
|
||||||
|
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||||
|
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// combineHandlers 组合多个处理函数链为一个
|
// combineHandlers 组合多个处理函数链为一个
|
||||||
|
|
@ -546,6 +616,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
|
||||||
// 这些中间件将应用于所有注册的路由
|
// 这些中间件将应用于所有注册的路由
|
||||||
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
||||||
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -739,47 +810,24 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 尝试不区分大小写的查找
|
if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
|
||||||
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法
|
// 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
|
||||||
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
|
ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
|
||||||
if found && engine.RedirectFixedPath {
|
if found {
|
||||||
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
|
c.fixedPathBuf = ciPath[:0]
|
||||||
return
|
c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建处理链
|
|
||||||
// 组合全局中间件和路由处理函数
|
|
||||||
handlers := engine.globalHandlers
|
|
||||||
|
|
||||||
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
|
|
||||||
// 则在全局中间件之后添加 MethodNotAllowed 处理器
|
|
||||||
if engine.HandleMethodNotAllowed {
|
|
||||||
handlers = append(handlers, MethodNotAllowed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
|
|
||||||
// 则在处理链的最后添加 UnMatchFS 处理器
|
|
||||||
if engine.unMatchFS.ServeUnmatchedAsFS {
|
if engine.unMatchFS.ServeUnmatchedAsFS {
|
||||||
/*
|
c.handlers = engine.unmatchedFSChain
|
||||||
var unMatchFSHandle = c.engine.unMatchFileServer
|
} else {
|
||||||
handlers = append(handlers, unMatchFSHandle)
|
c.handlers = engine.notFoundChain
|
||||||
*/
|
|
||||||
handlers = append(handlers, engine.UnMatchFSRoutes...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
|
|
||||||
// 则在处理链的最后添加 NoRoute 处理器
|
|
||||||
if engine.noRoute != nil {
|
|
||||||
handlers = append(handlers, engine.noRoute)
|
|
||||||
} else if len(engine.noRoutes) > 0 {
|
|
||||||
handlers = append(handlers, engine.noRoutes...)
|
|
||||||
}
|
|
||||||
|
|
||||||
handlers = append(handlers, NotFound())
|
|
||||||
|
|
||||||
c.handlers = handlers
|
|
||||||
c.Next() // 执行处理函数链
|
c.Next() // 执行处理函数链
|
||||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||||
}
|
}
|
||||||
|
|
@ -805,17 +853,38 @@ func isGeneralOptionsRequest(req *http.Request) bool {
|
||||||
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) allowedMethodsForPath(requestPath string) []string {
|
func shouldTryFixedPathLookup(path string, root *node) bool {
|
||||||
allowedMethods := make([]string, 0, len(engine.methodTrees))
|
if root != nil && root.hasCaseInsensitivePath {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i := 0; i < len(path); i++ {
|
||||||
|
c := path[i]
|
||||||
|
if c >= utf8.RuneSelf {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c >= 'A' && c <= 'Z' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
|
||||||
|
if cap(allowedMethods) < len(engine.methodTrees) {
|
||||||
|
allowedMethods = make([]string, 0, len(engine.methodTrees))
|
||||||
|
} else {
|
||||||
|
allowedMethods = allowedMethods[:0]
|
||||||
|
}
|
||||||
|
tempSkippedNodes := GetTempSkippedNodes()
|
||||||
for _, treeIter := range engine.methodTrees {
|
for _, treeIter := range engine.methodTrees {
|
||||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||||
tempSkippedNodes := GetTempSkippedNodes()
|
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||||
PutTempSkippedNodes(tempSkippedNodes)
|
|
||||||
if value.handlers != nil {
|
if value.handlers != nil {
|
||||||
allowedMethods = append(allowedMethods, treeIter.method)
|
allowedMethods = append(allowedMethods, treeIter.method)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
return allowedMethods
|
return allowedMethods
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
71
engine_benchmark_test.go
Normal file
71
engine_benchmark_test.go
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var benchmarkStatusCode int
|
||||||
|
|
||||||
|
func buildServeHTTPBenchmarkEngine() *Engine {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.GET("/api/v1/users/:id", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.POST("/api/v1/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
return engine
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(method, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkStatusCode = rr.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkServeHTTP(b *testing.B) {
|
||||||
|
engine := buildServeHTTPBenchmarkEngine()
|
||||||
|
|
||||||
|
b.Run("StaticHit", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("NotFound", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("MethodNotAllowed", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("OptionsAllow", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("FixedPathRedirect", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
|
||||||
|
})
|
||||||
|
}
|
||||||
141
engine_test.go
Normal file
141
engine_test.go
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleRequestRedirectFixedPath(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
|
||||||
|
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/Users/Profile", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
|
||||||
|
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/Users/Profile", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.NoRoute(func(c *Context) {
|
||||||
|
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
|
||||||
|
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.NoRoute(func(c *Context) {
|
||||||
|
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-NoRoute"); got != "" {
|
||||||
|
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.POST("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
allow := rr.Header().Get("Allow")
|
||||||
|
if allow != "GET, POST" && allow != "POST, GET" {
|
||||||
|
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultErrorHandleJSONShape(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||||
|
}
|
||||||
|
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
|
||||||
|
t.Fatalf("unexpected error payload: %+v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) {
|
||||||
|
|
||||||
if c.engine != nil {
|
if c.engine != nil {
|
||||||
if c.Request != nil && c.Request.RequestURI != "*" {
|
if c.Request != nil && c.Request.RequestURI != "*" {
|
||||||
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 {
|
if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 {
|
||||||
c.Writer.Header().Set("Allow", strings.Join(allow, ", "))
|
c.allowedMethodsBuf = allow[:0]
|
||||||
|
allowHeader := c.allowHeaderBuf[:0]
|
||||||
|
for i, method := range allow {
|
||||||
|
if i > 0 {
|
||||||
|
allowHeader = append(allowHeader, ',', ' ')
|
||||||
|
}
|
||||||
|
allowHeader = append(allowHeader, method...)
|
||||||
|
}
|
||||||
|
c.allowHeaderBuf = allowHeader[:0]
|
||||||
|
c.Writer.Header().Set("Allow", string(allowHeader))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
130
route_match_benchmark_test.go
Normal file
130
route_match_benchmark_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
var (
|
||||||
|
benchmarkRouteHandlers HandlersChain
|
||||||
|
benchmarkRouteFullPath string
|
||||||
|
benchmarkRouteParamsLen int
|
||||||
|
benchmarkRouteCIPath []byte
|
||||||
|
benchmarkRouteCIFound bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildRouteMatchBenchmarkTree() *node {
|
||||||
|
tree := &node{}
|
||||||
|
routes := []string{
|
||||||
|
"/",
|
||||||
|
"/health",
|
||||||
|
"/contact",
|
||||||
|
"/api/v1/users",
|
||||||
|
"/api/v1/users/:id",
|
||||||
|
"/api/v1/users/:id/settings",
|
||||||
|
"/assets/*filepath",
|
||||||
|
"/abc/b",
|
||||||
|
"/abc/:p1/cde",
|
||||||
|
"/abc/:p1/:p2/def/*filepath",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
tree.addRoute(route, fakeHandler(route))
|
||||||
|
}
|
||||||
|
|
||||||
|
return tree
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
params := make(Params, 0, 4)
|
||||||
|
skipped := make([]skippedNode, 0, 8)
|
||||||
|
|
||||||
|
value := tree.getValue(path, ¶ms, &skipped, true)
|
||||||
|
if wantFullPath == "" {
|
||||||
|
if value.handlers != nil {
|
||||||
|
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if value.handlers == nil {
|
||||||
|
b.Fatalf("expected match for %q, got nil handlers", path)
|
||||||
|
}
|
||||||
|
if value.fullPath != wantFullPath {
|
||||||
|
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
params = params[:0]
|
||||||
|
skipped = skipped[:0]
|
||||||
|
value = tree.getValue(path, ¶ms, &skipped, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteHandlers = value.handlers
|
||||||
|
benchmarkRouteFullPath = value.fullPath
|
||||||
|
if value.params != nil {
|
||||||
|
benchmarkRouteParamsLen = len(*value.params)
|
||||||
|
} else {
|
||||||
|
benchmarkRouteParamsLen = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouteMatch(b *testing.B) {
|
||||||
|
tree := buildRouteMatchBenchmarkTree()
|
||||||
|
|
||||||
|
b.Run("StaticHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ParamHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("BacktrackingHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Miss", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("CaseInsensitiveHit", func(b *testing.B) {
|
||||||
|
path := "/API/V1/USERS/123/SETTINGS"
|
||||||
|
out, found := tree.findCaseInsensitivePath(path, true)
|
||||||
|
if !found {
|
||||||
|
b.Fatalf("expected fixed-path match for %q", path)
|
||||||
|
}
|
||||||
|
if got := string(out); got != "/api/v1/users/123/settings" {
|
||||||
|
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
out, found = tree.findCaseInsensitivePath(path, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteCIPath = out
|
||||||
|
benchmarkRouteCIFound = found
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
|
||||||
|
path := "/DOES/NOT/EXIST"
|
||||||
|
out, found := tree.findCaseInsensitivePath(path, true)
|
||||||
|
if found || out != nil {
|
||||||
|
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
out, found = tree.findCaseInsensitivePath(path, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteCIPath = out
|
||||||
|
benchmarkRouteCIFound = found
|
||||||
|
})
|
||||||
|
}
|
||||||
116
tree.go
116
tree.go
|
|
@ -121,14 +121,28 @@ const (
|
||||||
|
|
||||||
// node 表示路由树中的一个节点.
|
// node 表示路由树中的一个节点.
|
||||||
type node struct {
|
type node struct {
|
||||||
path string // 当前节点的路径段
|
path string // 当前节点的路径段
|
||||||
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
|
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
|
||||||
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
|
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
|
||||||
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
|
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
|
||||||
priority uint32 // 节点的优先级, 用于查找时优先匹配
|
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
|
||||||
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
|
priority uint32 // 节点的优先级, 用于查找时优先匹配
|
||||||
handlers HandlersChain // 绑定到此节点的处理函数链
|
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
|
||||||
fullPath string // 完整路径, 用于调试和错误信息
|
handlers HandlersChain // 绑定到此节点的处理函数链
|
||||||
|
fullPath string // 完整路径, 用于调试和错误信息
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeNeedsCaseInsensitiveLookup(path string) bool {
|
||||||
|
for i := 0; i < len(path); i++ {
|
||||||
|
c := path[i]
|
||||||
|
if c >= utf8.RuneSelf {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c >= 'A' && c <= 'Z' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
|
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
|
||||||
|
|
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
|
||||||
func (n *node) addRoute(path string, handlers HandlersChain) {
|
func (n *node) addRoute(path string, handlers HandlersChain) {
|
||||||
fullPath := path // 记录完整的路径
|
fullPath := path // 记录完整的路径
|
||||||
n.priority++ // 增加当前节点的优先级
|
n.priority++ // 增加当前节点的优先级
|
||||||
|
if routeNeedsCaseInsensitiveLookup(path) {
|
||||||
|
n.hasCaseInsensitivePath = true
|
||||||
|
}
|
||||||
|
|
||||||
// 如果是空树(根节点)
|
// 如果是空树(根节点)
|
||||||
if len(n.path) == 0 && len(n.children) == 0 {
|
if len(n.path) == 0 && len(n.children) == 0 {
|
||||||
|
|
@ -452,12 +469,14 @@ type skippedNode struct {
|
||||||
// 建议进行 TSR(尾部斜杠重定向).
|
// 建议进行 TSR(尾部斜杠重定向).
|
||||||
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
|
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
|
||||||
var globalParamsCount int16 // 全局参数计数
|
var globalParamsCount int16 // 全局参数计数
|
||||||
|
var backtrackToWildChild bool
|
||||||
|
|
||||||
walk: // 外部循环用于遍历路由树
|
walk: // 外部循环用于遍历路由树
|
||||||
for {
|
for {
|
||||||
prefix := n.path // 当前节点的路径前缀
|
prefix := n.path // 当前节点的路径前缀
|
||||||
if len(path) > len(prefix) {
|
if len(path) > len(prefix) {
|
||||||
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
|
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
|
||||||
|
pathAtNode := path
|
||||||
path = path[len(prefix):] // 移除已匹配的前缀
|
path = path[len(prefix):] // 移除已匹配的前缀
|
||||||
|
|
||||||
// 在访问 path[0] 之前进行安全检查
|
// 在访问 path[0] 之前进行安全检查
|
||||||
|
|
@ -467,30 +486,26 @@ walk: // 外部循环用于遍历路由树
|
||||||
|
|
||||||
// 优先尝试所有非通配符子节点, 通过匹配索引字符
|
// 优先尝试所有非通配符子节点, 通过匹配索引字符
|
||||||
idxc := path[0] // 剩余路径的第一个字符
|
idxc := path[0] // 剩余路径的第一个字符
|
||||||
for i, c := range []byte(n.indices) {
|
if !backtrackToWildChild {
|
||||||
if c == idxc { // 如果找到匹配的索引字符
|
for i := 0; i < len(n.indices); i++ {
|
||||||
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
|
if n.indices[i] == idxc { // 如果找到匹配的索引字符
|
||||||
if n.wildChild {
|
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
|
||||||
index := len(*skippedNodes)
|
if n.wildChild {
|
||||||
*skippedNodes = (*skippedNodes)[:index+1]
|
index := len(*skippedNodes)
|
||||||
(*skippedNodes)[index] = skippedNode{
|
*skippedNodes = (*skippedNodes)[:index+1]
|
||||||
path: prefix + path, // 记录跳过的路径
|
(*skippedNodes)[index] = skippedNode{
|
||||||
node: &node{ // 复制当前节点的状态
|
path: pathAtNode, // 记录进入当前节点时的剩余路径
|
||||||
path: n.path,
|
node: n,
|
||||||
wildChild: n.wildChild,
|
paramsCount: globalParamsCount, // 记录当前参数计数
|
||||||
nType: n.nType,
|
}
|
||||||
priority: n.priority,
|
|
||||||
children: n.children,
|
|
||||||
handlers: n.handlers,
|
|
||||||
fullPath: n.fullPath,
|
|
||||||
},
|
|
||||||
paramsCount: globalParamsCount, // 记录当前参数计数
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
n = n.children[i] // 移动到匹配的子节点
|
n = n.children[i] // 移动到匹配的子节点
|
||||||
continue walk // 继续外部循环
|
continue walk // 继续外部循环
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
backtrackToWildChild = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !n.wildChild {
|
if !n.wildChild {
|
||||||
|
|
@ -507,7 +522,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
|
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount // 恢复参数计数
|
globalParamsCount = skippedNode.paramsCount // 恢复参数计数
|
||||||
continue walk // 继续外部循环
|
backtrackToWildChild = true
|
||||||
|
continue walk // 继续外部循环
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
i := len(*value.params)
|
i := len(*value.params)
|
||||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||||
val := path[:end] // 提取参数值
|
val := path[:end] // 提取参数值
|
||||||
if unescape { // 如果需要进行 URL 解码
|
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||||
if v, err := url.QueryUnescape(val); err == nil {
|
if v, err := url.QueryUnescape(val); err == nil {
|
||||||
val = v // 解码成功则更新值
|
val = v // 解码成功则更新值
|
||||||
}
|
}
|
||||||
|
|
@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
i := len(*value.params)
|
i := len(*value.params)
|
||||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||||
val := path // 参数值是剩余的整个路径
|
val := path // 参数值是剩余的整个路径
|
||||||
if unescape { // 如果需要进行 URL 解码
|
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||||
if v, err := url.QueryUnescape(path); err == nil {
|
if v, err := url.QueryUnescape(path); err == nil {
|
||||||
val = v // 解码成功则更新值
|
val = v // 解码成功则更新值
|
||||||
}
|
}
|
||||||
|
|
@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount
|
globalParamsCount = skippedNode.paramsCount
|
||||||
|
backtrackToWildChild = true
|
||||||
continue walk
|
continue walk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
}
|
}
|
||||||
|
|
||||||
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
|
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == '/' { // 如果索引中包含 '/'
|
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||||
n = n.children[i] // 移动到对应的子节点
|
n = n.children[i] // 移动到对应的子节点
|
||||||
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||||
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
|
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
|
||||||
|
|
@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount
|
globalParamsCount = skippedNode.paramsCount
|
||||||
|
backtrackToWildChild = true
|
||||||
continue walk
|
continue walk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树
|
||||||
// 它还可以选择修复尾部斜杠.
|
// 它还可以选择修复尾部斜杠.
|
||||||
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
|
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
|
||||||
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
|
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
|
||||||
const stackBufSize = 128 // 栈上缓冲区的默认大小
|
return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
|
||||||
|
}
|
||||||
|
|
||||||
// 在常见情况下使用栈上静态大小的缓冲区.
|
func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
|
||||||
// 如果路径太长, 则在堆上分配缓冲区.
|
if buf != nil {
|
||||||
buf := make([]byte, 0, stackBufSize)
|
buf = buf[:0]
|
||||||
if length := len(path) + 1; length > stackBufSize {
|
}
|
||||||
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区
|
if cap(buf) < len(path)+1 {
|
||||||
|
buf = make([]byte, 0, len(path)+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
ciPath := n.findCaseInsensitivePathRec(
|
ciPath := n.findCaseInsensitivePathRec(
|
||||||
|
|
@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
// 未找到处理函数.
|
// 未找到处理函数.
|
||||||
// 尝试通过添加尾部斜杠来修复路径
|
// 尝试通过添加尾部斜杠来修复路径
|
||||||
if fixTrailingSlash {
|
if fixTrailingSlash {
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == '/' { // 如果索引中包含 '/'
|
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||||
n = n.children[i] // 移动到对应的子节点
|
n = n.children[i] // 移动到对应的子节点
|
||||||
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||||
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
|
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
|
||||||
|
|
@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
if rb[0] != 0 {
|
if rb[0] != 0 {
|
||||||
// 旧 rune 未处理完
|
// 旧 rune 未处理完
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 继续处理子节点
|
// 继续处理子节点
|
||||||
n = n.children[i]
|
n = n.children[i]
|
||||||
npLen = len(n.path)
|
npLen = len(n.path)
|
||||||
|
|
@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树
|
||||||
rb = shiftNRuneBytes(rb, off)
|
rb = shiftNRuneBytes(rb, off)
|
||||||
|
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
// 小写匹配
|
// 小写匹配
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
|
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
|
||||||
if out := n.children[i].findCaseInsensitivePathRec(
|
if out := n.children[i].findCaseInsensitivePathRec(
|
||||||
path, ciPath, rb, fixTrailingSlash,
|
path, ciPath, rb, fixTrailingSlash,
|
||||||
|
|
@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树
|
||||||
rb = shiftNRuneBytes(rb, off)
|
rb = shiftNRuneBytes(rb, off)
|
||||||
|
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
// 大写匹配
|
// 大写匹配
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 继续处理子节点
|
// 继续处理子节点
|
||||||
n = n.children[i]
|
n = n.children[i]
|
||||||
npLen = len(n.path)
|
npLen = len(n.path)
|
||||||
|
|
|
||||||
66
tree_test.go
66
tree_test.go
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Used as a workaround since we can't compare functions or their addresses
|
// Used as a workaround since we can't compare functions or their addresses
|
||||||
|
|
@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode {
|
||||||
return &ps
|
return &ps
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
resultCh := make(chan nodeValue, 1)
|
||||||
|
go func() {
|
||||||
|
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case value := <-resultCh:
|
||||||
|
return value
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
|
||||||
|
return nodeValue{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
|
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
|
||||||
unescape := false
|
unescape := false
|
||||||
if len(unescapes) >= 1 {
|
if len(unescapes) >= 1 {
|
||||||
|
|
@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
|
||||||
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
|
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routes []string
|
||||||
|
requestPath string
|
||||||
|
wantFullPath string
|
||||||
|
wantParams Params
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "param route after static dead end",
|
||||||
|
routes: []string{"/foo/bar", "/foo/:id/details"},
|
||||||
|
requestPath: "/foo/bar/details",
|
||||||
|
wantFullPath: "/foo/:id/details",
|
||||||
|
wantParams: Params{{Key: "id", Value: "bar"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "catch-all route after static dead end",
|
||||||
|
routes: []string{"/foo/bar", "/foo/:id/*rest"},
|
||||||
|
requestPath: "/foo/bar/baz.txt",
|
||||||
|
wantFullPath: "/foo/:id/*rest",
|
||||||
|
wantParams: Params{
|
||||||
|
{Key: "id", Value: "bar"},
|
||||||
|
{Key: "rest", Value: "/baz.txt"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tree := &node{}
|
||||||
|
for _, route := range tt.routes {
|
||||||
|
tree.addRoute(route, fakeHandler(route))
|
||||||
|
}
|
||||||
|
|
||||||
|
value := getValueWithTimeout(t, tree, tt.requestPath, false)
|
||||||
|
if value.handlers == nil {
|
||||||
|
t.Fatalf("expected handlers for %q", tt.requestPath)
|
||||||
|
}
|
||||||
|
if value.fullPath != tt.wantFullPath {
|
||||||
|
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
|
||||||
|
}
|
||||||
|
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
|
||||||
|
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue