diff --git a/context.go b/context.go index 8c52b1f..b6fbd46 100644 --- a/context.go +++ b/context.go @@ -18,7 +18,7 @@ import ( "net/netip" "net/url" "os" - "path/filepath" + "path" "strings" "sync" "time" @@ -65,10 +65,6 @@ type Context struct { // 请求体Body大小限制 MaxRequestBodySize int64 - - // skippedNodes 用于记录跳过的节点信息,以便回溯 - // 通常在处理嵌套路由时使用 - SkippedNodes []skippedNode } // --- Context 相关方法实现 --- @@ -84,13 +80,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.Request = req - //c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 - //避免params长度为0 - if cap(c.Params) > 0 { - c.Params = c.Params[:0] - } else { - c.Params = make(Params, 0, c.engine.maxParams) - } + c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 @@ -100,12 +90,6 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize - - if cap(c.SkippedNodes) > 0 { - c.SkippedNodes = c.SkippedNodes[:0] - } else { - c.SkippedNodes = make([]skippedNode, 0, 256) - } } // Next 在处理链中执行下一个处理函数 @@ -296,113 +280,6 @@ func (c *Context) Text(code int, text string) { c.Writer.Write([]byte(text)) } -// FileText -func (c *Context) FileText(code int, filePath string) { - // 清理path - cleanPath := filepath.Clean(filePath) - if !filepath.IsAbs(cleanPath) { - c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed")) - return - } - // 检查文件是否存在 - if _, err := os.Stat(cleanPath); os.IsNotExist(err) { - c.AddError(fmt.Errorf("file not found: %s", cleanPath)) - c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found")) - return - } - - // 打开文件 - file, err := os.Open(cleanPath) - if err != nil { - c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err)) - c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err)) - return - } - defer file.Close() - - // 获取文件信息以获取文件大小 - fileInfo, err := file.Stat() - if err != nil { - c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err)) - c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err)) - return - } - // 判断是否是dir - if fileInfo.IsDir() { - c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory")) - return - } - - c.SetHeader("Content-Type", "text/plain; charset=utf-8") - - c.SetBodyStream(file, int(fileInfo.Size())) -} - -/* -// not fot work -// FileTextSafeDir -func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) { - - // 清理path - cleanPath := path.Clean(filePath) - if !filepath.IsAbs(cleanPath) { - c.AddError(fmt.Errorf("relative path not allowed: %s", cleanPath)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("relative path not allowed")) - return - } - if strings.Contains(cleanPath, "..") { - c.AddError(fmt.Errorf("path traversal attempt detected: %s", cleanPath)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path traversal attempt detected")) - return - } - - // 判断filePath是否包含在safeDir内, 防止路径穿越 - relPath, err := filepath.Rel(safeDir, cleanPath) - if err != nil { - c.AddError(fmt.Errorf("failed to get relative path: %w", err)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("failed to get relative path: %w", err)) - return - } - cleanPath = filepath.Join(safeDir, relPath) - - // 检查文件是否存在 - if _, err := os.Stat(cleanPath); os.IsNotExist(err) { - c.AddError(fmt.Errorf("file not found: %s", cleanPath)) - c.ErrorUseHandle(http.StatusNotFound, fmt.Errorf("file not found")) - return - } - - // 打开文件 - file, err := os.Open(cleanPath) - if err != nil { - c.AddError(fmt.Errorf("failed to open file %s: %w", cleanPath, err)) - c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to open file: %w", err)) - return - } - defer file.Close() - - // 获取文件信息以获取文件大小 - fileInfo, err := file.Stat() - if err != nil { - c.AddError(fmt.Errorf("failed to get file info for %s: %w", cleanPath, err)) - c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to get file info: %w", err)) - return - } - // 判断是否是dir - if fileInfo.IsDir() { - c.AddError(fmt.Errorf("path is a directory, not a file: %s", cleanPath)) - c.ErrorUseHandle(http.StatusBadRequest, fmt.Errorf("path is a directory")) - return - } - - c.SetHeader("Content-Type", "text/plain; charset=utf-8") - - c.SetBodyStream(file, int(fileInfo.Size())) -} -*/ - // JSON 向响应写入 JSON 数据 // 设置 Content-Type 为 application/json func (c *Context) JSON(code int, obj any) { @@ -410,7 +287,6 @@ func (c *Context) JSON(code int, obj any) { c.Writer.WriteHeader(code) if err := json.MarshalWrite(c.Writer, obj); err != nil { c.AddError(fmt.Errorf("failed to marshal JSON: %w", err)) - c.Errorf("failed to marshal JSON: %s", err) c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err)) return } @@ -879,7 +755,7 @@ func (c *Context) GetRequestURIPath() string { // 将文件内容作为响应body func (c *Context) SetRespBodyFile(code int, filePath string) { // 清理path - cleanPath := filepath.Clean(filePath) + cleanPath := path.Clean(filePath) // 打开文件 file, err := os.Open(cleanPath) @@ -899,7 +775,7 @@ func (c *Context) SetRespBodyFile(code int, filePath string) { } // 尝试根据文件扩展名猜测 Content-Type - contentType := mime.TypeByExtension(filepath.Ext(cleanPath)) + contentType := mime.TypeByExtension(path.Ext(cleanPath)) if contentType == "" { // 如果无法猜测,则使用默认的二进制流类型 contentType = "application/octet-stream" diff --git a/engine.go b/engine.go index 0a95765..581258c 100644 --- a/engine.go +++ b/engine.go @@ -421,41 +421,6 @@ func getHandlerName(h HandlerFunc) string { } -const MaxSkippedNodesCap = 256 - -// TempSkippedNodesPool 存储 *[]skippedNode 以复用内存 -var TempSkippedNodesPool = sync.Pool{ - New: func() any { - // 返回一个指向容量为 256 的新切片的指针 - s := make([]skippedNode, 0, MaxSkippedNodesCap) - return &s - }, -} - -// GetTempSkippedNodes 从 Pool 中获取一个 *[]skippedNode 指针 -func GetTempSkippedNodes() *[]skippedNode { - // 直接返回 Pool 中存储的指针 - return TempSkippedNodesPool.Get().(*[]skippedNode) -} - -// PutTempSkippedNodes 将用完的 *[]skippedNode 指针放回 Pool -func PutTempSkippedNodes(skippedNodes *[]skippedNode) { - if skippedNodes == nil || *skippedNodes == nil { - return - } - - // 检查容量是否符合预期。如果容量不足,则丢弃,不放回 Pool。 - if cap(*skippedNodes) < MaxSkippedNodesCap { - return // 丢弃该对象,让 Pool 在下次 Get 时通过 New 重新分配 - } - - // 长度重置为 0,保留容量,实现复用 - *skippedNodes = (*skippedNodes)[:0] - - // 将指针存回 Pool - TempSkippedNodesPool.Put(skippedNodes) -} - // 405中间件 func MethodNotAllowed() HandlerFunc { return func(c *Context) { @@ -467,10 +432,9 @@ func MethodNotAllowed() HandlerFunc { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 allowedMethods := []string{} for _, treeIter := range engine.methodTrees { + var tempSkippedNodes []skippedNode // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } @@ -487,10 +451,9 @@ func MethodNotAllowed() HandlerFunc { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } + var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 if value.handlers != nil { // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) @@ -698,8 +661,9 @@ func (engine *Engine) handleRequest(c *Context) { // 查找匹配的节点和处理函数 // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 // skippedNodes 内部使用,因此无需从外部传入已分配的 slice + var skippedNodes []skippedNode // 用于回溯的跳过节点 // 直接在 rootNode 上调用 getValue 方法 - value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 if value.handlers != nil { //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 diff --git a/go.mod b/go.mod index f9d10a9..0b8d97b 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.25.1 require ( github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/httpc v0.8.2 - github.com/WJQSERVER/wanf v0.0.3 + github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee github.com/fenthope/reco v0.0.4 - github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e + github.com/go-json-experiment/json v0.0.0-20250813233538-9b1f9ea2e11b ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.49.0 // indirect + golang.org/x/net v0.43.0 // indirect ) diff --git a/go.sum b/go.sum index b75fec4..dcd4f26 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,13 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbe github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk= github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y= -github.com/WJQSERVER/wanf v0.0.3 h1:OqhG7ETiR5Knqr0lmbb+iUMw9O7re2vEogjVf06QevM= -github.com/WJQSERVER/wanf v0.0.3/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE= +github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee h1:tJ31DNBn6UhWkk8fiikAQWqULODM+yBcGAEar1tzdZc= +github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE= github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM= github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg= -github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= -github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= +github.com/go-json-experiment/json v0.0.0-20250813233538-9b1f9ea2e11b h1:6Q4zRHXS/YLOl9Ng1b1OOOBWMidAQZR3Gel0UKPC/KU= +github.com/go-json-experiment/json v0.0.0-20250813233538-9b1f9ea2e11b/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= diff --git a/tree.go b/tree.go index 31246a5..09711a1 100644 --- a/tree.go +++ b/tree.go @@ -5,6 +5,7 @@ package touka import ( + "bytes" "net/url" "strings" "unicode" @@ -26,6 +27,12 @@ func BytesToString(b []byte) string { return unsafe.String(unsafe.SliceData(b), len(b)) } +var ( + strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别 + strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别 + strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别 +) + // Param 是单个 URL 参数, 由键和值组成. type Param struct { Key string // 参数的键名 @@ -99,14 +106,17 @@ func (n *node) addChild(child *node) { // countParams 计算路径中参数(冒号)和捕获所有(星号)的数量. func countParams(path string) uint16 { - colons := strings.Count(path, ":") - stars := strings.Count(path, "*") - return uint16(colons + stars) + var n uint16 + s := StringToBytes(path) // 将路径字符串转换为字节切片 + n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 + n += uint16(bytes.Count(s, strStar)) // 统计星号的数量 + return n } // countSections 计算路径中斜杠('/')的数量, 即路径段的数量. func countSections(path string) uint16 { - return uint16(strings.Count(path, "/")) + s := StringToBytes(path) // 将路径字符串转换为字节切片 + return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量 } // nodeType 定义了节点的类型. @@ -408,10 +418,10 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) fullPath: fullPath, // 设置完整路径 } - n.addChild(child) // 添加子节点 - n.indices = "/" // 索引设置为 '/' - n = child // 移动到新创建的 catchAll 节点 - n.priority++ // 增加优先级 + n.addChild(child) // 添加子节点 + n.indices = string('/') // 索引设置为 '/' + n = child // 移动到新创建的 catchAll 节点 + n.priority++ // 增加优先级 // 第二个节点: 包含变量的节点 child = &node{