From 904aea5df88280fbbf65456a9bf80ced68624618 Mon Sep 17 00:00:00 2001 From: WJQSERVER <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 14 Dec 2025 22:56:37 +0800 Subject: [PATCH] refactor: Improve engine's tree processing and context handling. --- context.go | 18 +++++++++++++++++- engine.go | 9 +++------ tree.go | 26 ++++++++------------------ 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/context.go b/context.go index c79e4cc..644bbc6 100644 --- a/context.go +++ b/context.go @@ -65,6 +65,10 @@ type Context struct { // 请求体Body大小限制 MaxRequestBodySize int64 + + // skippedNodes 用于记录跳过的节点信息,以便回溯 + // 通常在处理嵌套路由时使用 + SkippedNodes []skippedNode } // --- Context 相关方法实现 --- @@ -80,7 +84,13 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.Request = req - c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //避免params长度为0 + if cap(c.Params) > 0 { + c.Params = c.Params[:0] + } else { + c.Params = make(Params, 0, 5) + } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 @@ -90,6 +100,12 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize + + if cap(c.SkippedNodes) > 0 { + c.SkippedNodes = c.SkippedNodes[:0] + } else { + c.SkippedNodes = make([]skippedNode, 0, 256) + } } // Next 在处理链中执行下一个处理函数 diff --git a/engine.go b/engine.go index 581258c..0cdd5cc 100644 --- a/engine.go +++ b/engine.go @@ -432,9 +432,8 @@ func MethodNotAllowed() HandlerFunc { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 allowedMethods := []string{} for _, treeIter := range engine.methodTrees { - var tempSkippedNodes []skippedNode // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) + value := treeIter.root.getValue(requestPath, nil, &c.SkippedNodes, false) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } @@ -451,9 +450,8 @@ func MethodNotAllowed() HandlerFunc { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } - var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + value := treeIter.root.getValue(requestPath, nil, &c.SkippedNodes, false) // 只查找是否存在,不需要参数 if value.handlers != nil { // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) @@ -661,9 +659,8 @@ func (engine *Engine) handleRequest(c *Context) { // 查找匹配的节点和处理函数 // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 // skippedNodes 内部使用,因此无需从外部传入已分配的 slice - var skippedNodes []skippedNode // 用于回溯的跳过节点 // 直接在 rootNode 上调用 getValue 方法 - value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码 if value.handlers != nil { //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 diff --git a/tree.go b/tree.go index 09711a1..31246a5 100644 --- a/tree.go +++ b/tree.go @@ -5,7 +5,6 @@ package touka import ( - "bytes" "net/url" "strings" "unicode" @@ -27,12 +26,6 @@ func BytesToString(b []byte) string { return unsafe.String(unsafe.SliceData(b), len(b)) } -var ( - strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别 - strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别 - strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别 -) - // Param 是单个 URL 参数, 由键和值组成. type Param struct { Key string // 参数的键名 @@ -106,17 +99,14 @@ func (n *node) addChild(child *node) { // countParams 计算路径中参数(冒号)和捕获所有(星号)的数量. func countParams(path string) uint16 { - var n uint16 - s := StringToBytes(path) // 将路径字符串转换为字节切片 - n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 - n += uint16(bytes.Count(s, strStar)) // 统计星号的数量 - return n + colons := strings.Count(path, ":") + stars := strings.Count(path, "*") + return uint16(colons + stars) } // countSections 计算路径中斜杠('/')的数量, 即路径段的数量. func countSections(path string) uint16 { - s := StringToBytes(path) // 将路径字符串转换为字节切片 - return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量 + return uint16(strings.Count(path, "/")) } // nodeType 定义了节点的类型. @@ -418,10 +408,10 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) fullPath: fullPath, // 设置完整路径 } - n.addChild(child) // 添加子节点 - n.indices = string('/') // 索引设置为 '/' - n = child // 移动到新创建的 catchAll 节点 - n.priority++ // 增加优先级 + n.addChild(child) // 添加子节点 + n.indices = "/" // 索引设置为 '/' + n = child // 移动到新创建的 catchAll 节点 + n.priority++ // 增加优先级 // 第二个节点: 包含变量的节点 child = &node{