From 5d979e56707a239e682f0c374cd6cc78d3bb2f3a Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:39:10 +0800 Subject: [PATCH] fix: reduce per-request context and fallback overhead Make Context keys lazy so requests that never call Set stop allocating on reset. Reuse stable 404 and 405 handlers and add focused benchmarks so ServeHTTP miss paths stay measurable. --- context.go | 2 +- context_benchmark_test.go | 78 +++++++++++++++++++++++++++++++++++++++ engine.go | 77 ++++++++++++++++++++------------------ engine_benchmark_test.go | 64 ++++++++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 36 deletions(-) create mode 100644 context_benchmark_test.go create mode 100644 engine_benchmark_test.go diff --git a/context.go b/context.go index 9c4ba7e..f24ceb0 100644 --- a/context.go +++ b/context.go @@ -97,7 +97,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..2198c59 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,78 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(c.Writer, req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + c.Set("request-id", i) + } + }) +} diff --git a/engine.go b/engine.go index ece023d..b2cc952 100644 --- a/engine.go +++ b/engine.go @@ -117,6 +117,46 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath) + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -479,45 +519,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..5780230 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,64 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) +}