From 6acac9edce474de3d9abeec76d45a703247d8a2d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:27:00 +0800 Subject: [PATCH 1/6] fix: streamline route matcher backtracking Avoid rebuilding skipped-node state during wildcard fallback so the matcher no longer loops on the same static branch and stops allocating on the hot path. Add focused route benchmarks and regression coverage to keep the optimized path stable. --- .gitignore | 3 +- engine.go | 13 ++-- route_match_benchmark_test.go | 130 ++++++++++++++++++++++++++++++++++ tree.go | 69 +++++++++--------- tree_test.go | 66 +++++++++++++++++ 5 files changed, 240 insertions(+), 41 deletions(-) create mode 100644 route_match_benchmark_test.go diff --git a/.gitignore b/.gitignore index 30d74d2..6f301cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test \ No newline at end of file +test +/bench_route_match_baseline.txt diff --git a/engine.go b/engine.go index b7cf330..ece023d 100644 --- a/engine.go +++ b/engine.go @@ -739,12 +739,13 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return + if engine.RedirectFixedPath { + // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return + } } } } diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go new file mode 100644 index 0000000..e0dd2aa --- /dev/null +++ b/route_match_benchmark_test.go @@ -0,0 +1,130 @@ +package touka + +import "testing" + +var ( + benchmarkRouteHandlers HandlersChain + benchmarkRouteFullPath string + benchmarkRouteParamsLen int + benchmarkRouteCIPath []byte + benchmarkRouteCIFound bool +) + +func buildRouteMatchBenchmarkTree() *node { + tree := &node{} + routes := []string{ + "/", + "/health", + "/contact", + "/api/v1/users", + "/api/v1/users/:id", + "/api/v1/users/:id/settings", + "/assets/*filepath", + "/abc/b", + "/abc/:p1/cde", + "/abc/:p1/:p2/def/*filepath", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + return tree +} + +func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { + b.Helper() + + params := make(Params, 0, 4) + skipped := make([]skippedNode, 0, 8) + + value := tree.getValue(path, ¶ms, &skipped, true) + if wantFullPath == "" { + if value.handlers != nil { + b.Fatalf("expected no match for %q, got %q", path, value.fullPath) + } + } else { + if value.handlers == nil { + b.Fatalf("expected match for %q, got nil handlers", path) + } + if value.fullPath != wantFullPath { + b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) + } + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + params = params[:0] + skipped = skipped[:0] + value = tree.getValue(path, ¶ms, &skipped, true) + } + + benchmarkRouteHandlers = value.handlers + benchmarkRouteFullPath = value.fullPath + if value.params != nil { + benchmarkRouteParamsLen = len(*value.params) + } else { + benchmarkRouteParamsLen = 0 + } +} + +func BenchmarkRouteMatch(b *testing.B) { + tree := buildRouteMatchBenchmarkTree() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") + }) + + b.Run("ParamHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") + }) + + b.Run("BacktrackingHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") + }) + + b.Run("Miss", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/does/not/exist", "") + }) + + b.Run("CaseInsensitiveHit", func(b *testing.B) { + path := "/API/V1/USERS/123/SETTINGS" + out, found := tree.findCaseInsensitivePath(path, true) + if !found { + b.Fatalf("expected fixed-path match for %q", path) + } + if got := string(out); got != "/api/v1/users/123/settings" { + b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) + + b.Run("CaseInsensitiveMiss", func(b *testing.B) { + path := "/DOES/NOT/EXIST" + out, found := tree.findCaseInsensitivePath(path, true) + if found || out != nil { + b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) +} diff --git a/tree.go b/tree.go index f5452f4..e9a10e6 100644 --- a/tree.go +++ b/tree.go @@ -452,12 +452,14 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 + var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -467,30 +469,26 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - for i, c := range []byte(n.indices) { - if c == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, - paramsCount: globalParamsCount, // 记录当前参数计数 + if !backtrackToWildChild { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: pathAtNode, // 记录进入当前节点时的剩余路径 + node: n, + paramsCount: globalParamsCount, // 记录当前参数计数 + } } - } - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } } + } else { + backtrackToWildChild = false } if !n.wildChild { @@ -507,7 +505,8 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - continue walk // 继续外部循环 + backtrackToWildChild = true + continue walk // 继续外部循环 } } } @@ -547,7 +546,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -599,7 +598,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -634,6 +633,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -658,8 +658,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -688,6 +688,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -758,8 +759,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -781,8 +782,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i, c := range []byte(n.indices) { - if c == idxc { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -813,9 +814,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 小写匹配 - if c == idxc { + if n.indices[i] == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -832,9 +833,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 大写匹配 - if c == idxc { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) diff --git a/tree_test.go b/tree_test.go index 7665afd..a35a1a8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode { return &ps } +func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { + t.Helper() + + resultCh := make(chan nodeValue, 1) + go func() { + resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) + }() + + select { + case value := <-resultCh: + return value + case <-time.After(2 * time.Second): + t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) + return nodeValue{} + } +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } + +func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { + tests := []struct { + name string + routes []string + requestPath string + wantFullPath string + wantParams Params + }{ + { + name: "param route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/details"}, + requestPath: "/foo/bar/details", + wantFullPath: "/foo/:id/details", + wantParams: Params{{Key: "id", Value: "bar"}}, + }, + { + name: "catch-all route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/*rest"}, + requestPath: "/foo/bar/baz.txt", + wantFullPath: "/foo/:id/*rest", + wantParams: Params{ + {Key: "id", Value: "bar"}, + {Key: "rest", Value: "/baz.txt"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := &node{} + for _, route := range tt.routes { + tree.addRoute(route, fakeHandler(route)) + } + + value := getValueWithTimeout(t, tree, tt.requestPath, false) + if value.handlers == nil { + t.Fatalf("expected handlers for %q", tt.requestPath) + } + if value.fullPath != tt.wantFullPath { + t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) + } + if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { + t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) + } + }) + } +} From 5d979e56707a239e682f0c374cd6cc78d3bb2f3a Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:39:10 +0800 Subject: [PATCH 2/6] fix: reduce per-request context and fallback overhead Make Context keys lazy so requests that never call Set stop allocating on reset. Reuse stable 404 and 405 handlers and add focused benchmarks so ServeHTTP miss paths stay measurable. --- context.go | 2 +- context_benchmark_test.go | 78 +++++++++++++++++++++++++++++++++++++++ engine.go | 77 ++++++++++++++++++++------------------ engine_benchmark_test.go | 64 ++++++++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 36 deletions(-) create mode 100644 context_benchmark_test.go create mode 100644 engine_benchmark_test.go diff --git a/context.go b/context.go index 9c4ba7e..f24ceb0 100644 --- a/context.go +++ b/context.go @@ -97,7 +97,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..2198c59 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,78 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(c.Writer, req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + c.Set("request-id", i) + } + }) +} diff --git a/engine.go b/engine.go index ece023d..b2cc952 100644 --- a/engine.go +++ b/engine.go @@ -117,6 +117,46 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath) + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -479,45 +519,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..5780230 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,64 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) +} From 2d4aefc86e5d0276bb0ad7dab39eefa75b3c68c7 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:06:56 +0800 Subject: [PATCH 3/6] fix: cut redirect and allow-path routing overhead Reuse fixed-path and Allow-header buffers so redirect and OPTIONS handling stop rebuilding temporary data on every request. Cache fallback chains and add regression coverage for redirect, 404, 405, and Allow behavior to keep the faster miss paths stable. --- context.go | 15 +++++ engine.go | 125 ++++++++++++++++++++++++++++----------- engine_benchmark_test.go | 7 +++ engine_test.go | 102 ++++++++++++++++++++++++++++++++ reverseproxy.go | 13 +++- tree.go | 50 ++++++++++++---- 6 files changed, 264 insertions(+), 48 deletions(-) create mode 100644 engine_test.go diff --git a/context.go b/context.go index f24ceb0..f06d21e 100644 --- a/context.go +++ b/context.go @@ -73,6 +73,12 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode + + // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. + fixedPathBuf []byte + + allowedMethodsBuf []string + allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } else { c.SkippedNodes = make([]skippedNode, 0, 256) } + if cap(c.fixedPathBuf) > 0 { + c.fixedPathBuf = c.fixedPathBuf[:0] + } + if cap(c.allowedMethodsBuf) > 0 { + c.allowedMethodsBuf = c.allowedMethodsBuf[:0] + } + if cap(c.allowHeaderBuf) > 0 { + c.allowHeaderBuf = c.allowHeaderBuf[:0] + } } // Next 在处理链中执行下一个处理函数 diff --git a/engine.go b/engine.go index b2cc952..5214654 100644 --- a/engine.go +++ b/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "strings" + "unicode/utf8" "net/http" @@ -82,6 +83,11 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 + + notFoundChain HandlersChain + notFoundNoMethodChain HandlersChain + unmatchedFSChain HandlersChain + unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -127,10 +133,19 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) + allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) + c.allowedMethodsBuf = allowedMethods[:0] if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allowedMethods { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) c.Status(http.StatusOK) return } @@ -251,6 +266,7 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -306,6 +322,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable + engine.rebuildFallbackChains() } // SetLogger传入实例 @@ -346,6 +363,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } + engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -531,12 +549,52 @@ func NotFound() HandlerFunc { func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil + Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs + Engine.rebuildFallbackChains() +} + +func (engine *Engine) rebuildFallbackChains() { + buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { + finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound + if includeMethodNotAllowed { + finalSize++ + } + if includeUnmatchedFS { + finalSize += len(engine.UnMatchFSRoutes) + } + if engine.noRoute != nil { + finalSize++ + } else { + finalSize += len(engine.noRoutes) + } + + chain := make(HandlersChain, 0, finalSize) + chain = append(chain, engine.globalHandlers...) + if includeMethodNotAllowed { + chain = append(chain, methodNotAllowedHandler) + } + if includeUnmatchedFS { + chain = append(chain, engine.UnMatchFSRoutes...) + } + if engine.noRoute != nil { + chain = append(chain, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + chain = append(chain, engine.noRoutes...) + } + chain = append(chain, notFoundHandler) + return chain + } + + engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) + engine.notFoundNoMethodChain = buildChain(false, false) + engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) + engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -553,6 +611,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // 这些中间件将应用于所有注册的路由 func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) + engine.rebuildFallbackChains() return engine } @@ -746,48 +805,24 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - if engine.RedirectFixedPath { + if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) if found { + c.fixedPathBuf = ciPath[:0] c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 return } + c.fixedPathBuf = ciPath[:0] } } } - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers - - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } - - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) + c.handlers = engine.unmatchedFSChain + } else { + c.handlers = engine.notFoundChain } - - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } - - handlers = append(handlers, NotFound()) - - c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } @@ -813,8 +848,28 @@ func isGeneralOptionsRequest(req *http.Request) bool { return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" } -func (engine *Engine) allowedMethodsForPath(requestPath string) []string { - allowedMethods := make([]string, 0, len(engine.methodTrees)) +func shouldTryFixedPathLookup(path string, root *node) bool { + if root != nil && root.hasCaseInsensitivePath { + return true + } + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { + if cap(allowedMethods) < len(engine.methodTrees) { + allowedMethods = make([]string, 0, len(engine.methodTrees)) + } else { + allowedMethods = allowedMethods[:0] + } for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 tempSkippedNodes := GetTempSkippedNodes() diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go index 5780230..666e8b2 100644 --- a/engine_benchmark_test.go +++ b/engine_benchmark_test.go @@ -16,6 +16,9 @@ func buildServeHTTPBenchmarkEngine() *Engine { engine.GET("/api/v1/users/:id", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) engine.POST("/api/v1/users", func(c *Context) { c.Status(http.StatusNoContent) }) @@ -61,4 +64,8 @@ func BenchmarkServeHTTP(b *testing.B) { b.Run("OptionsAllow", func(b *testing.B) { benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") }) + + b.Run("FixedPathRedirect", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") + }) } diff --git a/engine_test.go b/engine_test.go new file mode 100644 index 0000000..292d5e2 --- /dev/null +++ b/engine_test.go @@ -0,0 +1,102 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestHandleRequestRedirectFixedPath(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { + t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) + } +} + +func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) + } +} + +func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) + } + if location := rr.Header().Get("Location"); location != "/Users/Profile" { + t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) + } +} + +func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { + engine := New() + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "hit" { + t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) + } +} + +func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "" { + t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) + } +} + +func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) + } + allow := rr.Header().Get("Allow") + if allow != "GET, POST" && allow != "POST, GET" { + t.Fatalf("expected Allow header to list matching methods, got %q", allow) + } +} diff --git a/reverseproxy.go b/reverseproxy.go index 1b89b2a..ff49aef 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { if c.engine != nil { if c.Request != nil && c.Request.RequestURI != "*" { - if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { - c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { + c.allowedMethodsBuf = allow[:0] + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allow { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) } } } diff --git a/tree.go b/tree.go index e9a10e6..6595655 100644 --- a/tree.go +++ b/tree.go @@ -121,14 +121,28 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 +} + +func routeNeedsCaseInsensitiveLookup(path string) bool { + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 + if routeNeedsCaseInsensitiveLookup(path) { + n.hasCaseInsensitivePath = true + } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -702,13 +719,24 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { + return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) +} + +func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { const stackBufSize = 128 // 栈上缓冲区的默认大小 // 在常见情况下使用栈上静态大小的缓冲区. // 如果路径太长, 则在堆上分配缓冲区. - buf := make([]byte, 0, stackBufSize) - if length := len(path) + 1; length > stackBufSize { - buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 + if buf != nil { + buf = buf[:0] + } + if cap(buf) < len(path)+1 { + var stackBuf [stackBufSize]byte + if len(path)+1 <= stackBufSize { + buf = stackBuf[:0] + } else { + buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区 + } } ciPath := n.findCaseInsensitivePathRec( From 57847fa44647a1670f49bc22d1889b7e6203e0c8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:32:14 +0800 Subject: [PATCH 4/6] fix: avoid unsafe header buffer reuse Use safe string copies for pooled header buffers and simplify case-insensitive lookup buffering now that the pseudo stack path was ineffective. This addresses review concerns without changing the routing semantics. --- engine.go | 4 ++-- reverseproxy.go | 2 +- tree.go | 11 +---------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/engine.go b/engine.go index 5214654..698fbd5 100644 --- a/engine.go +++ b/engine.go @@ -145,7 +145,7 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { allowHeader = append(allowHeader, method...) } c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", BytesToString(allowHeader)) + c.Writer.Header().Set("Allow", string(allowHeader)) c.Status(http.StatusOK) return } @@ -810,7 +810,7 @@ func (engine *Engine) handleRequest(c *Context) { ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) if found { c.fixedPathBuf = ciPath[:0] - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 return } c.fixedPathBuf = ciPath[:0] diff --git a/reverseproxy.go b/reverseproxy.go index ff49aef..fe66e2b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -709,7 +709,7 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { allowHeader = append(allowHeader, method...) } c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", BytesToString(allowHeader)) + c.Writer.Header().Set("Allow", string(allowHeader)) } } } diff --git a/tree.go b/tree.go index 6595655..b159c8d 100644 --- a/tree.go +++ b/tree.go @@ -723,20 +723,11 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]by } func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { - const stackBufSize = 128 // 栈上缓冲区的默认大小 - - // 在常见情况下使用栈上静态大小的缓冲区. - // 如果路径太长, 则在堆上分配缓冲区. if buf != nil { buf = buf[:0] } if cap(buf) < len(path)+1 { - var stackBuf [stackBufSize]byte - if len(path)+1 <= stackBufSize { - buf = stackBuf[:0] - } else { - buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区 - } + buf = make([]byte, 0, len(path)+1) } ciPath := n.findCaseInsensitivePathRec( From fa027347d32012678df1dd7aafded6d8e1444c1a Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:35:39 +0800 Subject: [PATCH 5/6] fix: reduce default error response overhead Encode the built-in 404 and 405 payload with a fixed struct instead of a map so default error pages allocate less on the hot miss path. Add a regression test to keep the JSON shape stable. --- engine.go | 12 +++++++----- engine_test.go | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 698fbd5..81d3673 100644 --- a/engine.go +++ b/engine.go @@ -126,6 +126,12 @@ type ErrorHandler func(c *Context, code int, err error) var errMethodNotAllowed = errors.New("method not allowed") var errNotFound = errors.New("not found") +type defaultErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` +} + var methodNotAllowedHandler HandlerFunc = func(c *Context) { httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -187,11 +193,7 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if err != nil { errMsg = err.Error() } - c.JSON(code, H{ - "code": code, - "message": http.StatusText(code), - "error": errMsg, - }) + c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) c.Writer.Flush() c.Abort() return diff --git a/engine_test.go b/engine_test.go index 292d5e2..71f9772 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1,6 +1,7 @@ package touka import ( + "encoding/json" "net/http" "testing" ) @@ -100,3 +101,23 @@ func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { t.Fatalf("expected Allow header to list matching methods, got %q", allow) } } + +func TestDefaultErrorHandleJSONShape(t *testing.T) { + engine := New() + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" { + t.Fatalf("unexpected error payload: %+v", body) + } +} From 987ea81329e34d43357f200ea58a38226d4b1d3b Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:57:16 +0800 Subject: [PATCH 6/6] fix: avoid fixed-path miss panic and trim 405 fallback work --- engine.go | 14 +++++++++----- engine_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 81d3673..f9d233a 100644 --- a/engine.go +++ b/engine.go @@ -155,22 +155,25 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { c.Status(http.StatusOK) return } + return } // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + tempSkippedNodes := GetTempSkippedNodes() for _, treeIter := range engine.methodTrees { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() + *tempSkippedNodes = (*tempSkippedNodes)[:0] value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { + PutTempSkippedNodes(tempSkippedNodes) // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) return } } + PutTempSkippedNodes(tempSkippedNodes) } var notFoundHandler HandlerFunc = func(c *Context) { @@ -815,7 +818,7 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 return } - c.fixedPathBuf = ciPath[:0] + c.fixedPathBuf = c.fixedPathBuf[:0] } } } @@ -872,15 +875,16 @@ func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods [ } else { allowedMethods = allowedMethods[:0] } + tempSkippedNodes := GetTempSkippedNodes() for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() + *tempSkippedNodes = (*tempSkippedNodes)[:0] value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } } + PutTempSkippedNodes(tempSkippedNodes) return allowedMethods } diff --git a/engine_test.go b/engine_test.go index 71f9772..571f4b7 100644 --- a/engine_test.go +++ b/engine_test.go @@ -48,6 +48,24 @@ func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { } } +func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic for fixed-path miss: %v", r) + } + }() + + rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code) + } +} + func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { engine := New() engine.NoRoute(func(c *Context) {