Compare commits

..

No commits in common. "d439662adf1a60a75dc8b22d6955497d31c30bd1" and "c7a9a889e4b256fd0035f01439d6ab1372bbf7d4" have entirely different histories.

50 changed files with 833 additions and 9292 deletions

3
.gitignore vendored
View file

@ -1,2 +1 @@
test test
/bench_route_match_baseline.txt

View file

@ -59,9 +59,9 @@ func main() {
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
}) })
// 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) // 启动服务器 (支持优雅关闭)
log.Println("Touka Server starting on :8080...") log.Println("Touka Server starting on :8080...")
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
log.Fatalf("Touka server failed to start: %v", err) log.Fatalf("Touka server failed to start: %v", err)
} }
} }

View file

@ -70,13 +70,13 @@ func main() {
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
// ... 其他配置 // ... 其他配置
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```
#### 1.3. 服务器生命周期管理 #### 1.3. 服务器生命周期管理
Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。
```go ```go
func main() { func main() {
@ -90,11 +90,11 @@ func main() {
fmt.Println("自定义的 HTTP 服务器配置已应用") fmt.Println("自定义的 HTTP 服务器配置已应用")
}) })
// 启动服务器,并通过 Run 选项启用优雅关闭 // 启动服务器,并支持优雅关闭
// Run(...) 会阻塞当前 goroutine // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号
// WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 // 第二个参数是优雅关闭的超时时间
fmt.Println("服务器启动于 :8080") fmt.Println("服务器启动于 :8080")
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
log.Fatalf("服务器启动失败: %v", err) log.Fatalf("服务器启动失败: %v", err)
} }
} }
@ -187,7 +187,7 @@ func main() {
} }
} }
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
func AuthMiddleware() touka.HandlerFunc { func AuthMiddleware() touka.HandlerFunc {
@ -313,7 +313,7 @@ func main() {
}) })
}) })
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
// templates/index.html // templates/index.html
@ -400,7 +400,7 @@ func main() {
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
}) })
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```
@ -483,7 +483,7 @@ func main() {
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
r.StaticDir("/files", "./non-existent-dir") r.StaticDir("/files", "./non-existent-dir")
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```
@ -546,7 +546,7 @@ func main() {
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
r.StaticFS("/", http.FS(subFS)) r.StaticFS("/", http.FS(subFS))
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```

View file

@ -1,52 +0,0 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"github.com/WJQSERVER-STUDIO/httpc"
"github.com/fenthope/reco"
)
// --- reco 兼容函数 ---
// GetLogReco 返回底层的 reco.Logger 实例
// 用于需要访问 reco 特定功能的场景
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
//
//go:fix inline
func (engine *Engine) GetLogReco() *reco.Logger {
return engine.LogReco
}
// SetLogReco 设置 reco.Logger 实例
// 用于向后兼容,等价于 SetLogger(l)
//
//go:fix inline
func (engine *Engine) SetLogReco(l *reco.Logger) {
engine.LogReco = l
engine.logger = l
}
// GetLoggerReco 返回底层的 reco.Logger 实例
// 用于需要访问 reco 特定功能的场景
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
//
//go:fix inline
func (c *Context) GetLoggerReco() *reco.Logger {
if rl, ok := c.engine.logger.(*reco.Logger); ok {
return rl
}
return c.engine.LogReco
}
// --- httpc 兼容函数 ---
// GetHTTPC 返回底层的 httpc.Client 实例
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
//
//go:fix inline
func (c *Context) GetHTTPC() *httpc.Client {
return c.Client()
}

View file

@ -26,6 +26,7 @@ import (
"time" "time"
"github.com/WJQSERVER/wanf" "github.com/WJQSERVER/wanf"
"github.com/fenthope/reco"
"github.com/go-json-experiment/json" "github.com/go-json-experiment/json"
"github.com/WJQSERVER-STUDIO/go-utils/iox" "github.com/WJQSERVER-STUDIO/go-utils/iox"
@ -43,8 +44,6 @@ type Context struct {
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
index int8 // 当前执行到处理链的哪个位置 index int8 // 当前执行到处理链的哪个位置
requestBodyPrepared bool
mu sync.RWMutex mu sync.RWMutex
Keys map[string]any // 用于在中间件之间传递数据 Keys map[string]any // 用于在中间件之间传递数据
@ -72,12 +71,6 @@ type Context struct {
// skippedNodes 用于记录跳过的节点信息,以便回溯 // skippedNodes 用于记录跳过的节点信息,以便回溯
// 通常在处理嵌套路由时使用 // 通常在处理嵌套路由时使用
SkippedNodes []skippedNode SkippedNodes []skippedNode
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
fixedPathBuf []byte
allowedMethodsBuf []string
allowHeaderBuf []byte
} }
// --- Context 相关方法实现 --- // --- Context 相关方法实现 ---
@ -102,42 +95,19 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} }
c.handlers = nil c.handlers = nil
c.index = -1 // 初始为 -1`Next()` 将其设置为 0 c.index = -1 // 初始为 -1`Next()` 将其设置为 0
c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Keys = make(map[string]any) // 每次请求重新创建 map避免数据污染
c.Errors = c.Errors[:0] // 清空 Errors 切片 c.Errors = c.Errors[:0] // 清空 Errors 切片
c.queryCache = nil // 清空查询参数缓存 c.queryCache = nil // 清空查询参数缓存
c.formCache = nil // 清空表单数据缓存 c.formCache = nil // 清空表单数据缓存
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
c.requestBodyPrepared = false
if cap(c.SkippedNodes) > 0 { if cap(c.SkippedNodes) > 0 {
c.SkippedNodes = c.SkippedNodes[:0] c.SkippedNodes = c.SkippedNodes[:0]
} else { } else {
c.SkippedNodes = make([]skippedNode, 0, 256) c.SkippedNodes = make([]skippedNode, 0, 256)
} }
if cap(c.fixedPathBuf) > 0 {
c.fixedPathBuf = c.fixedPathBuf[:0]
}
if cap(c.allowedMethodsBuf) > 0 {
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
}
if cap(c.allowHeaderBuf) > 0 {
c.allowHeaderBuf = c.allowHeaderBuf[:0]
}
}
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if len(data) == 0 {
return
}
if _, err := c.Writer.Write(data); err != nil {
wrapped := fmt.Errorf("%s: %w", contextMsg, err)
c.AddError(wrapped)
if c.engine != nil && c.engine.logger != nil {
c.engine.logger.Errorf("%s: %v", contextMsg, err)
}
}
} }
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数
@ -267,18 +237,6 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
c.MaxRequestBodySize = size c.MaxRequestBodySize = size
} }
func (c *Context) prepareRequestBody() io.ReadCloser {
if c.Request == nil || c.Request.Body == nil {
return nil
}
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
return c.Request.Body
}
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
c.requestBodyPrepared = true
return c.Request.Body
}
// Query 从 URL 查询参数中获取值 // Query 从 URL 查询参数中获取值
// 懒加载解析查询参数,并进行缓存 // 懒加载解析查询参数,并进行缓存
func (c *Context) Query(key string) string { func (c *Context) Query(key string) string {
@ -300,39 +258,7 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
// 懒加载解析表单数据,并进行缓存 // 懒加载解析表单数据,并进行缓存
func (c *Context) PostForm(key string) string { func (c *Context) PostForm(key string) string {
if c.formCache == nil { if c.formCache == nil {
if c.MaxRequestBodySize > 0 { c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded
c.prepareRequestBody()
}
contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
switch mediaType {
case "multipart/form-data":
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
case "application/x-www-form-urlencoded":
if err := c.Request.ParseForm(); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
default:
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
if !errors.Is(err, http.ErrNotMultipart) {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
}
}
c.formCache = c.Request.PostForm c.formCache = c.Request.PostForm
} }
return c.formCache.Get(key) return c.formCache.Get(key)
@ -356,20 +282,20 @@ func (c *Context) Param(key string) string {
func (c *Context) Raw(code int, contentType string, data []byte) { func (c *Context) Raw(code int, contentType string, data []byte) {
c.Writer.Header().Set("Content-Type", contentType) c.Writer.Header().Set("Content-Type", contentType)
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.writeResponseBody(data, "failed to write raw response") c.Writer.Write(data)
} }
// String 向响应写入格式化的字符串 // String 向响应写入格式化的字符串
func (c *Context) String(code int, format string, values ...any) { func (c *Context) String(code int, format string, values ...any) {
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response") c.Writer.Write(fmt.Appendf(nil, format, values...))
} }
// Text 向响应写入无需格式化的string // Text 向响应写入无需格式化的string
func (c *Context) Text(code int, text string) { func (c *Context) Text(code int, text string) {
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.writeResponseBody([]byte(text), "failed to write text response") c.Writer.Write([]byte(text))
} }
// FileText // FileText
@ -412,11 +338,8 @@ func (c *Context) FileText(code int, filePath string) {
} }
c.SetHeader("Content-Type", "text/plain; charset=utf-8") c.SetHeader("Content-Type", "text/plain; charset=utf-8")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
c.Writer.WriteHeader(code) c.SetBodyStream(file, int(fileInfo.Size()))
if _, err := iox.Copy(c.Writer, file); err != nil {
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
}
} }
/* /*
@ -494,22 +417,6 @@ func (c *Context) JSON(code int, obj any) {
} }
} }
// JSONBuf 先将 JSON 编码到 buffer, 成功后再写入状态码和响应体.
// 与 JSON 相比,编码失败时可以正确返回 500 状态码,代价是多一次内存分配.
func (c *Context) JSONBuf(code int, obj any) {
var buf bytes.Buffer
if err := json.MarshalWrite(&buf, obj); err != nil {
errMsg := fmt.Errorf("failed to marshal JSON: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response")
}
// GOB 向响应写入GOB数据 // GOB 向响应写入GOB数据
// 设置 Content-Type 为 application/octet-stream // 设置 Content-Type 为 application/octet-stream
func (c *Context) GOB(code int, obj any) { func (c *Context) GOB(code int, obj any) {
@ -524,21 +431,6 @@ func (c *Context) GOB(code int, obj any) {
} }
} }
// GOBBuf 先将 GOB 编码到 buffer, 成功后再写入状态码和响应体.
func (c *Context) GOBBuf(code int, obj any) {
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(obj); err != nil {
errMsg := fmt.Errorf("failed to encode GOB: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/octet-stream")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response")
}
// WANF向响应写入WANF数据 // WANF向响应写入WANF数据
// 设置 application/vnd.wjqserver.wanf; charset=utf-8 // 设置 application/vnd.wjqserver.wanf; charset=utf-8
func (c *Context) WANF(code int, obj any) { func (c *Context) WANF(code int, obj any) {
@ -553,21 +445,6 @@ func (c *Context) WANF(code int, obj any) {
} }
} }
// WANFBuf 先将 WANF 编码到 buffer, 成功后再写入状态码和响应体.
func (c *Context) WANFBuf(code int, obj any) {
var buf bytes.Buffer
encoder := wanf.NewStreamEncoder(&buf)
if err := encoder.Encode(obj); err != nil {
errMsg := fmt.Errorf("failed to encode WANF: %w", err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response")
}
// HTML 渲染 HTML 模板 // HTML 渲染 HTML 模板
// 如果 Engine 配置了 HTMLRender则使用它进行渲染 // 如果 Engine 配置了 HTMLRender则使用它进行渲染
// 否则,会进行简单的字符串输出 // 否则,会进行简单的字符串输出
@ -589,37 +466,7 @@ func (c *Context) HTML(code int, name string, obj any) {
// 可以扩展支持其他渲染器接口 // 可以扩展支持其他渲染器接口
} }
// 默认简单输出,用于未配置 HTMLRender 的情况 // 默认简单输出,用于未配置 HTMLRender 的情况
c.writeResponseBody(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj), "failed to write HTML response") c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))
}
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
// 如果模板渲染失败,则返回 500 错误且不写入任何内容.
func (c *Context) HTMLBuf(code int, name string, obj any) {
if c.engine == nil || c.engine.HTMLRender == nil {
// 没有渲染器,回退到简单输出
c.HTML(code, name, obj)
return
}
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
var buf bytes.Buffer
err := tpl.ExecuteTemplate(&buf, name, obj)
if err != nil {
// 渲染失败,记录错误并返回 500不写入任何内容
errMsg := fmt.Errorf("failed to render HTML template '%s': %w", name, err)
c.AddError(errMsg)
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
return
}
// 渲染成功,写入响应
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response")
return
}
// 不支持的渲染器类型,回退到简单输出
c.HTML(code, name, obj)
} }
// Redirect 执行 HTTP 重定向 // Redirect 执行 HTTP 重定向
@ -634,16 +481,10 @@ func (c *Context) Redirect(code int, location string) {
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象 // ShouldBindJSON 尝试将请求体绑定到 JSON 对象
func (c *Context) ShouldBindJSON(obj any) error { func (c *Context) ShouldBindJSON(obj any) error {
var body io.ReadCloser if c.Request.Body == nil {
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
err := json.UnmarshalRead(body, obj) err := json.UnmarshalRead(c.Request.Body, obj)
if err != nil { if err != nil {
return fmt.Errorf("json binding error: %w", err) return fmt.Errorf("json binding error: %w", err)
} }
@ -652,16 +493,10 @@ func (c *Context) ShouldBindJSON(obj any) error {
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
func (c *Context) ShouldBindWANF(obj any) error { func (c *Context) ShouldBindWANF(obj any) error {
var body io.ReadCloser if c.Request.Body == nil {
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder, err := wanf.NewStreamDecoder(body) decoder, err := wanf.NewStreamDecoder(c.Request.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to create WANF decoder: %w", err) return fmt.Errorf("failed to create WANF decoder: %w", err)
} }
@ -674,16 +509,10 @@ func (c *Context) ShouldBindWANF(obj any) error {
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
func (c *Context) ShouldBindGOB(obj any) error { func (c *Context) ShouldBindGOB(obj any) error {
var body io.ReadCloser if c.Request.Body == nil {
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder := gob.NewDecoder(body) decoder := gob.NewDecoder(c.Request.Body)
if err := decoder.Decode(obj); err != nil { if err := decoder.Decode(obj); err != nil {
return fmt.Errorf("GOB binding error: %w", err) return fmt.Errorf("GOB binding error: %w", err)
} }
@ -800,10 +629,6 @@ func setFieldValue(field reflect.Value, values []string) error {
// ShouldBindForm 尝试将表单数据绑定到结构体 // ShouldBindForm 尝试将表单数据绑定到结构体
// 支持 application/x-www-form-urlencoded 和 multipart/form-data // 支持 application/x-www-form-urlencoded 和 multipart/form-data
func (c *Context) ShouldBindForm(obj any) error { func (c *Context) ShouldBindForm(obj any) error {
if c.MaxRequestBodySize > 0 {
c.prepareRequestBody()
}
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType) mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
@ -812,7 +637,7 @@ func (c *Context) ShouldBindForm(obj any) error {
switch mediaType { switch mediaType {
case "multipart/form-data": case "multipart/form-data":
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
return fmt.Errorf("parse multipart form error: %w", err) return fmt.Errorf("parse multipart form error: %w", err)
} }
case "application/x-www-form-urlencoded": case "application/x-www-form-urlencoded":
@ -826,7 +651,6 @@ func (c *Context) ShouldBindForm(obj any) error {
if err := bindForm(c.Request.Form, obj); err != nil { if err := bindForm(c.Request.Form, obj); err != nil {
return fmt.Errorf("form binding error: %w", err) return fmt.Errorf("form binding error: %w", err)
} }
c.formCache = c.Request.PostForm
return nil return nil
} }
@ -864,29 +688,10 @@ func (c *Context) GetErrors() []error {
return c.Errors return c.Errors
} }
// Client 返回当前请求的 HTTPClient // Client 返回 Engine 提供的 HTTPClient
// 如果请求处理函数或中间件设置了自定义 HTTPClient返回该实例 // 方便在请求处理函数中进行出站 HTTP 请求
// 否则返回 Engine 提供的默认实例
//
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
func (c *Context) Client() *httpc.Client { func (c *Context) Client() *httpc.Client {
if c.HTTPClient != nil { return c.HTTPClient
return c.HTTPClient
}
return c.engine.HTTPClient
}
// HTTPC 返回自动关联请求 Context 的 HTTP 客户端
// 当请求被取消时,通过此客户端发起的出站请求也会自动取消
func (c *Context) HTTPC() *contextHTTPClient {
client := c.HTTPClient
if client == nil {
client = c.engine.HTTPClient
}
return &contextHTTPClient{
client: client,
ctx: c.ctx,
}
} }
// Context() 返回请求的上下文,用于取消操作 // Context() 返回请求的上下文,用于取消操作
@ -946,30 +751,37 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBody() io.ReadCloser { func (c *Context) GetReqBody() io.ReadCloser {
if c.MaxRequestBodySize > 0 {
return c.prepareRequestBody()
}
if c.Request == nil || c.Request.Body == nil {
return nil
}
return c.Request.Body return c.Request.Body
} }
// GetReqBodyFull 读取并返回请求体的所有内容 // GetReqBodyFull 读取并返回请求体的所有内容
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBodyFull() ([]byte, error) { func (c *Context) GetReqBodyFull() ([]byte, error) {
body := c.GetReqBody() if c.Request.Body == nil {
if body == nil {
return nil, nil return nil, nil
} }
defer func() {
err := body.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
data, err := io.ReadAll(body) var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -979,18 +791,31 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
// 类似 GetReqBodyFull, 返回 *bytes.Buffer // 类似 GetReqBodyFull, 返回 *bytes.Buffer
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
body := c.GetReqBody() if c.Request.Body == nil {
if body == nil {
return nil, nil return nil, nil
} }
defer func() {
err := body.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
data, err := io.ReadAll(body) var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -1149,9 +974,14 @@ func (c *Context) GetProtocol() string {
return c.Request.Proto return c.Request.Proto
} }
// GetLogger 获取engine的Logger接口 // GetHTTPC 获取框架自带传递的httpc
func (c *Context) GetLogger() Logger { func (c *Context) GetHTTPC() *httpc.Client {
return c.engine.logger return c.HTTPClient
}
// GetLogger 获取engine的Logger
func (c *Context) GetLogger() *reco.Logger {
return c.engine.LogReco
} }
// GetReqQueryString // GetReqQueryString
@ -1254,25 +1084,17 @@ func (c *Context) SetSameSite(samesite http.SameSite) {
} }
// SetCookie 设置一个 HTTP cookie // SetCookie 设置一个 HTTP cookie
// sameSite 参数是可选的,如果不提供则使用通过 SetSameSite 设置的值 func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) {
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool, sameSite ...http.SameSite) {
if path == "" { if path == "" {
path = "/" path = "/"
} }
site := c.sameSite
if len(sameSite) > 0 {
if len(sameSite) > 1 {
c.Warnf("SetCookie: only the first SameSite value will be used, got %d values", len(sameSite))
}
site = sameSite[0]
}
http.SetCookie(c.Writer, &http.Cookie{ http.SetCookie(c.Writer, &http.Cookie{
Name: name, Name: name,
Value: url.QueryEscape(value), Value: url.QueryEscape(value),
MaxAge: maxAge, MaxAge: maxAge,
Path: path, Path: path,
Domain: domain, Domain: domain,
SameSite: site, SameSite: c.sameSite,
Secure: secure, Secure: secure,
HttpOnly: httpOnly, HttpOnly: httpOnly,
}) })
@ -1310,25 +1132,25 @@ func (c *Context) DeleteCookie(name string) {
// === 日志记录 === // === 日志记录 ===
func (c *Context) Debugf(format string, args ...any) { func (c *Context) Debugf(format string, args ...any) {
c.engine.logger.Debugf(format, args...) c.engine.LogReco.Debugf(format, args...)
} }
func (c *Context) Infof(format string, args ...any) { func (c *Context) Infof(format string, args ...any) {
c.engine.logger.Infof(format, args...) c.engine.LogReco.Infof(format, args...)
} }
func (c *Context) Warnf(format string, args ...any) { func (c *Context) Warnf(format string, args ...any) {
c.engine.logger.Warnf(format, args...) c.engine.LogReco.Warnf(format, args...)
} }
func (c *Context) Errorf(format string, args ...any) { func (c *Context) Errorf(format string, args ...any) {
c.engine.logger.Errorf(format, args...) c.engine.LogReco.Errorf(format, args...)
} }
func (c *Context) Fatalf(format string, args ...any) { func (c *Context) Fatalf(format string, args ...any) {
c.engine.logger.Fatalf(format, args...) c.engine.LogReco.Fatalf(format, args...)
} }
func (c *Context) Panicf(format string, args ...any) { func (c *Context) Panicf(format string, args ...any) {
c.engine.logger.Panicf(format, args...) c.engine.LogReco.Panicf(format, args...)
} }

View file

@ -1,81 +0,0 @@
package touka
import (
"net/http"
"testing"
)
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
c, _ := CreateTestContext(nil)
if c.Keys != nil {
t.Fatalf("expected fresh test context Keys to be nil before first Set")
}
c.Set("answer", 42)
if c.Keys == nil {
t.Fatalf("expected Set to allocate Keys map")
}
if value, exists := c.Get("answer"); !exists || value != 42 {
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
}
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(UnwrapResponseWriter(c.Writer), req)
if c.Keys != nil {
t.Fatalf("expected reset to clear Keys without allocating a new map")
}
if value, exists := c.Get("answer"); exists || value != nil {
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
}
ctxValue := c.Value("missing")
if ctxValue != nil {
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
}
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected MustGet to panic for missing key after reset")
}
}()
_ = c.MustGet("answer")
}
func BenchmarkContextReset(b *testing.B) {
b.Run("NoKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
rawWriter := UnwrapResponseWriter(c.Writer)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(rawWriter, req)
}
})
b.Run("WithKeysUse", func(b *testing.B) {
c, _ := CreateTestContext(nil)
rawWriter := UnwrapResponseWriter(c.Writer)
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.reset(rawWriter, req)
c.Set("request-id", i)
}
})
}

View file

@ -1,174 +0,0 @@
package touka
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
type zeroNilThenEOFReader struct {
readCalls int
}
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
r.readCalls++
if r.readCalls == 1 {
return 0, nil
}
return 0, io.EOF
}
func (r *zeroNilThenEOFReader) Close() error {
return nil
}
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
t.Helper()
dir := t.TempDir()
filePath := filepath.Join(dir, "hello.txt")
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
t.Fatalf("write temp file: %v", err)
}
rr := httptest.NewRecorder()
c, _ := CreateTestContext(rr)
c.FileText(http.StatusCreated, filePath)
if rr.Code != http.StatusCreated {
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
}
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
t.Fatalf("unexpected content type: %q", got)
}
if body := rr.Body.String(); body != "hello touka" {
t.Fatalf("unexpected body: %q", body)
}
}
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("expected exact limit read to succeed, got %v", err)
}
if string(data) != "abcd" {
t.Fatalf("unexpected data: %q", string(data))
}
}
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
defer reader.Close()
_, err := io.ReadAll(reader)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
defer reader.Close()
buf := make([]byte, 1)
n, err := reader.Read(buf)
if n != 0 || err != nil {
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
}
n, err = reader.Read(buf)
if n != 0 || !errors.Is(err, io.EOF) {
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
}
}
func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0)
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("expected zero limit to leave body unlimited, got %v", err)
}
if string(data) != "abc" {
t.Fatalf("unexpected data: %q", string(data))
}
}
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader(`{"name":"abcdef"}`)
req := httptest.NewRequest(http.MethodPost, "/json", body)
req.Header.Set("Content-Type", "application/json")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(8)
var payload struct {
Name string `json:"name"`
}
err := c.ShouldBindJSON(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
var payload struct {
Name string `form:"name"`
}
err := c.ShouldBindForm(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
if got := c.PostForm("name"); got != "" {
t.Fatalf("expected empty value on over-limit form body, got %q", got)
}
if len(c.Errors) == 0 {
t.Fatal("expected parse error to be recorded")
}
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
}
}

View file

@ -1,58 +0,0 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"context"
"github.com/WJQSERVER-STUDIO/httpc"
)
// contextHTTPClient 包装 httpc.Client自动关联请求的 Context
// 当请求被取消时,出站 HTTP 请求也会自动取消
type contextHTTPClient struct {
client *httpc.Client
ctx context.Context
}
// NewRequestBuilder 创建请求构建器,自动关联请求 Context
func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder {
return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx)
}
// GET 创建 GET 请求构建器
func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder {
return c.client.GET(urlStr).WithContext(c.ctx)
}
// POST 创建 POST 请求构建器
func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder {
return c.client.POST(urlStr).WithContext(c.ctx)
}
// PUT 创建 PUT 请求构建器
func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder {
return c.client.PUT(urlStr).WithContext(c.ctx)
}
// DELETE 创建 DELETE 请求构建器
func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder {
return c.client.DELETE(urlStr).WithContext(c.ctx)
}
// PATCH 创建 PATCH 请求构建器
func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder {
return c.client.PATCH(urlStr).WithContext(c.ctx)
}
// HEAD 创建 HEAD 请求构建器
func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder {
return c.client.HEAD(urlStr).WithContext(c.ctx)
}
// OPTIONS 创建 OPTIONS 请求构建器
func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder {
return c.client.OPTIONS(urlStr).WithContext(c.ctx)
}

View file

@ -44,9 +44,7 @@ r.SetTLSServerConfigurator(func(server *http.Server) {
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext
```go ```go
// 使用默认协议配置 // 使用默认协议配置(仅 HTTP/1.1
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
r.SetDefaultProtocols() r.SetDefaultProtocols()
// 自定义协议配置 // 自定义协议配置
@ -59,147 +57,33 @@ r.SetProtocols(&touka.ProtocolsConfig{
### 启动方式 ### 启动方式
Touka 统一通过 `Run(opts...)` 启动服务器 Touka 提供了多种服务器启动方式
```go ```go
// 1. 简单启动(无优雅停机) // 1. 简单启动(无优雅停机)
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
// 2. 带优雅停机的启动 // 2. 带优雅停机的启动
r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) r.RunShutdown(":8080", 10*time.Second)
// 3. 带上下文的优雅停机 // 3. 带上下文的优雅停机
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() r.RunShutdownWithContext(":8080", ctx, 10*time.Second)
r.Run(
touka.WithAddr(":8080"),
touka.WithGracefulShutdown(10*time.Second),
touka.WithShutdownContext(ctx),
)
// 4. HTTPS 启动 // 4. HTTPS 启动
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
// 其他 TLS 配置... // 其他 TLS 配置...
} }
// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 r.RunTLS(":443", tlsConfig, 10*time.Second)
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithGracefulShutdownDefault(),
)
// 5. HTTPS + HTTP 重定向 // 5. HTTPS + HTTP 重定向
// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second)
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(":80"),
touka.WithGracefulShutdown(10*time.Second),
)
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
``` ```
### HTTPS Redirect Host 策略
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
可用的 redirect 子选项:
- `touka.WithUseHeaderHost(true|false)`
- `touka.WithRedirectHostHeaders([]string{...})`
- `touka.WithRedirectHost("example.com")`
#### 模式一:使用请求输入侧的 host
`WithUseHeaderHost(true)` 时:
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header并使用第一个非空值
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
```
#### 模式二:使用配置的固定 host
`WithUseHeaderHost(false)` 时:
- 不读取 `Request.Host`
- 不读取 `WithRedirectHostHeaders(...)`
- 必须配置 `WithRedirectHost("example.com")`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
```
#### 严格校验规则
以下组合会直接返回配置错误:
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
#### 优先级关系
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
2. `WithUseHeaderHost(...)` 决定 host 来源模式
3. 当 `WithUseHeaderHost(true)` 时:
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
- 未配置时使用 `Request.Host`
4. 当 `WithUseHeaderHost(false)` 时:
- 只使用 `WithRedirectHost(...)`
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
## 优雅停机 (Graceful Shutdown) ## 优雅停机 (Graceful Shutdown)
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。
```go ```go
r := touka.Default() r := touka.Default()
@ -207,7 +91,7 @@ r := touka.Default()
// 监听 SIGINT 和 SIGTERM 信号 // 监听 SIGINT 和 SIGTERM 信号
// 如果在 10 秒内未处理完,则强制关闭 // 如果在 10 秒内未处理完,则强制关闭
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
log.Fatal("服务器退出异常:", err) log.Fatal("服务器退出异常:", err)
} }
``` ```

View file

@ -1,188 +0,0 @@
# HTTP Client (httpc)
Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。
## 核心特性
- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context
- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏
- **链式调用**:保持 httpc 原有的组合式构建器风格
## 基本用法
### 简单 GET 请求
```go
r.GET("/proxy", func(c *touka.Context) {
body, err := c.HTTPC().
GET("https://api.example.com/data").
Text()
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.String(200, body)
})
```
### POST JSON 请求
```go
r.POST("/users", func(c *touka.Context) {
var req struct {
Name string `json:"name"`
Email string `json:"email"`
}
c.ShouldBindJSON(&req)
var result struct {
ID int `json:"id"`
Name string `json:"name"`
}
err := c.HTTPC().
POST("https://api.example.com/users").
SetHeader("Authorization", "Bearer "+token).
SetJSONBody(req).
DecodeJSON(&result)
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.JSON(200, result)
})
```
### 带查询参数
```go
r.GET("/search", func(c *touka.Context) {
query := c.Query("q")
var result SearchResult
err := c.HTTPC().
GET("https://api.example.com/search").
SetQueryParam("q", query).
SetQueryParam("limit", "10").
DecodeJSON(&result)
if err != nil {
c.JSON(500, touka.H{"error": err.Error()})
return
}
c.JSON(200, result)
})
```
## API 对比
### 旧方式Deprecated
```go
// 需要手动 WithContext容易忘记
resp, err := c.Client().
WithContext(c.Context()).
GET(url).
Execute()
```
### 新方式(推荐)
```go
// 自动关联请求 Context
resp, err := c.HTTPC().
GET(url).
Execute()
```
## Context 取消机制
使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消:
```go
r.GET("/long-task", func(c *touka.Context) {
// 这个请求会在客户端断开时自动取消
resp, err := c.HTTPC().
GET("https://slow-api.example.com/data").
Execute()
// 如果客户端已断开err 会包含 context.Canceled
if errors.Is(err, context.Canceled) {
return // 客户端已断开,无需处理
}
// ...
})
```
## 完整 API
### contextHTTPClient 方法
| 方法 | 返回类型 | 说明 |
|------|----------|------|
| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 |
| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 |
| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 |
| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 |
| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 |
| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 |
| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 |
| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 |
### httpc.RequestBuilder 链式方法
返回 `*httpc.RequestBuilder`(用于链式调用):
| 方法 | 说明 |
|------|------|
| `WithContext(ctx)` | 设置 Context通常不需要已自动关联 |
| `NoDefaultHeaders()` | 不添加默认 Header |
| `SetHeader(key, value)` | 设置 Header |
| `AddHeader(key, value)` | 添加 Header可重复 |
| `SetHeaders(map)` | 批量设置 Headers |
| `SetQueryParam(key, value)` | 设置查询参数 |
| `AddQueryParam(key, value)` | 添加查询参数(可重复) |
| `SetQueryParams(map)` | 批量设置查询参数 |
| `SetBody(io.Reader)` | 设置请求 Body |
| `SetRawBody([]byte)` | 设置字节 Body |
返回 `(*httpc.RequestBuilder, error)`(可能失败):
| 方法 | 说明 |
|------|------|
| `SetJSONBody(any)` | 设置 JSON Body |
| `SetXMLBody(any)` | 设置 XML Body |
| `SetGOBBody(any)` | 设置 GOB Body |
### 终结方法
| 方法 | 返回类型 | 说明 |
|------|----------|------|
| `Build()` | `(*http.Request, error)` | 构建请求但不执行 |
| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 |
| `DecodeJSON(v)` | `error` | 执行并解码 JSON |
| `DecodeXML(v)` | `error` | 执行并解码 XML |
| `DecodeGOB(v)` | `error` | 执行并解码 GOB |
| `Text()` | `(string, error)` | 执行并返回文本 |
| `Bytes()` | `([]byte, error)` | 执行并返回字节 |
| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 |
## 迁移指南
### go:fix inline 兼容
旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`
```bash
go fix ./...
```
### 手动迁移
| 旧代码 | 新代码 |
|--------|--------|
| `c.GetHTTPC()` | `c.Client()``c.HTTPC()` |
| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` |
## 示例
完整示例请参考 [examples/httpc](../examples/httpc)。

View file

@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求 3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。

View file

@ -1,400 +0,0 @@
# Touka Logger 接口迁移方案
## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计
---
## 一、问题分析
当前架构问题:
```
Engine.LogReco → *reco.Logger (公开字段, 直接访问)
Context.GetLogger() → 返回 *reco.Logger (具体类型)
Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...)
```
这导致用户无法替换日志实现(如 zap/logrus
---
## 二、目标架构
```
Engine.logger → Logger 接口 (私有)
Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容)
Engine.GetLogger() → 返回 Logger 接口
Engine.SetLogger(Logger)→ 设置日志实现
Context.GetLogger() → 返回 Logger 接口
Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...)
```
---
## 三、Logger 接口定义
```go
// logger.go
package touka
// Logger 是日志接口,支持任意日志库实现
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
Warnf(format string, args ...any)
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
Panicf(format string, args ...any)
}
// CloserLogger 可选扩展,支持关闭操作
type CloserLogger interface {
Logger
Close() error
}
```
---
## 四、Engine 结构变更
```go
// engine.go 变更
type Engine struct {
// ... 其他字段保持不变
// logger 是新的日志接口 (私有)
logger Logger
// logReco 是保留的 reco.Logger 引用 (私有)
// 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger
logReco *reco.Logger
// 其他字段...
}
```
新增/修改方法:
```go
// GetLogger 返回日志接口
func (engine *Engine) GetLogger() Logger {
return engine.logger
}
// SetLogger 设置任意 Logger 实现
func (engine *Engine) SetLogger(l Logger) {
engine.logger = l
// 如果是 *reco.Logger 类型,同步更新 logReco
if rl, ok := l.(*reco.Logger); ok {
engine.logReco = rl
} else {
engine.logReco = nil
}
}
// SetLoggerCfg 使用 reco.Config 配置日志
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
logger := NewLogger(logcfg)
engine.logger = logger
engine.logReco = logger
}
```
---
## 五、`go:fix inline` 兼容性函数
### 5.1 旧 API 包装函数
`compat.go` 中定义:
```go
// compat.go
package touka
import "github.com/fenthope/reco"
// GetLogReco 返回 reco.Logger用于向后兼容
//
//go:fix inline
func (engine *Engine) GetLogReco() *reco.Logger {
return engine.logReco
}
// SetLogReco 设置 reco.Logger用于向后兼容
//
//go:fix inline
func (engine *Engine) SetLogReco(l *reco.Logger) {
engine.logReco = l
engine.logger = l
}
```
### 5.2 Context 日志方法的 inline 包装
```go
// context_compat.go
package touka
// Debugf 记录 Debug 级别日志
//
//go:fix inline
func (c *Context) Debugf(format string, args ...any) {
c.engine.logger.Debugf(format, args...)
}
// Infof 记录 Info 级别日志
//
//go:fix inline
func (c *Context) Infof(format string, args ...any) {
c.engine.logger.Infof(format, args...)
}
// Warnf 记录 Warn 级别日志
//
//go:fix inline
func (c *Context) Warnf(format string, args ...any) {
c.engine.logger.Warnf(format, args...)
}
// Errorf 记录 Error 级别日志
//
//go:fix inline
func (c *Context) Errorf(format string, args ...any) {
c.engine.logger.Errorf(format, args...)
}
// Fatalf 记录 Fatal 级别日志
//
//go:fix inline
func (c *Context) Fatalf(format string, args ...any) {
c.engine.logger.Fatalf(format, args...)
}
// Panicf 记录 Panic 级别日志
//
//go:fix inline
func (c *Context) Panicf(format string, args ...any) {
c.engine.logger.Panicf(format, args...)
}
```
### 5.3 GetLogger 返回类型的兼容处理
由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数:
```go
// context_compat.go (续)
// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景
//
//go:fix inline
func (c *Context) GetLoggerReco() *reco.Logger {
if rl, ok := c.engine.logger.(*reco.Logger); ok {
return rl
}
return nil
}
```
---
## 六、go:fix inline 工作原理
### 迁移前用户代码:
```go
func handler(c *touka.Context) {
// 旧 API 调用
c.Debugf("request: %s", c.Request.URL.Path)
c.engine.LogReco.Infof("server started")
}
```
### go fix 执行后(自动替换):
```go
func handler(c *touka.Context) {
// Debugf 被替换为函数体
c.engine.logger.Debugf("request: %s", c.Request.URL.Path)
// LogReco 访问无法通过 inline 自动处理,需要手动迁移
// 或者通过 getter 调用
}
```
### 对于字段访问的处理策略:
`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略:
1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated
2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco`
3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()`
---
## 七、迁移前后对比
### 场景 1基本日志调用
**迁移前:**
```go
func myHandler(c *touka.Context) {
c.Debugf("processing request %s", c.Request.URL.Path)
c.Infof("user %s logged in", username)
c.Warnf("slow query: %v", duration)
c.Errorf("db error: %v", err)
}
```
**迁移后(自动替换):**
```go
func myHandler(c *touka.Context) {
c.engine.logger.Debugf("processing request %s", c.Request.URL.Path)
c.engine.logger.Infof("user %s logged in", username)
c.engine.logger.Warnf("slow query: %v", duration)
c.engine.logger.Errorf("db error: %v", err)
}
```
### 场景 2Engine 配置日志
**迁移前:**
```go
engine := touka.New()
engine.LogReco = myLogger // 直接赋值
logger := engine.LogReco // 直接读取
```
**迁移后(手动 + 自动混合):**
```go
engine := touka.New()
// 方式 1使用新 API推荐
engine.SetLogger(myLogger)
logger := engine.GetLogger()
// 方式 2通过 go:fix inline 自动替换为 getter
// engine.SetLogReco(myLogger) ← go fix 替换
// logger := engine.GetLogReco() ← go fix 替换
```
### 场景 3使用第三方日志库新功能
```go
import "go.uber.org/zap"
func main() {
zapLogger, _ := zap.NewProduction()
defer zapLogger.Sync()
engine := touka.New()
// 使用 zap 替代默认的 reco.Logger
engine.SetLogger(&ZapAdapter{logger: zapLogger})
engine.GET("/api", func(c *touka.Context) {
c.Infof("api called") // 自动使用 zap 输出
})
}
// ZapAdapter 适配 zap 到 touka.Logger 接口
type ZapAdapter struct {
logger *zap.Logger
}
func (z *ZapAdapter) Debugf(format string, args ...any) {
z.logger.Debug(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Infof(format string, args ...any) {
z.logger.Info(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Warnf(format string, args ...any) {
z.logger.Warn(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Errorf(format string, args ...any) {
z.logger.Error(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Fatalf(format string, args ...any) {
z.logger.Fatal(fmt.Sprintf(format, args...))
}
func (z *ZapAdapter) Panicf(format string, args ...any) {
z.logger.Panic(fmt.Sprintf(format, args...))
}
```
---
## 八、内部使用迁移
框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`
需要修改的文件:
- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf`
- `recovery.go`: 如有使用日志
- `logreco.go`: CloseLogger 方法
```go
// context.go 修改前
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if _, err := c.Writer.Write(data); err != nil {
if c.engine.LogReco != nil {
c.engine.LogReco.Errorf("%s: %v", contextMsg, err)
}
}
}
// context.go 修改后
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if _, err := c.Writer.Write(data); err != nil {
if c.engine.logger != nil {
c.engine.logger.Errorf("%s: %v", contextMsg, err)
}
}
}
```
---
## 九、完整文件结构
```
touka/
├── logger.go # Logger 接口定义
├── logreco.go # reco.Logger 相关工具函数
├── compat.go # go:fix inline 兼容性函数 (Engine)
├── context_compat.go # go:fix inline 兼容性函数 (Context)
├── engine.go # Engine 结构变更
├── context.go # Context 日志方法变更
└── ...
```
---
## 十、版本策略
| 版本 | 变更内容 |
|------|---------|
| v1.x | 引入 Logger 接口LogReco 标记 deprecated |
| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 |
| v3.x | 移除 go:fix inline 兼容函数 |
---
## 十一、go:fix inline 限制说明
1. **字段访问无法自动迁移**`engine.LogReco` 字段访问需要用户手动修改
2. **返回类型变更需谨慎**`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败
3. **inline 函数有大小限制**:函数体过大会影响内联效果
4. **跨包迁移**`go:fix inline` 支持跨包,但用户必须运行 `go fix`
---
## 十二、推荐迁移步骤
1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数
2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分
3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()`
4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置

View file

@ -26,41 +26,6 @@ api.Use(AuthMiddleware())
} }
``` ```
也可以在创建组时直接传入中间件:
```go
api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware())
{
api.GET("/user", handleUser)
api.POST("/data", handleData)
}
```
### 路由级中间件
为单个路由注册中间件,仅对该路由生效。
```go
// 单个路由中间件
r.GET("/protected", AuthMiddleware(), func(c *touka.Context) {
c.String(http.StatusOK, "Protected content")
})
// 多个路由中间件(按顺序执行)
r.POST("/upload",
RateLimitMiddleware(),
AuthMiddleware(),
PermissionCheckMiddleware(),
func(c *touka.Context) {
// 处理上传
},
)
// 路由组中的单个路由也可以使用路由级中间件
api := r.Group("/api")
api.GET("/admin", AdminAuthMiddleware(), adminHandler)
```
## 编写自定义中间件 ## 编写自定义中间件
中间件的函数签名是 `touka.HandlerFunc` 中间件的函数签名是 `touka.HandlerFunc`
@ -102,36 +67,6 @@ func APIKeyAuth() touka.HandlerFunc {
} }
``` ```
## 中间件执行顺序
理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行:
```go
// 全局中间件
r.Use(GlobalMiddleware1())
r.Use(GlobalMiddleware2())
// 组中间件
api := r.Group("/api", GroupMiddleware1())
api.Use(GroupMiddleware2())
// 路由级中间件
api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler)
```
对于 `/api/users` 请求,执行顺序为:
1. `GlobalMiddleware1()` - 全局中间件
2. `GlobalMiddleware2()` - 全局中间件
3. `GroupMiddleware1()` - 路由组中间件
4. `GroupMiddleware2()` - 路由组中间件
5. `RouteMiddleware1()` - 路由级中间件
6. `RouteMiddleware2()` - 路由级中间件
7. `userHandler` - 最终处理函数
```
请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应
```
## 内置中间件 ## 内置中间件
- **Recovery**: 捕获任何发生的 panic恢复运行并返回 500 错误。它还负责调用全局错误处理器。 - **Recovery**: 捕获任何发生的 panic恢复运行并返回 500 错误。它还负责调用全局错误处理器。

View file

@ -46,7 +46,7 @@ func main() {
// 4. 启动服务器并监听 8080 端口 // 4. 启动服务器并监听 8080 端口
log.Println("Touka server is running on :8080") log.Println("Touka server is running on :8080")
if err := r.Run(touka.WithAddr(":8080")); err != nil { if err := r.Run(":8080"); err != nil {
log.Fatalf("Server failed: %v", err) log.Fatalf("Server failed: %v", err)
} }
} }
@ -66,11 +66,11 @@ go run main.go
## 优雅停机 ## 优雅停机
在生产环境中,我们推荐`Run` 追加优雅关闭选项。启用后Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)` 在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成
```go ```go
// 等待 10 秒以处理剩余请求 // 等待 10 秒以处理剩余请求
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
log.Fatalf("Server forced to shutdown: %v", err) log.Fatalf("Server forced to shutdown: %v", err)
} }
``` ```

View file

@ -28,7 +28,7 @@ func main() {
Target: target, Target: target,
})) }))
_ = r.Run(touka.WithAddr(":8080")) _ = r.Run(":8080")
} }
``` ```
@ -59,16 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
```go ```go
type ReverseProxyConfig struct { type ReverseProxyConfig struct {
Target *url.URL Target *url.URL
Targets []string
LoadBalancing ReverseProxyLoadBalancingConfig
PassiveHealth ReverseProxyPassiveHealthConfig
Transport http.RoundTripper Transport http.RoundTripper
FlushInterval time.Duration FlushInterval time.Duration
BufferPool BufferPool BufferPool BufferPool
AllowH2CUpstream bool
ModifyRequest func(*http.Request) ModifyRequest func(*http.Request)
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
@ -83,133 +78,12 @@ type ReverseProxyConfig struct {
### `Target` ### `Target`
`Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme``host` 必填。表示后端目标地址,至少需要提供 `scheme``host`
```go ```go
target, _ := url.Parse("http://backend:9000") target, _ := url.Parse("http://backend:9000")
``` ```
### `Targets`
可选。用于配置多个后端目标地址。
- `Target``Targets` 互斥,只能使用其中一种
- `Targets` 的每一项都必须是完整 URL
- 每个 target 仍然可以自带 base path 和 query
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001/base?from=a",
"http://127.0.0.1:9002/base?from=b",
},
}))
```
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
### `LoadBalancing`
用于配置 upstream 选择策略和重试行为。
```go
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
```
当前内置策略:
- `touka.LBRandom()`
- `touka.LBRoundRobin()`
- `touka.LBFirst()`
- `touka.LBLeastConn()`
- `touka.LBIPHash()`
- `touka.LBClientIPHash()`
- `touka.LBURIHash()`
- `touka.LBHeader("X-Upstream", fallback)`
- `touka.LBQuery("tenant", fallback)`
其中:
- `LBFirst()` 适合主备/故障转移顺序
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback则默认回退到 `LBRandom()`
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
Retries: 1,
},
}))
```
重试说明:
- 只对未开始收到上游响应的失败进行重试
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
- `Retries` 表示额外重试次数
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
- `TryInterval` 表示两次重试之间的等待间隔
### `PassiveHealth`
用于配置被动健康检查。它不会后台探测 upstream而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
```go
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
```
- `FailDuration > 0` 时启用被动健康跟踪
- `MaxFails <= 0` 时默认按 `1` 处理
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
```go
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Targets: []string{
"http://127.0.0.1:9001",
"http://127.0.0.1:9002",
},
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
Policy: touka.LBFirst(),
},
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
UnhealthyStatus: []int{http.StatusServiceUnavailable},
},
}))
```
### `AllowH2CUpstream`
允许代理使用未加密 HTTP/2h2c`http://` upstream 通信。
- 默认关闭
- 这是一个显式配置项
- 启用后Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
- 这意味着上游本身也必须显式支持 h2c它不是“先试 h2c失败再自动回退到 h1”的协商模式
```go
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target,
AllowH2CUpstream: true,
}))
```
对于下游 HTTP/2 extended `CONNECT` websocket 场景Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
### `Transport` ### `Transport`
可选。用于自定义底层转发所使用的 `http.RoundTripper` 可选。用于自定义底层转发所使用的 `http.RoundTripper`
@ -276,8 +150,6 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
在请求真正发往后端前,对出站请求做最后修改。 在请求真正发往后端前,对出站请求做最后修改。
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
常见用途: 常见用途:
- 覆盖 `Host` - 覆盖 `Host`
@ -370,20 +242,11 @@ const (
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "_gateway-1", ForwardedBy: "gateway-1",
Via: "edge-1", Via: "edge-1",
})) }))
``` ```
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
- IPv4`203.0.113.43`
- IPv6 / 带端口:`[2001:db8::17]:443`
- 匿名标识:`_gateway-1`
- 未知:`unknown`
`gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
`Via` 不是“留空即禁用”的开关。当前实现中: `Via` 不是“留空即禁用”的开关。当前实现中:
- 如果 `Via` 非空,则使用该值追加 `Via` - 如果 `Via` 非空,则使用该值追加 `Via`
@ -419,14 +282,11 @@ Touka 会尽量遵循代理链语义:
Touka 的反向代理实现支持以下能力: Touka 的反向代理实现支持以下能力:
- `CONNECT` 隧道转发HTTP/1.x
- HTTP/2 extended `CONNECT`
- `Connection: Upgrade` / `Upgrade` 协议升级转发 - `Connection: Upgrade` / `Upgrade` 协议升级转发
- WebSocket 等 101 Switching Protocols 场景 - WebSocket 等 101 Switching Protocols 场景
- SSEServer-Sent Events立即刷新 - SSEServer-Sent Events立即刷新
- Trailer 透传 - Trailer 透传
- 1xx 响应透传 - 1xx 响应透传
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
例如,代理 WebSocket 服务: 例如,代理 WebSocket 服务:
@ -481,7 +341,7 @@ func main() {
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
Target: target, Target: target,
ForwardedHeaders: touka.ForwardedBoth, ForwardedHeaders: touka.ForwardedBoth,
ForwardedBy: "_gateway-1", ForwardedBy: "gateway-1",
Via: "gateway-1", Via: "gateway-1",
FlushInterval: 100 * time.Millisecond, FlushInterval: 100 * time.Millisecond,
ModifyRequest: func(req *http.Request) { ModifyRequest: func(req *http.Request) {
@ -497,7 +357,7 @@ func main() {
}, },
})) }))
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View file

@ -22,8 +22,6 @@ r.ANY("/any", handle)
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
``` ```
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
## 路径参数 (Named Parameters) ## 路径参数 (Named Parameters)
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
@ -142,7 +140,7 @@ func main() {
r := touka.Default() r := touka.Default()
fsroot, _ := fs.Sub(content, "dist") fsroot, _ := fs.Sub(content, "dist")
r.StaticFS("/", http.FS(fsroot)) r.StaticFS("/", http.FS(fsroot))
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```

View file

@ -40,40 +40,43 @@ r.GET("/events", func(c *touka.Context) {
## 模式二:通道模式 (EventStreamChan) ## 模式二:通道模式 (EventStreamChan)
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。与回调模式类似,此方法是**阻塞的**handler 会在此方法中停留,直到事件 channel 被关闭或客户端断开连接。 如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。
```go ```go
r.GET("/events-chan", func(c *touka.Context) { r.GET("/events-chan", func(c *touka.Context) {
eventChan := make(chan touka.Event) eventChan, errChan := c.EventStreamChan()
ctx := c.Request.Context()
// 在独立的 goroutine 中发送事件. // 监听错误/断开连接
go func() { go func() {
defer close(eventChan) // 务必在结束时关闭以结束事件流. if err := <-errChan; err != nil {
log.Printf("SSE 错误: %v", err)
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return // 客户端已断开, 退出 goroutine.
case eventChan <- touka.Event{
Data: fmt.Sprintf("消息 #%d", i),
}:
}
time.Sleep(1 * time.Second)
} }
}() }()
// EventStreamChan 会阻塞直到流结束. // 发送数据
c.EventStreamChan(eventChan) go func() {
defer close(eventChan) // 务必在结束时关闭
for i := 0; i < 10; i++ {
select {
case <-c.Request.Context().Done():
return
default:
eventChan <- touka.Event{
Data: fmt.Sprintf("消息 #%d", i),
}
time.Sleep(1 * time.Second)
}
}
}()
}) })
``` ```
## 最佳实践 ## 最佳实践
1. **资源回收**: `EventStreamChan` 是阻塞的handler 在事件流结束前不会返回。将 `c.Request.Context().Done()``eventChan <- ...` 作为同一个 `select` 的两个分支,确保发送操作本身能够响应客户端断开。 1. **资源回收**: 确保在 `EventStreamChan` 模式下正确监听 `c.Request.Context().Done()` 以避免 Goroutine 泄漏。
2. **关闭 Channel**: 生产者完成发送后必须 `close(eventChan)`,否则 handler 会永远阻塞。 2. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
3. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。 3. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx配置了足够大的写超时时间。
4. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx配置了足够大的写超时时间。
## 优雅关闭与资源清理 ## 优雅关闭与资源清理
@ -125,4 +128,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)``touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文 **注意:** 请务必使用 `RunShutdown``RunTLS``RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号

View file

@ -39,7 +39,7 @@ func main() {
// 您也可以使用 StaticFS 服务根路径 // 您也可以使用 StaticFS 服务根路径
// r.StaticFS("/", http.FS(fsroot)) // r.StaticFS("/", http.FS(fsroot))
r.Run(touka.WithAddr(":8080")) r.Run(":8080")
} }
``` ```

2
ecw.go
View file

@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := ecw.w.(http.Hijacker) hijacker, ok := ecw.w.(http.Hijacker)
if !ok { if !ok {
return nil, nil, http.ErrNotSupported return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
} }
return hijacker.Hijack() return hijacker.Hijack()
} }

View file

@ -1,59 +0,0 @@
package touka
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) {
c, _ := CreateTestContext(nil)
ecw := AcquireErrorCapturingResponseWriter(c)
defer ReleaseErrorCapturingResponseWriter(ecw)
ecw.capturedErrorSignal = true
ecw.Header().Set("Content-Type", "text/plain")
ecw.Header().Add("X-Test", "one")
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler)
if len(ecw.headerSnapshot) != 0 {
t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot)
}
}
func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) {
c, _ := CreateTestContext(nil)
ecw := AcquireErrorCapturingResponseWriter(c)
defer ReleaseErrorCapturingResponseWriter(ecw)
rawWriter := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
keys := make([]string, 16)
for i := range keys {
keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i)))
}
values := []string{"one", "two", "three"}
for _, key := range keys {
ecw.headerSnapshot[key] = values
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler)
for _, key := range keys {
ecw.headerSnapshot[key] = values
}
}
}

392
engine.go
View file

@ -7,11 +7,9 @@ package touka
import ( import (
"context" "context"
"errors" "errors"
"io"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"unicode/utf8"
"net/http" "net/http"
@ -19,7 +17,6 @@ import (
"github.com/WJQSERVER-STUDIO/httpc" "github.com/WJQSERVER-STUDIO/httpc"
"github.com/fenthope/reco" "github.com/fenthope/reco"
"github.com/go-json-experiment/json"
) )
// Last 返回链中的最后一个处理函数 // Last 返回链中的最后一个处理函数
@ -52,14 +49,8 @@ type Engine struct {
HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求
// LogReco 保留的 reco.Logger 字段
// Deprecated: 使用 SetLogger/GetLogger 替代
LogReco *reco.Logger LogReco *reco.Logger
// logger 是新的日志接口,支持任意 Logger 实现
// 优先级: logger > LogReco
logger Logger
HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
routesInfo []RouteInfo // 存储所有注册的路由信息 routesInfo []RouteInfo // 存储所有注册的路由信息
@ -90,11 +81,6 @@ type Engine struct {
// GlobalMaxRequestBodySize 全局请求体Body大小限制 // GlobalMaxRequestBodySize 全局请求体Body大小限制
GlobalMaxRequestBodySize int64 GlobalMaxRequestBodySize int64
notFoundChain HandlersChain
notFoundNoMethodChain HandlersChain
unmatchedFSChain HandlersChain
unmatchedFSNoMethodChain HandlersChain
} }
// HandleFunc 注册一个或多个 HTTP 方法的路由 // HandleFunc 注册一个或多个 HTTP 方法的路由
@ -130,90 +116,6 @@ type ErrorHandle struct {
type ErrorHandler func(c *Context, code int, err error) type ErrorHandler func(c *Context, code int, err error)
var errMethodNotAllowed = errors.New("method not allowed")
var errNotFound = errors.New("not found")
type defaultErrorResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error())
var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error())
func mustMarshalDefaultErrorBody(code int, errMsg string) []byte {
body, err := json.Marshal(defaultErrorResponse{
Code: code,
Message: http.StatusText(code),
Error: errMsg,
})
if err != nil {
panic(err)
}
return body
}
func writeDefaultErrorJSON(c *Context, code int, body []byte) {
if c == nil || c.Writer == nil {
return
}
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code)
c.writeResponseBody(body, "failed to write default error response")
c.Writer.Flush()
c.Abort()
}
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request)
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
c.allowedMethodsBuf = allowedMethods[:0]
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
allowHeader := c.allowHeaderBuf[:0]
for i, method := range allowedMethods {
if i > 0 {
allowHeader = append(allowHeader, ',', ' ')
}
allowHeader = append(allowHeader, method...)
}
c.allowHeaderBuf = allowHeader[:0]
c.Writer.Header().Set("Allow", string(allowHeader))
c.Status(http.StatusOK)
return
}
return
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
*tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
if value.handlers != nil {
PutTempSkippedNodes(tempSkippedNodes)
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
return
}
}
PutTempSkippedNodes(tempSkippedNodes)
}
var notFoundHandler HandlerFunc = func(c *Context) {
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
}
// defaultErrorHandle 默认错误处理 // defaultErrorHandle 默认错误处理
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
select { select {
@ -224,22 +126,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
if c.Writer.Written() { if c.Writer.Written() {
return return
} }
if len(c.Errors) == 0 {
switch {
case code == http.StatusNotFound && errors.Is(err, errNotFound):
writeDefaultErrorJSON(c, code, defaultNotFoundBody)
return
case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed):
writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody)
return
}
}
// 输出json 状态码与状态码对应描述 // 输出json 状态码与状态码对应描述
var errMsg string var errMsg string
if err != nil { if err != nil {
errMsg = err.Error() errMsg = err.Error()
} }
c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) c.JSON(code, H{
"code": code,
"message": http.StatusText(code),
"error": errMsg,
})
c.Writer.Flush() c.Writer.Flush()
c.Abort() c.Abort()
return return
@ -314,7 +210,6 @@ func New() *Engine {
TLSServerConfigurator: nil, TLSServerConfigurator: nil,
GlobalMaxRequestBodySize: -1, GlobalMaxRequestBodySize: -1,
} }
engine.rebuildFallbackChains()
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
//engine.SetProtocols(GetDefaultProtocolsConfig()) //engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
@ -370,32 +265,18 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
// 是否开启MethodNotAllowed // 是否开启MethodNotAllowed
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
engine.HandleMethodNotAllowed = enable engine.HandleMethodNotAllowed = enable
engine.rebuildFallbackChains()
} }
// SetLogger 传入 Logger 接口实例 // SetLogger传入实例
func (engine *Engine) SetLogger(logger Logger) { func (engine *Engine) SetLogger(logger *reco.Logger) {
engine.logger = logger
// 同步更新 LogReco 以保持向后兼容
if rl, ok := logger.(*reco.Logger); ok {
engine.LogReco = rl
} else {
engine.LogReco = nil
}
}
// GetLogger 返回 Logger 接口实例
func (engine *Engine) GetLogger() Logger {
return engine.logger
}
// SetLoggerCfg 使用 reco.Config 配置日志
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
logger := NewLogger(logcfg)
engine.logger = logger
engine.LogReco = logger engine.LogReco = logger
} }
// 配置日志LoggerCfg
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
engine.LogReco = NewLogger(logcfg)
}
// 设置自定义错误处理 // 设置自定义错误处理
func (engine *Engine) SetErrorHandler(handler ErrorHandler) { func (engine *Engine) SetErrorHandler(handler ErrorHandler) {
engine.errorHandle.useDefault = false engine.errorHandle.useDefault = false
@ -424,7 +305,6 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
engine.unMatchFS.ServeUnmatchedAsFS = false engine.unMatchFS.ServeUnmatchedAsFS = false
engine.UnMatchFSRoutes = nil engine.UnMatchFSRoutes = nil
} }
engine.rebuildFallbackChains()
} }
// 获取默认Protocol配置 // 获取默认Protocol配置
@ -460,28 +340,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
}() }()
} }
func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
if protocols == nil {
return nil
}
cloned := *protocols
return &cloned
}
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
if protocols != nil {
srv.Protocols = cloneServerProtocols(protocols)
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
panic(err)
}
}
}
}
// applyDefaultServerConfig 应用框架的默认配置到 http.Server // applyDefaultServerConfig 应用框架的默认配置到 http.Server
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
applyServerProtocols(srv, engine.serverProtocols) if engine.serverProtocols != nil {
srv.Protocols = engine.serverProtocols
}
} }
// 配置全局Req Body大小限制 // 配置全局Req Body大小限制
@ -610,64 +473,66 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
// 405中间件 // 405中间件
func MethodNotAllowed() HandlerFunc { func MethodNotAllowed() HandlerFunc {
return methodNotAllowedHandler return func(c *Context) {
httpMethod := c.Request.Method
requestPath := c.Request.URL.Path
engine := c.engine
// 是否是OPTIONS方式
if httpMethod == http.MethodOptions {
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
allowedMethods := []string{}
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
if len(allowedMethods) > 0 {
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
c.Status(http.StatusOK)
return
}
}
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
for _, treeIter := range engine.methodTrees {
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
continue
}
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
tempSkippedNodes := GetTempSkippedNodes()
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
PutTempSkippedNodes(tempSkippedNodes)
if value.handlers != nil {
// 使用定义的ErrorHandle处理
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
return
}
}
}
} }
// 404最后处理 // 404最后处理
func NotFound() HandlerFunc { func NotFound() HandlerFunc {
return notFoundHandler return func(c *Context) {
engine := c.engine
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
}
} }
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoute(handler HandlerFunc) { func (Engine *Engine) NoRoute(handler HandlerFunc) {
Engine.noRoute = handler Engine.noRoute = handler
Engine.noRoutes = nil Engine.noRoutes = nil
Engine.rebuildFallbackChains()
} }
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
Engine.noRoute = nil Engine.noRoute = nil
Engine.noRoutes = handlerFuncs Engine.noRoutes = handlerFuncs
Engine.rebuildFallbackChains()
}
func (engine *Engine) rebuildFallbackChains() {
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
if includeMethodNotAllowed {
finalSize++
}
if includeUnmatchedFS {
finalSize += len(engine.UnMatchFSRoutes)
}
if engine.noRoute != nil {
finalSize++
} else {
finalSize += len(engine.noRoutes)
}
chain := make(HandlersChain, 0, finalSize)
chain = append(chain, engine.globalHandlers...)
if includeMethodNotAllowed {
chain = append(chain, methodNotAllowedHandler)
}
if includeUnmatchedFS {
chain = append(chain, engine.UnMatchFSRoutes...)
}
if engine.noRoute != nil {
chain = append(chain, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
chain = append(chain, engine.noRoutes...)
}
chain = append(chain, notFoundHandler)
return chain
}
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
engine.notFoundNoMethodChain = buildChain(false, false)
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
} }
// combineHandlers 组合多个处理函数链为一个 // combineHandlers 组合多个处理函数链为一个
@ -682,9 +547,8 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// Use 将全局中间件添加到 Engine // Use 将全局中间件添加到 Engine
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) Router { func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains()
return engine return engine
} }
@ -751,7 +615,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
// Group 创建一个新的路由组 // Group 创建一个新的路由组
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter {
return &RouterGroup{ return &RouterGroup{
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
basePath: resolveRoutePath("/", relativePath), basePath: resolveRoutePath("/", relativePath),
@ -760,7 +624,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router
} }
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
// 它也实现了 Router 接口,允许嵌套分组 // 它也实现了 IRouter 接口,允许嵌套分组
type RouterGroup struct { type RouterGroup struct {
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
basePath string // 组路径前缀 basePath string // 组路径前缀
@ -769,7 +633,7 @@ type RouterGroup struct {
// Use 将中间件应用于当前路由组 // Use 将中间件应用于当前路由组
// 这些中间件将应用于当前组及其子组的所有路由 // 这些中间件将应用于当前组及其子组的所有路由
func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter {
group.Handlers = append(group.Handlers, middleware...) group.Handlers = append(group.Handlers, middleware...)
return group return group
} }
@ -815,7 +679,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
} }
// Group 为当前组创建一个新的子组 // Group 为当前组创建一个新的子组
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter {
return &RouterGroup{ return &RouterGroup{
Handlers: group.engine.combineHandlers(group.Handlers, handlers), Handlers: group.engine.combineHandlers(group.Handlers, handlers),
basePath: resolveRoutePath(group.basePath, relativePath), basePath: resolveRoutePath(group.basePath, relativePath),
@ -840,13 +704,8 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// handleRequest 负责根据请求查找路由并执行相应的处理函数链 // handleRequest 负责根据请求查找路由并执行相应的处理函数链
// 这是路由查找和执行的核心逻辑 // 这是路由查找和执行的核心逻辑
func (engine *Engine) handleRequest(c *Context) { func (engine *Engine) handleRequest(c *Context) {
if isGeneralOptionsRequest(c.Request) {
engine.handleGeneralOptions(c)
return
}
httpMethod := c.Request.Method httpMethod := c.Request.Method
requestPath := routeLookupPath(c.Request) requestPath := c.Request.URL.Path
// 查找对应的路由树的根节点 // 查找对应的路由树的根节点
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
@ -866,7 +725,7 @@ func (engine *Engine) handleRequest(c *Context) {
} }
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
redirectPath := requestPath redirectPath := requestPath
@ -878,98 +737,51 @@ func (engine *Engine) handleRequest(c *Context) {
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
return return
} }
if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { // 尝试不区分大小写的查找
// 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. // 直接在 rootNode 上调用 findCaseInsensitivePath 方法
ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
if found { if found && engine.RedirectFixedPath {
c.fixedPathBuf = ciPath[:0] c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 return
return
}
c.fixedPathBuf = c.fixedPathBuf[:0]
} }
} }
} }
if engine.unMatchFS.ServeUnmatchedAsFS { // 构建处理链
c.handlers = engine.unmatchedFSChain // 组合全局中间件和路由处理函数
} else { handlers := engine.globalHandlers
c.handlers = engine.notFoundChain
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
// 则在全局中间件之后添加 MethodNotAllowed 处理器
if engine.HandleMethodNotAllowed {
handlers = append(handlers, MethodNotAllowed())
} }
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
// 则在处理链的最后添加 UnMatchFS 处理器
if engine.unMatchFS.ServeUnmatchedAsFS {
/*
var unMatchFSHandle = c.engine.unMatchFileServer
handlers = append(handlers, unMatchFSHandle)
*/
handlers = append(handlers, engine.UnMatchFSRoutes...)
}
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
// 则在处理链的最后添加 NoRoute 处理器
if engine.noRoute != nil {
handlers = append(handlers, engine.noRoute)
} else if len(engine.noRoutes) > 0 {
handlers = append(handlers, engine.noRoutes...)
}
handlers = append(handlers, NotFound())
c.handlers = handlers
c.Next() // 执行处理函数链 c.Next() // 执行处理函数链
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送
} }
func routeLookupPath(req *http.Request) string {
if req == nil {
return ""
}
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
return "/" + req.RequestURI
}
if isGeneralOptionsRequest(req) {
return ""
}
if req.URL == nil {
return ""
}
return req.URL.Path
}
func isGeneralOptionsRequest(req *http.Request) bool {
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
}
func shouldTryFixedPathLookup(path string, root *node) bool {
if root != nil && root.hasCaseInsensitivePath {
return true
}
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
}
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
if cap(allowedMethods) < len(engine.methodTrees) {
allowedMethods = make([]string, 0, len(engine.methodTrees))
} else {
allowedMethods = allowedMethods[:0]
}
tempSkippedNodes := GetTempSkippedNodes()
for _, treeIter := range engine.methodTrees {
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
*tempSkippedNodes = (*tempSkippedNodes)[:0]
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
if value.handlers != nil {
allowedMethods = append(allowedMethods, treeIter.method)
}
}
PutTempSkippedNodes(tempSkippedNodes)
return allowedMethods
}
func (engine *Engine) handleGeneralOptions(c *Context) {
if c == nil || c.Request == nil {
return
}
c.Writer.Header().Set("Content-Length", "0")
if c.Request.ContentLength != 0 {
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
_, _ = io.Copy(io.Discard, mb)
}
c.Writer.WriteHeader(http.StatusOK)
c.Abort()
}
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
// 它可以用于在长连接 (如 SSE) 中监听关闭信号. // 它可以用于在长连接 (如 SSE) 中监听关闭信号.
func (engine *Engine) Context() context.Context { func (engine *Engine) Context() context.Context {

View file

@ -1,71 +0,0 @@
package touka
import (
"net/http"
"net/http/httptest"
"testing"
)
var benchmarkStatusCode int
func buildServeHTTPBenchmarkEngine() *Engine {
engine := New()
engine.GET("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/api/v1/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
return engine
}
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
b.Helper()
req, err := http.NewRequest(method, path, nil)
if err != nil {
b.Fatalf("failed to build request: %v", err)
}
rr := httptest.NewRecorder()
engine.ServeHTTP(rr, req)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr = httptest.NewRecorder()
engine.ServeHTTP(rr, req)
}
benchmarkStatusCode = rr.Code
}
func BenchmarkServeHTTP(b *testing.B) {
engine := buildServeHTTPBenchmarkEngine()
b.Run("StaticHit", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
})
b.Run("NotFound", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
})
b.Run("MethodNotAllowed", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
})
b.Run("OptionsAllow", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
})
b.Run("FixedPathRedirect", func(b *testing.B) {
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
})
}

View file

@ -1,306 +0,0 @@
package touka
import (
"bufio"
"encoding/json"
"errors"
"html/template"
"net"
"net/http"
"testing"
)
type failingResponseWriter struct {
header http.Header
status int
err error
}
func (w *failingResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *failingResponseWriter) WriteHeader(statusCode int) {
if w.status == 0 {
w.status = statusCode
}
}
func (w *failingResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
if w.err != nil {
return 0, w.err
}
return len(p), nil
}
func (w *failingResponseWriter) Flush() {}
func (w *failingResponseWriter) Status() int {
return w.status
}
func (w *failingResponseWriter) Size() int {
return 0
}
func (w *failingResponseWriter) Written() bool {
return w.status != 0
}
func (w *failingResponseWriter) IsHijacked() bool {
return false
}
func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func TestHandleRequestRedirectFixedPath(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
}
}
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
}
}
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
}
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
}
}
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
engine := New()
engine.GET("/Users/Profile", func(c *Context) {
c.Status(http.StatusNoContent)
})
defer func() {
if r := recover(); r != nil {
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
}
}()
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
}
}
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
engine := New()
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
}
}
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.NoRoute(func(c *Context) {
c.Writer.Header().Set("X-NoRoute", "hit")
c.Next()
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-NoRoute"); got != "" {
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
}
}
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
engine.POST("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
if rr.Code != http.StatusOK {
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
}
allow := rr.Header().Get("Allow")
if allow != "GET, POST" && allow != "POST, GET" {
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
}
}
func TestDefaultErrorHandleJSONShape(t *testing.T) {
engine := New()
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
if rr.Code != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
}
var body struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
}
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
t.Fatalf("unexpected error payload: %+v", body)
}
}
func TestDefaultMethodNotAllowedJSONShape(t *testing.T) {
engine := New()
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
var body struct {
Code int `json:"code"`
Message string `json:"message"`
Error string `json:"error"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
}
if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" {
t.Fatalf("unexpected error payload: %+v", body)
}
}
func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
engine := New()
engine.SetErrorHandler(func(c *Context, code int, err error) {
c.Writer.Header().Set("X-Custom-Error", "1")
c.String(code, "custom:%v", err)
})
engine.GET("/users", func(c *Context) {
c.Status(http.StatusNoContent)
})
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
if rr.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
}
if got := rr.Header().Get("X-Custom-Error"); got != "1" {
t.Fatalf("expected custom error header, got %q", got)
}
if rr.Body.String() != "custom:method not allowed" {
t.Fatalf("expected custom error body, got %q", rr.Body.String())
}
}
func TestResponseHelpersCaptureWriteErrors(t *testing.T) {
testCases := []struct {
name string
run func(*Context)
}{
{name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }},
{name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }},
{name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }},
{name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }},
{name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }},
{name: "HTMLBuf", run: func(c *Context) {
c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`))
c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"})
}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
c, _ := CreateTestContext(w)
tc.run(c)
if got := len(c.Errors); got != 1 {
t.Fatalf("expected exactly one captured error, got %d", got)
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
})
}
}
func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
engine := New()
c, _ := CreateTestContext(w)
c.engine = engine
req, err := http.NewRequest(http.MethodGet, "/missing", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(w, req)
defaultErrorHandle(c, http.StatusNotFound, errNotFound)
if len(c.Errors) == 0 {
t.Fatal("expected write error to be captured")
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
if c.Writer.Status() != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status())
}
if !c.IsAborted() {
t.Fatal("expected fast path to abort context")
}
}

View file

@ -1,103 +0,0 @@
package main
import (
"fmt"
"net/http"
"github.com/infinite-iroha/touka"
)
func main() {
r := touka.Default()
// 示例 1简单 GET 请求(自动关联请求 Context
r.GET("/proxy", func(c *touka.Context) {
// 使用 HTTPC() 方法,自动关联请求 Context
// 当客户端断开连接时,出站请求也会自动取消
body, err := c.HTTPC().
GET("https://httpbin.org/get").
Text()
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.String(http.StatusOK, "%s", body)
})
// 示例 2带 Header 的 POST 请求
r.POST("/users", func(c *touka.Context) {
var req struct {
Name string `json:"name"`
Email string `json:"email"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
return
}
var result struct {
ID int `json:"id"`
Name string `json:"name"`
}
// 链式调用,保持 httpc 风格
// 注意SetJSONBody 返回 (*RequestBuilder, error)
rb, err := c.HTTPC().
POST("https://httpbin.org/post").
SetHeader("X-API-Key", "secret").
SetJSONBody(req)
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
if err := rb.DecodeJSON(&result); err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
})
// 示例 3带查询参数的请求
r.GET("/search", func(c *touka.Context) {
query := c.DefaultQuery("q", "")
page := c.DefaultQuery("page", "1")
var result struct {
Items []string `json:"items"`
Total int `json:"total"`
}
err := c.HTTPC().
GET("https://httpbin.org/get").
SetQueryParam("q", query).
SetQueryParam("page", page).
DecodeJSON(&result)
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, result)
})
// 示例 4使用底层 httpc.Client旧方式仍可用但不推荐
r.GET("/legacy", func(c *touka.Context) {
// 旧方式:需要手动 WithContext
body, err := c.Client().
GET("https://httpbin.org/get").
WithContext(c.Context()).
Text()
if err != nil {
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
return
}
c.String(http.StatusOK, "%s", body)
})
fmt.Println("Server running on :8080")
fmt.Println("Try:")
fmt.Println(" curl http://localhost:8080/proxy")
fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users")
fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'")
// r.Run(touka.WithAddr(":8080"))
}

View file

@ -1,71 +0,0 @@
package main
import (
"fmt"
"log/slog"
"net/http"
"os"
"github.com/infinite-iroha/touka"
)
// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口
type SlogAdapter struct {
logger *slog.Logger
}
func NewSlogAdapter(handler slog.Handler) *SlogAdapter {
return &SlogAdapter{
logger: slog.New(handler),
}
}
func (s *SlogAdapter) Debugf(format string, args ...any) {
s.logger.Debug(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Infof(format string, args ...any) {
s.logger.Info(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Warnf(format string, args ...any) {
s.logger.Warn(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Errorf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
}
func (s *SlogAdapter) Fatalf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
os.Exit(1)
}
func (s *SlogAdapter) Panicf(format string, args ...any) {
s.logger.Error(fmt.Sprintf(format, args...))
panic(fmt.Sprintf(format, args...))
}
func main() {
engine := touka.New()
// 使用 slog 替换默认的 reco.Logger
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelDebug,
})
slogAdapter := NewSlogAdapter(handler)
engine.SetLogger(slogAdapter)
engine.GET("/", func(c *touka.Context) {
c.Infof("request received: %s", c.Request.URL.Path)
c.JSON(http.StatusOK, map[string]string{"message": "hello"})
})
// 也可以获取 Logger 接口
logger := engine.GetLogger()
logger.Debugf("engine started")
// 也可以直接使用 slog
slog.Info("Server running", "addr", ":8080")
// engine.Run(":8080")
}

7
go.mod
View file

@ -3,15 +3,14 @@ module github.com/infinite-iroha/touka
go 1.26 go 1.26
require ( require (
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2
github.com/WJQSERVER-STUDIO/httpc v0.9.3 github.com/WJQSERVER-STUDIO/httpc v0.9.0
github.com/WJQSERVER/wanf v0.0.8 github.com/WJQSERVER/wanf v0.0.8
github.com/fenthope/reco v0.0.5 github.com/fenthope/reco v0.0.5
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
golang.org/x/net v0.53.0
) )
require ( require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/text v0.36.0 // indirect golang.org/x/net v0.52.0 // indirect
) )

14
go.sum
View file

@ -1,7 +1,7 @@
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM=
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
github.com/WJQSERVER-STUDIO/httpc v0.9.3 h1:wYZkz9f/+2WuDuzPlExebvnn0q6QeArM15Y51HJ5UUI= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4=
github.com/WJQSERVER-STUDIO/httpc v0.9.3/go.mod h1:vtaDmN/8gN8Es1DJsGvvrFr8kErysJndu87i+KOWUHY= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg=
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=
github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8= github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8=
github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw= github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw=
@ -10,7 +10,5 @@ github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVw
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=

View file

@ -1,88 +0,0 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
_ "unsafe"
"golang.org/x/net/http2"
)
var enableHTTP2ExtendedConnectOnce sync.Once
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
var xnetDisableHTTP2ExtendedConnectProtocol bool
func enableHTTP2ExtendedConnectProtocol() {
enableHTTP2ExtendedConnectOnce.Do(func() {
xnetDisableHTTP2ExtendedConnectProtocol = false
})
}
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
if srv == nil {
return nil
}
enableHTTP2ExtendedConnectProtocol()
return http2.ConfigureServer(srv, nil)
}
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
enableHTTP2ExtendedConnectProtocol()
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
transport.Protocols.SetHTTP2(true)
return transport
}
func newHTTP1BridgeTransport() http.RoundTripper {
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
}
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetHTTP1(true)
if tlsConfig == nil {
transport.TLSClientConfig = &tls.Config{}
} else {
transport.TLSClientConfig = tlsConfig.Clone()
}
if len(transport.TLSClientConfig.NextProtos) == 0 {
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
return transport
}
func newH2CTransport() http.RoundTripper {
transport := cloneDefaultTransport()
transport.Protocols = new(http.Protocols)
transport.Protocols.SetUnencryptedHTTP2(true)
return transport
}
func cloneDefaultTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
return transport.Clone()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

View file

@ -1,150 +0,0 @@
package touka
import (
"bytes"
"io"
"testing"
"github.com/WJQSERVER-STUDIO/go-utils/iox"
)
type benchmarkResetReader struct {
data []byte
off int
}
func (r *benchmarkResetReader) Read(p []byte) (int, error) {
if r.off >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.off:])
r.off += n
return n, nil
}
func (r *benchmarkResetReader) Reset() {
r.off = 0
}
type benchmarkDiscardWriter struct{}
func (benchmarkDiscardWriter) Write(p []byte) (int, error) {
return len(p), nil
}
var benchmarkIOXResult int64
var benchmarkIOXBytes []byte
func BenchmarkIOXCopyComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.Copy(w, r)
if err != nil {
b.Fatalf("io.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.Copy(w, r)
if err != nil {
b.Fatalf("iox.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXCopyBufferComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("io.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("iox.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXReadAllComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("io.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
b.Run("iox.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("iox.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
}

View file

@ -1,23 +0,0 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2024 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
// Logger 是日志接口支持多种日志库实现reco、zap、logrus 等)
// 用户可以通过实现此接口来替换默认的日志实现
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
Warnf(format string, args ...any)
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
Panicf(format string, args ...any)
}
// CloserLogger 可选扩展接口,支持关闭操作
// 如果 Logger 实现了此接口Engine 在关闭时会调用 Close()
type CloserLogger interface {
Logger
Close() error
}

View file

@ -39,16 +39,7 @@ func CloseLogger(logger *reco.Logger) {
} }
} }
// CloseLogger 关闭 Engine 的日志实现
// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法
func (engine *Engine) CloseLogger() { func (engine *Engine) CloseLogger() {
if cl, ok := engine.logger.(CloserLogger); ok {
if err := cl.Close(); err != nil {
log.Printf("Close Logger Error: %s", err)
}
return
}
// 兼容旧代码
if engine.LogReco != nil { if engine.LogReco != nil {
CloseLogger(engine.LogReco) CloseLogger(engine.LogReco)
} }

View file

@ -23,21 +23,19 @@ type maxBytesReader struct {
n int64 n int64
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
read atomic.Int64 read atomic.Int64
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
emptyAtLimit atomic.Bool
} }
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
// //
// 如果 r 为 nil, 会 panic. // 如果 r 为 nil, 会 panic.
// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r. // 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
if r == nil { if r == nil {
panic("NewMaxBytesReader called with a nil reader") panic("NewMaxBytesReader called with a nil reader")
} }
// 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser. // 如果限制为数, 意味着不限制, 直接返回原始的 ReadCloser.
if n <= 0 { if n < 0 {
return r return r
} }
return &maxBytesReader{ return &maxBytesReader{
@ -48,53 +46,48 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
func (mbr *maxBytesReader) Read(p []byte) (int, error) { func (mbr *maxBytesReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
readSoFar := mbr.read.Load() readSoFar := mbr.read.Load()
remaining := mbr.n - readSoFar
if remaining < 0 { // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
if readSoFar >= mbr.n {
return 0, ErrBodyTooLarge return 0, ErrBodyTooLarge
} }
if remaining == 0 {
var probe [1]byte
n, err := mbr.r.Read(probe[:])
if n > 0 {
mbr.read.Add(int64(n))
return 0, ErrBodyTooLarge
}
if err != nil {
return 0, err
}
if mbr.emptyAtLimit.Swap(true) {
return 0, ErrBodyTooLarge
}
return 0, nil
}
mbr.emptyAtLimit.Store(false)
// 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 // 计算当前还可以读取多少字节.
if int64(len(p))-1 > remaining { remaining := mbr.n - readSoFar
p = p[:remaining+1]
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
if int64(len(p)) > remaining {
p = p[:remaining]
} }
// 从底层 Reader 读取数据. // 从底层 Reader 读取数据.
n, err := mbr.r.Read(p) n, err := mbr.r.Read(p)
if int64(n) <= remaining { // 如果实际读取到了数据, 更新原子计数器.
if n > 0 { if n > 0 {
mbr.read.Add(int64(n)) readSoFar = mbr.read.Add(int64(n))
} }
// 如果底层 Read 返回错误 (例如 io.EOF).
if err != nil {
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
return n, err return n, err
} }
// 读取结果跨过了限制,只向上层暴露允许的部分。 // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
if remaining > 0 { // 这是处理"跨越"限制情况的关键.
mbr.read.Add(remaining) if readSoFar > mbr.n {
// 返回实际读取的字节数 n, 并附上超限错误.
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
return n, ErrBodyTooLarge
} }
return int(remaining), ErrBodyTooLarge
// 一切正常, 返回读取的字节数和 nil 错误.
return n, nil
} }
// Close 方法关闭底层的 ReadCloser, 保证资源释放. // Close 方法关闭底层的 ReadCloser, 保证资源释放.

View file

@ -11,16 +11,18 @@ import (
) )
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播.
// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消.
type mergedContext struct { type mergedContext struct {
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
context.Context context.Context
// 保存了所有的父 context, 用于 Value() 方法的查找.
parents []context.Context parents []context.Context
// 用于手动取消此 mergedContext 的函数.
cancel context.CancelFunc
} }
// MergeCtx 创建并返回一个新的 context.Context. // MergeCtx 创建并返回一个新的 context.Context.
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. // 自动被取消 (逻辑或关系).
// //
// 新的 context 会继承: // 新的 context 会继承:
// - Deadline: 所有父 context 中最早的截止时间. // - Deadline: 所有父 context 中最早的截止时间.
@ -30,8 +32,7 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
return context.WithCancel(context.Background()) return context.WithCancel(context.Background())
} }
if len(parents) == 1 { if len(parents) == 1 {
ctx, cancel := context.WithCancelCause(parents[0]) return context.WithCancel(parents[0])
return ctx, func() { cancel(nil) }
} }
var earliestDeadline time.Time var earliestDeadline time.Time
@ -43,93 +44,79 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
} }
} }
// cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播. var baseCtx context.Context
cancelCtx, cancelCause := context.WithCancelCause(context.Background()) var baseCancel context.CancelFunc
// deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline).
// 当 cancelCtx 被取消时, deadlineCtx 也会被取消;
// 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx.
var deadlineCtx context.Context
var deadlineCancel context.CancelFunc
if !earliestDeadline.IsZero() { if !earliestDeadline.IsZero() {
deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded) baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
} } else {
baseCtx, baseCancel = context.WithCancel(context.Background())
// 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline),
// 否则用 cancelCtx.
embedCtx := cancelCtx
if deadlineCtx != nil {
embedCtx = deadlineCtx
} }
mc := &mergedContext{ mc := &mergedContext{
Context: embedCtx, Context: baseCtx,
parents: parents, parents: parents,
cancel: baseCancel,
} }
// 启动监控 goroutine, 监听 parent 取消或 deadline 到期. // 启动一个监控 goroutine.
go func() { go func() {
// 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏. defer mc.cancel()
parentDone := orDone(append(mc.parents, cancelCtx)...)
if deadlineCtx != nil { // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
defer deadlineCancel() // 同时监听 baseCtx.Done() 以便支持手动取消.
select { select {
case <-parentDone: case <-orDone(mc.parents...):
// parent 取消或手动 cancel() case <-mc.Context.Done():
for _, p := range mc.parents {
if p.Err() != nil {
cancelCause(context.Cause(p))
return
}
}
// 手动 cancel(), cause 已由 cancelCause() 设置
case <-deadlineCtx.Done():
// deadline 到期, 需要关闭 cancelCtx 并设置 cause
cancelCause(context.DeadlineExceeded)
}
} else {
<-parentDone
for _, p := range mc.parents {
if p.Err() != nil {
cancelCause(context.Cause(p))
return
}
}
} }
}() }()
return mc, func() { cancelCause(nil) } return mc, mc.cancel
} }
// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), // Value 返回当前Ctx Value
// 再按传入顺序从 parents 中查找.
func (mc *mergedContext) Value(key any) any { func (mc *mergedContext) Value(key any) any {
if v := mc.Context.Value(key); v != nil { return mc.Context.Value(key)
return v
}
for _, p := range mc.parents {
if val := p.Value(key); val != nil {
return val
}
}
return nil
} }
// Deadline, Done, Err 均由嵌入的 context.Context 提供. // Deadline 实现了 context.Context 的 Deadline 方法.
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
return mc.Context.Deadline()
}
// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. // Done 实现了 context.Context 的 Done 方法.
func (mc *mergedContext) Done() <-chan struct{} {
return mc.Context.Done()
}
// Err 实现了 context.Context 的 Err 方法.
func (mc *mergedContext) Err() error {
return mc.Context.Err()
}
// orDone 是一个辅助函数, 返回一个 channel.
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
func orDone(contexts ...context.Context) <-chan struct{} { func orDone(contexts ...context.Context) <-chan struct{} {
done := make(chan struct{}) done := make(chan struct{})
var once sync.Once var once sync.Once
closeDone := func() {
once.Do(func() {
close(done)
})
}
// 为每个父 context 启动一个 goroutine.
for _, ctx := range contexts { for _, ctx := range contexts {
go func(c context.Context) { go func(c context.Context) {
select { select {
case <-c.Done(): case <-c.Done():
once.Do(func() { close(done) }) closeDone()
case <-done: case <-done:
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
} }
}(ctx) }(ctx)
} }
return done return done
} }

View file

@ -1,256 +0,0 @@
package touka
import (
"context"
"errors"
"testing"
"time"
)
func TestMergeCtx_NoParents(t *testing.T) {
ctx, cancel := MergeCtx()
defer cancel()
if ctx.Err() != nil {
t.Fatal("expected no error before cancel")
}
cancel()
if ctx.Err() == nil {
t.Fatal("expected error after cancel")
}
}
func TestMergeCtx_SingleParent(t *testing.T) {
parent, parentCancel := context.WithCancel(context.Background())
ctx, cancel := MergeCtx(parent)
defer cancel()
if ctx.Err() != nil {
t.Fatal("expected no error before parent cancel")
}
parentCancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after parent cancel")
}
}
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel1()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p1 cancel")
}
// p2 should still be fine
if p2.Err() != nil {
t.Fatal("expected p2 to be unaffected")
}
}
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel2()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p2 cancel")
}
}
func TestMergeCtx_ExternalCancel(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
cancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after external cancel")
}
}
func TestMergeCtx_CausePropagation(t *testing.T) {
testErr := errors.New("test cause")
p1, cancel1 := context.WithCancelCause(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel1(testErr)
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p1 cancel")
}
cause := context.Cause(ctx)
if cause != testErr {
t.Fatalf("expected cause %v, got %v", testErr, cause)
}
cancel1(nil) // cleanup (already cancelled, no-op)
}
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
testErr := errors.New("second parent cause")
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancelCause(context.Background())
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
cancel2(testErr)
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after p2 cancel")
}
cause := context.Cause(ctx)
if cause != testErr {
t.Fatalf("expected cause %v, got %v", testErr, cause)
}
cancel1()
}
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
now := time.Now()
early := now.Add(100 * time.Millisecond)
late := now.Add(1 * time.Hour)
p1, cancel1 := context.WithDeadline(context.Background(), late)
p2, cancel2 := context.WithDeadline(context.Background(), early)
defer cancel1()
defer cancel2()
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
dl, ok := ctx.Deadline()
if !ok {
t.Fatal("expected deadline to be set")
}
if !dl.Equal(early) {
t.Fatalf("expected deadline %v, got %v", early, dl)
}
}
func TestMergeCtx_Deadline_Expires(t *testing.T) {
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancelP()
ctx, cancel := MergeCtx(p)
defer cancel()
<-ctx.Done()
if ctx.Err() == nil {
t.Fatal("expected error after deadline expires")
}
}
func TestMergeCtx_ValueLookup(t *testing.T) {
type key struct{}
p1 := context.WithValue(context.Background(), key{}, "from_p1")
p2 := context.WithValue(context.Background(), key{}, "from_p2")
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
val := ctx.Value(key{})
if val != "from_p1" {
t.Fatalf("expected 'from_p1', got %v", val)
}
}
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
type key1 struct{}
type key2 struct{}
p1 := context.WithValue(context.Background(), key1{}, "val1")
p2 := context.WithValue(context.Background(), key2{}, "val2")
ctx, cancel := MergeCtx(p1, p2)
defer cancel()
if v := ctx.Value(key1{}); v != "val1" {
t.Fatalf("expected 'val1', got %v", v)
}
if v := ctx.Value(key2{}); v != "val2" {
t.Fatalf("expected 'val2', got %v", v)
}
if v := ctx.Value("missing"); v != nil {
t.Fatalf("expected nil, got %v", v)
}
}
func TestMergeCtx_ContextInterface(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
defer cancel2()
var ctx context.Context
ctx, _ = MergeCtx(p1, p2)
// Verify all Context interface methods work
_ = ctx.Done()
_ = ctx.Err()
_, _ = ctx.Deadline()
_ = ctx.Value("any")
}
func TestOrDone_SingleContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
done := orDone(ctx)
cancel()
<-done // should not block
}
func TestOrDone_MultipleContexts(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
done := orDone(p1, p2)
cancel1()
<-done // should not block
}
func TestOrDone_SecondContextCancels(t *testing.T) {
p1, cancel1 := context.WithCancel(context.Background())
p2, cancel2 := context.WithCancel(context.Background())
defer cancel1()
done := orDone(p1, p2)
cancel2()
<-done // should not block
}

View file

@ -70,25 +70,42 @@ func TestApplyDefaultServerConfig(t *testing.T) {
} }
} }
func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { func TestRunTLSProtocolInheritance(t *testing.T) {
engine := New() engine := New()
srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2
if engine.useDefaultProtocols {
if !srv.Protocols.HTTP2() { engine.setProtocols(&ProtocolsConfig{
t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") Http1: true,
Http2: true,
})
} }
// 模拟用户设置了自定义协议后进入 TLS 运行模式 srv := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv)
if !srv.Protocols.HTTP2() {
t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config")
}
// 模拟用户设置了自定义协议后调用 RunTLS
engine = New() engine = New()
engine.SetProtocols(&ProtocolsConfig{ engine.SetProtocols(&ProtocolsConfig{
Http1: true, Http1: true,
Http2: false, // 用户明确不想要 HTTP/2 Http2: false, // 用户明确不想要 HTTP/2
}) })
srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true,
})
}
srv2 := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv2)
if srv2.Protocols.HTTP2() { if srv2.Protocols.HTTP2() {
t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously")
} }
} }

View file

@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// 尝试从底层 ResponseWriter 获取 Hijacker 接口 // 尝试从底层 ResponseWriter 获取 Hijacker 接口
hj, ok := rw.ResponseWriter.(http.Hijacker) hj, ok := rw.ResponseWriter.(http.Hijacker)
if !ok { if !ok {
return nil, nil, http.ErrNotSupported return nil, nil, errors.New("http.Hijacker interface not supported")
} }
// 调用底层的 Hijack 方法 // 调用底层的 Hijack 方法

File diff suppressed because it is too large Load diff

View file

@ -1,355 +0,0 @@
package touka
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
type benchmarkReadSeeker struct {
data []byte
off int
}
func (r *benchmarkReadSeeker) Read(p []byte) (int, error) {
if r.off >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.off:])
r.off += n
return n, nil
}
func (r *benchmarkReadSeeker) Reset() {
r.off = 0
}
type benchmarkResponseWriter struct {
header http.Header
status int
size int
}
func newBenchmarkResponseWriter() *benchmarkResponseWriter {
return &benchmarkResponseWriter{header: make(http.Header)}
}
func (w *benchmarkResponseWriter) Header() http.Header {
return w.header
}
func (w *benchmarkResponseWriter) WriteHeader(statusCode int) {
if w.status == 0 {
w.status = statusCode
}
}
func (w *benchmarkResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
w.size += len(p)
return len(p), nil
}
func (w *benchmarkResponseWriter) Flush() {}
func (w *benchmarkResponseWriter) Status() int {
return w.status
}
func (w *benchmarkResponseWriter) Size() int {
return w.size
}
func (w *benchmarkResponseWriter) Written() bool {
return w.status != 0
}
func (w *benchmarkResponseWriter) IsHijacked() bool {
return false
}
func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func (w *benchmarkResponseWriter) reset() {
clear(w.header)
w.status = 0
w.size = 0
}
var benchmarkReverseProxySink int
func BenchmarkReverseProxyCopyResponse(b *testing.B) {
body := bytes.Repeat([]byte("0123456789abcdef"), 4096)
proxy := newReverseProxyHandler(ReverseProxyConfig{})
dst := newBenchmarkResponseWriter()
src := &benchmarkReadSeeker{data: body}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
dst.reset()
src.Reset()
if err := proxy.copyResponse(dst, src, 0); err != nil {
b.Fatalf("copyResponse failed: %v", err)
}
}
benchmarkReverseProxySink = dst.Size()
}
func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 3,
},
},
}
now := time.Now()
proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)}
proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil))
}
}
func BenchmarkReverseProxySelectUpstream(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 3,
},
},
}
proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)}
c, _ := CreateTestContext(nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
selected, err := proxy.selectUpstream(c, nil)
if err != nil {
b.Fatalf("selectUpstream failed: %v", err)
}
benchmarkReverseProxySink = selected.index
}
}
func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
{key: "d", index: 3},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
},
}
c, _ := CreateTestContext(nil)
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
selected, err := proxy.selectUpstream(c, nil)
if err != nil {
b.Fatalf("selectUpstream failed: %v", err)
}
benchmarkReverseProxySink = selected.index
}
}
func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) {
proxy := newReverseProxyHandler(ReverseProxyConfig{})
dst := newBenchmarkResponseWriter()
src := bytes.NewBufferString("hello, reverse proxy")
if err := proxy.copyResponse(dst, src, 0); err != nil {
t.Fatalf("copyResponse failed: %v", err)
}
if got, want := dst.Size(), len("hello, reverse proxy"); got != want {
t.Fatalf("expected %d bytes copied, got %d", want, got)
}
}
type fixedLenBufferPool struct {
buf []byte
}
func (p *fixedLenBufferPool) Get() []byte {
return p.buf
}
func (p *fixedLenBufferPool) Put(buf []byte) {
p.buf = buf
}
type recordingReader struct {
chunk int
reads []int
left int
}
func (r *recordingReader) Read(p []byte) (int, error) {
if r.left == 0 {
return 0, io.EOF
}
n := min(r.chunk, len(p), r.left)
if n == 0 {
return 0, errors.New("reader received zero-length buffer")
}
for i := range n {
p[i] = 'x'
}
r.left -= n
r.reads = append(r.reads, len(p))
return n, nil
}
func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) {
pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)}
proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool})
dst := newBenchmarkResponseWriter()
src := &recordingReader{chunk: 8, left: 24}
if err := proxy.copyResponse(dst, src, 0); err != nil {
t.Fatalf("copyResponse failed: %v", err)
}
if len(src.reads) == 0 {
t.Fatal("expected reader to be used")
}
for _, size := range src.reads {
if size != 8 {
t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size)
}
}
}
func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) {
now := time.Now()
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a"},
{key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}},
{key: "c"},
},
config: ReverseProxyConfig{
PassiveHealth: ReverseProxyPassiveHealthConfig{
FailDuration: time.Minute,
MaxFails: 2,
},
},
}
available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}})
if len(available) != 1 {
t.Fatalf("expected only one available upstream, got %d", len(available))
}
if available[0].key != "a" {
t.Fatalf("expected upstream 'a', got %q", available[0].key)
}
}
func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) {
proxy := &reverseProxyHandler{
upstreams: []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
},
config: ReverseProxyConfig{
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
},
}
c, _ := CreateTestContext(nil)
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"}
selectedA, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
selectedB, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
if selectedA.key != selectedB.key {
t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key)
}
c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"}
selectedC, err := proxy.selectUpstream(c, nil)
if err != nil {
t.Fatalf("selectUpstream failed: %v", err)
}
if selectedC == nil {
t.Fatal("expected upstream for reordered multi-value header")
}
}
func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) {
candidates := []*reverseProxyUpstream{
{key: "a", index: 0},
{key: "b", index: 1},
{key: "c", index: 2},
}
testCases := [][]string{
{"tenant-a"},
{"tenant-a", "tenant-b"},
{"", "tenant-b"},
{"tenant-a", ""},
{"", ""},
}
for _, values := range testCases {
got := reverseProxySelectHRWValues(candidates, values)
want := reverseProxySelectHRW(candidates, strings.Join(values, ","))
if got == nil || want == nil {
t.Fatalf("expected non-nil upstreams for values %v", values)
}
if got.key != want.key {
t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key)
}
}
}
var _ io.Writer = (*benchmarkResponseWriter)(nil)

View file

@ -1,530 +0,0 @@
package touka
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
)
func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Server"); got != "Caddy" {
t.Errorf("expected X-Server=Caddy, got %q", got)
}
if got := r.Header.Get("X-Location"); got != "/api/v2/resource" {
t.Errorf("expected X-Location=/api/v2/resource, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Server": {{Search: "NGINX", Replace: "Caddy"}},
"X-Location": {{Search: "v1", Replace: "v2"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Server", "NGINX")
req.Header.Set("X-Location", "/api/v1/resource")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Route"); got != "/proxy-upstream" {
t.Errorf("expected X-Route=/proxy-upstream, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Route", "/original/upstream")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Host-A"); got != "new.example.com" {
t.Errorf("expected X-Host-A=new.example.com, got %q", got)
}
if got := r.Header.Get("X-Host-B"); got != "new.example.com" {
t.Errorf("expected X-Host-B=new.example.com, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"*": {{Search: "old.example.com", Replace: "new.example.com"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Host-A", "old.example.com")
req.Header.Set("X-Host-B", "old.example.com")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "backend-internal:8080")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Replace: map[string][]Replacement{
"X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}},
},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Backend"); got != "public.example.com" {
t.Errorf("expected X-Backend=public.example.com, got %q", got)
}
}
func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Replace: map[string][]Replacement{
"X-Test": {{SearchRegexp: "[invalid"}},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 500, got %d", resp.StatusCode)
}
}
func TestReplacementApply(t *testing.T) {
tests := []struct {
name string
r *Replacement
s string
want string
}{
{name: "nil replacement", r: nil, s: "hello", want: "hello"},
{name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""},
{name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"},
{name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"},
{name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"},
{name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"},
{name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"},
{name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.r.apply(tt.s); got != tt.want {
t.Errorf("Replacement.apply() = %q, want %q", got, tt.want)
}
})
}
}
func BenchmarkHeaderOpsAdd(b *testing.B) {
ops := &HeaderOps{
Add: map[string][]string{
"X-Custom-1": {"value-1"},
"X-Custom-2": {"value-2"},
"X-Custom-3": {"value-3"},
},
}
hdr := make(http.Header)
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr = make(http.Header)
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsSet(b *testing.B) {
ops := &HeaderOps{
Set: map[string][]string{
"X-Frame-Options": {"DENY"},
"X-Content-Type-Options": {"nosniff"},
"X-XSS-Protection": {"1; mode=block"},
},
}
hdr := make(http.Header)
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr = make(http.Header)
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsDeleteSingle(b *testing.B) {
ops := &HeaderOps{
Delete: []string{"X-Powered-By"},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Powered-By", "Express")
hdr.Set("X-Keep", "value")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) {
ops := &HeaderOps{
Delete: []string{"X-Debug-*"},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Debug-1", "v1")
hdr.Set("X-Debug-2", "v2")
hdr.Set("X-Keep", "value")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) {
ops := &HeaderOps{
Replace: map[string][]Replacement{
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("Location", "http://internal:8080/api/v1/users")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) {
re := regexp.MustCompile(`^http://([^/]+)(/.*)$`)
ops := &HeaderOps{
Replace: map[string][]Replacement{
"Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("Location", "http://internal:8080/api/v1/users")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) {
ops := &HeaderOps{
Replace: map[string][]Replacement{
"*": {{Search: "internal.example.com", Replace: "public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Host", "internal.example.com")
hdr.Set("X-Origin", "internal.example.com")
ops.applyTo(hdr, repl)
}
}
func BenchmarkHeaderOpsMixed(b *testing.B) {
ops := &HeaderOps{
Add: map[string][]string{
"X-Request-ID": {"req-123"},
},
Set: map[string][]string{
"X-Frame-Options": {"DENY"},
},
Delete: []string{"X-Powered-By"},
Replace: map[string][]Replacement{
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
},
}
repl := &reverseProxyReplacer{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
hdr := make(http.Header)
hdr.Set("X-Powered-By", "Express")
hdr.Set("Location", "http://internal:8080/api")
ops.applyTo(hdr, repl)
}
}
func BenchmarkReplacementApplySubstring(b *testing.B) {
r := &Replacement{Search: "old.example.com", Replace: "new.example.com"}
s := "https://old.example.com/api/v1/resource"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = r.apply(s)
}
}
func BenchmarkReplacementApplyRegexp(b *testing.B) {
r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)}
s := "https://old.example.com/api/v1/resource"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = r.apply(s)
}
}
func TestReverseProxyReplacerDynamicVars(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil)
req.Host = "example.com"
repl := newReverseProxyReplacer(req)
tests := []struct {
name string
input string
want string
}{
{"method", "{method}", "GET"},
{"host", "{host}", "example.com"},
{"path", "{path}", "/api/v1/users"},
{"query", "{query}", "sort=name&limit=10"},
{"scheme", "{scheme}", "http"},
{"proto", "{proto}", "HTTP/1.1"},
{"combined", "X-{method}-{path}", "X-GET-/api/v1/users"},
{"no vars", "static-value", "static-value"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := repl.Replace(tt.input); got != tt.want {
t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestReverseProxyReplacerNilRequest(t *testing.T) {
repl := newReverseProxyReplacer(nil)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil request, got %q", got)
}
}
func TestReverseProxyReplacerNilReplacer(t *testing.T) {
var repl *reverseProxyReplacer
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil replacer, got %q", got)
}
}
func TestReverseProxyReplacerFromHeader(t *testing.T) {
hdr := make(http.Header)
repl := newReverseProxyReplacerFromHeader(hdr)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string from header replacer, got %q", got)
}
}
func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" {
t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Method"); got != "GET" {
t.Errorf("expected X-Forwarded-Method=GET, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" {
t.Errorf("expected X-Forwarded-Host=client.example, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Forwarded-Path": {"{path}"},
"X-Forwarded-Method": {"{method}"},
"X-Forwarded-Host": {"{host}"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil)
req.Host = "client.example"
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}

View file

@ -1,220 +0,0 @@
package touka
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Custom-Header": {"test-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Sensitive") != "" {
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Delete: []string{"X-Sensitive"},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Sensitive", "should-be-removed")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsSet(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("X-Replace")
if got != "new-value" {
t.Errorf("expected X-Replace=new-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Set: map[string][]string{
"X-Replace": {"new-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Replace", "old-value")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyResponseHeaderOps(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "backend-server")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Set: map[string][]string{
"X-Custom": {"custom-value"},
},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
t.Errorf("expected X-Custom=custom-value, got %q", got)
}
}
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Powered-By", "Express")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Delete: []string{"X-Powered-By"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Powered-By"); got != "" {
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
}
}

View file

@ -1,409 +0,0 @@
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
// Copyright 2026 WJQSERVER. All rights reserved.
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
package touka
import (
"fmt"
"math/rand/v2"
"net/http"
"net/textproto"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
)
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
type ReverseProxyLoadBalancingConfig struct {
Policy ReverseProxyLBPolicy
Retries int
TryDuration time.Duration
TryInterval time.Duration
}
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
type ReverseProxyPassiveHealthConfig struct {
FailDuration time.Duration
MaxFails int
UnhealthyStatus []int
}
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
type ReverseProxyLBPolicy struct {
kind reverseProxyLBPolicyKind
key string
fallback *ReverseProxyLBPolicy
}
type reverseProxyLBPolicyKind uint8
const (
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
reverseProxyLBPolicyRoundRobin
reverseProxyLBPolicyFirst
reverseProxyLBPolicyLeastConn
reverseProxyLBPolicyIPHash
reverseProxyLBPolicyClientIPHash
reverseProxyLBPolicyURIHash
reverseProxyLBPolicyHeader
reverseProxyLBPolicyQuery
)
type reverseProxyUpstream struct {
key string
target *url.URL
index int
useH2C bool
extendedConnectTransport http.RoundTripper
bridgeTransport http.RoundTripper
h2cTransport http.RoundTripper
inFlight atomic.Int64
passiveMu sync.Mutex
failures []time.Time
}
func LBRandom() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
}
func LBRoundRobin() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
}
func LBFirst() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
}
func LBLeastConn() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
}
func LBIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
}
func LBClientIPHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
}
func LBURIHash() ReverseProxyLBPolicy {
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
}
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
policy.fallback = &fallback
}
return policy
}
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
switch policy.kind {
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
reverseProxyLBPolicyURIHash:
return nil
case reverseProxyLBPolicyHeader:
if policy.key == "" {
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
}
case reverseProxyLBPolicyQuery:
if policy.key == "" {
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
}
default:
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
}
if policy.fallback != nil {
return validateReverseProxyLBPolicy(*policy.fallback)
}
return nil
}
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
now := time.Now()
policy := p.config.LoadBalancing.Policy
candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream)
candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf)
if len(candidates) == 0 && len(excluded) > 0 {
candidates = p.availableUpstreamsInto(now, nil, candidates[:0])
}
if len(candidates) == 0 {
*candidateBuf = candidates[:0]
reverseProxyCandidatePool.Put(candidateBuf)
return nil, errReverseProxyNoAvailableUpstreams
}
selected := p.selectUpstreamWithPolicy(c, candidates, policy)
*candidateBuf = candidates[:0]
reverseProxyCandidatePool.Put(candidateBuf)
return selected, nil
}
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
return p.availableUpstreamsInto(now, excluded, nil)
}
func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream {
if cap(candidates) < len(p.upstreams) {
candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams))
} else {
candidates = candidates[:0]
}
for _, upstream := range p.upstreams {
if _, skip := excluded[upstream.key]; skip {
continue
}
if !upstream.healthy(now, p.config.PassiveHealth) {
continue
}
candidates = append(candidates, upstream)
}
return candidates
}
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
switch policy.kind {
case reverseProxyLBPolicyRoundRobin:
return candidates[p.nextRoundRobinIndex(len(candidates))]
case reverseProxyLBPolicyFirst:
return candidates[0]
case reverseProxyLBPolicyLeastConn:
return p.selectLeastConnUpstream(candidates)
case reverseProxyLBPolicyIPHash:
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
case reverseProxyLBPolicyClientIPHash:
return reverseProxySelectHRW(candidates, c.RequestIP())
case reverseProxyLBPolicyURIHash:
if c.Request == nil || c.Request.URL == nil {
return reverseProxySelectRandom(candidates)
}
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
case reverseProxyLBPolicyHeader:
if c.Request != nil && c.Request.Header != nil {
if values, ok := c.Request.Header[policy.key]; ok {
return reverseProxySelectHRWValues(candidates, values)
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyQuery:
if c.Request != nil && c.Request.URL != nil {
if values, ok := c.Request.URL.Query()[policy.key]; ok {
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
}
}
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
case reverseProxyLBPolicyRandom:
fallthrough
default:
return reverseProxySelectRandom(candidates)
}
}
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
if size <= 1 {
return 0
}
return int((p.roundRobin.Add(1) - 1) % uint64(size))
}
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
selected := candidates[0]
lowest := selected.inFlight.Load()
ties := []*reverseProxyUpstream{selected}
for _, upstream := range candidates[1:] {
count := upstream.inFlight.Load()
switch {
case count < lowest:
selected = upstream
lowest = count
ties = []*reverseProxyUpstream{upstream}
case count == lowest:
ties = append(ties, upstream)
}
}
if len(ties) == 1 {
return selected
}
return ties[p.nextRoundRobinIndex(len(ties))]
}
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(candidates) == 1 {
return candidates[0]
}
return candidates[rand.IntN(len(candidates))]
}
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if key == "" {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWScore(key, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWScore(key, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream {
if len(candidates) == 0 {
return nil
}
if len(values) == 0 {
return reverseProxySelectRandom(candidates)
}
selected := candidates[0]
bestScore := reverseProxyHRWValuesScore(values, selected.key)
for _, upstream := range candidates[1:] {
score := reverseProxyHRWValuesScore(values, upstream.key)
if score > bestScore {
selected = upstream
bestScore = score
}
}
return selected
}
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for i := 0; i < len(key); i++ {
h ^= uint64(key[i])
h *= prime64
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 {
const (
offset64 = 14695981039346656037
prime64 = 1099511628211
)
h := uint64(offset64)
for valueIndex, value := range values {
for i := 0; i < len(value); i++ {
h ^= uint64(value[i])
h *= prime64
}
if valueIndex+1 < len(values) {
h ^= ','
h *= prime64
}
}
h ^= 0xff
h *= prime64
for i := 0; i < len(upstreamKey); i++ {
h ^= uint64(upstreamKey[i])
h *= prime64
}
return h
}
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
if policy.fallback != nil {
return *policy.fallback
}
return LBRandom()
}
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return true
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
return len(u.failures) < maxFails
}
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
maxFails := reverseProxyPassiveMaxFails(config)
if config.FailDuration <= 0 || maxFails <= 0 {
return
}
u.passiveMu.Lock()
defer u.passiveMu.Unlock()
u.pruneFailuresLocked(now, config.FailDuration)
u.failures = append(u.failures, now)
}
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
if len(u.failures) == 0 || window <= 0 {
if window <= 0 {
u.failures = nil
}
return
}
cutoff := now.Add(-window)
keep := 0
for _, failureAt := range u.failures {
if failureAt.Before(cutoff) {
continue
}
u.failures[keep] = failureAt
keep++
}
u.failures = u.failures[:keep]
}
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
if config.FailDuration <= 0 {
return 0
}
if config.MaxFails <= 0 {
return 1
}
return config.MaxFails
}
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
if status <= 0 {
return false
}
return slices.Contains(config.UnhealthyStatus, status)
}

File diff suppressed because it is too large Load diff

View file

@ -1,130 +0,0 @@
package touka
import "testing"
var (
benchmarkRouteHandlers HandlersChain
benchmarkRouteFullPath string
benchmarkRouteParamsLen int
benchmarkRouteCIPath []byte
benchmarkRouteCIFound bool
)
func buildRouteMatchBenchmarkTree() *node {
tree := &node{}
routes := []string{
"/",
"/health",
"/contact",
"/api/v1/users",
"/api/v1/users/:id",
"/api/v1/users/:id/settings",
"/assets/*filepath",
"/abc/b",
"/abc/:p1/cde",
"/abc/:p1/:p2/def/*filepath",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
return tree
}
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
b.Helper()
params := make(Params, 0, 4)
skipped := make([]skippedNode, 0, 8)
value := tree.getValue(path, &params, &skipped, true)
if wantFullPath == "" {
if value.handlers != nil {
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
}
} else {
if value.handlers == nil {
b.Fatalf("expected match for %q, got nil handlers", path)
}
if value.fullPath != wantFullPath {
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
}
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
params = params[:0]
skipped = skipped[:0]
value = tree.getValue(path, &params, &skipped, true)
}
benchmarkRouteHandlers = value.handlers
benchmarkRouteFullPath = value.fullPath
if value.params != nil {
benchmarkRouteParamsLen = len(*value.params)
} else {
benchmarkRouteParamsLen = 0
}
}
func BenchmarkRouteMatch(b *testing.B) {
tree := buildRouteMatchBenchmarkTree()
b.Run("StaticHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
})
b.Run("ParamHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
})
b.Run("BacktrackingHit", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
})
b.Run("Miss", func(b *testing.B) {
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
})
b.Run("CaseInsensitiveHit", func(b *testing.B) {
path := "/API/V1/USERS/123/SETTINGS"
out, found := tree.findCaseInsensitivePath(path, true)
if !found {
b.Fatalf("expected fixed-path match for %q", path)
}
if got := string(out); got != "/api/v1/users/123/settings" {
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
path := "/DOES/NOT/EXIST"
out, found := tree.findCaseInsensitivePath(path, true)
if found || out != nil {
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
out, found = tree.findCaseInsensitivePath(path, true)
}
benchmarkRouteCIPath = out
benchmarkRouteCIFound = found
})
}

791
serve.go
View file

@ -14,7 +14,6 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -22,322 +21,329 @@ import (
"github.com/fenthope/reco" "github.com/fenthope/reco"
) )
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
const defaultShutdownTimeout = 5 * time.Second const defaultShutdownTimeout = 5 * time.Second
type runMode uint8 // --- 内部辅助函数 ---
const ( // resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080"
runModeHTTP runMode = iota func resolveAddress(addr []string) string {
runModeHTTPS switch len(addr) {
runModeHTTPSRedirect case 0:
) return ":8080"
case 1:
type runConfig struct { return addr[0]
addr string default:
httpRedirectAddr string panic("too many parameters provided for server address")
tlsConfig *tls.Config
redirectHost string
redirectHostHeaders []string
useHeaderHost bool
useHeaderHostSet bool
graceful bool
shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
}
type RunOption interface {
apply(*runConfig) error
}
type runOptionFunc func(*runConfig) error
func (f runOptionFunc) apply(cfg *runConfig) error {
return f(cfg)
}
func defaultRunConfig() runConfig {
return runConfig{
addr: ":8080",
shutdownTimeout: defaultShutdownTimeout,
mode: runModeHTTP,
useHeaderHost: true,
} }
} }
type HTTPRedirectOption interface { // getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值
applyRedirect(*runConfig) error func getShutdownTimeout(timeouts []time.Duration) time.Duration {
} if len(timeouts) > 0 && timeouts[0] > 0 {
return timeouts[0]
type redirectOptionFunc func(*runConfig) error
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
return f(cfg)
}
func WithAddr(addr string) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("run address must not be empty")
}
cfg.addr = addr
return nil
})
}
func WithTLS(tlsConfig *tls.Config) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if tlsConfig == nil {
return errors.New("tls.Config must not be nil")
}
cfg.tlsConfig = tlsConfig
if cfg.mode == runModeHTTP {
cfg.mode = runModeHTTPS
}
return nil
})
}
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("http redirect address must not be empty")
}
cfg.httpRedirectAddr = addr
cfg.mode = runModeHTTPSRedirect
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.applyRedirect(cfg); err != nil {
return err
}
}
return nil
})
}
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.useHeaderHost = enabled
cfg.useHeaderHostSet = true
return nil
})
}
func WithRedirectHost(host string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
if host == "" {
return errors.New("redirect host must not be empty")
}
cfg.redirectHost = host
return nil
})
}
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
for _, header := range headers {
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
if trimmed != "" {
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
}
}
return nil
})
}
func WithGracefulShutdown(timeout time.Duration) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownTimeoutSet = true
if timeout > 0 {
cfg.shutdownTimeout = timeout
} else {
cfg.shutdownTimeout = defaultShutdownTimeout
}
return nil
})
}
func WithGracefulShutdownDefault() RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownDefaultSet = true
cfg.shutdownTimeout = defaultShutdownTimeout
return nil
})
}
func WithShutdownContext(ctx context.Context) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if ctx == nil {
return errors.New("shutdown context must not be nil")
}
cfg.gracefulCtx = ctx
return nil
})
}
func serveServer(srv *http.Server, serveTLS bool) error {
if serveTLS {
return srv.ListenAndServeTLS("", "")
} }
return srv.ListenAndServe() return defaultShutdownTimeout
} }
func runServer(serverType string, srv *http.Server, serveTLS bool) { // runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server,
// 并处理其启动失败的致命错误
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS")
func runServer(serverType string, srv *http.Server) {
go func() { go func() {
var err error
protocol := "http" protocol := "http"
if serveTLS { if srv.TLSConfig != nil {
protocol = "https" protocol = "https"
} }
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
err := serveServer(srv, serveTLS) if srv.TLSConfig != nil {
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
// ListenAndServeTLS 的前两个参数可以为空字符串
err = srv.ListenAndServeTLS("", "")
} else {
err = srv.ListenAndServe()
}
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
// 则认为是一个严重错误,并终止程序
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Touka %s server failed: %v", serverType, err) log.Fatalf("Touka %s server failed: %v", serverType, err)
} }
}() }()
} }
func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { // handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器
if tlsConfig == nil { // 这是所有支持优雅关闭的 RunXXX 方法的最终归宿
return nil func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error {
} // 创建一个 channel 来接收操作系统信号
return tlsConfig.Clone() quit := make(chan os.Signal, 1)
} signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
<-quit // 阻塞,直到接收到上述信号之一
log.Println("Shutting down Touka server(s)...")
func parseHTTPSPort(addr string) (string, error) { // 关闭日志记录器
_, port, err := net.SplitHostPort(addr) if logger != nil {
if err != nil { go func() {
return "", fmt.Errorf("https address %q must include a port: %w", addr, err) log.Println("Closing Touka logger...")
} CloseLogger(logger)
return port, nil }()
}
func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
if serveTLS {
if engine.TLSServerConfigurator != nil {
engine.TLSServerConfigurator(srv)
return
}
}
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
}
func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
applyServerProtocols(srv, engine.serverProtocols)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
}
func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
if engine == nil {
return nil
}
if serveTLS && engine.useDefaultProtocols {
protocols := &http.Protocols{}
protocols.SetHTTP1(true)
protocols.SetHTTP2(true)
return protocols
}
return cloneServerProtocols(engine.serverProtocols)
}
func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
serveTLS := cfg.mode != runModeHTTP
server := &http.Server{
Addr: cfg.addr,
Handler: engine,
TLSConfig: cloneTLSConfig(cfg.tlsConfig),
}
if cfg.graceful {
server.BaseContext = func(net.Listener) context.Context {
return engine.shutdownCtx
}
server.RegisterOnShutdown(engine.shutdownCancel)
}
applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
applyMainServerConfig(engine, server, serveTLS)
return server
}
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
if r == nil {
return ""
}
for _, header := range headers {
value := strings.TrimSpace(r.Header.Get(header))
if value == "" {
continue
}
if comma := strings.IndexByte(value, ','); comma >= 0 {
value = strings.TrimSpace(value[:comma])
}
if value != "" {
return value
}
}
return ""
}
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return "", http.StatusInternalServerError, false
}
return cfg.redirectHost, 0, true
} }
if len(cfg.redirectHostHeaders) > 0 { // 创建一个带超时的上下文,用于 Shutdown
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) ctx, cancel := context.WithTimeout(context.Background(), timeout)
if host == "" { defer cancel()
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
if r == nil { var wg sync.WaitGroup
return "", http.StatusUpgradeRequired, false errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
}
host := strings.TrimSpace(r.Host)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { // 并发地关闭所有服务器
httpsAddr := cfg.addr for _, srv := range servers {
httpAddr := cfg.httpRedirectAddr wg.Add(1)
httpsPort, err := parseHTTPSPort(httpsAddr) go func(s *http.Server) {
if err != nil { defer wg.Done()
return nil, err if err := s.Shutdown(ctx); err != nil {
} // 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, statusCode, ok := redirectTargetHost(r, cfg)
if !ok {
http.Error(w, http.StatusText(statusCode), statusCode)
return
}
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
host = "[" + host + "]"
} }
}(srv)
}
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error {
// 创建一个 channel 来接收操作系统信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
// 启动服务器
serverStopped := make(chan error, 1)
for _, srv := range servers {
go func(s *http.Server) {
serverStopped <- s.ListenAndServe()
}(srv)
}
select {
case <-ctx.Done():
// Context 被取消 (例如,通过外部取消函数)
log.Println("Context cancelled, shutting down Touka server(s)...")
case err := <-serverStopped:
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("Touka HTTP server failed: %w", err)
}
log.Println("Touka HTTP server stopped gracefully.")
return nil // 服务器已自行优雅关闭,无需进一步处理
case <-quit:
// 接收到操作系统信号
log.Println("Shutting down Touka server(s) due to OS signal...")
}
// 关闭日志记录器
if logger != nil {
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
// 创建一个带超时的上下文,用于 Shutdown
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
// 并发地关闭所有服务器
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(shutdownCtx); err != nil {
// 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// --- 公共 Run 方法 ---
// Run 启动一个不支持优雅关闭的 HTTP 服务器
// 这是一个阻塞调用,主要用于简单的场景或快速测试
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
func (engine *Engine) Run(addr ...string) error {
address := resolveAddress(addr)
srv := &http.Server{Addr: address, Handler: engine}
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
return srv.ListenAndServe()
}
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error {
srv := &http.Server{
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
runServer("HTTP", srv)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
}
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error {
srv := &http.Server{
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
}
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
if tlsConfig == nil {
return errors.New("tls.Config must not be nil for RunTLS")
}
// 配置 HTTP/2 支持 (如果使用默认配置)
if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true, // 默认在 TLS 上启用 HTTP/2
})
}
srv := &http.Server{
Addr: addr,
Handler: engine,
TLSConfig: tlsConfig,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
engine.applyDefaultServerConfig(srv)
if engine.TLSServerConfigurator != nil {
engine.TLSServerConfigurator(srv)
} else if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
runServer("HTTPS", srv)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
}
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
return engine.RunTLS(addr, tlsConfig, timeouts...)
}
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
if tlsConfig == nil {
return errors.New("tls.Config must not be nil for RunTLSRedir")
}
// --- HTTPS 服务器 ---
if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true})
}
httpsSrv := &http.Server{
Addr: httpsAddr,
Handler: engine,
TLSConfig: tlsConfig,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
httpsSrv.RegisterOnShutdown(engine.shutdownCancel)
engine.applyDefaultServerConfig(httpsSrv)
if engine.TLSServerConfigurator != nil {
engine.TLSServerConfigurator(httpsSrv)
} else if engine.ServerConfigurator != nil {
engine.ServerConfigurator(httpsSrv)
}
// --- HTTP 重定向服务器 ---
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
_, httpsPort, err := net.SplitHostPort(httpsAddr)
if err != nil {
// 如果 httpsAddr 没有端口,这是一个配置错误
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr)
} }
targetURL := "https://" + host targetURL := "https://" + host
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
if httpsPort != "443" { if httpsPort != "443" {
targetURL = "https://" + net.JoinHostPort(host, httpsPort) targetURL = "https://" + net.JoinHostPort(host, httpsPort)
} }
@ -345,205 +351,22 @@ func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
http.Redirect(w, r, targetURL, http.StatusMovedPermanently) http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
}) })
httpSrv := &http.Server{
Addr: httpAddr,
Handler: redirectHandler,
}
engine.applyDefaultServerConfig(httpSrv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(httpSrv)
}
server := &http.Server{Addr: httpAddr, Handler: redirectHandler} // --- 启动服务器和优雅关闭 ---
applyRedirectServerConfig(engine, server) runServer("HTTPS", httpsSrv)
return server, nil runServer("HTTP Redirect", httpSrv)
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco)
} }
func validateRunConfig(cfg runConfig) error { // RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性
if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
return errors.New("WithHTTPRedirect requires WithTLS") return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...)
}
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
return errors.New("https mode requires WithTLS")
}
if cfg.gracefulCtx != nil && !cfg.graceful {
return errors.New("WithShutdownContext requires graceful shutdown")
}
if len(cfg.redirectHostHeaders) > 0 {
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
}
}
if cfg.useHeaderHostSet && cfg.useHeaderHost {
if cfg.redirectHost != "" {
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
}
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
}
if len(cfg.redirectHostHeaders) > 0 {
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
}
}
return nil
}
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
if cfg.shutdownTimeout > 0 {
return cfg.shutdownTimeout
}
}
return defaultShutdownTimeout
}
func closeLoggerAsync(logger *reco.Logger) {
if logger == nil {
return
}
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers))
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil {
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan)
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...)
}
return nil
}
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(quit)
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLS[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
select {
case err := <-serverStopped:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
return errors.Join(err, shutdownErr)
}
return err
}
log.Println("Touka server stopped gracefully.")
return nil
case <-quit:
log.Println("Shutting down Touka server(s) due to OS signal...")
case <-shutdownCtx.Done():
log.Println("Context cancelled, shutting down Touka server(s)...")
}
closeLoggerAsync(logger)
if err := shutdownServers(servers, timeout); err != nil {
return err
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// Run starts the engine with the provided startup options.
//
// Default behavior with no options:
// - HTTP only
// - listens on :8080
// - no graceful shutdown orchestration
//
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
// signal-aware graceful shutdown and request-context cancellation semantics.
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
func (engine *Engine) Run(opts ...RunOption) error {
cfg := defaultRunConfig()
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.apply(&cfg); err != nil {
return err
}
}
if cfg.httpRedirectAddr != "" {
cfg.mode = runModeHTTPSRedirect
} else if cfg.tlsConfig != nil {
cfg.mode = runModeHTTPS
}
if err := validateRunConfig(cfg); err != nil {
return err
}
serveTLS := cfg.mode != runModeHTTP
mainServer := buildMainServer(engine, cfg)
servers := []*http.Server{mainServer}
serveTLSFlags := []bool{serveTLS}
if cfg.mode == runModeHTTPSRedirect {
redirectServer, err := buildRedirectServer(engine, cfg)
if err != nil {
return err
}
servers = append(servers, redirectServer)
serveTLSFlags = append(serveTLSFlags, false)
}
if !cfg.graceful {
if len(servers) > 1 {
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLSFlags[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
err := <-serverStopped
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return errors.Join(err, shutdownErr)
}
return shutdownErr
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
protocolLabel := "HTTP"
if serveTLS {
protocolLabel = "HTTPS"
}
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
return serveServer(mainServer, serveTLS)
}
shutdownCtx := context.Background()
if cfg.gracefulCtx != nil {
shutdownCtx = cfg.gracefulCtx
}
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
} }

View file

@ -1,492 +0,0 @@
package touka
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func generateSelfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate private key: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "127.0.0.1"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("create self-signed cert: %v", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("parse self-signed cert: %v", err)
}
return cert
}
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on ephemeral port: %v", err)
}
addr := listener.Addr().String()
if err := listener.Close(); err != nil {
t.Fatalf("close temporary listener: %v", err)
}
srv := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("ok"))
}),
// RunShutdown uses the HTTP startup path and must not let a shared
// ServerConfigurator accidentally turn it into HTTPS.
TLSConfig: &tls.Config{},
}
errCh := make(chan error, 1)
go func() {
errCh <- serveServer(srv, false)
}()
client := &http.Client{Timeout: 200 * time.Millisecond}
var resp *http.Response
requestURL := "http://" + addr
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
resp, err = client.Get(requestURL)
if err == nil {
break
}
time.Sleep(20 * time.Millisecond)
}
if err != nil {
select {
case serveErr := <-errCh:
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
default:
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
}
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
}
if string(body) != "ok" {
t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
t.Fatalf("shutdown server: %v", err)
}
if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
}
}
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
engine := New()
err := engine.Run(WithHTTPRedirect(":80"))
if err == nil {
t.Fatal("expected redirect mode without TLS to fail")
}
}
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
engine := New()
err := engine.Run(
WithAddr(":443"),
WithTLS(&tls.Config{}),
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
)
if err == nil {
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
}
}
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
cfg := defaultRunConfig()
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
t.Fatalf("apply graceful default option: %v", err)
}
if !cfg.graceful {
t.Fatal("expected graceful shutdown to be enabled")
}
if cfg.shutdownTimeout != defaultShutdownTimeout {
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
}
}
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
cfg := defaultRunConfig()
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
t.Fatalf("apply TLS option: %v", err)
}
if cfg.mode != runModeHTTPS {
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
}
if cfg.graceful {
t.Fatal("expected TLS option to remain independent from graceful shutdown")
}
if cfg.tlsConfig != tlsConfig {
t.Fatal("expected TLS config to be preserved in run config")
}
}
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
engine := New()
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
t.Fatal("expected redirect server builder to reject https address without port")
}
}
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
cfg := defaultRunConfig()
ctx := t.Context()
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
t.Fatalf("apply shutdown context option: %v", err)
}
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
}
}
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
cfg := defaultRunConfig()
cfg.httpRedirectAddr = ":80"
if err := validateRunConfig(cfg); err != nil {
t.Fatalf("validate run config: %v", err)
}
if cfg.mode != runModeHTTP {
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
}
}
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = false
cfg.useHeaderHostSet = true
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected configured host mode without redirect host to fail validation")
}
}
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = true
cfg.useHeaderHostSet = true
cfg.redirectHost = "configured.example"
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
}
}
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
engine := New()
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
if server.BaseContext == nil {
t.Fatal("expected graceful main server to set BaseContext")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for base context check: %v", err)
}
defer listener.Close()
if got := server.BaseContext(listener); got != engine.shutdownCtx {
t.Fatal("expected graceful main server to use engine shutdown context")
}
}
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
engine := New()
serverConfigured := false
tlsConfigured := false
engine.SetServerConfigurator(func(s *http.Server) {
serverConfigured = true
s.ReadTimeout = time.Second
})
engine.SetTLSServerConfigurator(func(s *http.Server) {
tlsConfigured = true
s.IdleTimeout = time.Second
})
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !tlsConfigured {
t.Fatal("expected TLS configurator to run for HTTPS main server")
}
if serverConfigured {
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
}
if server.IdleTimeout != time.Second {
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
}
}
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
engine := New()
configured := false
engine.SetServerConfigurator(func(s *http.Server) {
configured = true
s.ReadTimeout = time.Second
})
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
if !configured {
t.Fatal("expected redirect server to use generic server configurator")
}
if server.ReadTimeout != time.Second {
t.Fatal("expected redirect server configurator changes to be applied")
}
}
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
engine := New()
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !httpsServer.Protocols.HTTP2() {
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
}
httpServer := buildMainServer(engine, defaultRunConfig())
if httpServer.Protocols.HTTP2() {
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
}
}
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
req.Header.Set("X-First-Host", "first.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUpgradeRequired {
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
}
}
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: false,
useHeaderHostSet: true,
redirectHost: "configured.example",
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
req.Host = "[::1]:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
t.Fatalf("unexpected IPv6 redirect location: %q", location)
}
}
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
if err == nil {
t.Fatal("expected gracefulServe to fail when one server cannot bind")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
}
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
if dialErr == nil {
conn.Close()
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
}
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
}
}
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
err = engine.Run(
WithAddr(occupiedAddr),
WithTLS(&tls.Config{}),
WithHTTPRedirect(redirectAddr),
)
if err == nil {
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
}
}

72
sse.go
View file

@ -111,40 +111,46 @@ func (c *Context) EventStream(streamer func(w io.Writer) bool) {
// EventStreamChan 返回用于 SSE 事件流的 channel. // EventStreamChan 返回用于 SSE 事件流的 channel.
// 这是为高级并发场景设计的、更灵活的API. // 这是为高级并发场景设计的、更灵活的API.
// //
// 与 EventStream 回调模式类似, 此方法是阻塞的: handler 会在此方法中停留, // 重要:
// 直到事件 channel 被关闭 (close eventChan) 或客户端断开连接. // - 调用者必须 close(eventChan) 来结束事件流.
// 这保证了 Context 不会在 SSE 流期间被 pool 回收. // - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开.
// // - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done().
// eventChan 必须在调用此方法之前创建, 以便调用者可以在独立的 goroutine 中发送事件.
// 调用者必须在完成后 close(eventChan) 来结束流.
// 生产者 goroutine 必须在 select 中监听 c.Request.Context().Done(), 否则在客户端断开时会产生 goroutine 泄漏.
// //
// 详细用法: // 详细用法:
// //
// r.GET("/sse/channel", func(c *touka.Context) { // r.GET("/sse/channel", func(c *touka.Context) {
// eventChan := make(chan touka.Event) // eventChan, errChan := c.EventStreamChan()
// //
// // 在独立的 goroutine 中异步发送事件. // // 必须在独立的goroutine中处理错误和连接断开.
// go func() { // go func() {
// defer close(eventChan) // 完成后关闭 channel 以结束事件流. // if err := <-errChan; err != nil {
// c.Errorf("SSE channel error: %v", err)
// }
// }()
//
// // 在另一个goroutine中异步发送事件.
// go func() {
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
// defer close(eventChan)
// //
// for i := 1; i <= 5; i++ { // for i := 1; i <= 5; i++ {
// select { // select {
// case <-c.Request.Context().Done(): // case <-c.Request.Context().Done():
// return // 客户端已断开, 退出 goroutine. // return // 客户端已断开, 退出 goroutine.
// case eventChan <- touka.Event{ // default:
// Id: fmt.Sprintf("%d", i), // eventChan <- touka.Event{
// Data: "hello from channel", // Id: fmt.Sprintf("%d", i),
// }: // Data: "hello from channel",
// }
// time.Sleep(2 * time.Second)
// } // }
// time.Sleep(2 * time.Second)
// } // }
// }() // }()
//
// // 阻塞直到事件流结束.
// c.EventStreamChan(eventChan)
// }) // })
func (c *Context) EventStreamChan(eventChan <-chan Event) { func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
eventChan := make(chan Event)
errChan := make(chan error, 1)
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform") c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
c.Writer.Header().Del("Connection") c.Writer.Header().Del("Connection")
@ -153,16 +159,8 @@ func (c *Context) EventStreamChan(eventChan <-chan Event) {
c.Writer.WriteHeader(http.StatusOK) c.Writer.WriteHeader(http.StatusOK)
c.Writer.Flush() c.Writer.Flush()
// 捕获稳定的引用, 不持有 *Context 指针, 以免 Context 被 pool 回收后出现竞态.
w := c.Writer
fl, _ := w.(http.Flusher)
reqCtx := c.Request.Context()
goroutineExited := make(chan struct{})
// 写入 goroutine: 从 eventChan 消费事件并写入响应.
go func() { go func() {
defer close(goroutineExited) defer close(errChan)
for { for {
select { select {
@ -170,23 +168,17 @@ func (c *Context) EventStreamChan(eventChan <-chan Event) {
if !ok { if !ok {
return return
} }
if err := event.Render(w); err != nil { if err := event.Render(c.Writer); err != nil {
errChan <- err
return return
} }
if fl != nil { c.Writer.Flush()
fl.Flush() case <-c.Request.Context().Done():
} errChan <- c.Request.Context().Err()
case <-reqCtx.Done():
return return
} }
} }
}() }()
// 阻塞直到: return eventChan, errChan
// 1. 写入 goroutine 退出 (eventChan 关闭或写入失败)
// 2. 客户端断开连接 (reqCtx 取消)
select {
case <-goroutineExited:
case <-reqCtx.Done():
}
} }

View file

@ -1,142 +0,0 @@
package touka
import (
"context"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestEventStreamChanBlocksHandler verifies that EventStreamChan blocks until
// the event channel is closed.
func TestEventStreamChanBlocksHandler(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil)
c, _ := CreateTestContextWithRequest(rr, req)
handlerReturned := make(chan struct{})
eventChan := make(chan Event)
// Start producer goroutine before EventStreamChan blocks
go func() {
defer close(eventChan)
time.Sleep(30 * time.Millisecond)
eventChan <- Event{Data: "hello"}
time.Sleep(30 * time.Millisecond)
}()
go func() {
c.EventStreamChan(eventChan)
close(handlerReturned)
}()
// Wait for goroutine to start
time.Sleep(10 * time.Millisecond)
// Handler should NOT have returned (eventChan not closed)
select {
case <-handlerReturned:
t.Fatal("Handler returned before eventChan was closed - EventStreamChan is not blocking")
case <-time.After(40 * time.Millisecond):
// good, still blocking
}
// Wait for producer to finish (30+30ms + margin)
select {
case <-handlerReturned:
// good, handler returned
case <-time.After(200 * time.Millisecond):
t.Fatal("Handler did not return after eventChan was closed")
}
}
// TestEventStreamChanUnblocksOnClientDisconnect verifies the handler returns
// when the request context is cancelled, even if eventChan is never closed.
func TestEventStreamChanUnblocksOnClientDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil).WithContext(ctx)
c, _ := CreateTestContextWithRequest(rr, req)
eventChan := make(chan Event)
handlerReturned := make(chan struct{})
// Producer never closes eventChan
go func() {
for {
select {
case <-ctx.Done():
return
case eventChan <- Event{Data: "tick"}:
time.Sleep(10 * time.Millisecond)
}
}
}()
go func() {
c.EventStreamChan(eventChan)
close(handlerReturned)
}()
// Handler should NOT have returned
select {
case <-handlerReturned:
t.Fatal("Handler returned before stream ended")
case <-time.After(60 * time.Millisecond):
// good, still blocked
}
// Cancel context to simulate client disconnect
cancel()
select {
case <-handlerReturned:
// good
case <-time.After(200 * time.Millisecond):
t.Fatal("Handler did not return after client disconnect")
}
}
// TestEventStreamChanWritesEvents verifies the SSE event format is correct.
func TestEventStreamChanWritesEvents(t *testing.T) {
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/sse", nil)
c, _ := CreateTestContextWithRequest(rr, req)
eventChan := make(chan Event)
go func() {
defer close(eventChan)
eventChan <- Event{Id: "1", Event: "tick", Data: "hello\nworld"}
eventChan <- Event{Id: "2", Data: "second"}
}()
c.EventStreamChan(eventChan)
body := rr.Body.String()
ct := rr.Header().Get("Content-Type")
if !strings.Contains(ct, "text/event-stream") {
t.Fatalf("expected text/event-stream content type, got %q", ct)
}
if !strings.Contains(body, "id: 1") {
t.Fatal("missing id field in first event")
}
if !strings.Contains(body, "event: tick") {
t.Fatal("missing event field in first event")
}
if !strings.Contains(body, "data: hello") {
t.Fatal("missing data line 1 in first event")
}
if !strings.Contains(body, "data: world") {
t.Fatal("missing data line 2 in first event")
}
if !strings.Contains(body, "id: 2") {
t.Fatal("missing id field in second event")
}
if !strings.Contains(body, "data: second") {
t.Fatal("missing data in second event")
}
}

View file

@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
// HandlersChain 定义处理函数链(中间件栈)的类型。 // HandlersChain 定义处理函数链(中间件栈)的类型。
type HandlersChain []HandlerFunc type HandlersChain []HandlerFunc
// Router 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。 // IRouter 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。
type Router interface { type IRouter interface {
Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组
Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
GET(relativePath string, handlers ...HandlerFunc) GET(relativePath string, handlers ...HandlerFunc)

118
tree.go
View file

@ -121,28 +121,14 @@ const (
// node 表示路由树中的一个节点. // node 表示路由树中的一个节点.
type node struct { type node struct {
path string // 当前节点的路径段 path string // 当前节点的路径段
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) priority uint32 // 节点的优先级, 用于查找时优先匹配
priority uint32 // 节点的优先级, 用于查找时优先匹配 children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 handlers HandlersChain // 绑定到此节点的处理函数链
handlers HandlersChain // 绑定到此节点的处理函数链 fullPath string // 完整路径, 用于调试和错误信息
fullPath string // 完整路径, 用于调试和错误信息
}
func routeNeedsCaseInsensitiveLookup(path string) bool {
for i := 0; i < len(path); i++ {
c := path[i]
if c >= utf8.RuneSelf {
return true
}
if c >= 'A' && c <= 'Z' {
return true
}
}
return false
} }
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
@ -176,9 +162,6 @@ func (n *node) incrementChildPrio(pos int) int {
func (n *node) addRoute(path string, handlers HandlersChain) { func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path // 记录完整的路径 fullPath := path // 记录完整的路径
n.priority++ // 增加当前节点的优先级 n.priority++ // 增加当前节点的优先级
if routeNeedsCaseInsensitiveLookup(path) {
n.hasCaseInsensitivePath = true
}
// 如果是空树(根节点) // 如果是空树(根节点)
if len(n.path) == 0 && len(n.children) == 0 { if len(n.path) == 0 && len(n.children) == 0 {
@ -469,14 +452,12 @@ type skippedNode struct {
// 建议进行 TSR(尾部斜杠重定向). // 建议进行 TSR(尾部斜杠重定向).
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
var globalParamsCount int16 // 全局参数计数 var globalParamsCount int16 // 全局参数计数
var backtrackToWildChild bool
walk: // 外部循环用于遍历路由树 walk: // 外部循环用于遍历路由树
for { for {
prefix := n.path // 当前节点的路径前缀 prefix := n.path // 当前节点的路径前缀
if len(path) > len(prefix) { if len(path) > len(prefix) {
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
pathAtNode := path
path = path[len(prefix):] // 移除已匹配的前缀 path = path[len(prefix):] // 移除已匹配的前缀
// 在访问 path[0] 之前进行安全检查 // 在访问 path[0] 之前进行安全检查
@ -486,26 +467,30 @@ walk: // 外部循环用于遍历路由树
// 优先尝试所有非通配符子节点, 通过匹配索引字符 // 优先尝试所有非通配符子节点, 通过匹配索引字符
idxc := path[0] // 剩余路径的第一个字符 idxc := path[0] // 剩余路径的第一个字符
if !backtrackToWildChild { for i, c := range []byte(n.indices) {
for i := 0; i < len(n.indices); i++ { if c == idxc { // 如果找到匹配的索引字符
if n.indices[i] == idxc { // 如果找到匹配的索引字符 // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 if n.wildChild {
if n.wildChild { index := len(*skippedNodes)
index := len(*skippedNodes) *skippedNodes = (*skippedNodes)[:index+1]
*skippedNodes = (*skippedNodes)[:index+1] (*skippedNodes)[index] = skippedNode{
(*skippedNodes)[index] = skippedNode{ path: prefix + path, // 记录跳过的路径
path: pathAtNode, // 记录进入当前节点时的剩余路径 node: &node{ // 复制当前节点的状态
node: n, path: n.path,
paramsCount: globalParamsCount, // 记录当前参数计数 wildChild: n.wildChild,
} nType: n.nType,
priority: n.priority,
children: n.children,
handlers: n.handlers,
fullPath: n.fullPath,
},
paramsCount: globalParamsCount, // 记录当前参数计数
} }
n = n.children[i] // 移动到匹配的子节点
continue walk // 继续外部循环
} }
n = n.children[i] // 移动到匹配的子节点
continue walk // 继续外部循环
} }
} else {
backtrackToWildChild = false
} }
if !n.wildChild { if !n.wildChild {
@ -522,8 +507,7 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
} }
globalParamsCount = skippedNode.paramsCount // 恢复参数计数 globalParamsCount = skippedNode.paramsCount // 恢复参数计数
backtrackToWildChild = true continue walk // 继续外部循环
continue walk // 继续外部循环
} }
} }
} }
@ -563,7 +547,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path[:end] // 提取参数值 val := path[:end] // 提取参数值
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if unescape { // 如果需要进行 URL 解码
if v, err := url.QueryUnescape(val); err == nil { if v, err := url.QueryUnescape(val); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -615,7 +599,7 @@ walk: // 外部循环用于遍历路由树
i := len(*value.params) i := len(*value.params)
*value.params = (*value.params)[:i+1] // 扩展切片 *value.params = (*value.params)[:i+1] // 扩展切片
val := path // 参数值是剩余的整个路径 val := path // 参数值是剩余的整个路径
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if unescape { // 如果需要进行 URL 解码
if v, err := url.QueryUnescape(path); err == nil { if v, err := url.QueryUnescape(path); err == nil {
val = v // 解码成功则更新值 val = v // 解码成功则更新值
} }
@ -650,7 +634,6 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -675,8 +658,8 @@ walk: // 外部循环用于遍历路由树
} }
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
for i := 0; i < len(n.indices); i++ { for i, c := range []byte(n.indices) {
if n.indices[i] == '/' { // 如果索引中包含 '/' if c == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
@ -705,7 +688,6 @@ walk: // 外部循环用于遍历路由树
*value.params = (*value.params)[:skippedNode.paramsCount] *value.params = (*value.params)[:skippedNode.paramsCount]
} }
globalParamsCount = skippedNode.paramsCount globalParamsCount = skippedNode.paramsCount
backtrackToWildChild = true
continue walk continue walk
} }
} }
@ -719,15 +701,13 @@ walk: // 外部循环用于遍历路由树
// 它还可以选择修复尾部斜杠. // 它还可以选择修复尾部斜杠.
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) const stackBufSize = 128 // 栈上缓冲区的默认大小
}
func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { // 在常见情况下使用栈上静态大小的缓冲区.
if buf != nil { // 如果路径太长, 则在堆上分配缓冲区.
buf = buf[:0] buf := make([]byte, 0, stackBufSize)
} if length := len(path) + 1; length > stackBufSize {
if cap(buf) < len(path)+1 { buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区
buf = make([]byte, 0, len(path)+1)
} }
ciPath := n.findCaseInsensitivePathRec( ciPath := n.findCaseInsensitivePathRec(
@ -778,8 +758,8 @@ walk: // 外部循环用于遍历路由树
// 未找到处理函数. // 未找到处理函数.
// 尝试通过添加尾部斜杠来修复路径 // 尝试通过添加尾部斜杠来修复路径
if fixTrailingSlash { if fixTrailingSlash {
for i := 0; i < len(n.indices); i++ { for i, c := range []byte(n.indices) {
if n.indices[i] == '/' { // 如果索引中包含 '/' if c == '/' { // 如果索引中包含 '/'
n = n.children[i] // 移动到对应的子节点 n = n.children[i] // 移动到对应的子节点
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
@ -801,8 +781,8 @@ walk: // 外部循环用于遍历路由树
if rb[0] != 0 { if rb[0] != 0 {
// 旧 rune 未处理完 // 旧 rune 未处理完
idxc := rb[0] idxc := rb[0]
for i := 0; i < len(n.indices); i++ { for i, c := range []byte(n.indices) {
if n.indices[i] == idxc { if c == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)
@ -833,9 +813,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i := 0; i < len(n.indices); i++ { for i, c := range []byte(n.indices) {
// 小写匹配 // 小写匹配
if n.indices[i] == idxc { if c == idxc {
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
if out := n.children[i].findCaseInsensitivePathRec( if out := n.children[i].findCaseInsensitivePathRec(
path, ciPath, rb, fixTrailingSlash, path, ciPath, rb, fixTrailingSlash,
@ -852,9 +832,9 @@ walk: // 外部循环用于遍历路由树
rb = shiftNRuneBytes(rb, off) rb = shiftNRuneBytes(rb, off)
idxc := rb[0] idxc := rb[0]
for i := 0; i < len(n.indices); i++ { for i, c := range []byte(n.indices) {
// 大写匹配 // 大写匹配
if n.indices[i] == idxc { if c == idxc {
// 继续处理子节点 // 继续处理子节点
n = n.children[i] n = n.children[i]
npLen = len(n.path) npLen = len(n.path)
@ -872,7 +852,7 @@ walk: // 外部循环用于遍历路由树
return nil // 未找到, 返回 nil return nil // 未找到, 返回 nil
} }
n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾 n = n.children[0] // 移动到通配符子节点(通常是唯一一个)
switch n.nType { switch n.nType {
case param: // 参数节点 case param: // 参数节点
// 查找参数结束位置('/' 或路径末尾) // 查找参数结束位置('/' 或路径末尾)

View file

@ -11,7 +11,6 @@ import (
"regexp" "regexp"
"strings" "strings"
"testing" "testing"
"time"
) )
// Used as a workaround since we can't compare functions or their addresses // Used as a workaround since we can't compare functions or their addresses
@ -40,23 +39,6 @@ func getSkippedNodes() *[]skippedNode {
return &ps return &ps
} }
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
t.Helper()
resultCh := make(chan nodeValue, 1)
go func() {
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
}()
select {
case value := <-resultCh:
return value
case <-time.After(2 * time.Second):
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
return nodeValue{}
}
}
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
unescape := false unescape := false
if len(unescapes) >= 1 { if len(unescapes) >= 1 {
@ -919,34 +901,6 @@ func TestTreeInvalidNodeType(t *testing.T) {
} }
} }
func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) {
tree := &node{}
routes := [...]string{
"/:user/:repo/info/refs",
"/healthz",
"/api/db/data",
"/api/db/sum",
}
for _, route := range routes {
tree.addRoute(route, fakeHandler(route))
}
defer func() {
if r := recover(); r != nil {
t.Fatalf("unexpected panic while looking up missing path: %v", r)
}
}()
if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil {
t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found)
}
if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil {
t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found)
}
}
func TestTreeInvalidParamsType(t *testing.T) { func TestTreeInvalidParamsType(t *testing.T) {
tree := &node{} tree := &node{}
// add a child with wildcard // add a child with wildcard
@ -1122,51 +1076,3 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
} }
} }
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
tests := []struct {
name string
routes []string
requestPath string
wantFullPath string
wantParams Params
}{
{
name: "param route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/details"},
requestPath: "/foo/bar/details",
wantFullPath: "/foo/:id/details",
wantParams: Params{{Key: "id", Value: "bar"}},
},
{
name: "catch-all route after static dead end",
routes: []string{"/foo/bar", "/foo/:id/*rest"},
requestPath: "/foo/bar/baz.txt",
wantFullPath: "/foo/:id/*rest",
wantParams: Params{
{Key: "id", Value: "bar"},
{Key: "rest", Value: "/baz.txt"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tree := &node{}
for _, route := range tt.routes {
tree.addRoute(route, fakeHandler(route))
}
value := getValueWithTimeout(t, tree, tt.requestPath, false)
if value.handlers == nil {
t.Fatalf("expected handlers for %q", tt.requestPath)
}
if value.fullPath != tt.wantFullPath {
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
}
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
}
})
}
}