diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30d74d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +test \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7076bf8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,199 @@ +WJQserver Studio 开源许可证 +版本 v2.0 + +版权所有 © WJQserver Studio 2025 +版权所有 © Infinite Iroha 2025 +版权所有 © WJQserver 2025 + +定义 + +* 许可 (License): 指的是在本许可证内定义的使用、复制、分发与修改软件的条款与要求。 +* 授权方 (Licensor): 指的是拥有版权的个人或组织,亦或是拥有版权的个人或组织所指派的实体,在本许可证中特指 WJQserver Studio。 +* 贡献者 (Contributor): 指的是授权方以及根据本许可证授予贡献代码或软件的个人或实体。 +* 您 (You): 指的是行使本许可授予的权限的个人或法律实体。 +* 衍生作品 (Derivative Works): 指的是基于本软件或本软件任何部分的修改作品,无论修改程度如何。这包括但不限于基于本软件或其任何部分的修改、修订、改编、翻译或其他形式的创作,以及包含本软件或其部分的集合作品。 +* 非营利性使用 (Non-profit Use): 指的是不以直接商业盈利为主要目的的使用方式,包括但不限于: + * 个人用途: 由个人为了个人学习、研究、实验、非商业项目、个人网站搭建、毕业设计、家庭内部娱乐等非直接商业目的使用软件。 + * 教育用途: 在教育机构(如学校、大学、培训机构)内部用于教学、研究、学术交流等活动。 + * 科研用途: 在科研院所、实验室等机构内部用于科学研究、实验开发等活动。 + * 慈善与公益用途: 由慈善机构、公益组织等非营利性组织为了其公益使命或慈善事业内部运营使用,或对外提供不直接产生商业利润的公益服务。 + * 内部运营用途 (非营利组织): 非营利性组织在其内部运营中使用软件,例如用于行政管理、会员管理、内部沟通、项目管理等非直接营利性活动。 + +开源与自由软件 + +本项目为开源软件,允许用户在遵循本许可证的前提下访问和使用源代码。 +本项目旨在向用户提供尽可能广泛的非商业使用自由,同时保障社区的共同发展和良性生态,并为商业创新提供清晰的路径。 +强调版权所有,所有权利由 WJQserver Studio 及贡献者共同保留。 + +许可证条款 + +1. 使用权限 + +* 1.1 非营利性使用: 您被授予在非营利性使用场景下,为了任何目的,自由使用本软件的权限。 非营利性使用的具体场景包括但不限于定义部分所列举的各种情况。 + +* 1.2 商业使用: 您可以在商业环境中使用本软件,无需获得额外授权,但您的商业使用行为必须遵守以下条款: + + * 1.2.1 保持声明: 您在进行商业使用时,不得移除或修改软件中包含的原始版权声明、许可证声明以及来源声明。 + * 1.2.2 开源继承 (Copyleft) 与互惠共享: 如果您或您的组织希望将本软件或其衍生作品用于任何商业用途,包括但不限于: + + * 盈利性分发: 销售、出租、许可分发本软件或其衍生作品。 + * 盈利性服务: 基于本软件或其衍生作品提供商业服务,例如 SaaS 服务、咨询服务、定制开发服务、收费技术支持服务等。 + * 嵌入式商业应用: 将本软件或其衍生作品嵌入到商业产品或解决方案中进行销售。 + * 组织内部商业运营: 在营利性组织的内部运营中使用修改后的版本以直接支持其商业活动,例如定制化内部系统,通过例如但不限于在软件或相关服务中投放广告 (例如 Google Ads 等),应用内购买 (内购), 会员订阅, 增值功能收费等方式直接或间接产生商业收入。 + + 您必须选择以下两种方式之一: + + * i) 继承本许可证并开源: 您必须以本许可证或兼容的开源许可证分发您的衍生作品,并公开您的衍生作品的全部源代码,使得您的衍生作品的接收者也享有与您相同的权利,包括进一步修改和商业使用的权利。 本选项旨在促进社区的共同发展和知识共享,确保基于本软件的商业创新成果也能回馈社区。 + * ii) 获得授权方明确授权: 如果您不希望以开源方式发布您的衍生作品,或者希望使用其他许可证进行分发,或者您希望在商业运营中使用修改后的版本但不开源,您必须事先获得 WJQserver Studio 的明确书面授权。 授权的具体条款和条件将由 WJQserver Studio 另行协商确定。 + +2. 复制与分发 + +* 2.1 原始版本复制与分发: 您可以复制和分发本软件的原始版本,前提是必须满足以下条件: + + * 保留所有声明: 完整保留所有原始版权声明、许可证声明、来源声明以及其他所有权声明。 + * 附带许可证: 在分发软件时,必须同时附带本许可证的完整文本,确保接收者知悉并理解本许可证的全部条款。 + +* 2.2 衍生作品复制与分发: 您可以复制和分发基于本软件的衍生作品,您对衍生作品的分发行为将受到本许可证第 1.2.2 条(开源继承与互惠共享)的约束。 + +3. 修改权限 + +* 3.1 自由修改: 您被授予自由修改本软件的权限,无论修改目的是非营利性使用还是商业用途。 + +* 3.2 修改后使用与分发约束: 当您将修改后的版本用于商业用途或分发修改后的版本时,您需要遵守本许可证第 1.2.2 条(开源继承与互惠共享)以及第 2 条(复制与分发)的规定。 即使您不分发修改后的版本,只要您将其用于商业目的,也需要遵守开源继承条款或获得授权。 + +* 3.3 贡献接受: WJQserver Studio 鼓励社区贡献代码。如果您向本项目贡献代码,您需要同意您的贡献代码按照本许可证条款进行许可。 + +4. 专利权 + +* 4.1 无专利担保,风险自担: 本软件以“现状”提供,授权方及贡献者明确声明,不对本软件的专利侵权问题做任何形式的担保,亦不承担任何因专利侵权可能产生的责任与后果。 用户理解并同意,使用本软件的专利风险完全由用户自行承担。 + +* 4.2 专利纠纷应对: 如因用户使用本软件而引发任何专利侵权指控、诉讼或索赔,用户应自行负责处理并承担全部法律责任。 授权方及贡献者无义务参与任何相关法律程序,亦不承担任何由此产生的费用或赔偿。 + +5. 免责声明 + +* 5.1 “现状”提供,无任何保证: 本软件按“现状”提供,不提供任何明示或暗示的保证,包括但不限于适销性、特定用途适用性及非侵权性。 + +* 5.2 责任限制: 在适用法律允许的最大范围内,在任何情况下,授权方或任何贡献者均不对因使用或无法使用本软件而产生的任何直接、间接、偶然、特殊、惩罚性或后果性损害(包括但不限于采购替代商品或服务;损失使用、数据或利润;或业务中断)负责,无论其是如何造成的,也无论依据何种责任理论,即使已被告知可能发生此类损害。 + +* 5.3 用户法律责任: 用户需根据当地法律对待本项目,确保遵守所有适用法规。 + +6. 许可证期限与终止 + +* 6.1 许可证期限: 除版权所有人主动宣布放弃本软件版权外,本许可证无限期生效。 + +* 6.2 许可证终止: 如果您未能遵守本许可证的任何条款或条件,授权方有权终止本许可证。 您的许可证将在您违反本许可证条款时自动终止。 + +* 6.3 终止后的效力: 许可证终止后,您根据本许可证所享有的所有权利将立即终止,但您在许可证终止前已合法分发的软件副本,其接收者所获得的许可及权利将不受影响,继续有效。 免责声明(第 5 条)和责任限制(第 5.2 条)在本许可证终止后仍然有效。 + +7. 条款修订 + +* 7.1 修订权利保留: 授权方保留随时修改本许可证条款的权利,以便更好地适应法律、技术发展以及社区需求。 + +* 7.2 修订生效与接受: 修订后的条款将在发布时生效,除非另行声明,否则继续使用、复制、分发或修改本软件即表示您接受修订后的条款。授权方鼓励用户定期查阅本许可证的最新版本。 + +8. 其他 + +* 8.1 法定权利: 本许可证不影响您作为最终用户在适用法律下的法定权利。 + +* 8.2 条款可分割性: 若本许可证的某些条款被认定为不可执行,其余条款仍然完全有效。 + +* 8.3 版本更新: 授权方可能会发布本许可证的修订版本或新版本。您可以选择是继续使用本许可证的旧版本还是选择适用新版本。 + +WJQserver Studio Open Source License +Version v2.0 + +Copyright © WJQserver Studio 2024 + +Definitions + +* License: Refers to the terms and requirements for use, reproduction, distribution, and modification defined within this license. +* Licensor: Refers to the individual or organization that holds the copyright, or the entity designated by the copyright holder, specifically WJQserver Studio in this license. +* Contributor: Refers to the Licensor and individuals or entities who contribute code or software under this License. +* You: Refers to the individual or legal entity exercising permissions granted by this License. +* Derivative Works: Refers to works modified based on the Software or any part thereof, regardless of the extent of modification. This includes but is not limited to modifications, revisions, adaptations, translations, or other forms of creation based on the Software or any part thereof, as well as collective works containing the Software or parts thereof. +* Non-profit Use: Refers to uses not primarily intended for direct commercial profit, including but not limited to: + * Personal Use: Use by an individual for personal learning, research, experimentation, non-commercial projects, personal website development, graduation projects, home entertainment, and other non-directly commercial purposes. + * Educational Use: Use within educational institutions (such as schools, universities, training organizations) for activities such as teaching, research, and academic exchange. + * Scientific Research Use: Use within scientific research institutions, laboratories, and similar organizations for activities such as scientific research and experimental development. + * Charitable and Public Welfare Use: Use by charitable organizations, public welfare organizations, and similar non-profit entities for their public missions or internal operation of charitable activities, or to provide public services that do not directly generate commercial profit. + * Internal Operational Use (Non-profit Organizations): Use within the internal operations of non-profit organizations, such as for administrative management, membership management, internal communication, project management, and other non-directly profit-generating activities. + +Open Source and Free Software + +This project is open-source software, allowing users to access and use the source code under the premise of complying with this License. +This project aims to provide users with the broadest possible freedom for non-commercial use while ensuring the common development and healthy ecosystem of the community, and providing a clear path for commercial innovation. +Copyright is emphasized; all rights are jointly reserved by WJQserver Studio and Contributors. + +License Terms + +1. Permissions for Use + +* 1.1 Non-profit Use: You are granted permission to freely use the Software for any purpose in non-profit use scenarios. Specific non-profit use scenarios include but are not limited to the various situations listed in the Definition section. + +* 1.2 Commercial Use: You may use the Software in a commercial environment without additional authorization, but your commercial use must comply with the following terms: + + * 1.2.1 Maintain Statements: When conducting commercial use, you must not remove or modify the original copyright notices, license notices, and source statements contained in the Software. + * 1.2.2 Open Source Inheritance (Copyleft) and Reciprocal Sharing: If you or your organization wish to use the Software or its Derivative Works for any commercial purpose, including but not limited to: + + * Profit-generating Distribution: Selling, renting, licensing, or distributing the Software or its Derivative Works. + * Profit-generating Services: Providing commercial services based on the Software or its Derivative Works, such as SaaS services, consulting services, custom development services, and paid technical support services. + * Embedded Commercial Applications: Embedding the Software or its Derivative Works into commercial products or solutions for sale. + * Internal Commercial Operations: Using modified versions within the internal operations of for-profit organizations to directly support their commercial activities, such as customized internal systems, generating commercial revenue directly or indirectly through means including but not limited to placing advertisements in the software or related services (e.g., Google Ads), in-app purchases, membership subscriptions, and charging for value-added features. + + You must choose one of the following two options: + + * i) Inherit this License and Open Source: You must distribute your Derivative Works under this License or a compatible open-source license and publicly disclose the entire source code of your Derivative Works, so that recipients of your Derivative Works also enjoy the same rights as you, including the right to further modify and use commercially. This option aims to promote the common development and knowledge sharing of the community, ensuring that commercial innovation achievements based on this Software can also contribute back to the community. + * ii) Obtain Explicit Authorization from the Licensor: If you do not wish to release your Derivative Works in an open-source manner, or wish to distribute them under another license, or you wish to use a modified version in commercial operations without open-sourcing it, you must obtain explicit written authorization from WJQserver Studio in advance. The specific terms and conditions of authorization will be determined separately by WJQserver Studio through negotiation. + +2. Reproduction and Distribution + +* 2.1 Reproduction and Distribution of Original Version: You may reproduce and distribute the original version of the Software, provided that the following conditions are met: + + * Retain All Statements: Completely retain all original copyright notices, license notices, source statements, and other proprietary notices. + * Accompany with License: When distributing the Software, you must also include the full text of this License to ensure that recipients are aware of and understand all terms of this License. + +* 2.2 Reproduction and Distribution of Derivative Works: You may reproduce and distribute Derivative Works based on the Software. Your distribution of Derivative Works will be subject to the constraints of Clause 1.2.2 of this License (Open Source Inheritance and Reciprocal Sharing). + +3. Modification Permissions + +* 3.1 Free Modification: You are granted permission to freely modify the Software, regardless of whether the purpose of modification is for non-profit use or commercial use. + +* 3.2 Constraints on Use and Distribution after Modification: When you use a modified version for commercial purposes or distribute a modified version, you need to comply with the provisions of Clause 1.2.2 of this License (Open Source Inheritance and Reciprocal Sharing) and Clause 2 (Reproduction and Distribution). Even if you do not distribute the modified version, as long as you use it for commercial purposes, you also need to comply with the open-source inheritance clause or obtain authorization. + +* 3.3 Contribution Acceptance: WJQserver Studio encourages community contribution of code. If you contribute code to this project, you need to agree that your contributed code is licensed under the terms of this License. + +4. Patent Rights + +* 4.1 No Patent Warranty, Risk Self-Bearing: The software is provided “AS IS”, and the Licensor and Contributors explicitly declare that they do not provide any form of warranty regarding patent infringement issues of this software, nor do they assume any responsibility and consequences arising from patent infringement. Users understand and agree that the patent risk of using this software is entirely borne by the users themselves. + +* 4.2 Handling of Patent Disputes: If any patent infringement allegations, lawsuits, or claims arise due to the user's use of this Software, the user shall be solely responsible for handling and bear all legal liabilities. The Licensor and Contributors are under no obligation to participate in any related legal proceedings, nor do they bear any costs or compensation arising therefrom. + +5. Disclaimer of Warranty + +* 5.1 “AS IS” Provision, No Warranty: The software is provided “AS IS” without any express or implied warranties, including but not limited to warranties of merchantability, fitness for a particular purpose, and non-infringement. + +* 5.2 Limitation of Liability: To the maximum extent permitted by applicable law, in no event shall the Licensor or any Contributor be liable for any direct, indirect, incidental, special, punitive, or consequential damages (including but not limited to procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. + +* 5.3 User Legal Responsibility: Users shall treat this project in accordance with local laws and regulations to ensure compliance with all applicable laws and regulations. + +6. License Term and Termination + +* 6.1 License Term: Unless the copyright holder proactively announces the abandonment of the copyright of this software, this License shall be effective indefinitely from the date of your acceptance. + +* 6.2 License Termination: If you fail to comply with any terms or conditions of this License, the Licensor has the right to terminate this License. Your License will automatically terminate upon your violation of the terms of this License. + +* 6.3 Effect after Termination: Upon termination of the License, all rights granted to you under this License will terminate immediately, but the licenses and rights obtained by recipients of software copies you have legally distributed before the termination of the License will not be affected and will remain valid. The Disclaimer of Warranty (Clause 5) and Limitation of Liability (Clause 5.2) shall remain in effect after the termination of this License. + +7. Revision of Terms + +* 7.1 Reservation of Revision Rights: The Licensor reserves the right to modify the terms of this License at any time to better adapt to legal, technological developments, and community needs. + +* 7.2 Effectiveness and Acceptance of Revisions: Revised terms will take effect upon publication, and unless otherwise stated, continued use, reproduction, distribution, or modification of the Software indicates your acceptance of the revised terms. The Licensor encourages users to periodically review the latest version of this License. + +8. Other + +* 8.1 Statutory Rights: This License does not affect your statutory rights as an end-user under applicable laws. + +* 8.2 Severability of Terms: If certain terms of this License are deemed unenforceable, the remaining terms shall remain in full force and effect. + +* 8.3 Version Updates: The Licensor may publish revised versions or new versions of this License. You may choose to continue using the old version of this License or choose to apply the new version. diff --git a/README.md b/README.md index 75b6733..4e3fa5c 100644 --- a/README.md +++ b/README.md @@ -1 +1,9 @@ -# touka \ No newline at end of file +# Touka 框架 + +## 许可证 + +本项目在v0阶段使用WJQSERVER STUDIO LICENSE许可证, 后续进行调整 + +tree部分来自[gin](https://github.com/gin-gonic/gin)与[httprouter](https://github.com/julienschmidt/httprouter) + +[WJQSERVER/httproute](https://github.com/WJQSERVER/httprouter)是本项目的前身(一个[httprouter](https://github.com/julienschmidt/httprouter)的fork版本) \ No newline at end of file diff --git a/context.go b/context.go new file mode 100644 index 0000000..4efb57d --- /dev/null +++ b/context.go @@ -0,0 +1,439 @@ +package touka + +import ( + "context" + "errors" + "fmt" + "html/template" + "io" + "math" + "net" + "net/http" + "net/netip" + "net/url" + "strings" + "sync" + + "github.com/go-json-experiment/json" + + "github.com/WJQSERVER-STUDIO/go-utils/copyb" + "github.com/WJQSERVER-STUDIO/httpc" +) + +const abortIndex int8 = math.MaxInt8 >> 1 + +// Context 是每个请求的上下文,封装了请求和响应,并提供了很多便捷方法 +// 它在中间件和最终处理函数之间传递 +type Context struct { + Writer ResponseWriter // 包装的 http.ResponseWriter + Request *http.Request + Params Params // 从 httprouter 获取的路径参数 + handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) + index int8 // 当前执行到处理链的哪个位置 + + mu sync.RWMutex + Keys map[string]interface{} // 用于在中间件之间传递数据 + + Errors []error // 用于收集处理过程中的错误 + + // 缓存查询参数和表单数据 + queryCache url.Values + formCache url.Values + + // 携带ctx以实现关闭逻辑 + ctx context.Context + + // HTTPClient 用于在此上下文中执行出站 HTTP 请求。 + // 它由 Engine 提供。 + HTTPClient *httpc.Client + + // 引用所属的 Engine 实例,方便访问 Engine 的配置(如 HTMLRender) + engine *Engine +} + +// --- Context 相关方法实现 --- + +// reset 重置 Context 对象以供复用。 +// 每次从 sync.Pool 中获取 Context 后,都需要调用此方法进行初始化。 +func (c *Context) reset(w http.ResponseWriter, req *http.Request) { + // 每次重置时,确保 Writer 包装的是最新的 http.ResponseWriter + // 并重置其内部状态 + if rw, ok := c.Writer.(*responseWriterImpl); ok { + rw.ResponseWriter = w + rw.status = 0 + rw.size = 0 + } else { + // 如果 c.Writer 不是 responseWriterImpl,重新创建 + c.Writer = newResponseWriter(w) + } + + c.Request = req + c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + c.handlers = nil + c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 + c.Keys = make(map[string]interface{}) // 每次请求重新创建 map,避免数据污染 + c.Errors = c.Errors[:0] // 清空 Errors 切片 + c.queryCache = nil // 清空查询参数缓存 + c.formCache = nil // 清空表单数据缓存 + c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 + // c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员 +} + +// Next 在处理链中执行下一个处理函数。 +// 这是中间件模式的核心,允许请求依次经过多个处理函数。 +func (c *Context) Next() { + c.index++ + for c.index < int8(len(c.handlers)) { + c.handlers[c.index](c) // 执行当前索引处的处理函数 + c.index++ // 移动到下一个处理函数 + } +} + +// Abort 停止处理链的后续执行。 +// 通常在中间件中,当遇到错误或需要提前终止请求时调用。 +func (c *Context) Abort() { + c.index = abortIndex // 将 index 设置为一个很大的值,使后续 Next() 调用跳过所有处理函数 +} + +// IsAborted 返回处理链是否已被中止。 +func (c *Context) IsAborted() bool { + return c.index >= abortIndex +} + +// AbortWithStatus 中止处理链并设置 HTTP 状态码。 +func (c *Context) AbortWithStatus(code int) { + c.Writer.WriteHeader(code) // 设置响应状态码 + c.Abort() // 中止处理链 +} + +// Set 将一个键值对存储到 Context 中。 +// 这是一个线程安全的操作,用于在中间件之间传递数据。 +func (c *Context) Set(key string, value interface{}) { + c.mu.Lock() // 加写锁 + if c.Keys == nil { + c.Keys = make(map[string]interface{}) + } + c.Keys[key] = value + c.mu.Unlock() // 解写锁 +} + +// Get 从 Context 中获取一个值。 +// 这是一个线程安全的操作。 +func (c *Context) Get(key string) (value interface{}, exists bool) { + c.mu.RLock() // 加读锁 + value, exists = c.Keys[key] + c.mu.RUnlock() // 解读锁 + return +} + +// MustGet 从 Context 中获取一个值,如果不存在则 panic。 +// 适用于确定值一定存在的场景。 +func (c *Context) MustGet(key string) interface{} { + if value, exists := c.Get(key); exists { + return value + } + panic("Key \"" + key + "\" does not exist in context.") +} + +// Query 从 URL 查询参数中获取值。 +// 懒加载解析查询参数,并进行缓存。 +func (c *Context) Query(key string) string { + if c.queryCache == nil { + c.queryCache = c.Request.URL.Query() // 首次访问时解析并缓存 + } + return c.queryCache.Get(key) +} + +// DefaultQuery 从 URL 查询参数中获取值,如果不存在则返回默认值。 +func (c *Context) DefaultQuery(key, defaultValue string) string { + if value := c.Query(key); value != "" { + return value + } + return defaultValue +} + +// PostForm 从 POST 请求体中获取表单值。 +// 懒加载解析表单数据,并进行缓存。 +func (c *Context) PostForm(key string) string { + if c.formCache == nil { + c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded + c.formCache = c.Request.PostForm + } + return c.formCache.Get(key) +} + +// DefaultPostForm 从 POST 请求体中获取表单值,如果不存在则返回默认值。 +func (c *Context) DefaultPostForm(key, defaultValue string) string { + if value := c.PostForm(key); value != "" { + return value + } + return defaultValue +} + +// Param 从 URL 路径参数中获取值。 +// 例如,对于路由 /users/:id,c.Param("id") 可以获取 id 的值。 +func (c *Context) Param(key string) string { + return c.Params.ByName(key) +} + +// String 向响应写入格式化的字符串。 +func (c *Context) String(code int, format string, values ...interface{}) { + c.Writer.WriteHeader(code) + c.Writer.Write([]byte(fmt.Sprintf(format, values...))) +} + +// JSON 向响应写入 JSON 数据。 +// 设置 Content-Type 为 application/json。 +func (c *Context) JSON(code int, obj interface{}) { + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.Writer.WriteHeader(code) + // 实际 JSON 编码 + jsonBytes, err := json.Marshal(obj) + if err != nil { + c.AddError(fmt.Errorf("failed to marshal JSON: %w", err)) + c.String(http.StatusInternalServerError, "Internal Server Error: Failed to marshal JSON") + return + } + c.Writer.Write(jsonBytes) +} + +// HTML 渲染 HTML 模板。 +// 如果 Engine 配置了 HTMLRender,则使用它进行渲染。 +// 否则,会进行简单的字符串输出。 +// 预留接口,可以扩展为支持多种模板引擎。 +func (c *Context) HTML(code int, name string, obj interface{}) { + c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + c.Writer.WriteHeader(code) + + if c.engine != nil && c.engine.HTMLRender != nil { + // 假设 HTMLRender 是一个 *template.Template 实例 + if tpl, ok := c.engine.HTMLRender.(*template.Template); ok { + err := tpl.ExecuteTemplate(c.Writer, name, obj) + if err != nil { + c.AddError(fmt.Errorf("failed to render HTML template '%s': %w", name, err)) + c.String(http.StatusInternalServerError, "Internal Server Error: Failed to render HTML template") + } + return + } + // 可以扩展支持其他渲染器接口 + } + // 默认简单输出,用于未配置 HTMLRender 的情况 + c.Writer.Write([]byte(fmt.Sprintf("\n
%v
", name, obj))) +} + +// Redirect 执行 HTTP 重定向。 +// code 应为 3xx 状态码 (如 http.StatusMovedPermanently, http.StatusFound)。 +func (c *Context) Redirect(code int, location string) { + http.Redirect(c.Writer, c.Request, location, code) + c.Abort() + if fl, ok := c.Writer.(http.Flusher); ok { + fl.Flush() + } +} + +// ShouldBindJSON 尝试将请求体绑定到 JSON 对象。 +func (c *Context) ShouldBindJSON(obj interface{}) error { + if c.Request.Body == nil { + return errors.New("request body is empty") + } + /* + decoder := json.NewDecoder(c.Request.Body) + if err := decoder.Decode(obj); err != nil { + return fmt.Errorf("json binding error: %w", err) + } + */ + err := json.UnmarshalRead(c.Request.Body, obj) + if err != nil { + return fmt.Errorf("json binding error: %w", err) + } + return nil +} + +// ShouldBind 尝试将请求体绑定到各种类型(JSON, Form, XML 等)。 +// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式。 +// 预留接口,可根据项目需求进行扩展。 +func (c *Context) ShouldBind(obj interface{}) error { + // TODO: 完整的通用绑定逻辑 + // 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 + // 例如: + // contentType := c.Request.Header.Get("Content-Type") + // if strings.HasPrefix(contentType, "application/json") { + // return c.ShouldBindJSON(obj) + // } + // if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || strings.HasPrefix(contentType, "multipart/form-data") { + // return c.ShouldBindForm(obj) // 需要实现 ShouldBindForm + // } + return errors.New("generic binding not fully implemented yet, implement based on Content-Type") +} + +// AddError 添加一个错误到 Context。 +// 允许在处理请求过程中收集多个错误。 +func (c *Context) AddError(err error) { + c.Errors = append(c.Errors, err) +} + +// Errors 返回 Context 中收集的所有错误。 +func (c *Context) GetErrors() []error { + return c.Errors +} + +// Client 返回 Engine 提供的 HTTPClient。 +// 方便在请求处理函数中进行出站 HTTP 请求。 +func (c *Context) Client() *httpc.Client { + return c.HTTPClient +} + +// Context() 返回请求的上下文,用于取消操作。 +// 这是 Go 标准库的 `context.Context`,用于请求的取消和超时管理。 +func (c *Context) Context() context.Context { + return c.ctx +} + +// Done returns a channel that is closed when the request context is cancelled or times out. +// 继承自 `context.Context`。 +func (c *Context) Done() <-chan struct{} { + return c.ctx.Done() +} + +// Err returns the error, if any, that caused the context to be canceled or to +// time out. +// 继承自 `context.Context`。 +func (c *Context) Err() error { + return c.ctx.Err() +} + +// Value returns the value associated with this context for key, or nil if no +// value is associated with key. +// 可以用于从 Context 中获取与特定键关联的值,包括 Go 原生 Context 的值和 Touka Context 的 Keys。 +func (c *Context) Value(key interface{}) interface{} { + if keyAsString, ok := key.(string); ok { + if val, exists := c.Get(keyAsString); exists { + return val + } + } + return c.ctx.Value(key) // 尝试从 Go 原生 Context 中获取值 +} + +// GetWriter 获得一个 io.Writer 接口,可以直接向响应体写入数据。 +// 这对于需要自定义流式写入或与其他需要 io.Writer 的库集成非常有用。 +func (c *Context) GetWriter() io.Writer { + return c.Writer // ResponseWriter 接口嵌入了 http.ResponseWriter,而 http.ResponseWriter 实现了 io.Writer +} + +// WriteStream 接受一个 io.Reader 并将其内容流式传输到响应体。 +// 返回写入的字节数和可能遇到的错误。 +// 该方法在开始写入之前,会确保设置 HTTP 状态码为 200 OK。 +func (c *Context) WriteStream(reader io.Reader) (written int64, err error) { + // 确保在写入数据前设置状态码。 + // WriteHeader 会在第一次写入时被 Write 方法隐式调用,但显式调用可以确保状态码的预期。 + if !c.Writer.Written() { + c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK + } + + written, err = copyb.Copy(c.Writer, reader) // 从 reader 读取并写入 ResponseWriter + if err != nil { + c.AddError(fmt.Errorf("failed to write stream: %w", err)) + } + return written, err +} + +// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 +// 注意:请求体只能读取一次。 +func (c *Context) GetReqBody() io.ReadCloser { + return c.Request.Body +} + +// RequestIP 返回客户端的 IP 地址。 +// 它会根据 Engine 的配置 (ForwardByClientIP) 尝试从 X-Forwarded-For 或 X-Real-IP 等头部获取, +// 否则回退到 Request.RemoteAddr。 +func (c *Context) RequestIP() string { + if c.engine.ForwardByClientIP { + for _, headerName := range c.engine.RemoteIPHeaders { + if ipValue := c.Request.Header.Get(headerName); ipValue != "" { + // X-Forwarded-For 可能包含多个 IP,约定第一个(最左边)是客户端 IP + // 其他头部(如 X-Real-IP)通常只有一个 + ips := strings.Split(ipValue, ",") + for _, singleIP := range ips { + trimmedIP := strings.TrimSpace(singleIP) + // 使用 netip.ParseAddr 进行 IP 地址的解析和格式验证 + addr, err := netip.ParseAddr(trimmedIP) + if err == nil { + // 成功解析到合法的 IP 地址格式,立即返回 + return addr.String() + } + // 如果当前 singleIP 无效,继续检查列表中的下一个 + } + } + } + } + + // 如果没有启用 ForwardByClientIP 或头部中没有找到有效 IP,回退到 Request.RemoteAddr + // RemoteAddr 通常是 "host:port" 格式,但也可能直接就是 IP 地址 + remoteAddrStr := c.Request.RemoteAddr + ip, _, err := net.SplitHostPort(remoteAddrStr) // 尝试分离 host 和 port + if err != nil { + // 如果分离失败,意味着 remoteAddrStr 可能直接就是 IP 地址(或畸形) + ip = remoteAddrStr // 此时将整个 remoteAddrStr 作为候选 IP + } + + // 对从 RemoteAddr 中提取/使用的 IP 进行最终的合法性验证 + addr, parseErr := netip.ParseAddr(ip) + if parseErr == nil { + return addr.String() // 成功解析并返回合法 IP + } + + return "" +} + +// ClientIP 返回客户端的 IP 地址。 +// 这是一个别名,与 RequestIP 功能相同。 +func (c *Context) ClientIP() string { + return c.RequestIP() +} + +// ContentType 返回请求的 Content-Type 头部。 +func (c *Context) ContentType() string { + return c.GetReqHeader("Content-Type") +} + +// UserAgent 返回请求的 User-Agent 头部。 +func (c *Context) UserAgent() string { + return c.GetReqHeader("User-Agent") +} + +// Status 设置响应状态码。 +func (c *Context) Status(code int) { + c.Writer.WriteHeader(code) +} + +// File 将指定路径的文件作为响应发送。 +// 它会设置 Content-Type 和 Content-Disposition 头部。 +func (c *Context) File(filepath string) { + http.ServeFile(c.Writer, c.Request, filepath) + c.Abort() // 发送文件后中止后续处理 +} + +// SetHeader 设置响应头部。 +func (c *Context) SetHeader(key, value string) { + c.Writer.Header().Set(key, value) +} + +// AddHeader 添加响应头部。 +func (c *Context) AddHeader(key, value string) { + c.Writer.Header().Add(key, value) +} + +// DelHeader 删除响应头部。 +func (c *Context) DelHeader(key string) { + c.Writer.Header().Del(key) +} + +// GetReqHeader 获取请求头部的值。 +func (c *Context) GetReqHeader(key string) string { + return c.Request.Header.Get(key) +} + +// GetAllReqHeader 获取所有请求头部。 +func (c *Context) GetAllReqHeader() http.Header { + return c.Request.Header +} diff --git a/ecw.go b/ecw.go new file mode 100644 index 0000000..8f1417a --- /dev/null +++ b/ecw.go @@ -0,0 +1,143 @@ +package touka + +import ( + "net/http" + "sync" +) + +// errorCapturingResponseWriter 用于在 FileServer 处理时捕获错误状态码 +// 并在用户设置了自定义 ErrorHandler 时, 用该 ErrorHandler 处理此错误 +type errorCapturingResponseWriter struct { + w http.ResponseWriter // 原始的 ResponseWriter (通常是 touka.ResponseWriter 实例) + r *http.Request // 当前请求 + ctx *Context // 当前 touka.Context + errorHandlerFunc ErrorHandler // 实际要调用的错误处理函数 + statusCode int // FileServer 尝试设置的状态码 + headerSnapshot http.Header // FileServer 在调用 WriteHeader 前可能设置的头部快照 + capturedErrorSignal bool // 标记 FileServer 是否意图发送一个错误状态码 (>=400) + responseStarted bool // 标记包装器是否已经向原始 w 发送过任何数据 +} + +// errorResponseWriterPool 是用于复用 errorCapturingResponseWriter 实例的对象池 +var errorResponseWriterPool = sync.Pool{ + New: func() interface{} { + return &errorCapturingResponseWriter{ + headerSnapshot: make(http.Header), // 预先初始化 map, 减少 reset 时的分配 + } + }, +} + +// reset 重置 errorCapturingResponseWriter 的状态以供复用 +func (ecw *errorCapturingResponseWriter) reset(w http.ResponseWriter, r *http.Request, ctx *Context, eh ErrorHandler) { + ecw.w = w + ecw.r = r + ecw.ctx = ctx + ecw.errorHandlerFunc = eh + ecw.statusCode = 0 + // 清空 headerSnapshot, 但保留底层容量, 避免再次分配 + for k := range ecw.headerSnapshot { + delete(ecw.headerSnapshot, k) + } + ecw.capturedErrorSignal = false + ecw.responseStarted = false +} + +// AcquireErrorCapturingResponseWriter 从对象池获取一个 errorCapturingResponseWriter 实例 +// 必须在处理完成后调用 ReleaseErrorCapturingResponseWriter +func AcquireErrorCapturingResponseWriter(c *Context, eh ErrorHandler) *errorCapturingResponseWriter { + ecw := errorResponseWriterPool.Get().(*errorCapturingResponseWriter) + ecw.reset(c.Writer, c.Request, c, eh) // 传入 Touka Context 的 Writer + return ecw +} + +// ReleaseErrorCapturingResponseWriter 将一个 errorCapturingResponseWriter 实例返回到对象池 +func ReleaseErrorCapturingResponseWriter(ecw *errorCapturingResponseWriter) { + ecw.reset(nil, nil, nil, nil) // 清空敏感信息 + errorResponseWriterPool.Put(ecw) +} + +// Header 返回一个 http.Header +// 如果捕获到错误信号, 则操作内部的快照头部, 因为这些头部可能不会被发送, 或者会被 ErrorHandler 覆盖 +// 否则, 代理到原始 ResponseWriter 的 Header() +func (ecw *errorCapturingResponseWriter) Header() http.Header { + if ecw.capturedErrorSignal { + return ecw.headerSnapshot + } + // 返回原始 ResponseWriter 的 Header(), 确保 FileServer 设置的头部直接作用于最终响应 + return ecw.w.Header() +} + +// WriteHeader 记录状态码 +// 如果状态码表示错误 (>=400), 则激活 capturedErrorSignal 并不将状态码传递给原始 ResponseWriter +// 如果状态码表示成功, 则将快照中的头部(如果有)复制到原始 w, 然后调用原始 w.WriteHeader +func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) { + if ecw.responseStarted { + return // 响应已开始, 忽略后续的 WriteHeader 调用 + } + ecw.statusCode = statusCode // 总是记录 FileServer 意图的状态码 + + if statusCode >= http.StatusBadRequest { + ecw.capturedErrorSignal = true + // 是一个错误状态码 (>=400), 激活错误信号 + // 不会将这个 WriteHeader 传递给原始的 w, 等待 processAfterFileServer 处理 + } else { + // 是成功状态码 + // 将 ecw.headerSnapshot 中(由 FileServer 在此之前通过 ecw.Header() 设置的) + // 任何头部直接复制到原始的 w.Header(), 确保多值头部正确传递 + for k, v := range ecw.headerSnapshot { + ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 + } + ecw.w.WriteHeader(statusCode) // 实际写入状态码到原始 ResponseWriter + ecw.responseStarted = true // 标记成功响应已开始 + } +} + +// Write 将数据写入响应 +// 如果 capturedErrorSignal 为 true, 则丢弃数据, 因为 ErrorHandlerFunc 将负责响应体 +// 如果是成功路径, 则在必要时先发送隐式的 200 OK 头部, 然后将数据写入原始 ResponseWriter +func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) { + if ecw.capturedErrorSignal { + return len(data), nil // 假装写入成功, 避免 FileServer 内部的错误 + } + + if !ecw.responseStarted { + if ecw.statusCode == 0 { // 如果 statusCode 仍为0 (WriteHeader 从未被显式调用) + ecw.statusCode = http.StatusOK // 隐式 200 OK + } + // 将 headerSnapshot 中的头部复制到原始 ResponseWriter 的 Header + for k, v := range ecw.headerSnapshot { + ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 + } + ecw.w.WriteHeader(ecw.statusCode) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) + ecw.responseStarted = true + } + return ecw.w.Write(data) // 写入数据到原始 ResponseWriter +} + +// Flush 尝试刷新缓冲的数据到客户端 +// 仅当未捕获错误且响应已开始, 并且原始 ResponseWriter 支持 http.Flusher 时才执行 +func (ecw *errorCapturingResponseWriter) Flush() { + if flusher, ok := ecw.w.(http.Flusher); ok { + if !ecw.capturedErrorSignal && ecw.responseStarted { + flusher.Flush() + } + } +} + +// processAfterFileServer 在 http.FileServer.ServeHTTP 调用完成后执行 +// 如果之前捕获了错误信号 (capturedErrorSignal is true) 并且响应尚未开始 +// 它将调用配置的 ErrorHandlerFunc 来处理错误 +func (ecw *errorCapturingResponseWriter) processAfterFileServer() { + if ecw.capturedErrorSignal && !ecw.responseStarted { + if ecw.ctx.engine.noRoute != nil { + ecw.ctx.Next() + } else { + // 调用用户自定义的 ErrorHandlerFunc, 由它负责完整的错误响应 + ecw.errorHandlerFunc(ecw.ctx, ecw.statusCode) + ecw.ctx.Abort() + } + } + // 如果 !ecw.capturedErrorSignal, 则成功路径已通过代理写入 ecw.w, 无需额外操作 + // 如果 ecw.capturedErrorSignal && ecw.responseStarted, 表示在捕获错误信号之前, + // 成功路径的响应已经开始, 此时无法再进行错误处理覆盖 +} diff --git a/engine.go b/engine.go new file mode 100644 index 0000000..8f5d406 --- /dev/null +++ b/engine.go @@ -0,0 +1,679 @@ +package touka + +import ( + "context" + "reflect" + "runtime" + "strings" + + "net/http" + "path" + + "sync" + + "github.com/WJQSERVER-STUDIO/httpc" +) + +// Last 返回链中的最后一个处理函数。 +// 如果链为空,则返回 nil。 +func (c HandlersChain) Last() HandlerFunc { + if len(c) > 0 { + return c[len(c)-1] + } + return nil +} + +// Engine 是 Touka 框架的核心,负责路由注册、中间件管理和请求分发。 +// 它实现了 http.Handler 接口,可以直接用于 http.ListenAndServe。 +type Engine struct { + methodTrees methodTrees // 存储所有HTTP方法的路由树 + + pool sync.Pool // Context Pool 用于复用 Context 对象,提高性能。 + + globalHandlers HandlersChain // 全局中间件,应用于所有路由。 + + maxParams uint16 // 记录所有路由中最大的参数数量,用于优化 Params 切片的分配。 + + // 可配置项,用于控制框架行为,参考 Gin + RedirectTrailingSlash bool // 是否自动重定向带尾部斜杠的路径到不带尾部斜杠的路径 (e.g. /foo/ -> /foo) + RedirectFixedPath bool // 是否自动修复路径中的大小写错误 (e.g. /Foo -> /foo) + HandleMethodNotAllowed bool // 是否启用 MethodNotAllowed 处理器 + ForwardByClientIP bool // 是否信任 X-Forwarded-For 等头部获取客户端 IP + RemoteIPHeaders []string // 用于获取客户端 IP 的头部列表,例如 {"X-Forwarded-For", "X-Real-IP"} + // TrustedProxies []string // 可信代理 IP 列表,用于判断是否使用 X-Forwarded-For 等头部 (预留接口) + + HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求。 + + HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 + + routesInfo []RouteInfo // 存储所有注册的路由信息 + + errorHandle ErrorHandle // 错误处理 + + noRoute HandlerFunc + + unMatchFS UnMatchFS // 未匹配下的处理 + + serverProtocols *http.Protocols //服务协议 + Protocols ProtocolsConfig //协议版本配置 + useDefaultProtocols bool //是否使用默认协议 +} + +type ErrorHandle struct { + useDefault bool + handler ErrorHandler +} + +type ErrorHandler func(c *Context, code int) + +// defaultErrorHandle 默认错误处理 +func defaultErrorHandle(c *Context, code int) { // 检查客户端是否已断开连接 + select { + case <-c.Request.Context().Done(): + + return + default: + // 输出json 状态码与状态码对应描述 + c.JSON(code, H{ + "code": code, + "message": http.StatusText(code), + }) + c.Writer.Flush() + c.Abort() + return + } +} + +type UnMatchFS struct { + FSForUnmatched http.FileSystem + ServeUnmatchedAsFS bool +} + +// ProtocolsConfig 协议版本配置结构体 +type ProtocolsConfig struct { + Http1 bool // 是否启用 HTTP/1.1 + Http2 bool // 是否启用 HTTP/2 + Http2_Cleartext bool // 是否启用 H2C +} + +// New 创建并返回一个 Engine 实例。 +func New() *Engine { + engine := &Engine{ + methodTrees: make(methodTrees, 0, 9), // 常见的HTTP方法有9个 (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, CONNECT, TRACE) + RedirectTrailingSlash: true, + RedirectFixedPath: true, + HandleMethodNotAllowed: true, + ForwardByClientIP: true, + HTTPClient: httpc.New(), // 提供一个默认的 HTTPClient + routesInfo: make([]RouteInfo, 0), // 初始化路由信息切片 + globalHandlers: make(HandlersChain, 0), + RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, + errorHandle: ErrorHandle{ + useDefault: true, + handler: defaultErrorHandle, + }, + unMatchFS: UnMatchFS{ + ServeUnmatchedAsFS: false, + }, + } + //engine.SetProtocols(GetDefaultProtocolsConfig()) + engine.SetDefaultProtocols() + // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 + engine.pool.New = func() interface{} { + return &Context{ + Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter + Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配 + Keys: make(map[string]interface{}), + Errors: make([]error, 0), + ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 + HTTPClient: engine.HTTPClient, + engine: engine, // Context 持有 Engine 引用,方便访问 Engine 的配置 + } + } + + return engine +} + +// 生成一个携带默认中间件的Engine +func Default() *Engine { + engine := New() + engine.Use(Recovery()) + return engine +} + +// === 外部操作方法 === + +// 设置自定义错误处理 +func (engine *Engine) SetErrorHandler(handler ErrorHandler) { + engine.errorHandle.useDefault = false + engine.errorHandle.handler = handler +} + +// 获取一个默认错误处理handle +func (engine *Engine) GetDefaultErrHandler() ErrorHandler { + return defaultErrorHandle +} + +// 传入并配置unMatchFS +func (engine *Engine) SetUnMatchFS(fs http.FileSystem) { + if fs != nil { + engine.unMatchFS.FSForUnmatched = fs + engine.unMatchFS.ServeUnmatchedAsFS = true + } else { + engine.unMatchFS.ServeUnmatchedAsFS = false + } +} + +// 获取默认Protocol配置 +func GetDefaultProtocolsConfig() *ProtocolsConfig { + return &ProtocolsConfig{ + Http1: true, + Http2: false, + Http2_Cleartext: false, + } +} + +// 设置默认Protocols +func (engine *Engine) SetDefaultProtocols() { + engine.useDefaultProtocols = true + engine.SetProtocols(GetDefaultProtocolsConfig()) +} + +// 设置Protocol +func (engine *Engine) SetProtocols(config *ProtocolsConfig) { + engine.Protocols = *config + engine.serverProtocols = &http.Protocols{} // 初始化指针 + func() { + var p http.Protocols + p.SetHTTP1(config.Http1) + p.SetHTTP2(config.Http2) + p.SetUnencryptedHTTP2(config.Http2_Cleartext) + *engine.serverProtocols = p // 将值赋给指针指向的结构体 + }() + engine.useDefaultProtocols = false +} + +// 配置Req IP来源 Headers +func (engine *Engine) SetRemoteIPHeaders(headers []string) { + engine.RemoteIPHeaders = headers +} + +// SetForwardByClientIP 设置是否信任 X-Forwarded-For 等头部获取客户端 IP。 +func (engine *Engine) SetForwardByClientIP(enable bool) { + engine.ForwardByClientIP = enable +} + +// SetHTTPClient 设置 Engine 使用的 httpc.Client。 +func (engine *Engine) SetHTTPClient(client *httpc.Client) { + if client != nil { + engine.HTTPClient = client + } +} + +// registerMethodTree 内部方法,用于获取或注册对应 HTTP 方法的路由树根节点。 +// 如果该方法没有对应的树,则创建一个新的树。 +func (engine *Engine) registerMethodTree(method string) *node { + for _, tree := range engine.methodTrees { + if tree.method == method { + return tree.root + } + } + // 如果没有找到,则创建一个新的方法树并添加到列表中 + root := &node{ + nType: root, // 根节点类型 + fullPath: "/", // 根路径 + } + engine.methodTrees = append(engine.methodTrees, methodTree{method: method, root: root}) + return root +} + +// addRoute 将一个路由及处理函数链添加到路由树中。 +// 这是框架内部路由注册的核心逻辑。 +// groupPath 用于记录路由所属的分组路径。 +func (engine *Engine) addRoute(method, absolutePath, groupPath string, handlers HandlersChain) { // relativePath 更名为 absolutePath + if absolutePath == "" { + panic("absolute path must not be empty") + } + if len(handlers) == 0 { + panic("handlers must not be empty") + } + + // 检查并更新 maxParams,使用 absolutePath + if n := countParams(absolutePath); n > engine.maxParams { + engine.maxParams = n + } + + root := engine.registerMethodTree(method) + root.addRoute(absolutePath, handlers) // 调用 node 的 addRoute 方法将路由添加到树中 + + handlerName := "unknown" + if len(handlers) > 0 { + handlerName = getHandlerName(handlers.Last()) + } + + engine.routesInfo = append(engine.routesInfo, RouteInfo{ + Method: method, + Path: absolutePath, // 使用完整的绝对路径 + Handler: handlerName, + Group: groupPath, + }) +} + +// getHandlerName 辅助函数,用于获取 HandlerFunc 的名称。 +// 注意:这只是一个简单的反射实现,对于匿名函数或闭包,可能返回不可读的名称。 +func getHandlerName(h HandlerFunc) string { + //return reflect.TypeOf(h).Name() // 对于具名函数,返回函数名。对于匿名函数,可能返回空字符串或类似 func123 这样的名称。 + // 更精确的获取函数名需要 import "runtime" + // pc := reflect.ValueOf(h).Pointer() + // f := runtime.FuncForPC(pc) + // return f.Name() + + if h == nil { + return "nil_handler" + } + pc := reflect.ValueOf(h).Pointer() + f := runtime.FuncForPC(pc) + return f.Name() // 返回例如 "main.HomeHandler" 或 "touka.Logger" + +} + +// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口。 +// 每个传入的 HTTP 请求都会调用此方法。 +func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // 从 Context Pool 中获取一个 Context 对象进行复用 + c := engine.pool.Get().(*Context) + c.reset(w, req) // 重置 Context 对象的状态以适应当前请求 + + // 执行请求处理 + engine.handleRequest(c) + + // 将 Context 对象放回 Context Pool,以供下次复用 + engine.pool.Put(c) +} + +// handleRequest 负责根据请求查找路由并执行相应的处理函数链。 +// 这是路由查找和执行的核心逻辑。 +func (engine *Engine) handleRequest(c *Context) { + httpMethod := c.Request.Method + requestPath := c.Request.URL.Path + + // 查找对应的路由树的根节点 + rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 + if rootNode != nil { + // 查找匹配的节点和处理函数 + // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 + // skippedNodes 内部使用,因此无需从外部传入已分配的 slice + var skippedNodes []skippedNode // 用于回溯的跳过节点 + // 直接在 rootNode 上调用 getValue 方法 + value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + + if value.handlers != nil { + //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 + c.handlers = value.handlers + c.Next() // 执行处理函数链 + c.Writer.Flush() // 确保所有缓冲的响应数据被发送 + return + } + + // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) + if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if value.tsr && engine.RedirectTrailingSlash { + // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ + redirectPath := requestPath + if len(requestPath) > 0 && requestPath[len(requestPath)-1] == '/' { + redirectPath = requestPath[:len(requestPath)-1] + } else { + redirectPath = requestPath + "/" + } + c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 + return + } + // 尝试不区分大小写的查找 + // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found && engine.RedirectFixedPath { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return + } + } + } + /* + // 如果没有找到路由,且启用了 MethodNotAllowed 处理 + if engine.HandleMethodNotAllowed { + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := []string{} + for _, treeIter := range engine.methodTrees { + var tempSkippedNodes []skippedNode + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed) + return + } + } + } + + // 是否开启了UnMatchFS + if engine.unMatchFS.ServeUnmatchedAsFS { + // 若不是GET HEAD OPTIONS则返回405 + if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { + // 使用 http.FileServer 处理未匹配的请求 + fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched) + //ecw := newErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + defer ReleaseErrorCapturingResponseWriter(ecw) + fileServer.ServeHTTP(ecw, c.Request) + ecw.processAfterFileServer() + return + } else { + log.Printf("Not Allowed Method: %s", c.Request.Method) + // 若为OPTIONS + if c.Request.Method == http.MethodOptions { + //返回allow get + c.Writer.Header().Set("Allow", "GET") + c.Status(http.StatusOK) + c.Abort() + return + } else { + engine.errorHandle.handler(c, http.StatusMethodNotAllowed) + return + } + } + + } else { + engine.errorHandle.handler(c, http.StatusNotFound) + return + } + */ + + // 构建处理链 + // 组合全局中间件和路由处理函数 + handlers := engine.globalHandlers + + // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 + // 则在全局中间件之后添加 MethodNotAllowed 处理器 + if engine.HandleMethodNotAllowed { + handlers = append(handlers, MethodNotAllowed()) + } + + // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed + // 则在处理链的最后添加 UnMatchFS 处理器 + if engine.unMatchFS.ServeUnmatchedAsFS { + handlers = append(handlers, unMatchFSHandle()) + } + + // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS + // 则在处理链的最后添加 NoRoute 处理器 + if engine.noRoute != nil { + handlers = append(handlers, engine.noRoute) + } + + handlers = append(handlers, NotFound()) + + c.handlers = handlers + c.Next() // 执行处理函数链 + c.Writer.Flush() // 确保所有缓冲的响应数据被发送 + +} + +// UnMatchFS HandleFunc +func unMatchFSHandle() HandlerFunc { + return func(c *Context) { + engine := c.engine + if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { + // 使用 http.FileServer 处理未匹配的请求 + fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched) + //ecw := newErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + defer ReleaseErrorCapturingResponseWriter(ecw) + fileServer.ServeHTTP(ecw, c.Request) + ecw.processAfterFileServer() + return + } else { + if engine.noRoute == nil { + // 若为OPTIONS + if c.Request.Method == http.MethodOptions { + //返回allow get + c.Writer.Header().Set("Allow", "GET") + c.Status(http.StatusOK) + c.Abort() + return + } else { + engine.errorHandle.handler(c, http.StatusMethodNotAllowed) + return + } + } else { + c.Next() + } + } + } +} + +// 405中间件 +func MethodNotAllowed() HandlerFunc { + return func(c *Context) { + httpMethod := c.Request.Method + requestPath := c.Request.URL.Path + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := []string{} + for _, treeIter := range engine.methodTrees { + var tempSkippedNodes []skippedNode + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed) + return + } + } + } +} + +// 404最后处理 +func NotFound() HandlerFunc { + return func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound) + return + } +} + +// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) +func (Engine *Engine) NoRoute(handler HandlerFunc) { + Engine.noRoute = handler +} + +// combineHandlers 组合多个处理函数链为一个。 +// 这是构建完整处理链(全局中间件 + 组中间件 + 路由处理函数)的关键。 +func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) HandlersChain { + finalSize := len(h1) + len(h2) + mergedHandlers := make(HandlersChain, finalSize) + copy(mergedHandlers, h1) + copy(mergedHandlers[len(h1):], h2) + return mergedHandlers +} + +// Use 将全局中间件添加到 Engine。 +// 这些中间件将应用于所有注册的路由。 +func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { + engine.globalHandlers = append(engine.globalHandlers, middleware...) + return engine +} + +// Handle 注册通用 HTTP 方法的路由。 +// 这是所有具体 HTTP 方法注册的基础方法。 +func (engine *Engine) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) { + absolutePath := path.Join("/", relativePath) // 修正:统一使用 path.Join 进行路径拼接 + // 修正:将全局中间件与此路由的处理函数合并 + fullHandlers := engine.combineHandlers(engine.globalHandlers, handlers) + engine.addRoute(httpMethod, absolutePath, "/", fullHandlers) +} + +// GET 注册 GET 方法的路由。 +func (engine *Engine) GET(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodGet, relativePath, handlers...) +} + +// POST 注册 POST 方法的路由。 +func (engine *Engine) POST(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodPost, relativePath, handlers...) +} + +// PUT 注册 PUT 方法的路由。 +func (engine *Engine) PUT(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodPut, relativePath, handlers...) +} + +// DELETE 注册 DELETE 方法的路由。 +func (engine *Engine) DELETE(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodDelete, relativePath, handlers...) +} + +// PATCH 注册 PATCH 方法的路由。 +func (engine *Engine) PATCH(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodPatch, relativePath, handlers...) +} + +// HEAD 注册 HEAD 方法的路由。 +func (engine *Engine) HEAD(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodHead, relativePath, handlers...) +} + +// OPTIONS 注册 OPTIONS 方法的路由。 +func (engine *Engine) OPTIONS(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodOptions, relativePath, handlers...) +} + +// ANY 注册所有常见 HTTP 方法的路由。 +func (engine *Engine) ANY(relativePath string, handlers ...HandlerFunc) { + engine.Handle(http.MethodGet, relativePath, handlers...) + engine.Handle(http.MethodPost, relativePath, handlers...) + engine.Handle(http.MethodPut, relativePath, handlers...) + engine.Handle(http.MethodDelete, relativePath, handlers...) + engine.Handle(http.MethodPatch, relativePath, handlers...) + engine.Handle(http.MethodHead, relativePath, handlers...) + engine.Handle(http.MethodOptions, relativePath, handlers...) +} + +// GetRouterInfo 返回所有已注册的路由信息。 +func (engine *Engine) GetRouterInfo() []RouteInfo { + return engine.routesInfo +} + +// Group 创建一个新的路由组。 +// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起。 +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { + return &RouterGroup{ + Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 + basePath: path.Join("/", relativePath), + engine: engine, // 指向 Engine 实例 + } +} + +// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由。 +// 它也实现了 IRouter 接口,允许嵌套分组。 +type RouterGroup struct { + Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 + basePath string // 组路径前缀 + engine *Engine // 指向 Engine 实例,用于注册路由到全局路由树 +} + +// Use 将中间件应用于当前路由组。 +// 这些中间件将应用于当前组及其子组的所有路由。 +func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { + group.Handlers = append(group.Handlers, middleware...) + return group +} + +// Handle 注册通用 HTTP 方法的路由到当前组。 +// 路径是相对于当前组的 basePath。 +func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) { + absolutePath := path.Join(group.basePath, relativePath) + fullHandlers := group.engine.combineHandlers(group.Handlers, handlers) + group.engine.addRoute(httpMethod, absolutePath, group.basePath, fullHandlers) +} + +// GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, ANY 方法与 Engine 类似,只是通过 Group 的 Handle 方法注册。 +func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodGet, relativePath, handlers...) +} +func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodPost, relativePath, handlers...) +} +func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodPut, relativePath, handlers...) +} +func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodDelete, relativePath, handlers...) +} +func (group *RouterGroup) PATCH(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodPatch, relativePath, handlers...) +} +func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodHead, relativePath, handlers...) +} +func (group *RouterGroup) OPTIONS(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodOptions, relativePath, handlers...) +} +func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { + group.Handle(http.MethodGet, relativePath, handlers...) + group.Handle(http.MethodPost, relativePath, handlers...) + group.Handle(http.MethodPut, relativePath, handlers...) + group.Handle(http.MethodDelete, relativePath, handlers...) + group.Handle(http.MethodPatch, relativePath, handlers...) + group.Handle(http.MethodHead, relativePath, handlers...) + group.Handle(http.MethodOptions, relativePath, handlers...) +} + +// Group 为当前组创建一个新的子组。 +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { + return &RouterGroup{ + Handlers: group.engine.combineHandlers(group.Handlers, handlers), + basePath: path.Join(group.basePath, relativePath), + engine: group.engine, // 指向 Engine 实例 + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..42be0b7 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/WJQSERVER/touka + +go 1.24.3 + +require ( + github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 + github.com/WJQSERVER-STUDIO/httpc v0.5.1 + github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 +) + +require github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6dfccbb --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKUGPOAijN1sMtEYoFg= +github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc= +github.com/WJQSERVER-STUDIO/httpc v0.5.1 h1:+TKCPYBuj7PAHuiduGCGAqsHAa4QtsUfoVwRN777q64= +github.com/WJQSERVER-STUDIO/httpc v0.5.1/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE= +github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 h1:o8UqXPI6SVwQt04RGsqKp3qqmbOfTNMqDrWsc4O47kk= +github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= diff --git a/recovery.go b/recovery.go new file mode 100644 index 0000000..5597e53 --- /dev/null +++ b/recovery.go @@ -0,0 +1,38 @@ +package touka + +import ( + "fmt" + "log" + "net/http" + "runtime/debug" +) + +// Recovery 返回一个 Touka 的 HandlerFunc,用于捕获处理链中的 panic。 +func Recovery() HandlerFunc { + return func(c *Context) { + // 使用 defer 和 recover() 来捕获 panic + defer func() { + if r := recover(); r != nil { + // 记录 panic 信息和堆栈追踪 + err := fmt.Errorf("panic occurred: %v", r) + log.Printf("[Recovery] %s\n%s", err, debug.Stack()) // 记录错误和堆栈 + + // 检查客户端是否已断开连接,如果已断开则不再尝试写入响应 + select { + case <-c.Request.Context().Done(): + log.Printf("[Recovery] Client disconnected, skipping response for panic: %v", r) + return // 客户端已断开,直接返回 + default: + // 客户端未断开,返回 500 Internal Server Error + // 使用统一的错误处理机制 + c.engine.errorHandle.handler(c, http.StatusInternalServerError) + // Abort() 确保后续的处理函数不再执行 + c.Abort() + } + } + }() + + // 继续执行处理链中的下一个处理函数 + c.Next() + } +} diff --git a/respw.go b/respw.go new file mode 100644 index 0000000..ea8aa85 --- /dev/null +++ b/respw.go @@ -0,0 +1,85 @@ +package touka + +import ( + "bufio" + "errors" + "net" + "net/http" +) + +// --- ResponseWriter 包装 --- + +// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问。 +type ResponseWriter interface { + http.ResponseWriter + http.Hijacker // 支持 WebSocket 等 + http.Flusher // 支持流式响应 + + Status() int // 返回写入的 HTTP 状态码,如果未写入则为 0 + Size() int // 返回已写入响应体的字节数 + Written() bool // 返回 WriteHeader 是否已被调用 +} + +// responseWriterImpl 是 ResponseWriter 的具体实现。 +type responseWriterImpl struct { + http.ResponseWriter + size int + status int // 0 表示尚未写入状态码 +} + +// NewResponseWriter 创建并返回一个 responseWriterImpl 实例。 +func newResponseWriter(w http.ResponseWriter) ResponseWriter { + rw := &responseWriterImpl{ + ResponseWriter: w, + status: 0, // 明确初始状态 + size: 0, + } + return rw +} + +func (rw *responseWriterImpl) WriteHeader(statusCode int) { + if rw.status == 0 { // 确保只设置一次 + rw.status = statusCode + rw.ResponseWriter.WriteHeader(statusCode) + } +} + +func (rw *responseWriterImpl) Write(b []byte) (int, error) { + if rw.status == 0 { + // 如果 WriteHeader 没被显式调用,Go 的 http server 会默认为 200 + // 我们在这里也将其标记为 200,因为即将写入数据。 + rw.status = http.StatusOK + // ResponseWriter.Write 会在第一次写入时自动调用 WriteHeader(http.StatusOK) + // 所以不需要在这里显式调用 rw.ResponseWriter.WriteHeader(http.StatusOK) + } + n, err := rw.ResponseWriter.Write(b) + rw.size += n + return n, err +} + +func (rw *responseWriterImpl) Status() int { + return rw.status +} + +func (rw *responseWriterImpl) Size() int { + return rw.size +} + +func (rw *responseWriterImpl) Written() bool { + return rw.status != 0 +} + +// Hijack 实现 http.Hijacker 接口。 +func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("http.Hijacker interface not supported") +} + +// Flush 实现 http.Flusher 接口。 +func (rw *responseWriterImpl) Flush() { + if fl, ok := rw.ResponseWriter.(http.Flusher); ok { + fl.Flush() + } +} diff --git a/serve.go b/serve.go new file mode 100644 index 0000000..ecda82b --- /dev/null +++ b/serve.go @@ -0,0 +1,231 @@ +package touka + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" +) + +const defaultShutdownTimeout = 5 * time.Second // 定义默认的优雅关闭超时时间 + +// resolveAddress 辅助函数,处理传入的地址参数。 +func resolveAddress(addr []string) string { + switch len(addr) { + case 0: + return ":8080" // 默认端口 + case 1: + return addr[0] + default: + panic("too many parameters for Run method") // 参数过多则报错 + } +} + +// Run 启动 HTTP 服务器。 +// 接受一个可选的地址参数,如果未提供则默认为 ":8080"。 +func (engine *Engine) Run(addr ...string) (err error) { + address := resolveAddress(addr) // 解析服务器地址 + log.Printf("Touka server listening on %s\n", address) + err = http.ListenAndServe(address, engine) // 启动 HTTP 服务器 + return +} + +// getShutdownTimeout 解析可选的超时参数,如果未提供或无效,则返回默认超时。 +func getShutdownTimeout(timeouts []time.Duration) time.Duration { + var timeout time.Duration + if len(timeouts) > 0 { + timeout = timeouts[0] + if timeout <= 0 { + log.Printf("Warning: Provided shutdown timeout (%v) is non-positive. Using default timeout %v.\n", timeout, defaultShutdownTimeout) + timeout = defaultShutdownTimeout + } + } else { + timeout = defaultShutdownTimeout + } + return timeout +} + +// handleGracefulShutdown 处理一个或多个 http.Server 实例的优雅关闭。 +// 它监听操作系统信号,并在指定超时时间内尝试关闭所有服务器。 +func handleGracefulShutdown(servers []*http.Server, timeout time.Duration) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down Touka server(s)...") + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + var errs []error + var errsMutex sync.Mutex // 保护 errs 切片 + + for _, srv := range servers { + srv := srv // capture loop variable + wg.Add(1) + go func() { + defer wg.Done() + if err := srv.Shutdown(ctx); err != nil { + errsMutex.Lock() + if err == context.DeadlineExceeded { + log.Printf("Server %s shutdown timed out after %v.\n", srv.Addr, timeout) + errs = append(errs, fmt.Errorf("server %s shutdown timed out", srv.Addr)) + } else { + log.Printf("Server %s forced to shutdown: %v\n", srv.Addr, err) + errs = append(errs, fmt.Errorf("server %s forced to shutdown: %w", srv.Addr, err)) + } + errsMutex.Unlock() + } + }() + } + wg.Wait() // 等待所有服务器的关闭 Goroutine 完成 + + if len(errs) > 0 { + return errors.Join(errs...) // 返回所有收集到的错误 + } + + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// RunShutdown 启动 HTTP 服务器并支持优雅关闭。 +// 它监听操作系统信号 (SIGINT, SIGTERM),并在指定超时时间内优雅地关闭服务器。 +// addr: 服务器监听的地址,例如 ":8080"。 +// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 +func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { + timeout := getShutdownTimeout(timeouts) + + srv := &http.Server{ + Addr: addr, + Handler: engine, // Engine 实现了 http.Handler 接口 + } + + // 启动服务器在单独的 Goroutine 中运行,以便主 Goroutine 可以监听信号 + go func() { + log.Printf("Touka HTTP server listening on %s\n", addr) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Touka HTTP server listen error: %s\n", err) + } + }() + + return handleGracefulShutdown([]*http.Server{srv}, timeout) +} + +// RunWithTLS 启动 HTTPS 服务器并支持优雅关闭。 +// 用户需自行创建并传入 *tls.Config 实例,以提供完整的 TLS 配置自由度。 +// addr: 服务器监听的地址,例如 ":8443"。 +// tlsConfig: 包含 TLS 证书、密钥及其他配置的 tls.Config 实例。 +// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 +func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunWithTLS") + } + timeout := getShutdownTimeout(timeouts) + + srv := &http.Server{ + Addr: addr, + Handler: engine, + TLSConfig: tlsConfig, // 使用用户传入的 tls.Config + } + + if engine.useDefaultProtocols { + //加入HTTP2支持 + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, // 默认启用 HTTP/2 + Http2_Cleartext: false, + }) + } + + go func() { + log.Printf("Touka HTTPS server listening on %s\n", addr) + if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + log.Fatalf("Touka HTTPS server listen error: %s\n", err) + } + }() + + return handleGracefulShutdown([]*http.Server{srv}, timeout) +} + +// RunWithTLSRedir 启动 HTTP 和 HTTPS 服务器,并将所有 HTTP 请求重定向到 HTTPS。 +// httpAddr: HTTP 服务器监听的地址,例如 ":80"。 +// httpsAddr: HTTPS 服务器监听的地址,例如 ":443"。 +// tlsConfig: 包含 TLS 证书、密钥及其他配置的 tls.Config 实例,用于 HTTPS 服务器。 +// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 +func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunWithTLSRedir") + } + timeout := getShutdownTimeout(timeouts) + + // HTTPS Server + httpsSrv := &http.Server{ + Addr: httpsAddr, + Handler: engine, + TLSConfig: tlsConfig, // 使用用户传入的 tls.Config + } + + if engine.useDefaultProtocols { + //加入HTTP2支持 + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, // 默认启用 HTTP/2 + Http2_Cleartext: false, + }) + } + + // HTTP Server for redirection + httpSrv := &http.Server{ + Addr: httpAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 从 r.Host 提取 hostname,例如 "localhost:8080" -> "localhost" + hostOnly, _, err := net.SplitHostPort(r.Host) + if err != nil { // r.Host 可能没有端口,例如 "example.com" + hostOnly = r.Host + } + + // 从 httpsAddr 提取目标 HTTPS 端口,例如 ":443" -> "443" + _, targetHttpsPort, err := net.SplitHostPort(httpsAddr) + if err != nil { // httpsAddr 必须包含一个有效的端口 + log.Fatalf("Error: Invalid HTTPS address '%s' for redirection. Must specify a port (e.g., ':443').", httpsAddr) + } + + var redirectHost string + if targetHttpsPort == "443" { + redirectHost = hostOnly // 如果是默认 HTTPS 端口,则无需在 URL 中显式指定端口 + } else { + redirectHost = net.JoinHostPort(hostOnly, targetHttpsPort) // 否则,显式指定端口 + } + + // 构建目标 HTTPS URL + targetURL := "https://" + redirectHost + r.URL.RequestURI() + http.Redirect(w, r, targetURL, http.StatusMovedPermanently) // 301 Permanent Redirect + }), + } + + // Start HTTPS server + go func() { + log.Printf("Touka HTTPS server listening on %s\n", httpsAddr) + if err := httpsSrv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { // 同样,传入空字符串 + log.Fatalf("Touka HTTPS server listen error: %s\n", err) + } + }() + + // Start HTTP redirect server + go func() { + log.Printf("Touka HTTP redirect server listening on %s\n", httpAddr) + if err := httpSrv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Touka HTTP redirect server listen error: %s\n", err) + } + }() + + return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, timeout) +} diff --git a/touka.go b/touka.go new file mode 100644 index 0000000..ba8400d --- /dev/null +++ b/touka.go @@ -0,0 +1,44 @@ +package touka + +import ( + "net/http" +) + +const ( + defaultMemory = 32 << 20 // 32 MB, Gin 的默认值,用于 ParseMultipartForm +) + +type H map[string]interface{} // map简写, 类似gin.H + +type Handle func(http.ResponseWriter, *http.Request, Params) + +// HandlerFunc 定义框架处理函数的类型,包括中间件和最终的路由处理函数。 +type HandlerFunc func(*Context) + +// HandlersChain 定义处理函数链(中间件栈)的类型。 +type HandlersChain []HandlerFunc + +// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type IRouter interface { + Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 + Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 + + Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 + GET(relativePath string, handlers ...HandlerFunc) + POST(relativePath string, handlers ...HandlerFunc) + PUT(relativePath string, handlers ...HandlerFunc) + DELETE(relativePath string, handlers ...HandlerFunc) + PATCH(relativePath string, handlers ...HandlerFunc) + HEAD(relativePath string, handlers ...HandlerFunc) + OPTIONS(relativePath string, handlers ...HandlerFunc) + ANY(relativePath string, handlers ...HandlerFunc) // 注册所有HTTP方法 +} + +// RouteInfo 包含一个已注册路由的详细信息。 +// 由 Router.GetRouters() 方法返回。 +type RouteInfo struct { + Method string // HTTP 方法 (GET, POST, PUT, DELETE 等) + Path string // 路由路径 + Handler string // 处理函数名称 + Group string // 路由分组 +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..6f99223 --- /dev/null +++ b/tree.go @@ -0,0 +1,926 @@ +// Copyright 2013 Julien Schmidt. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be found +// at https://github.com/julienschmidt/httprouter/blob/master/LICENSE +// This tree.go is gin's fork, you can see https://github.com/gin-gonic/gin/blob/master/tree.go + +package touka // 定义包名为 touka,该包可能是一个路由或Web框架的核心组件 + +import ( + "bytes" // 导入 bytes 包,用于操作字节切片 + "net/url" // 导入 net/url 包,用于 URL 解析和转义 + "strings" // 导入 strings 包,用于字符串操作 + "unicode" // 导入 unicode 包,用于处理 Unicode 字符 + "unicode/utf8" // 导入 unicode/utf8 包,用于 UTF-8 编码和解码 + "unsafe" // 导入 unsafe 包,用于不安全的类型转换,以避免内存分配 +) + +// StringToBytes 将字符串转换为字节切片,不进行内存分配。 +// 更多详情,请参见 https://github.com/golang/go/issues/53003#issuecomment-1140276077。 +// 注意:此函数使用 unsafe 包,应谨慎使用,因为它可能导致内存不安全。 +func StringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// BytesToString 将字节切片转换为字符串,不进行内存分配。 +// 更多详情,请参见 https://github.com/golang/go/issues/53003#issuecomment-1140276077。 +// 注意:此函数使用 unsafe 包,应谨慎使用,因为它可能导致内存不安全。 +func BytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} + +var ( + strColon = []byte(":") // 定义字节切片常量,表示冒号,用于路径参数识别 + strStar = []byte("*") // 定义字节切片常量,表示星号,用于捕获所有路径识别 + strSlash = []byte("/") // 定义字节切片常量,表示斜杠,用于路径分隔符识别 +) + +// Param 是单个 URL 参数,由键和值组成。 +type Param struct { + Key string // 参数的键名 + Value string // 参数的值 +} + +// Params 是 Param 类型的切片,由路由器返回。 +// 该切片是有序的,第一个 URL 参数也是切片中的第一个值。 +// 因此,按索引读取值是安全的。 +type Params []Param + +// Get 返回键名与给定名称匹配的第一个 Param 的值,并返回一个布尔值 true。 +// 如果未找到匹配的 Param,则返回空字符串和布尔值 false。 +func (ps Params) Get(name string) (string, bool) { + for _, entry := range ps { + if entry.Key == name { + return entry.Value, true + } + } + return "", false +} + +// ByName 返回键名与给定名称匹配的第一个 Param 的值。 +// 如果未找到匹配的 Param,则返回空字符串。 +func (ps Params) ByName(name string) (va string) { + va, _ = ps.Get(name) // 调用 Get 方法获取值,忽略第二个返回值 + return +} + +// methodTree 表示特定 HTTP 方法的路由树。 +type methodTree struct { + method string // HTTP 方法(例如 "GET", "POST") + root *node // 该方法的根节点 +} + +// methodTrees 是 methodTree 的切片。 +type methodTrees []methodTree + +// get 根据给定的 HTTP 方法查找并返回对应的根节点。 +// 如果找不到,则返回 nil。 +func (trees methodTrees) get(method string) *node { + for _, tree := range trees { + if tree.method == method { + return tree.root + } + } + return nil +} + +// longestCommonPrefix 计算两个字符串的最长公共前缀的长度。 +func longestCommonPrefix(a, b string) int { + i := 0 + max_ := min(len(a), len(b)) // 找出两个字符串中较短的长度 + for i < max_ && a[i] == b[i] { // 遍历直到达到较短长度或字符不匹配 + i++ + } + return i // 返回公共前缀的长度 +} + +// addChild 添加一个子节点,并将通配符子节点(如果存在)保持在数组的末尾。 +func (n *node) addChild(child *node) { + if n.wildChild && len(n.children) > 0 { + // 如果当前节点有通配符子节点,且已有子节点,则将通配符子节点移到末尾 + wildcardChild := n.children[len(n.children)-1] + n.children = append(n.children[:len(n.children)-1], child, wildcardChild) + } else { + // 否则,直接添加子节点 + n.children = append(n.children, child) + } +} + +// countParams 计算路径中参数(冒号)和捕获所有(星号)的数量。 +func countParams(path string) uint16 { + var n uint16 + s := StringToBytes(path) // 将路径字符串转换为字节切片 + n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 + n += uint16(bytes.Count(s, strStar)) // 统计星号的数量 + return n +} + +// countSections 计算路径中斜杠('/')的数量,即路径段的数量。 +func countSections(path string) uint16 { + s := StringToBytes(path) // 将路径字符串转换为字节切片 + return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量 +} + +// nodeType 定义了节点的类型。 +type nodeType uint8 + +const ( + static nodeType = iota // 静态节点,路径中不包含参数或通配符 + root // 根节点 + param // 参数节点(例如:name) + catchAll // 捕获所有节点(例如*path) +) + +// node 表示路由树中的一个节点。 +type node struct { + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串,用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + nType nodeType // 节点的类型(静态、根、参数、捕获所有) + priority uint32 // 节点的优先级,用于查找时优先匹配 + children []*node // 子节点切片,最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径,用于调试和错误信息 +} + +// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序。 +func (n *node) incrementChildPrio(pos int) int { + cs := n.children // 获取子节点切片 + cs[pos].priority++ // 增加指定位置子节点的优先级 + prio := cs[pos].priority // 获取新的优先级 + + // 调整位置(向前移动) + newPos := pos + // 从当前位置向前遍历,如果前一个子节点的优先级小于当前子节点,则交换位置 + for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { + // 交换节点位置 + cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] + } + + // 构建新的索引字符字符串 + if newPos != pos { + // 如果位置发生变化,则重新构建 indices 字符串 + // 前缀部分 + 移动的索引字符 + 剩余部分 + n.indices = n.indices[:newPos] + // 未改变的前缀,可能为空 + n.indices[pos:pos+1] + // 被移动的索引字符 + n.indices[newPos:pos] + n.indices[pos+1:] // 除去原位置字符的其余部分 + } + + return newPos // 返回新的位置 +} + +// addRoute 为给定路径添加一个带有处理函数的节点。 +// 非并发安全! +func (n *node) addRoute(path string, handlers HandlersChain) { + fullPath := path // 记录完整的路径 + n.priority++ // 增加当前节点的优先级 + + // 如果是空树(根节点) + if len(n.path) == 0 && len(n.children) == 0 { + n.insertChild(path, fullPath, handlers) // 直接插入子节点 + n.nType = root // 设置为根节点类型 + return + } + + parentFullPathIndex := 0 // 记录父节点的完整路径索引 + +walk: // 外部循环用于遍历和构建路由树 + for { + // 找到最长公共前缀。 + // 这也意味着公共前缀不包含 ':' 或 '*',因为现有键不能包含这些字符。 + i := longestCommonPrefix(path, n.path) + + // 分裂边 (Split edge) + // 如果公共前缀小于当前节点的路径长度,说明当前节点需要被分裂 + if i < len(n.path) { + child := node{ + path: n.path[i:], // 子节点路径是当前节点路径的剩余部分 + wildChild: n.wildChild, // 继承通配符子节点状态 + nType: static, // 分裂后的新节点是静态类型 + indices: n.indices, // 继承索引 + children: n.children, // 继承子节点 + handlers: n.handlers, // 继承处理函数 + priority: n.priority - 1, // 优先级减1,因为分裂会降低优先级 + fullPath: n.fullPath, // 继承完整路径 + } + + n.children = []*node{&child} // 当前节点现在只有一个子节点:新分裂出的子节点 + // 将当前节点的 indices 设置为新子节点路径的第一个字符 + n.indices = BytesToString([]byte{n.path[i]}) // []byte 用于正确的 Unicode 字符转换 + n.path = path[:i] // 当前节点的路径更新为公共前缀 + n.handlers = nil // 当前节点不再有处理函数(因为它被分裂了) + n.wildChild = false // 当前节点不再是通配符子节点 + n.fullPath = fullPath[:parentFullPathIndex+i] // 更新完整路径 + } + + // 将新节点作为当前节点的子节点 + // 如果路径仍然有剩余部分(即未完全匹配) + if i < len(path) { + path = path[i:] // 移除已匹配的前缀 + c := path[0] // 获取剩余路径的第一个字符 + + // '/' 在参数之后 + // 如果当前节点是参数类型,且剩余路径以 '/' 开头,并且只有一个子节点 + // 则继续遍历其唯一的子节点 + if n.nType == param && c == '/' && len(n.children) == 1 { + parentFullPathIndex += len(n.path) // 更新父节点完整路径索引 + n = n.children[0] // 移动到子节点 + n.priority++ // 增加子节点优先级 + continue walk // 继续外部循环 + } + + // 检查是否存在以下一个路径字节开头的子节点 + for i, max_ := 0, len(n.indices); i < max_; i++ { + if c == n.indices[i] { // 如果找到匹配的索引字符 + parentFullPathIndex += len(n.path) // 更新父节点完整路径索引 + i = n.incrementChildPrio(i) // 增加子节点优先级并重新排序 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } + } + + // 否则,插入新节点 + // 如果第一个字符不是 ':' 也不是 '*',且当前节点不是 catchAll 类型 + if c != ':' && c != '*' && n.nType != catchAll { + // 将新字符添加到索引字符串 + n.indices += BytesToString([]byte{c}) // []byte 用于正确的 Unicode 字符转换 + child := &node{ + fullPath: fullPath, // 设置子节点的完整路径 + } + n.addChild(child) // 添加新子节点 + n.incrementChildPrio(len(n.indices) - 1) // 增加新子节点的优先级并重新排序 + n = child // 移动到新子节点 + } else if n.wildChild { + // 正在插入一个通配符节点,需要检查是否与现有通配符冲突 + n = n.children[len(n.children)-1] // 移动到现有的通配符子节点 + n.priority++ // 增加其优先级 + + // 检查通配符是否匹配 + // 如果剩余路径长度大于等于通配符节点的路径长度,且通配符节点路径是剩余路径的前缀 + // 并且不是 catchAll 类型(不能有子路由), + // 并且通配符之后没有更多字符或紧跟着 '/' + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // 不能向 catchAll 添加子节点 + n.nType != catchAll && + // 检查更长的通配符,例如 :name 和 :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk // 继续外部循环 + } + + // 通配符冲突 + pathSeg := path + if n.nType != catchAll { + pathSeg, _, _ = strings.Cut(pathSeg, "/") // 如果不是 catchAll,则截取到下一个 '/' + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path // 构造冲突前缀 + panic("'" + pathSeg + // 抛出 panic 表示通配符冲突 + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + + n.insertChild(path, fullPath, handlers) // 插入子节点(可能包含通配符) + return // 完成添加路由 + } + + // 否则,将处理函数添加到当前节点 + if n.handlers != nil { + panic("handlers are already registered for path '" + fullPath + "'") // 如果已注册处理函数,则报错 + } + n.handlers = handlers // 设置处理函数 + n.fullPath = fullPath // 设置完整路径 + return // 完成添加路由 + } +} + +// findWildcard 搜索通配符段并检查名称是否包含无效字符。 +// 如果未找到通配符,则返回 -1 作为索引。 +func findWildcard(path string) (wildcard string, i int, valid bool) { + // 查找开始位置 + escapeColon := false // 是否正在处理转义字符 + for start, c := range []byte(path) { + if escapeColon { + escapeColon = false + if c == ':' { // 如果转义字符是 ':',则跳过 + continue + } + panic("invalid escape string in path '" + path + "'") // 无效的转义字符串 + } + if c == '\\' { // 如果是反斜杠,则设置转义标志 + escapeColon = true + continue + } + // 通配符以 ':' (参数) 或 '*' (捕获所有) 开头 + if c != ':' && c != '*' { + continue + } + + // 查找结束位置并检查无效字符 + valid = true // 默认为有效 + for end, c := range []byte(path[start+1:]) { + switch c { + case '/': // 如果遇到斜杠,说明通配符段结束 + return path[start : start+1+end], start, valid + case ':', '*': // 如果在通配符段中再次遇到 ':' 或 '*',则无效 + valid = false + } + } + return path[start:], start, valid // 返回找到的通配符、起始索引和有效性 + } + return "", -1, false // 未找到通配符 +} + +// insertChild 插入一个带有处理函数的节点。 +// 此函数处理包含通配符的路径插入逻辑。 +func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) { + for { + // 找到第一个通配符之前的前缀 + wildcard, i, valid := findWildcard(path) + if i < 0 { // 未找到通配符,结束循环 + break + } + + // 通配符名称只能包含一个 ':' 或 '*' 字符 + if !valid { + panic("only one wildcard per path segment is allowed, has: '" + + wildcard + "' in path '" + fullPath + "'") // 报错:每个路径段只允许一个通配符 + } + + // 检查通配符是否有名称 + if len(wildcard) < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") // 报错:通配符必须有非空名称 + } + + if wildcard[0] == ':' { // 如果是参数节点 (param) + if i > 0 { + // 在当前通配符之前插入前缀 + n.path = path[:i] // 当前节点路径更新为前缀 + path = path[i:] // 剩余路径去除前缀 + } + + child := &node{ + nType: param, // 子节点类型为参数 + path: wildcard, // 子节点路径为通配符名称 + fullPath: fullPath, // 设置子节点的完整路径 + } + n.addChild(child) // 添加子节点 + n.wildChild = true // 当前节点标记为有通配符子节点 + n = child // 移动到新创建的参数节点 + n.priority++ // 增加优先级 + + // 如果路径不以通配符结束,则会有一个以 '/' 开头的子路径 + if len(wildcard) < len(path) { + path = path[len(wildcard):] // 剩余路径去除通配符部分 + + child := &node{ + priority: 1, // 新子节点优先级 + fullPath: fullPath, // 设置子节点的完整路径 + } + n.addChild(child) // 添加子节点(通常是斜杠后的静态部分) + n = child // 移动到这个新子节点 + continue // 继续循环,查找下一个通配符或结束 + } + + // 否则,我们已经完成。将处理函数插入到新叶节点中 + n.handlers = handlers // 设置处理函数 + return // 完成 + } + + // 如果是捕获所有节点 (catchAll) + if i+len(wildcard) != len(path) { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") // 报错:捕获所有路由只能在路径末尾 + } + + // 检查路径段冲突 + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + pathSeg := "" + if len(n.children) != 0 { + pathSeg, _, _ = strings.Cut(n.children[0].path, "/") + } + panic("catch-all wildcard '" + path + // 报错:捕获所有通配符与现有路径段冲突 + "' in new path '" + fullPath + + "' conflicts with existing path segment '" + pathSeg + + "' in existing prefix '" + n.path + pathSeg + + "'") + } + + // 当前固定宽度为 1,用于 '/' + i-- + if i < 0 || path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") // 报错:捕获所有之前没有 '/' + } + + n.path = path[:i] // 当前节点路径更新为 catchAll 之前的部分 + + // 第一个节点:路径为空的 catchAll 节点 + child := &node{ + wildChild: true, // 标记为有通配符子节点 + nType: catchAll, // 类型为 catchAll + fullPath: fullPath, // 设置完整路径 + } + + n.addChild(child) // 添加子节点 + n.indices = string('/') // 索引设置为 '/' + n = child // 移动到新创建的 catchAll 节点 + n.priority++ // 增加优先级 + + // 第二个节点:包含变量的节点 + child = &node{ + path: path[i:], // 路径为 catchAll 的实际路径段 + nType: catchAll, // 类型为 catchAll + handlers: handlers, // 设置处理函数 + priority: 1, // 优先级 + fullPath: fullPath, // 设置完整路径 + } + n.children = []*node{child} // 将其设为当前节点的唯一子节点 + + return // 完成 + } + + // 如果没有找到通配符,简单地插入路径和处理函数 + n.path = path // 设置当前节点路径 + n.handlers = handlers // 设置处理函数 + n.fullPath = fullPath // 设置完整路径 +} + +// nodeValue 包含 (*Node).getValue 方法的返回值 +type nodeValue struct { + handlers HandlersChain // 匹配到的处理函数链 + params *Params // 提取的 URL 参数 + tsr bool // 是否建议进行尾部斜杠重定向 (Trailing Slash Redirect) + fullPath string // 匹配到的完整路径 +} + +// skippedNode 结构体用于在 getValue 查找过程中记录跳过的节点信息,以便回溯。 +type skippedNode struct { + path string // 跳过时的当前路径 + node *node // 跳过的节点 + paramsCount int16 // 跳过时已收集的参数数量 +} + +// getValue 返回注册到给定路径(key)的处理函数。通配符的值会保存到 map 中。 +// 如果找不到处理函数,则在存在一个带有额外(或不带)尾部斜杠的处理函数时, +// 建议进行 TSR(尾部斜杠重定向)。 +func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { + var globalParamsCount int16 // 全局参数计数 + +walk: // 外部循环用于遍历路由树 + for { + prefix := n.path // 当前节点的路径前缀 + if len(path) > len(prefix) { + if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + path = path[len(prefix):] // 移除已匹配的前缀 + + // 优先尝试所有非通配符子节点,通过匹配索引字符 + idxc := path[0] // 剩余路径的第一个字符 + for i, c := range []byte(n.indices) { + if c == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点,则将当前节点添加到 skippedNodes,以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: prefix + path, // 记录跳过的路径 + node: &node{ // 复制当前节点的状态 + path: n.path, + wildChild: n.wildChild, + nType: n.nType, + priority: n.priority, + children: n.children, + handlers: n.handlers, + fullPath: n.fullPath, + }, + paramsCount: globalParamsCount, // 记录当前参数计数 + } + } + + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } + } + + if !n.wildChild { + // 如果路径在循环结束时不等于 '/' 且当前节点没有子节点 + // 当前节点需要回溯到最后一个有效的 skippedNode + if path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] // 弹出 skippedNode + if strings.HasSuffix(skippedNode.path, path) { // 如果跳过的路径包含当前路径 + path = skippedNode.path // 恢复路径 + n = skippedNode.node // 恢复节点 + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 + } + globalParamsCount = skippedNode.paramsCount // 恢复参数计数 + continue walk // 继续外部循环 + } + } + } + + // 未找到。 + // 如果存在一个带有额外(或不带)尾部斜杠的处理函数, + // 我们可以建议重定向到相同 URL,不带尾部斜杠。 + value.tsr = path == "/" && n.handlers != nil // 如果路径是 "/" 且当前节点有处理函数,则建议 TSR + return value + } + + // 处理通配符子节点,它总是位于数组的末尾 + n = n.children[len(n.children)-1] // 移动到通配符子节点 + globalParamsCount++ // 增加全局参数计数 + + switch n.nType { + case param: // 参数节点 + // 查找参数结束位置('/' 或路径末尾) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // 保存参数值 + if params != nil { + // 如果需要,预分配容量 + if cap(*params) < int(globalParamsCount) { + newParams := make(Params, len(*params), globalParamsCount) + copy(newParams, *params) + *params = newParams + } + + if value.params == nil { + value.params = params + } + // 在预分配的容量内扩展切片 + i := len(*value.params) + *value.params = (*value.params)[:i+1] // 扩展切片 + val := path[:end] // 提取参数值 + if unescape { // 如果需要进行 URL 解码 + if v, err := url.QueryUnescape(val); err == nil { + val = v // 解码成功则更新值 + } + } + (*value.params)[i] = Param{ // 存储参数 + Key: n.path[1:], // 参数键名(去除冒号) + Value: val, // 参数值 + } + } + + // 我们需要继续深入! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] // 移除已提取的参数部分 + n = n.children[0] // 移动到下一个子节点 + continue walk // 继续外部循环 + } + + // ... 但我们无法继续 + value.tsr = len(path) == end+1 // 如果路径只剩下斜杠,则建议 TSR + return value + } + + if value.handlers = n.handlers; value.handlers != nil { + value.fullPath = n.fullPath + return value // 如果当前节点有处理函数,则返回 + } + if len(n.children) == 1 { + // 未找到处理函数。检查是否存在此路径加尾部斜杠的处理函数,以进行 TSR 建议 + n = n.children[0] + value.tsr = (n.path == "/" && n.handlers != nil) || (n.path == "" && n.indices == "/") + } + return value + + case catchAll: // 捕获所有节点 + // 保存参数值 + if params != nil { + // 如果需要,预分配容量 + if cap(*params) < int(globalParamsCount) { + newParams := make(Params, len(*params), globalParamsCount) + copy(newParams, *params) + *params = newParams + } + + if value.params == nil { + value.params = params + } + // 在预分配的容量内扩展切片 + i := len(*value.params) + *value.params = (*value.params)[:i+1] // 扩展切片 + val := path // 参数值是剩余的整个路径 + if unescape { // 如果需要进行 URL 解码 + if v, err := url.QueryUnescape(path); err == nil { + val = v // 解码成功则更新值 + } + } + (*value.params)[i] = Param{ // 存储参数 + Key: n.path[2:], // 参数键名(去除星号) + Value: val, // 参数值 + } + } + + value.handlers = n.handlers // 设置处理函数 + value.fullPath = n.fullPath + return value // 返回 + + default: + panic("invalid node type") // 无效的节点类型 + } + } + } + + if path == prefix { // 如果路径完全匹配当前节点的前缀 + // 如果当前路径不等于 '/' 且节点没有注册的处理函数,且最近匹配的节点有子节点 + // 当前节点需要回溯到最后一个有效的 skippedNode + if n.handlers == nil && path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + } + // 我们应该已经到达包含处理函数的节点。 + // 检查此节点是否注册了处理函数。 + if value.handlers = n.handlers; value.handlers != nil { + value.fullPath = n.fullPath + return value // 如果有处理函数,则返回 + } + + // 如果此路由没有处理函数,但此路由有通配符子节点, + // 则此路径必须有一个带有额外尾部斜杠的处理函数。 + if path == "/" && n.wildChild && n.nType != root { + value.tsr = true // 建议 TSR + return value + } + + if path == "/" && n.nType == static { + value.tsr = true // 如果是静态节点且路径是根,则建议 TSR + return value + } + + // 未找到处理函数。检查此路径加尾部斜杠是否存在处理函数,以进行尾部斜杠重定向建议 + for i, c := range []byte(n.indices) { + if c == '/' { // 如果索引中包含 '/' + n = n.children[i] // 移动到对应的子节点 + value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 + (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 + return value + } + } + + return value + } + + // 未找到。我们可以建议重定向到相同 URL,添加一个额外的尾部斜杠, + // 如果该路径的叶节点存在。 + value.tsr = path == "/" || // 如果路径是根路径 + (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && // 或者前缀比路径多一个斜杠 + path == prefix[:len(prefix)-1] && n.handlers != nil) // 且路径是前缀去掉最后一个斜杠,且有处理函数 + + // 回溯到最后一个有效的 skippedNode + if !value.tsr && path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + } + + return value // 返回未找到 + } +} + +// findCaseInsensitivePath 对给定路径进行不区分大小写的查找,并尝试找到处理函数。 +// 它还可以选择修复尾部斜杠。 +// 它返回大小写校正后的路径和一个布尔值,指示查找是否成功。 +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { + const stackBufSize = 128 // 栈上缓冲区的默认大小 + + // 在常见情况下使用栈上静态大小的缓冲区。 + // 如果路径太长,则在堆上分配缓冲区。 + buf := make([]byte, 0, stackBufSize) + if length := len(path) + 1; length > stackBufSize { + buf = make([]byte, 0, length) // 如果路径太长,则分配更大的缓冲区 + } + + ciPath := n.findCaseInsensitivePathRec( + path, + buf, // 预分配足够的内存给新路径 + [4]byte{}, // 空的 rune 缓冲区 + fixTrailingSlash, // 是否修复尾部斜杠 + ) + + return ciPath, ciPath != nil // 返回校正后的路径和是否成功找到 +} + +// shiftNRuneBytes 将字节数组中的字节向左移动 n 个字节。 +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} // 移动1位 + case 2: + return [4]byte{rb[2], rb[3]} // 移动2位 + case 3: + return [4]byte{rb[3]} // 移动3位 + default: + return [4]byte{} // 其他情况返回空 + } +} + +// findCaseInsensitivePathRec 由 n.findCaseInsensitivePath 使用的递归不区分大小写查找函数。 +func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte { + npLen := len(n.path) // 当前节点的路径长度 + +walk: // 外部循环用于遍历路由树 + // 只要剩余路径长度大于等于当前节点路径长度,且当前节点路径(除第一个字符外)不区分大小写匹配剩余路径 + for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) { + // 将公共前缀添加到结果中 + oldPath := path // 保存原始路径 + path = path[npLen:] // 移除已匹配的前缀 + ciPath = append(ciPath, n.path...) // 将当前节点的路径添加到不区分大小写路径中 + + if len(path) == 0 { // 如果路径已完全匹配 + // 我们应该已经到达包含处理函数的节点。 + // 检查此节点是否注册了处理函数。 + if n.handlers != nil { + return ciPath // 如果有处理函数,则返回校正后的路径 + } + + // 未找到处理函数。 + // 尝试通过添加尾部斜杠来修复路径 + if fixTrailingSlash { + for i, c := range []byte(n.indices) { + if c == '/' { // 如果索引中包含 '/' + n = n.children[i] // 移动到对应的子节点 + if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 + (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 + return append(ciPath, '/') // 返回添加斜杠后的路径 + } + return nil // 否则返回 nil + } + } + } + return nil // 未找到,返回 nil + } + + // 如果此节点没有通配符(参数或捕获所有)子节点, + // 我们可以直接查找下一个子节点并继续遍历树。 + if !n.wildChild { + // 跳过已处理的 rune 字节 + rb = shiftNRuneBytes(rb, npLen) + + if rb[0] != 0 { + // 旧 rune 未处理完 + idxc := rb[0] + for i, c := range []byte(n.indices) { + if c == idxc { + // 继续处理子节点 + n = n.children[i] + npLen = len(n.path) + continue walk // 继续外部循环 + } + } + } else { + // 处理一个新的 rune + var rv rune + + // 查找 rune 的开始位置。 + // Runes 最长为 4 字节。 + // -4 肯定会是另一个 rune。 + var off int + for max_ := min(npLen, 3); off < max_; off++ { + if i := npLen - off; utf8.RuneStart(oldPath[i]) { + // 从缓存路径读取 rune + rv, _ = utf8.DecodeRuneInString(oldPath[i:]) + break + } + } + + // 计算当前 rune 的小写字节 + lo := unicode.ToLower(rv) + utf8.EncodeRune(rb[:], lo) // 将小写 rune 编码到缓冲区 + + // 跳过已处理的字节 + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // 小写匹配 + if c == idxc { + // 必须使用递归方法,因为大写字节和小写字节都可能作为索引存在 + if out := n.children[i].findCaseInsensitivePathRec( + path, ciPath, rb, fixTrailingSlash, + ); out != nil { + return out // 如果找到,则返回 + } + break + } + } + + // 如果未找到匹配项,则对大写 rune 执行相同操作(如果它不同) + if up := unicode.ToUpper(rv); up != lo { + utf8.EncodeRune(rb[:], up) // 将大写 rune 编码到缓冲区 + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // 大写匹配 + if c == idxc { + // 继续处理子节点 + n = n.children[i] + npLen = len(n.path) + continue walk // 继续外部循环 + } + } + } + } + + // 未找到。我们可以建议重定向到相同 URL,不带尾部斜杠, + // 如果该路径的叶节点存在。 + if fixTrailingSlash && path == "/" && n.handlers != nil { + return ciPath // 如果可以修复尾部斜杠且有处理函数,则返回 + } + return nil // 未找到,返回 nil + } + + n = n.children[0] // 移动到通配符子节点(通常是唯一一个) + switch n.nType { + case param: // 参数节点 + // 查找参数结束位置('/' 或路径末尾) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // 将参数值添加到不区分大小写路径中 + ciPath = append(ciPath, path[:end]...) + + // 我们需要继续深入! + if end < len(path) { + if len(n.children) > 0 { + // 继续处理子节点 + n = n.children[0] + npLen = len(n.path) + path = path[end:] + continue // 继续外部循环 + } + + // ... 但我们无法继续 + if fixTrailingSlash && len(path) == end+1 { + return ciPath // 如果可以修复尾部斜杠且路径只剩下斜杠,则返回 + } + return nil // 未找到,返回 nil + } + + if n.handlers != nil { + return ciPath // 如果有处理函数,则返回 + } + + if fixTrailingSlash && len(n.children) == 1 { + // 未找到处理函数。检查此路径加尾部斜杠是否存在处理函数 + n = n.children[0] + if n.path == "/" && n.handlers != nil { + return append(ciPath, '/') // 返回添加斜杠后的路径 + } + } + + return nil // 未找到,返回 nil + + case catchAll: // 捕获所有节点 + return append(ciPath, path...) // 返回添加剩余路径后的路径(捕获所有) + + default: + panic("invalid node type") // 无效的节点类型 + } + } + + // 未找到。 + // 尝试通过添加/删除尾部斜杠来修复路径 + if fixTrailingSlash { + if path == "/" { + return ciPath // 如果路径是根路径,则返回 + } + // 如果路径长度比当前节点路径少一个斜杠,且末尾是斜杠, + // 且不区分大小写匹配,且当前节点有处理函数 + if len(path)+1 == npLen && n.path[len(path)] == '/' && + strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil { + return append(ciPath, n.path...) // 返回添加当前节点路径后的路径 + } + } + return nil // 未找到,返回 nil +}