diff --git a/context.go b/context.go index f24ceb0..f06d21e 100644 --- a/context.go +++ b/context.go @@ -73,6 +73,12 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode + + // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. + fixedPathBuf []byte + + allowedMethodsBuf []string + allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } else { c.SkippedNodes = make([]skippedNode, 0, 256) } + if cap(c.fixedPathBuf) > 0 { + c.fixedPathBuf = c.fixedPathBuf[:0] + } + if cap(c.allowedMethodsBuf) > 0 { + c.allowedMethodsBuf = c.allowedMethodsBuf[:0] + } + if cap(c.allowHeaderBuf) > 0 { + c.allowHeaderBuf = c.allowHeaderBuf[:0] + } } // Next 在处理链中执行下一个处理函数 diff --git a/engine.go b/engine.go index b2cc952..5214654 100644 --- a/engine.go +++ b/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "strings" + "unicode/utf8" "net/http" @@ -82,6 +83,11 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 + + notFoundChain HandlersChain + notFoundNoMethodChain HandlersChain + unmatchedFSChain HandlersChain + unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -127,10 +133,19 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) + allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) + c.allowedMethodsBuf = allowedMethods[:0] if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allowedMethods { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) c.Status(http.StatusOK) return } @@ -251,6 +266,7 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -306,6 +322,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable + engine.rebuildFallbackChains() } // SetLogger传入实例 @@ -346,6 +363,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } + engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -531,12 +549,52 @@ func NotFound() HandlerFunc { func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil + Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs + Engine.rebuildFallbackChains() +} + +func (engine *Engine) rebuildFallbackChains() { + buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { + finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound + if includeMethodNotAllowed { + finalSize++ + } + if includeUnmatchedFS { + finalSize += len(engine.UnMatchFSRoutes) + } + if engine.noRoute != nil { + finalSize++ + } else { + finalSize += len(engine.noRoutes) + } + + chain := make(HandlersChain, 0, finalSize) + chain = append(chain, engine.globalHandlers...) + if includeMethodNotAllowed { + chain = append(chain, methodNotAllowedHandler) + } + if includeUnmatchedFS { + chain = append(chain, engine.UnMatchFSRoutes...) + } + if engine.noRoute != nil { + chain = append(chain, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + chain = append(chain, engine.noRoutes...) + } + chain = append(chain, notFoundHandler) + return chain + } + + engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) + engine.notFoundNoMethodChain = buildChain(false, false) + engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) + engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -553,6 +611,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // 这些中间件将应用于所有注册的路由 func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) + engine.rebuildFallbackChains() return engine } @@ -746,48 +805,24 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - if engine.RedirectFixedPath { + if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) if found { + c.fixedPathBuf = ciPath[:0] c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 return } + c.fixedPathBuf = ciPath[:0] } } } - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers - - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } - - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) + c.handlers = engine.unmatchedFSChain + } else { + c.handlers = engine.notFoundChain } - - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } - - handlers = append(handlers, NotFound()) - - c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } @@ -813,8 +848,28 @@ func isGeneralOptionsRequest(req *http.Request) bool { return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" } -func (engine *Engine) allowedMethodsForPath(requestPath string) []string { - allowedMethods := make([]string, 0, len(engine.methodTrees)) +func shouldTryFixedPathLookup(path string, root *node) bool { + if root != nil && root.hasCaseInsensitivePath { + return true + } + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { + if cap(allowedMethods) < len(engine.methodTrees) { + allowedMethods = make([]string, 0, len(engine.methodTrees)) + } else { + allowedMethods = allowedMethods[:0] + } for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 tempSkippedNodes := GetTempSkippedNodes() diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go index 5780230..666e8b2 100644 --- a/engine_benchmark_test.go +++ b/engine_benchmark_test.go @@ -16,6 +16,9 @@ func buildServeHTTPBenchmarkEngine() *Engine { engine.GET("/api/v1/users/:id", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) engine.POST("/api/v1/users", func(c *Context) { c.Status(http.StatusNoContent) }) @@ -61,4 +64,8 @@ func BenchmarkServeHTTP(b *testing.B) { b.Run("OptionsAllow", func(b *testing.B) { benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") }) + + b.Run("FixedPathRedirect", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") + }) } diff --git a/engine_test.go b/engine_test.go new file mode 100644 index 0000000..292d5e2 --- /dev/null +++ b/engine_test.go @@ -0,0 +1,102 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestHandleRequestRedirectFixedPath(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { + t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) + } +} + +func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) + } +} + +func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) + } + if location := rr.Header().Get("Location"); location != "/Users/Profile" { + t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) + } +} + +func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { + engine := New() + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "hit" { + t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) + } +} + +func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "" { + t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) + } +} + +func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) + } + allow := rr.Header().Get("Allow") + if allow != "GET, POST" && allow != "POST, GET" { + t.Fatalf("expected Allow header to list matching methods, got %q", allow) + } +} diff --git a/reverseproxy.go b/reverseproxy.go index 1b89b2a..ff49aef 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { if c.engine != nil { if c.Request != nil && c.Request.RequestURI != "*" { - if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { - c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { + c.allowedMethodsBuf = allow[:0] + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allow { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) } } } diff --git a/tree.go b/tree.go index e9a10e6..6595655 100644 --- a/tree.go +++ b/tree.go @@ -121,14 +121,28 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 +} + +func routeNeedsCaseInsensitiveLookup(path string) bool { + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 + if routeNeedsCaseInsensitiveLookup(path) { + n.hasCaseInsensitivePath = true + } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -702,13 +719,24 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { + return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) +} + +func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { const stackBufSize = 128 // 栈上缓冲区的默认大小 // 在常见情况下使用栈上静态大小的缓冲区. // 如果路径太长, 则在堆上分配缓冲区. - buf := make([]byte, 0, stackBufSize) - if length := len(path) + 1; length > stackBufSize { - buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 + if buf != nil { + buf = buf[:0] + } + if cap(buf) < len(path)+1 { + var stackBuf [stackBufSize]byte + if len(path)+1 <= stackBufSize { + buf = stackBuf[:0] + } else { + buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区 + } } ciPath := n.findCaseInsensitivePathRec(