diff --git a/context.go b/context.go index 9c4ba7e..f24ceb0 100644 --- a/context.go +++ b/context.go @@ -97,7 +97,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..2198c59 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,78 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(c.Writer, req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + c.Set("request-id", i) + } + }) +} diff --git a/engine.go b/engine.go index ece023d..b2cc952 100644 --- a/engine.go +++ b/engine.go @@ -117,6 +117,46 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath) + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -479,45 +519,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..5780230 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,64 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) +}