diff --git a/.gitignore b/.gitignore index 30d74d2..6f301cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test \ No newline at end of file +test +/bench_route_match_baseline.txt diff --git a/context.go b/context.go index 9c4ba7e..f06d21e 100644 --- a/context.go +++ b/context.go @@ -73,6 +73,12 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode + + // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. + fixedPathBuf []byte + + allowedMethodsBuf []string + allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -97,7 +103,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 @@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } else { c.SkippedNodes = make([]skippedNode, 0, 256) } + if cap(c.fixedPathBuf) > 0 { + c.fixedPathBuf = c.fixedPathBuf[:0] + } + if cap(c.allowedMethodsBuf) > 0 { + c.allowedMethodsBuf = c.allowedMethodsBuf[:0] + } + if cap(c.allowHeaderBuf) > 0 { + c.allowHeaderBuf = c.allowHeaderBuf[:0] + } } // Next 在处理链中执行下一个处理函数 diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..2198c59 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,78 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(c.Writer, req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + c.Set("request-id", i) + } + }) +} diff --git a/engine.go b/engine.go index b7cf330..f9d233a 100644 --- a/engine.go +++ b/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "strings" + "unicode/utf8" "net/http" @@ -82,6 +83,11 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 + + notFoundChain HandlersChain + notFoundNoMethodChain HandlersChain + unmatchedFSChain HandlersChain + unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -117,6 +123,64 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +type defaultErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` +} + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) + c.allowedMethodsBuf = allowedMethods[:0] + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allowedMethods { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", string(allowHeader)) + c.Status(http.StatusOK) + return + } + return + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + tempSkippedNodes := GetTempSkippedNodes() + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + *tempSkippedNodes = (*tempSkippedNodes)[:0] + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + if value.handlers != nil { + PutTempSkippedNodes(tempSkippedNodes) + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } + PutTempSkippedNodes(tempSkippedNodes) +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -132,11 +196,7 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if err != nil { errMsg = err.Error() } - c.JSON(code, H{ - "code": code, - "message": http.StatusText(code), - "error": errMsg, - }) + c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) c.Writer.Flush() c.Abort() return @@ -211,6 +271,7 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -266,6 +327,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable + engine.rebuildFallbackChains() } // SetLogger传入实例 @@ -306,6 +368,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } + engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -479,57 +542,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil + Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs + Engine.rebuildFallbackChains() +} + +func (engine *Engine) rebuildFallbackChains() { + buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { + finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound + if includeMethodNotAllowed { + finalSize++ + } + if includeUnmatchedFS { + finalSize += len(engine.UnMatchFSRoutes) + } + if engine.noRoute != nil { + finalSize++ + } else { + finalSize += len(engine.noRoutes) + } + + chain := make(HandlersChain, 0, finalSize) + chain = append(chain, engine.globalHandlers...) + if includeMethodNotAllowed { + chain = append(chain, methodNotAllowedHandler) + } + if includeUnmatchedFS { + chain = append(chain, engine.UnMatchFSRoutes...) + } + if engine.noRoute != nil { + chain = append(chain, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + chain = append(chain, engine.noRoutes...) + } + chain = append(chain, notFoundHandler) + return chain + } + + engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) + engine.notFoundNoMethodChain = buildChain(false, false) + engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) + engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -546,6 +616,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // 这些中间件将应用于所有注册的路由 func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) + engine.rebuildFallbackChains() return engine } @@ -739,47 +810,24 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return + if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { + // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. + ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) + if found { + c.fixedPathBuf = ciPath[:0] + c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 + return + } + c.fixedPathBuf = c.fixedPathBuf[:0] } } } - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers - - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } - - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) + c.handlers = engine.unmatchedFSChain + } else { + c.handlers = engine.notFoundChain } - - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } - - handlers = append(handlers, NotFound()) - - c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } @@ -805,17 +853,38 @@ func isGeneralOptionsRequest(req *http.Request) bool { return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" } -func (engine *Engine) allowedMethodsForPath(requestPath string) []string { - allowedMethods := make([]string, 0, len(engine.methodTrees)) +func shouldTryFixedPathLookup(path string, root *node) bool { + if root != nil && root.hasCaseInsensitivePath { + return true + } + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { + if cap(allowedMethods) < len(engine.methodTrees) { + allowedMethods = make([]string, 0, len(engine.methodTrees)) + } else { + allowedMethods = allowedMethods[:0] + } + tempSkippedNodes := GetTempSkippedNodes() for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() + *tempSkippedNodes = (*tempSkippedNodes)[:0] value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } } + PutTempSkippedNodes(tempSkippedNodes) return allowedMethods } diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..666e8b2 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,71 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) + + b.Run("FixedPathRedirect", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") + }) +} diff --git a/engine_test.go b/engine_test.go new file mode 100644 index 0000000..571f4b7 --- /dev/null +++ b/engine_test.go @@ -0,0 +1,141 @@ +package touka + +import ( + "encoding/json" + "net/http" + "testing" +) + +func TestHandleRequestRedirectFixedPath(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { + t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) + } +} + +func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) + } +} + +func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) + } + if location := rr.Header().Get("Location"); location != "/Users/Profile" { + t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) + } +} + +func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic for fixed-path miss: %v", r) + } + }() + + rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code) + } +} + +func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { + engine := New() + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "hit" { + t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) + } +} + +func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "" { + t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) + } +} + +func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) + } + allow := rr.Header().Get("Allow") + if allow != "GET, POST" && allow != "POST, GET" { + t.Fatalf("expected Allow header to list matching methods, got %q", allow) + } +} + +func TestDefaultErrorHandleJSONShape(t *testing.T) { + engine := New() + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" { + t.Fatalf("unexpected error payload: %+v", body) + } +} diff --git a/reverseproxy.go b/reverseproxy.go index 1b89b2a..fe66e2b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { if c.engine != nil { if c.Request != nil && c.Request.RequestURI != "*" { - if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { - c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { + c.allowedMethodsBuf = allow[:0] + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allow { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", string(allowHeader)) } } } diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go new file mode 100644 index 0000000..e0dd2aa --- /dev/null +++ b/route_match_benchmark_test.go @@ -0,0 +1,130 @@ +package touka + +import "testing" + +var ( + benchmarkRouteHandlers HandlersChain + benchmarkRouteFullPath string + benchmarkRouteParamsLen int + benchmarkRouteCIPath []byte + benchmarkRouteCIFound bool +) + +func buildRouteMatchBenchmarkTree() *node { + tree := &node{} + routes := []string{ + "/", + "/health", + "/contact", + "/api/v1/users", + "/api/v1/users/:id", + "/api/v1/users/:id/settings", + "/assets/*filepath", + "/abc/b", + "/abc/:p1/cde", + "/abc/:p1/:p2/def/*filepath", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + return tree +} + +func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { + b.Helper() + + params := make(Params, 0, 4) + skipped := make([]skippedNode, 0, 8) + + value := tree.getValue(path, ¶ms, &skipped, true) + if wantFullPath == "" { + if value.handlers != nil { + b.Fatalf("expected no match for %q, got %q", path, value.fullPath) + } + } else { + if value.handlers == nil { + b.Fatalf("expected match for %q, got nil handlers", path) + } + if value.fullPath != wantFullPath { + b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) + } + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + params = params[:0] + skipped = skipped[:0] + value = tree.getValue(path, ¶ms, &skipped, true) + } + + benchmarkRouteHandlers = value.handlers + benchmarkRouteFullPath = value.fullPath + if value.params != nil { + benchmarkRouteParamsLen = len(*value.params) + } else { + benchmarkRouteParamsLen = 0 + } +} + +func BenchmarkRouteMatch(b *testing.B) { + tree := buildRouteMatchBenchmarkTree() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") + }) + + b.Run("ParamHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") + }) + + b.Run("BacktrackingHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") + }) + + b.Run("Miss", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/does/not/exist", "") + }) + + b.Run("CaseInsensitiveHit", func(b *testing.B) { + path := "/API/V1/USERS/123/SETTINGS" + out, found := tree.findCaseInsensitivePath(path, true) + if !found { + b.Fatalf("expected fixed-path match for %q", path) + } + if got := string(out); got != "/api/v1/users/123/settings" { + b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) + + b.Run("CaseInsensitiveMiss", func(b *testing.B) { + path := "/DOES/NOT/EXIST" + out, found := tree.findCaseInsensitivePath(path, true) + if found || out != nil { + b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) +} diff --git a/tree.go b/tree.go index f5452f4..b159c8d 100644 --- a/tree.go +++ b/tree.go @@ -121,14 +121,28 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 +} + +func routeNeedsCaseInsensitiveLookup(path string) bool { + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 + if routeNeedsCaseInsensitiveLookup(path) { + n.hasCaseInsensitivePath = true + } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -452,12 +469,14 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 + var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -467,30 +486,26 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - for i, c := range []byte(n.indices) { - if c == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, - paramsCount: globalParamsCount, // 记录当前参数计数 + if !backtrackToWildChild { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: pathAtNode, // 记录进入当前节点时的剩余路径 + node: n, + paramsCount: globalParamsCount, // 记录当前参数计数 + } } - } - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } } + } else { + backtrackToWildChild = false } if !n.wildChild { @@ -507,7 +522,8 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - continue walk // 继续外部循环 + backtrackToWildChild = true + continue walk // 继续外部循环 } } } @@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { - const stackBufSize = 128 // 栈上缓冲区的默认大小 + return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) +} - // 在常见情况下使用栈上静态大小的缓冲区. - // 如果路径太长, 则在堆上分配缓冲区. - buf := make([]byte, 0, stackBufSize) - if length := len(path) + 1; length > stackBufSize { - buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 +func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { + if buf != nil { + buf = buf[:0] + } + if cap(buf) < len(path)+1 { + buf = make([]byte, 0, len(path)+1) } ciPath := n.findCaseInsensitivePathRec( @@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i, c := range []byte(n.indices) { - if c == idxc { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 小写匹配 - if c == idxc { + if n.indices[i] == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 大写匹配 - if c == idxc { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) diff --git a/tree_test.go b/tree_test.go index 7665afd..a35a1a8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode { return &ps } +func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { + t.Helper() + + resultCh := make(chan nodeValue, 1) + go func() { + resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) + }() + + select { + case value := <-resultCh: + return value + case <-time.After(2 * time.Second): + t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) + return nodeValue{} + } +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } + +func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { + tests := []struct { + name string + routes []string + requestPath string + wantFullPath string + wantParams Params + }{ + { + name: "param route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/details"}, + requestPath: "/foo/bar/details", + wantFullPath: "/foo/:id/details", + wantParams: Params{{Key: "id", Value: "bar"}}, + }, + { + name: "catch-all route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/*rest"}, + requestPath: "/foo/bar/baz.txt", + wantFullPath: "/foo/:id/*rest", + wantParams: Params{ + {Key: "id", Value: "bar"}, + {Key: "rest", Value: "/baz.txt"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := &node{} + for _, route := range tt.routes { + tree.addRoute(route, fakeHandler(route)) + } + + value := getValueWithTimeout(t, tree, tt.requestPath, false) + if value.handlers == nil { + t.Fatalf("expected handlers for %q", tt.requestPath) + } + if value.fullPath != tt.wantFullPath { + t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) + } + if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { + t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) + } + }) + } +}