diff --git a/.gitignore b/.gitignore index 30d74d2..6f301cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test \ No newline at end of file +test +/bench_route_match_baseline.txt diff --git a/engine.go b/engine.go index b7cf330..ece023d 100644 --- a/engine.go +++ b/engine.go @@ -739,12 +739,13 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return + if engine.RedirectFixedPath { + // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return + } } } } diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go new file mode 100644 index 0000000..e0dd2aa --- /dev/null +++ b/route_match_benchmark_test.go @@ -0,0 +1,130 @@ +package touka + +import "testing" + +var ( + benchmarkRouteHandlers HandlersChain + benchmarkRouteFullPath string + benchmarkRouteParamsLen int + benchmarkRouteCIPath []byte + benchmarkRouteCIFound bool +) + +func buildRouteMatchBenchmarkTree() *node { + tree := &node{} + routes := []string{ + "/", + "/health", + "/contact", + "/api/v1/users", + "/api/v1/users/:id", + "/api/v1/users/:id/settings", + "/assets/*filepath", + "/abc/b", + "/abc/:p1/cde", + "/abc/:p1/:p2/def/*filepath", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + return tree +} + +func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { + b.Helper() + + params := make(Params, 0, 4) + skipped := make([]skippedNode, 0, 8) + + value := tree.getValue(path, ¶ms, &skipped, true) + if wantFullPath == "" { + if value.handlers != nil { + b.Fatalf("expected no match for %q, got %q", path, value.fullPath) + } + } else { + if value.handlers == nil { + b.Fatalf("expected match for %q, got nil handlers", path) + } + if value.fullPath != wantFullPath { + b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) + } + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + params = params[:0] + skipped = skipped[:0] + value = tree.getValue(path, ¶ms, &skipped, true) + } + + benchmarkRouteHandlers = value.handlers + benchmarkRouteFullPath = value.fullPath + if value.params != nil { + benchmarkRouteParamsLen = len(*value.params) + } else { + benchmarkRouteParamsLen = 0 + } +} + +func BenchmarkRouteMatch(b *testing.B) { + tree := buildRouteMatchBenchmarkTree() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") + }) + + b.Run("ParamHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") + }) + + b.Run("BacktrackingHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") + }) + + b.Run("Miss", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/does/not/exist", "") + }) + + b.Run("CaseInsensitiveHit", func(b *testing.B) { + path := "/API/V1/USERS/123/SETTINGS" + out, found := tree.findCaseInsensitivePath(path, true) + if !found { + b.Fatalf("expected fixed-path match for %q", path) + } + if got := string(out); got != "/api/v1/users/123/settings" { + b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) + + b.Run("CaseInsensitiveMiss", func(b *testing.B) { + path := "/DOES/NOT/EXIST" + out, found := tree.findCaseInsensitivePath(path, true) + if found || out != nil { + b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) +} diff --git a/tree.go b/tree.go index f5452f4..e9a10e6 100644 --- a/tree.go +++ b/tree.go @@ -452,12 +452,14 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 + var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -467,30 +469,26 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - for i, c := range []byte(n.indices) { - if c == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, - paramsCount: globalParamsCount, // 记录当前参数计数 + if !backtrackToWildChild { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: pathAtNode, // 记录进入当前节点时的剩余路径 + node: n, + paramsCount: globalParamsCount, // 记录当前参数计数 + } } - } - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } } + } else { + backtrackToWildChild = false } if !n.wildChild { @@ -507,7 +505,8 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - continue walk // 继续外部循环 + backtrackToWildChild = true + continue walk // 继续外部循环 } } } @@ -547,7 +546,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -599,7 +598,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -634,6 +633,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -658,8 +658,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -688,6 +688,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -758,8 +759,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -781,8 +782,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i, c := range []byte(n.indices) { - if c == idxc { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -813,9 +814,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 小写匹配 - if c == idxc { + if n.indices[i] == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -832,9 +833,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 大写匹配 - if c == idxc { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) diff --git a/tree_test.go b/tree_test.go index 7665afd..a35a1a8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode { return &ps } +func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { + t.Helper() + + resultCh := make(chan nodeValue, 1) + go func() { + resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) + }() + + select { + case value := <-resultCh: + return value + case <-time.After(2 * time.Second): + t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) + return nodeValue{} + } +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } + +func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { + tests := []struct { + name string + routes []string + requestPath string + wantFullPath string + wantParams Params + }{ + { + name: "param route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/details"}, + requestPath: "/foo/bar/details", + wantFullPath: "/foo/:id/details", + wantParams: Params{{Key: "id", Value: "bar"}}, + }, + { + name: "catch-all route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/*rest"}, + requestPath: "/foo/bar/baz.txt", + wantFullPath: "/foo/:id/*rest", + wantParams: Params{ + {Key: "id", Value: "bar"}, + {Key: "rest", Value: "/baz.txt"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := &node{} + for _, route := range tt.routes { + tree.addRoute(route, fakeHandler(route)) + } + + value := getValueWithTimeout(t, tree, tt.requestPath, false) + if value.handlers == nil { + t.Fatalf("expected handlers for %q", tt.requestPath) + } + if value.fullPath != tt.wantFullPath { + t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) + } + if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { + t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) + } + }) + } +}