diff --git a/context.go b/context.go index 0620c41..07fc0fa 100644 --- a/context.go +++ b/context.go @@ -73,11 +73,11 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.Request = req c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 c.handlers = nil - c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]interface{}) // 每次请求重新创建 map,避免数据污染 - c.Errors = c.Errors[:0] // 清空 Errors 切片 - c.queryCache = nil // 清空查询参数缓存 - c.formCache = nil // 清空表单数据缓存 + c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 + c.Keys = nil // 延迟初始化 Keys map 直到第一次 Set 调用 + c.Errors = c.Errors[:0] // 清空 Errors 切片 + c.queryCache = nil // 清空查询参数缓存 + c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 // c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员 @@ -115,7 +115,7 @@ func (c *Context) AbortWithStatus(code int) { func (c *Context) Set(key string, value interface{}) { c.mu.Lock() // 加写锁 if c.Keys == nil { - c.Keys = make(map[string]interface{}) + c.Keys = make(map[string]interface{}) // 首次使用时分配 } c.Keys[key] = value c.mu.Unlock() // 解写锁 @@ -125,8 +125,12 @@ func (c *Context) Set(key string, value interface{}) { // 这是一个线程安全的操作 func (c *Context) Get(key string) (value interface{}, exists bool) { c.mu.RLock() // 加读锁 + // Defer unlock to ensure it's always called + defer c.mu.RUnlock() // 解读锁 + if c.Keys == nil { + return nil, false // 如果 Keys map 未初始化,则键肯定不存在 + } value, exists = c.Keys[key] - c.mu.RUnlock() // 解读锁 return } @@ -474,17 +478,29 @@ func (c *Context) RequestIP() string { for _, headerName := range c.engine.RemoteIPHeaders { if ipValue := c.Request.Header.Get(headerName); ipValue != "" { // X-Forwarded-For 可能包含多个 IP,约定第一个(最左边)是客户端 IP - // 其他头部(如 X-Real-IP)通常只有一个 - ips := strings.Split(ipValue, ",") - for _, singleIP := range ips { - trimmedIP := strings.TrimSpace(singleIP) + // Iterate through comma-separated IPs without allocating a slice from strings.Split + currentPos := 0 + for currentPos < len(ipValue) { + nextComma := strings.IndexByte(ipValue[currentPos:], ',') + var ipSegment string + if nextComma == -1 { + ipSegment = ipValue[currentPos:] + currentPos = len(ipValue) // End loop + } else { + ipSegment = ipValue[currentPos : currentPos+nextComma] + currentPos += nextComma + 1 // Move past segment and comma + } + + trimmedIP := strings.TrimSpace(ipSegment) + if trimmedIP == "" { // Skip empty segments that might result from "ip1,,ip2" + continue + } // 使用 netip.ParseAddr 进行 IP 地址的解析和格式验证 addr, err := netip.ParseAddr(trimmedIP) if err == nil { // 成功解析到合法的 IP 地址格式,立即返回 return addr.String() } - // 如果当前 singleIP 无效,继续检查列表中的下一个 } } } diff --git a/tree.go b/tree.go index 6f99223..71ecfe4 100644 --- a/tree.go +++ b/tree.go @@ -453,9 +453,9 @@ type nodeValue struct { // skippedNode 结构体用于在 getValue 查找过程中记录跳过的节点信息,以便回溯。 type skippedNode struct { - path string // 跳过时的当前路径 - node *node // 跳过的节点 - paramsCount int16 // 跳过时已收集的参数数量 + path string // 跳过时的完整请求路径段 (n.path + remaining path at that point) + node *node // 当时被跳过的节点 (n), direct pointer + paramsCount int16 // 当时已收集的参数数量 } // getValue 返回注册到给定路径(key)的处理函数。通配符的值会保存到 map 中。 @@ -477,21 +477,11 @@ walk: // 外部循环用于遍历路由树 if c == idxc { // 如果找到匹配的索引字符 // 如果当前节点有通配符子节点,则将当前节点添加到 skippedNodes,以便回溯 if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, + *skippedNodes = append(*skippedNodes, skippedNode{ + path: prefix + path, // Path is n.path + remaining path after matching n.path + node: n, // Store pointer to the original node `n` paramsCount: globalParamsCount, // 记录当前参数计数 - } + }) } n = n.children[i] // 移动到匹配的子节点