mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
commit
43fede96d5
50 changed files with 9252 additions and 791 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1 +1,2 @@
|
|||
test
|
||||
/bench_route_match_baseline.txt
|
||||
|
|
|
|||
|
|
@ -59,9 +59,9 @@ func main() {
|
|||
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
|
||||
})
|
||||
|
||||
// 启动服务器 (支持优雅关闭)
|
||||
// 启动服务器(通过 WithGracefulShutdown 启用优雅关闭)
|
||||
log.Println("Touka Server starting on :8080...")
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("Touka server failed to start: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,13 +70,13 @@ func main() {
|
|||
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
|
||||
|
||||
// ... 其他配置
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
#### 1.3. 服务器生命周期管理
|
||||
|
||||
Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。
|
||||
Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。
|
||||
|
||||
```go
|
||||
func main() {
|
||||
|
|
@ -90,11 +90,11 @@ func main() {
|
|||
fmt.Println("自定义的 HTTP 服务器配置已应用")
|
||||
})
|
||||
|
||||
// 启动服务器,并支持优雅关闭
|
||||
// RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号
|
||||
// 第二个参数是优雅关闭的超时时间
|
||||
// 启动服务器,并通过 Run 选项启用优雅关闭
|
||||
// Run(...) 会阻塞当前 goroutine
|
||||
// WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒
|
||||
fmt.Println("服务器启动于 :8080")
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -187,7 +187,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
|
||||
func AuthMiddleware() touka.HandlerFunc {
|
||||
|
|
@ -313,7 +313,7 @@ func main() {
|
|||
})
|
||||
})
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
|
||||
// templates/index.html
|
||||
|
|
@ -400,7 +400,7 @@ func main() {
|
|||
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
|
||||
})
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -483,7 +483,7 @@ func main() {
|
|||
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
|
||||
r.StaticDir("/files", "./non-existent-dir")
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -546,7 +546,7 @@ func main() {
|
|||
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
|
||||
r.StaticFS("/", http.FS(subFS))
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
52
compat.go
Normal file
52
compat.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"github.com/WJQSERVER-STUDIO/httpc"
|
||||
"github.com/fenthope/reco"
|
||||
)
|
||||
|
||||
// --- reco 兼容函数 ---
|
||||
|
||||
// GetLogReco 返回底层的 reco.Logger 实例
|
||||
// 用于需要访问 reco 特定功能的场景
|
||||
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
|
||||
//
|
||||
//go:fix inline
|
||||
func (engine *Engine) GetLogReco() *reco.Logger {
|
||||
return engine.LogReco
|
||||
}
|
||||
|
||||
// SetLogReco 设置 reco.Logger 实例
|
||||
// 用于向后兼容,等价于 SetLogger(l)
|
||||
//
|
||||
//go:fix inline
|
||||
func (engine *Engine) SetLogReco(l *reco.Logger) {
|
||||
engine.LogReco = l
|
||||
engine.logger = l
|
||||
}
|
||||
|
||||
// GetLoggerReco 返回底层的 reco.Logger 实例
|
||||
// 用于需要访问 reco 特定功能的场景
|
||||
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) GetLoggerReco() *reco.Logger {
|
||||
if rl, ok := c.engine.logger.(*reco.Logger); ok {
|
||||
return rl
|
||||
}
|
||||
return c.engine.LogReco
|
||||
}
|
||||
|
||||
// --- httpc 兼容函数 ---
|
||||
|
||||
// GetHTTPC 返回底层的 httpc.Client 实例
|
||||
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) GetHTTPC() *httpc.Client {
|
||||
return c.Client()
|
||||
}
|
||||
314
context.go
314
context.go
|
|
@ -26,7 +26,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/WJQSERVER/wanf"
|
||||
"github.com/fenthope/reco"
|
||||
"github.com/go-json-experiment/json"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
||||
|
|
@ -44,6 +43,8 @@ type Context struct {
|
|||
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
||||
index int8 // 当前执行到处理链的哪个位置
|
||||
|
||||
requestBodyPrepared bool
|
||||
|
||||
mu sync.RWMutex
|
||||
Keys map[string]any // 用于在中间件之间传递数据
|
||||
|
||||
|
|
@ -71,6 +72,12 @@ type Context struct {
|
|||
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
||||
// 通常在处理嵌套路由时使用
|
||||
SkippedNodes []skippedNode
|
||||
|
||||
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
|
||||
fixedPathBuf []byte
|
||||
|
||||
allowedMethodsBuf []string
|
||||
allowHeaderBuf []byte
|
||||
}
|
||||
|
||||
// --- Context 相关方法实现 ---
|
||||
|
|
@ -95,19 +102,42 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
c.handlers = nil
|
||||
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
||||
c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染
|
||||
c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
|
||||
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
||||
c.queryCache = nil // 清空查询参数缓存
|
||||
c.formCache = nil // 清空表单数据缓存
|
||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||
c.requestBodyPrepared = false
|
||||
|
||||
if cap(c.SkippedNodes) > 0 {
|
||||
c.SkippedNodes = c.SkippedNodes[:0]
|
||||
} else {
|
||||
c.SkippedNodes = make([]skippedNode, 0, 256)
|
||||
}
|
||||
if cap(c.fixedPathBuf) > 0 {
|
||||
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||
}
|
||||
if cap(c.allowedMethodsBuf) > 0 {
|
||||
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
|
||||
}
|
||||
if cap(c.allowHeaderBuf) > 0 {
|
||||
c.allowHeaderBuf = c.allowHeaderBuf[:0]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if _, err := c.Writer.Write(data); err != nil {
|
||||
wrapped := fmt.Errorf("%s: %w", contextMsg, err)
|
||||
c.AddError(wrapped)
|
||||
if c.engine != nil && c.engine.logger != nil {
|
||||
c.engine.logger.Errorf("%s: %v", contextMsg, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Next 在处理链中执行下一个处理函数
|
||||
|
|
@ -237,6 +267,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
|
|||
c.MaxRequestBodySize = size
|
||||
}
|
||||
|
||||
func (c *Context) prepareRequestBody() io.ReadCloser {
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return nil
|
||||
}
|
||||
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
|
||||
return c.Request.Body
|
||||
}
|
||||
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
c.requestBodyPrepared = true
|
||||
return c.Request.Body
|
||||
}
|
||||
|
||||
// Query 从 URL 查询参数中获取值
|
||||
// 懒加载解析查询参数,并进行缓存
|
||||
func (c *Context) Query(key string) string {
|
||||
|
|
@ -258,7 +300,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
|
|||
// 懒加载解析表单数据,并进行缓存
|
||||
func (c *Context) PostForm(key string) string {
|
||||
if c.formCache == nil {
|
||||
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
c.prepareRequestBody()
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
|
||||
switch mediaType {
|
||||
case "multipart/form-data":
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
case "application/x-www-form-urlencoded":
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
if !errors.Is(err, http.ErrNotMultipart) {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
c.formCache = c.Request.PostForm
|
||||
}
|
||||
return c.formCache.Get(key)
|
||||
|
|
@ -282,20 +356,20 @@ func (c *Context) Param(key string) string {
|
|||
func (c *Context) Raw(code int, contentType string, data []byte) {
|
||||
c.Writer.Header().Set("Content-Type", contentType)
|
||||
c.Writer.WriteHeader(code)
|
||||
c.Writer.Write(data)
|
||||
c.writeResponseBody(data, "failed to write raw response")
|
||||
}
|
||||
|
||||
// String 向响应写入格式化的字符串
|
||||
func (c *Context) String(code int, format string, values ...any) {
|
||||
c.Writer.WriteHeader(code)
|
||||
c.Writer.Write(fmt.Appendf(nil, format, values...))
|
||||
c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response")
|
||||
}
|
||||
|
||||
// Text 向响应写入无需格式化的string
|
||||
func (c *Context) Text(code int, text string) {
|
||||
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.Writer.Write([]byte(text))
|
||||
c.writeResponseBody([]byte(text), "failed to write text response")
|
||||
}
|
||||
|
||||
// FileText
|
||||
|
|
@ -338,8 +412,11 @@ func (c *Context) FileText(code int, filePath string) {
|
|||
}
|
||||
|
||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||
|
||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||||
c.Writer.WriteHeader(code)
|
||||
if _, err := iox.Copy(c.Writer, file); err != nil {
|
||||
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -417,6 +494,22 @@ func (c *Context) JSON(code int, obj any) {
|
|||
}
|
||||
}
|
||||
|
||||
// JSONBuf 先将 JSON 编码到 buffer, 成功后再写入状态码和响应体.
|
||||
// 与 JSON 相比,编码失败时可以正确返回 500 状态码,代价是多一次内存分配.
|
||||
func (c *Context) JSONBuf(code int, obj any) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.MarshalWrite(&buf, obj); err != nil {
|
||||
errMsg := fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
c.AddError(errMsg)
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response")
|
||||
}
|
||||
|
||||
// GOB 向响应写入GOB数据
|
||||
// 设置 Content-Type 为 application/octet-stream
|
||||
func (c *Context) GOB(code int, obj any) {
|
||||
|
|
@ -431,6 +524,21 @@ func (c *Context) GOB(code int, obj any) {
|
|||
}
|
||||
}
|
||||
|
||||
// GOBBuf 先将 GOB 编码到 buffer, 成功后再写入状态码和响应体.
|
||||
func (c *Context) GOBBuf(code int, obj any) {
|
||||
var buf bytes.Buffer
|
||||
encoder := gob.NewEncoder(&buf)
|
||||
if err := encoder.Encode(obj); err != nil {
|
||||
errMsg := fmt.Errorf("failed to encode GOB: %w", err)
|
||||
c.AddError(errMsg)
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/octet-stream")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response")
|
||||
}
|
||||
|
||||
// WANF向响应写入WANF数据
|
||||
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
|
||||
func (c *Context) WANF(code int, obj any) {
|
||||
|
|
@ -445,6 +553,21 @@ func (c *Context) WANF(code int, obj any) {
|
|||
}
|
||||
}
|
||||
|
||||
// WANFBuf 先将 WANF 编码到 buffer, 成功后再写入状态码和响应体.
|
||||
func (c *Context) WANFBuf(code int, obj any) {
|
||||
var buf bytes.Buffer
|
||||
encoder := wanf.NewStreamEncoder(&buf)
|
||||
if err := encoder.Encode(obj); err != nil {
|
||||
errMsg := fmt.Errorf("failed to encode WANF: %w", err)
|
||||
c.AddError(errMsg)
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response")
|
||||
}
|
||||
|
||||
// HTML 渲染 HTML 模板
|
||||
// 如果 Engine 配置了 HTMLRender,则使用它进行渲染
|
||||
// 否则,会进行简单的字符串输出
|
||||
|
|
@ -466,7 +589,37 @@ func (c *Context) HTML(code int, name string, obj any) {
|
|||
// 可以扩展支持其他渲染器接口
|
||||
}
|
||||
// 默认简单输出,用于未配置 HTMLRender 的情况
|
||||
c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))
|
||||
c.writeResponseBody(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj), "failed to write HTML response")
|
||||
}
|
||||
|
||||
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
|
||||
// 如果模板渲染失败,则返回 500 错误且不写入任何内容.
|
||||
func (c *Context) HTMLBuf(code int, name string, obj any) {
|
||||
if c.engine == nil || c.engine.HTMLRender == nil {
|
||||
// 没有渲染器,回退到简单输出
|
||||
c.HTML(code, name, obj)
|
||||
return
|
||||
}
|
||||
|
||||
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
|
||||
var buf bytes.Buffer
|
||||
err := tpl.ExecuteTemplate(&buf, name, obj)
|
||||
if err != nil {
|
||||
// 渲染失败,记录错误并返回 500,不写入任何内容
|
||||
errMsg := fmt.Errorf("failed to render HTML template '%s': %w", name, err)
|
||||
c.AddError(errMsg)
|
||||
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||
return
|
||||
}
|
||||
// 渲染成功,写入响应
|
||||
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response")
|
||||
return
|
||||
}
|
||||
|
||||
// 不支持的渲染器类型,回退到简单输出
|
||||
c.HTML(code, name, obj)
|
||||
}
|
||||
|
||||
// Redirect 执行 HTTP 重定向
|
||||
|
|
@ -481,10 +634,16 @@ func (c *Context) Redirect(code int, location string) {
|
|||
|
||||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||||
func (c *Context) ShouldBindJSON(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
err := json.UnmarshalRead(c.Request.Body, obj)
|
||||
err := json.UnmarshalRead(body, obj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json binding error: %w", err)
|
||||
}
|
||||
|
|
@ -493,10 +652,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
|
|||
|
||||
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
||||
func (c *Context) ShouldBindWANF(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
|
||||
decoder, err := wanf.NewStreamDecoder(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||||
}
|
||||
|
|
@ -509,10 +674,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
|
|||
|
||||
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
||||
func (c *Context) ShouldBindGOB(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
decoder := gob.NewDecoder(c.Request.Body)
|
||||
decoder := gob.NewDecoder(body)
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return fmt.Errorf("GOB binding error: %w", err)
|
||||
}
|
||||
|
|
@ -629,6 +800,10 @@ func setFieldValue(field reflect.Value, values []string) error {
|
|||
// ShouldBindForm 尝试将表单数据绑定到结构体
|
||||
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
||||
func (c *Context) ShouldBindForm(obj any) error {
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
c.prepareRequestBody()
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
|
|
@ -637,7 +812,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
|||
|
||||
switch mediaType {
|
||||
case "multipart/form-data":
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
return fmt.Errorf("parse multipart form error: %w", err)
|
||||
}
|
||||
case "application/x-www-form-urlencoded":
|
||||
|
|
@ -651,6 +826,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
|||
if err := bindForm(c.Request.Form, obj); err != nil {
|
||||
return fmt.Errorf("form binding error: %w", err)
|
||||
}
|
||||
c.formCache = c.Request.PostForm
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -688,10 +864,29 @@ func (c *Context) GetErrors() []error {
|
|||
return c.Errors
|
||||
}
|
||||
|
||||
// Client 返回 Engine 提供的 HTTPClient
|
||||
// 方便在请求处理函数中进行出站 HTTP 请求
|
||||
// Client 返回当前请求的 HTTPClient
|
||||
// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例;
|
||||
// 否则返回 Engine 提供的默认实例
|
||||
//
|
||||
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||
func (c *Context) Client() *httpc.Client {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return c.engine.HTTPClient
|
||||
}
|
||||
|
||||
// HTTPC 返回自动关联请求 Context 的 HTTP 客户端
|
||||
// 当请求被取消时,通过此客户端发起的出站请求也会自动取消
|
||||
func (c *Context) HTTPC() *contextHTTPClient {
|
||||
client := c.HTTPClient
|
||||
if client == nil {
|
||||
client = c.engine.HTTPClient
|
||||
}
|
||||
return &contextHTTPClient{
|
||||
client: client,
|
||||
ctx: c.ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// Context() 返回请求的上下文,用于取消操作
|
||||
|
|
@ -751,37 +946,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
|
|||
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
||||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBody() io.ReadCloser {
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
return c.prepareRequestBody()
|
||||
}
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Request.Body
|
||||
}
|
||||
|
||||
// GetReqBodyFull 读取并返回请求体的所有内容
|
||||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||
if c.Request.Body == nil {
|
||||
body := c.GetReqBody()
|
||||
if body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
err := body.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := iox.ReadAll(limitBytesReader)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
@ -791,31 +979,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
|
||||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||
if c.Request.Body == nil {
|
||||
body := c.GetReqBody()
|
||||
if body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
err := body.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := iox.ReadAll(limitBytesReader)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
@ -974,14 +1149,9 @@ func (c *Context) GetProtocol() string {
|
|||
return c.Request.Proto
|
||||
}
|
||||
|
||||
// GetHTTPC 获取框架自带传递的httpc
|
||||
func (c *Context) GetHTTPC() *httpc.Client {
|
||||
return c.HTTPClient
|
||||
}
|
||||
|
||||
// GetLogger 获取engine的Logger
|
||||
func (c *Context) GetLogger() *reco.Logger {
|
||||
return c.engine.LogReco
|
||||
// GetLogger 获取engine的Logger接口
|
||||
func (c *Context) GetLogger() Logger {
|
||||
return c.engine.logger
|
||||
}
|
||||
|
||||
// GetReqQueryString
|
||||
|
|
@ -1084,17 +1254,25 @@ func (c *Context) SetSameSite(samesite http.SameSite) {
|
|||
}
|
||||
|
||||
// SetCookie 设置一个 HTTP cookie
|
||||
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) {
|
||||
// sameSite 参数是可选的,如果不提供则使用通过 SetSameSite 设置的值
|
||||
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool, sameSite ...http.SameSite) {
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
site := c.sameSite
|
||||
if len(sameSite) > 0 {
|
||||
if len(sameSite) > 1 {
|
||||
c.Warnf("SetCookie: only the first SameSite value will be used, got %d values", len(sameSite))
|
||||
}
|
||||
site = sameSite[0]
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: url.QueryEscape(value),
|
||||
MaxAge: maxAge,
|
||||
Path: path,
|
||||
Domain: domain,
|
||||
SameSite: c.sameSite,
|
||||
SameSite: site,
|
||||
Secure: secure,
|
||||
HttpOnly: httpOnly,
|
||||
})
|
||||
|
|
@ -1132,25 +1310,25 @@ func (c *Context) DeleteCookie(name string) {
|
|||
|
||||
// === 日志记录 ===
|
||||
func (c *Context) Debugf(format string, args ...any) {
|
||||
c.engine.LogReco.Debugf(format, args...)
|
||||
c.engine.logger.Debugf(format, args...)
|
||||
}
|
||||
|
||||
func (c *Context) Infof(format string, args ...any) {
|
||||
c.engine.LogReco.Infof(format, args...)
|
||||
c.engine.logger.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (c *Context) Warnf(format string, args ...any) {
|
||||
c.engine.LogReco.Warnf(format, args...)
|
||||
c.engine.logger.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func (c *Context) Errorf(format string, args ...any) {
|
||||
c.engine.LogReco.Errorf(format, args...)
|
||||
c.engine.logger.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (c *Context) Fatalf(format string, args ...any) {
|
||||
c.engine.LogReco.Fatalf(format, args...)
|
||||
c.engine.logger.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
func (c *Context) Panicf(format string, args ...any) {
|
||||
c.engine.LogReco.Panicf(format, args...)
|
||||
c.engine.logger.Panicf(format, args...)
|
||||
}
|
||||
|
|
|
|||
81
context_benchmark_test.go
Normal file
81
context_benchmark_test.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
|
||||
c, _ := CreateTestContext(nil)
|
||||
if c.Keys != nil {
|
||||
t.Fatalf("expected fresh test context Keys to be nil before first Set")
|
||||
}
|
||||
|
||||
c.Set("answer", 42)
|
||||
if c.Keys == nil {
|
||||
t.Fatalf("expected Set to allocate Keys map")
|
||||
}
|
||||
if value, exists := c.Get("answer"); !exists || value != 42 {
|
||||
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
c.reset(UnwrapResponseWriter(c.Writer), req)
|
||||
|
||||
if c.Keys != nil {
|
||||
t.Fatalf("expected reset to clear Keys without allocating a new map")
|
||||
}
|
||||
if value, exists := c.Get("answer"); exists || value != nil {
|
||||
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
|
||||
}
|
||||
|
||||
ctxValue := c.Value("missing")
|
||||
if ctxValue != nil {
|
||||
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("expected MustGet to panic for missing key after reset")
|
||||
}
|
||||
}()
|
||||
_ = c.MustGet("answer")
|
||||
}
|
||||
|
||||
func BenchmarkContextReset(b *testing.B) {
|
||||
b.Run("NoKeysUse", func(b *testing.B) {
|
||||
c, _ := CreateTestContext(nil)
|
||||
rawWriter := UnwrapResponseWriter(c.Writer)
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
c.reset(rawWriter, req)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithKeysUse", func(b *testing.B) {
|
||||
c, _ := CreateTestContext(nil)
|
||||
rawWriter := UnwrapResponseWriter(c.Writer)
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
c.reset(rawWriter, req)
|
||||
c.Set("request-id", i)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
174
context_bodylimit_test.go
Normal file
174
context_bodylimit_test.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type zeroNilThenEOFReader struct {
|
||||
readCalls int
|
||||
}
|
||||
|
||||
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
|
||||
r.readCalls++
|
||||
if r.readCalls == 1 {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (r *zeroNilThenEOFReader) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
filePath := filepath.Join(dir, "hello.txt")
|
||||
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
|
||||
t.Fatalf("write temp file: %v", err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
c, _ := CreateTestContext(rr)
|
||||
|
||||
c.FileText(http.StatusCreated, filePath)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
|
||||
t.Fatalf("unexpected content type: %q", got)
|
||||
}
|
||||
if body := rr.Body.String(); body != "hello touka" {
|
||||
t.Fatalf("unexpected body: %q", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("expected exact limit read to succeed, got %v", err)
|
||||
}
|
||||
if string(data) != "abcd" {
|
||||
t.Fatalf("unexpected data: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
|
||||
defer reader.Close()
|
||||
|
||||
_, err := io.ReadAll(reader)
|
||||
if !errors.Is(err, ErrBodyTooLarge) {
|
||||
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
|
||||
defer reader.Close()
|
||||
|
||||
buf := make([]byte, 1)
|
||||
n, err := reader.Read(buf)
|
||||
if n != 0 || err != nil {
|
||||
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
|
||||
}
|
||||
|
||||
n, err = reader.Read(buf)
|
||||
if n != 0 || !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0)
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("expected zero limit to leave body unlimited, got %v", err)
|
||||
}
|
||||
if string(data) != "abc" {
|
||||
t.Fatalf("unexpected data: %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
body := strings.NewReader(`{"name":"abcdef"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/json", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
c.SetMaxRequestBodySize(8)
|
||||
|
||||
var payload struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
err := c.ShouldBindJSON(&payload)
|
||||
if !errors.Is(err, ErrBodyTooLarge) {
|
||||
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
body := strings.NewReader("name=abcdef")
|
||||
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
c.SetMaxRequestBodySize(4)
|
||||
|
||||
var payload struct {
|
||||
Name string `form:"name"`
|
||||
}
|
||||
|
||||
err := c.ShouldBindForm(&payload)
|
||||
if !errors.Is(err, ErrBodyTooLarge) {
|
||||
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
body := strings.NewReader("name=abcdef")
|
||||
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
c.SetMaxRequestBodySize(4)
|
||||
|
||||
if got := c.PostForm("name"); got != "" {
|
||||
t.Fatalf("expected empty value on over-limit form body, got %q", got)
|
||||
}
|
||||
if len(c.Errors) == 0 {
|
||||
t.Fatal("expected parse error to be recorded")
|
||||
}
|
||||
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
|
||||
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
|
||||
}
|
||||
}
|
||||
58
context_httpc.go
Normal file
58
context_httpc.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/httpc"
|
||||
)
|
||||
|
||||
// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context
|
||||
// 当请求被取消时,出站 HTTP 请求也会自动取消
|
||||
type contextHTTPClient struct {
|
||||
client *httpc.Client
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRequestBuilder 创建请求构建器,自动关联请求 Context
|
||||
func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// GET 创建 GET 请求构建器
|
||||
func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.GET(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// POST 创建 POST 请求构建器
|
||||
func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.POST(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// PUT 创建 PUT 请求构建器
|
||||
func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.PUT(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// DELETE 创建 DELETE 请求构建器
|
||||
func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.DELETE(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// PATCH 创建 PATCH 请求构建器
|
||||
func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.PATCH(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// HEAD 创建 HEAD 请求构建器
|
||||
func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.HEAD(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
|
||||
// OPTIONS 创建 OPTIONS 请求构建器
|
||||
func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder {
|
||||
return c.client.OPTIONS(urlStr).WithContext(c.ctx)
|
||||
}
|
||||
134
docs/advanced.md
134
docs/advanced.md
|
|
@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) {
|
|||
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext):
|
||||
|
||||
```go
|
||||
// 使用默认协议配置(仅 HTTP/1.1)
|
||||
// 使用默认协议配置
|
||||
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
|
||||
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
|
||||
r.SetDefaultProtocols()
|
||||
|
||||
// 自定义协议配置
|
||||
|
|
@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{
|
|||
|
||||
### 启动方式
|
||||
|
||||
Touka 提供了多种服务器启动方式:
|
||||
Touka 统一通过 `Run(opts...)` 启动服务器:
|
||||
|
||||
```go
|
||||
// 1. 简单启动(无优雅停机)
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
|
||||
// 2. 带优雅停机的启动
|
||||
r.RunShutdown(":8080", 10*time.Second)
|
||||
r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second))
|
||||
|
||||
// 3. 带上下文的优雅停机
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
r.RunShutdownWithContext(":8080", ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
r.Run(
|
||||
touka.WithAddr(":8080"),
|
||||
touka.WithGracefulShutdown(10*time.Second),
|
||||
touka.WithShutdownContext(ctx),
|
||||
)
|
||||
|
||||
// 4. HTTPS 启动
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
// 其他 TLS 配置...
|
||||
}
|
||||
r.RunTLS(":443", tlsConfig, 10*time.Second)
|
||||
// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithGracefulShutdownDefault(),
|
||||
)
|
||||
|
||||
// 5. HTTPS + HTTP 重定向
|
||||
r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second)
|
||||
// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(":80"),
|
||||
touka.WithGracefulShutdown(10*time.Second),
|
||||
)
|
||||
|
||||
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host)
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(true),
|
||||
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||
),
|
||||
)
|
||||
|
||||
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host)
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(false),
|
||||
touka.WithRedirectHost("example.com"),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### HTTPS Redirect Host 策略
|
||||
|
||||
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
|
||||
|
||||
可用的 redirect 子选项:
|
||||
|
||||
- `touka.WithUseHeaderHost(true|false)`
|
||||
- `touka.WithRedirectHostHeaders([]string{...})`
|
||||
- `touka.WithRedirectHost("example.com")`
|
||||
|
||||
#### 模式一:使用请求输入侧的 host
|
||||
|
||||
当 `WithUseHeaderHost(true)` 时:
|
||||
|
||||
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
|
||||
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值
|
||||
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
|
||||
|
||||
示例:
|
||||
|
||||
```go
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(true),
|
||||
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### 模式二:使用配置的固定 host
|
||||
|
||||
当 `WithUseHeaderHost(false)` 时:
|
||||
|
||||
- 不读取 `Request.Host`
|
||||
- 不读取 `WithRedirectHostHeaders(...)`
|
||||
- 必须配置 `WithRedirectHost("example.com")`
|
||||
|
||||
示例:
|
||||
|
||||
```go
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(false),
|
||||
touka.WithRedirectHost("example.com"),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### 严格校验规则
|
||||
|
||||
以下组合会直接返回配置错误:
|
||||
|
||||
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
|
||||
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
|
||||
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
|
||||
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
|
||||
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
|
||||
|
||||
#### 优先级关系
|
||||
|
||||
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
|
||||
2. `WithUseHeaderHost(...)` 决定 host 来源模式
|
||||
3. 当 `WithUseHeaderHost(true)` 时:
|
||||
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
|
||||
- 未配置时使用 `Request.Host`
|
||||
4. 当 `WithUseHeaderHost(false)` 时:
|
||||
- 只使用 `WithRedirectHost(...)`
|
||||
|
||||
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
|
||||
|
||||
## 优雅停机 (Graceful Shutdown)
|
||||
|
||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。
|
||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
|
||||
|
||||
```go
|
||||
r := touka.Default()
|
||||
|
|
@ -91,7 +207,7 @@ r := touka.Default()
|
|||
|
||||
// 监听 SIGINT 和 SIGTERM 信号
|
||||
// 如果在 10 秒内未处理完,则强制关闭
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatal("服务器退出异常:", err)
|
||||
}
|
||||
```
|
||||
|
|
|
|||
188
docs/httpc.md
Normal file
188
docs/httpc.md
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
# HTTP Client (httpc)
|
||||
|
||||
Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。
|
||||
|
||||
## 核心特性
|
||||
|
||||
- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context
|
||||
- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏
|
||||
- **链式调用**:保持 httpc 原有的组合式构建器风格
|
||||
|
||||
## 基本用法
|
||||
|
||||
### 简单 GET 请求
|
||||
|
||||
```go
|
||||
r.GET("/proxy", func(c *touka.Context) {
|
||||
body, err := c.HTTPC().
|
||||
GET("https://api.example.com/data").
|
||||
Text()
|
||||
if err != nil {
|
||||
c.JSON(500, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.String(200, body)
|
||||
})
|
||||
```
|
||||
|
||||
### POST JSON 请求
|
||||
|
||||
```go
|
||||
r.POST("/users", func(c *touka.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
c.ShouldBindJSON(&req)
|
||||
|
||||
var result struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
err := c.HTTPC().
|
||||
POST("https://api.example.com/users").
|
||||
SetHeader("Authorization", "Bearer "+token).
|
||||
SetJSONBody(req).
|
||||
DecodeJSON(&result)
|
||||
if err != nil {
|
||||
c.JSON(500, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(200, result)
|
||||
})
|
||||
```
|
||||
|
||||
### 带查询参数
|
||||
|
||||
```go
|
||||
r.GET("/search", func(c *touka.Context) {
|
||||
query := c.Query("q")
|
||||
|
||||
var result SearchResult
|
||||
err := c.HTTPC().
|
||||
GET("https://api.example.com/search").
|
||||
SetQueryParam("q", query).
|
||||
SetQueryParam("limit", "10").
|
||||
DecodeJSON(&result)
|
||||
if err != nil {
|
||||
c.JSON(500, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(200, result)
|
||||
})
|
||||
```
|
||||
|
||||
## API 对比
|
||||
|
||||
### 旧方式(Deprecated)
|
||||
|
||||
```go
|
||||
// 需要手动 WithContext,容易忘记
|
||||
resp, err := c.Client().
|
||||
WithContext(c.Context()).
|
||||
GET(url).
|
||||
Execute()
|
||||
```
|
||||
|
||||
### 新方式(推荐)
|
||||
|
||||
```go
|
||||
// 自动关联请求 Context
|
||||
resp, err := c.HTTPC().
|
||||
GET(url).
|
||||
Execute()
|
||||
```
|
||||
|
||||
## Context 取消机制
|
||||
|
||||
使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消:
|
||||
|
||||
```go
|
||||
r.GET("/long-task", func(c *touka.Context) {
|
||||
// 这个请求会在客户端断开时自动取消
|
||||
resp, err := c.HTTPC().
|
||||
GET("https://slow-api.example.com/data").
|
||||
Execute()
|
||||
|
||||
// 如果客户端已断开,err 会包含 context.Canceled
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return // 客户端已断开,无需处理
|
||||
}
|
||||
// ...
|
||||
})
|
||||
```
|
||||
|
||||
## 完整 API
|
||||
|
||||
### contextHTTPClient 方法
|
||||
|
||||
| 方法 | 返回类型 | 说明 |
|
||||
|------|----------|------|
|
||||
| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 |
|
||||
| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 |
|
||||
| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 |
|
||||
| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 |
|
||||
| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 |
|
||||
| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 |
|
||||
| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 |
|
||||
| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 |
|
||||
|
||||
### httpc.RequestBuilder 链式方法
|
||||
|
||||
返回 `*httpc.RequestBuilder`(用于链式调用):
|
||||
|
||||
| 方法 | 说明 |
|
||||
|------|------|
|
||||
| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) |
|
||||
| `NoDefaultHeaders()` | 不添加默认 Header |
|
||||
| `SetHeader(key, value)` | 设置 Header |
|
||||
| `AddHeader(key, value)` | 添加 Header(可重复) |
|
||||
| `SetHeaders(map)` | 批量设置 Headers |
|
||||
| `SetQueryParam(key, value)` | 设置查询参数 |
|
||||
| `AddQueryParam(key, value)` | 添加查询参数(可重复) |
|
||||
| `SetQueryParams(map)` | 批量设置查询参数 |
|
||||
| `SetBody(io.Reader)` | 设置请求 Body |
|
||||
| `SetRawBody([]byte)` | 设置字节 Body |
|
||||
|
||||
返回 `(*httpc.RequestBuilder, error)`(可能失败):
|
||||
|
||||
| 方法 | 说明 |
|
||||
|------|------|
|
||||
| `SetJSONBody(any)` | 设置 JSON Body |
|
||||
| `SetXMLBody(any)` | 设置 XML Body |
|
||||
| `SetGOBBody(any)` | 设置 GOB Body |
|
||||
|
||||
### 终结方法
|
||||
|
||||
| 方法 | 返回类型 | 说明 |
|
||||
|------|----------|------|
|
||||
| `Build()` | `(*http.Request, error)` | 构建请求但不执行 |
|
||||
| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 |
|
||||
| `DecodeJSON(v)` | `error` | 执行并解码 JSON |
|
||||
| `DecodeXML(v)` | `error` | 执行并解码 XML |
|
||||
| `DecodeGOB(v)` | `error` | 执行并解码 GOB |
|
||||
| `Text()` | `(string, error)` | 执行并返回文本 |
|
||||
| `Bytes()` | `([]byte, error)` | 执行并返回字节 |
|
||||
| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 |
|
||||
|
||||
## 迁移指南
|
||||
|
||||
### go:fix inline 兼容
|
||||
|
||||
旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`:
|
||||
|
||||
```bash
|
||||
go fix ./...
|
||||
```
|
||||
|
||||
### 手动迁移
|
||||
|
||||
| 旧代码 | 新代码 |
|
||||
|--------|--------|
|
||||
| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` |
|
||||
| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` |
|
||||
|
||||
## 示例
|
||||
|
||||
完整示例请参考 [examples/httpc](../examples/httpc)。
|
||||
|
|
@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
|
|||
|
||||
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
|
||||
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
|
||||
3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。
|
||||
3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。
|
||||
|
||||
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。
|
||||
|
|
|
|||
400
docs/logger-migration-design.md
Normal file
400
docs/logger-migration-design.md
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
# Touka Logger 接口迁移方案
|
||||
|
||||
## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计
|
||||
|
||||
---
|
||||
|
||||
## 一、问题分析
|
||||
|
||||
当前架构问题:
|
||||
```
|
||||
Engine.LogReco → *reco.Logger (公开字段, 直接访问)
|
||||
Context.GetLogger() → 返回 *reco.Logger (具体类型)
|
||||
Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...)
|
||||
```
|
||||
|
||||
这导致用户无法替换日志实现(如 zap/logrus)。
|
||||
|
||||
---
|
||||
|
||||
## 二、目标架构
|
||||
|
||||
```
|
||||
Engine.logger → Logger 接口 (私有)
|
||||
Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容)
|
||||
Engine.GetLogger() → 返回 Logger 接口
|
||||
Engine.SetLogger(Logger)→ 设置日志实现
|
||||
Context.GetLogger() → 返回 Logger 接口
|
||||
Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 三、Logger 接口定义
|
||||
|
||||
```go
|
||||
// logger.go
|
||||
package touka
|
||||
|
||||
// Logger 是日志接口,支持任意日志库实现
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Infof(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
Fatalf(format string, args ...any)
|
||||
Panicf(format string, args ...any)
|
||||
}
|
||||
|
||||
// CloserLogger 可选扩展,支持关闭操作
|
||||
type CloserLogger interface {
|
||||
Logger
|
||||
Close() error
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 四、Engine 结构变更
|
||||
|
||||
```go
|
||||
// engine.go 变更
|
||||
type Engine struct {
|
||||
// ... 其他字段保持不变
|
||||
|
||||
// logger 是新的日志接口 (私有)
|
||||
logger Logger
|
||||
|
||||
// logReco 是保留的 reco.Logger 引用 (私有)
|
||||
// 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger
|
||||
logReco *reco.Logger
|
||||
|
||||
// 其他字段...
|
||||
}
|
||||
```
|
||||
|
||||
新增/修改方法:
|
||||
|
||||
```go
|
||||
// GetLogger 返回日志接口
|
||||
func (engine *Engine) GetLogger() Logger {
|
||||
return engine.logger
|
||||
}
|
||||
|
||||
// SetLogger 设置任意 Logger 实现
|
||||
func (engine *Engine) SetLogger(l Logger) {
|
||||
engine.logger = l
|
||||
// 如果是 *reco.Logger 类型,同步更新 logReco
|
||||
if rl, ok := l.(*reco.Logger); ok {
|
||||
engine.logReco = rl
|
||||
} else {
|
||||
engine.logReco = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetLoggerCfg 使用 reco.Config 配置日志
|
||||
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
||||
logger := NewLogger(logcfg)
|
||||
engine.logger = logger
|
||||
engine.logReco = logger
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 五、`go:fix inline` 兼容性函数
|
||||
|
||||
### 5.1 旧 API 包装函数
|
||||
|
||||
在 `compat.go` 中定义:
|
||||
|
||||
```go
|
||||
// compat.go
|
||||
package touka
|
||||
|
||||
import "github.com/fenthope/reco"
|
||||
|
||||
// GetLogReco 返回 reco.Logger,用于向后兼容
|
||||
//
|
||||
//go:fix inline
|
||||
func (engine *Engine) GetLogReco() *reco.Logger {
|
||||
return engine.logReco
|
||||
}
|
||||
|
||||
// SetLogReco 设置 reco.Logger,用于向后兼容
|
||||
//
|
||||
//go:fix inline
|
||||
func (engine *Engine) SetLogReco(l *reco.Logger) {
|
||||
engine.logReco = l
|
||||
engine.logger = l
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 Context 日志方法的 inline 包装
|
||||
|
||||
```go
|
||||
// context_compat.go
|
||||
package touka
|
||||
|
||||
// Debugf 记录 Debug 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Debugf(format string, args ...any) {
|
||||
c.engine.logger.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Infof 记录 Info 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Infof(format string, args ...any) {
|
||||
c.engine.logger.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Warnf 记录 Warn 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Warnf(format string, args ...any) {
|
||||
c.engine.logger.Warnf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf 记录 Error 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Errorf(format string, args ...any) {
|
||||
c.engine.logger.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Fatalf 记录 Fatal 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Fatalf(format string, args ...any) {
|
||||
c.engine.logger.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// Panicf 记录 Panic 级别日志
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) Panicf(format string, args ...any) {
|
||||
c.engine.logger.Panicf(format, args...)
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 GetLogger 返回类型的兼容处理
|
||||
|
||||
由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数:
|
||||
|
||||
```go
|
||||
// context_compat.go (续)
|
||||
|
||||
// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景
|
||||
//
|
||||
//go:fix inline
|
||||
func (c *Context) GetLoggerReco() *reco.Logger {
|
||||
if rl, ok := c.engine.logger.(*reco.Logger); ok {
|
||||
return rl
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 六、go:fix inline 工作原理
|
||||
|
||||
### 迁移前用户代码:
|
||||
```go
|
||||
func handler(c *touka.Context) {
|
||||
// 旧 API 调用
|
||||
c.Debugf("request: %s", c.Request.URL.Path)
|
||||
c.engine.LogReco.Infof("server started")
|
||||
}
|
||||
```
|
||||
|
||||
### go fix 执行后(自动替换):
|
||||
```go
|
||||
func handler(c *touka.Context) {
|
||||
// Debugf 被替换为函数体
|
||||
c.engine.logger.Debugf("request: %s", c.Request.URL.Path)
|
||||
|
||||
// LogReco 访问无法通过 inline 自动处理,需要手动迁移
|
||||
// 或者通过 getter 调用
|
||||
}
|
||||
```
|
||||
|
||||
### 对于字段访问的处理策略:
|
||||
|
||||
`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略:
|
||||
|
||||
1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated
|
||||
2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco`
|
||||
3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()`
|
||||
|
||||
---
|
||||
|
||||
## 七、迁移前后对比
|
||||
|
||||
### 场景 1:基本日志调用
|
||||
|
||||
**迁移前:**
|
||||
```go
|
||||
func myHandler(c *touka.Context) {
|
||||
c.Debugf("processing request %s", c.Request.URL.Path)
|
||||
c.Infof("user %s logged in", username)
|
||||
c.Warnf("slow query: %v", duration)
|
||||
c.Errorf("db error: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
**迁移后(自动替换):**
|
||||
```go
|
||||
func myHandler(c *touka.Context) {
|
||||
c.engine.logger.Debugf("processing request %s", c.Request.URL.Path)
|
||||
c.engine.logger.Infof("user %s logged in", username)
|
||||
c.engine.logger.Warnf("slow query: %v", duration)
|
||||
c.engine.logger.Errorf("db error: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### 场景 2:Engine 配置日志
|
||||
|
||||
**迁移前:**
|
||||
```go
|
||||
engine := touka.New()
|
||||
engine.LogReco = myLogger // 直接赋值
|
||||
logger := engine.LogReco // 直接读取
|
||||
```
|
||||
|
||||
**迁移后(手动 + 自动混合):**
|
||||
```go
|
||||
engine := touka.New()
|
||||
|
||||
// 方式 1:使用新 API(推荐)
|
||||
engine.SetLogger(myLogger)
|
||||
logger := engine.GetLogger()
|
||||
|
||||
// 方式 2:通过 go:fix inline 自动替换为 getter
|
||||
// engine.SetLogReco(myLogger) ← go fix 替换
|
||||
// logger := engine.GetLogReco() ← go fix 替换
|
||||
```
|
||||
|
||||
### 场景 3:使用第三方日志库(新功能)
|
||||
|
||||
```go
|
||||
import "go.uber.org/zap"
|
||||
|
||||
func main() {
|
||||
zapLogger, _ := zap.NewProduction()
|
||||
defer zapLogger.Sync()
|
||||
|
||||
engine := touka.New()
|
||||
// 使用 zap 替代默认的 reco.Logger
|
||||
engine.SetLogger(&ZapAdapter{logger: zapLogger})
|
||||
|
||||
engine.GET("/api", func(c *touka.Context) {
|
||||
c.Infof("api called") // 自动使用 zap 输出
|
||||
})
|
||||
}
|
||||
|
||||
// ZapAdapter 适配 zap 到 touka.Logger 接口
|
||||
type ZapAdapter struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Debugf(format string, args ...any) {
|
||||
z.logger.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Infof(format string, args ...any) {
|
||||
z.logger.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Warnf(format string, args ...any) {
|
||||
z.logger.Warn(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Errorf(format string, args ...any) {
|
||||
z.logger.Error(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Fatalf(format string, args ...any) {
|
||||
z.logger.Fatal(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (z *ZapAdapter) Panicf(format string, args ...any) {
|
||||
z.logger.Panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 八、内部使用迁移
|
||||
|
||||
框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`:
|
||||
|
||||
需要修改的文件:
|
||||
- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf`
|
||||
- `recovery.go`: 如有使用日志
|
||||
- `logreco.go`: CloseLogger 方法
|
||||
|
||||
```go
|
||||
// context.go 修改前
|
||||
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||
if _, err := c.Writer.Write(data); err != nil {
|
||||
if c.engine.LogReco != nil {
|
||||
c.engine.LogReco.Errorf("%s: %v", contextMsg, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// context.go 修改后
|
||||
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||
if _, err := c.Writer.Write(data); err != nil {
|
||||
if c.engine.logger != nil {
|
||||
c.engine.logger.Errorf("%s: %v", contextMsg, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 九、完整文件结构
|
||||
|
||||
```
|
||||
touka/
|
||||
├── logger.go # Logger 接口定义
|
||||
├── logreco.go # reco.Logger 相关工具函数
|
||||
├── compat.go # go:fix inline 兼容性函数 (Engine)
|
||||
├── context_compat.go # go:fix inline 兼容性函数 (Context)
|
||||
├── engine.go # Engine 结构变更
|
||||
├── context.go # Context 日志方法变更
|
||||
└── ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 十、版本策略
|
||||
|
||||
| 版本 | 变更内容 |
|
||||
|------|---------|
|
||||
| v1.x | 引入 Logger 接口,LogReco 标记 deprecated |
|
||||
| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 |
|
||||
| v3.x | 移除 go:fix inline 兼容函数 |
|
||||
|
||||
---
|
||||
|
||||
## 十一、go:fix inline 限制说明
|
||||
|
||||
1. **字段访问无法自动迁移**:`engine.LogReco` 字段访问需要用户手动修改
|
||||
2. **返回类型变更需谨慎**:`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败
|
||||
3. **inline 函数有大小限制**:函数体过大会影响内联效果
|
||||
4. **跨包迁移**:`go:fix inline` 支持跨包,但用户必须运行 `go fix`
|
||||
|
||||
---
|
||||
|
||||
## 十二、推荐迁移步骤
|
||||
|
||||
1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数
|
||||
2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分
|
||||
3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()`
|
||||
4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置
|
||||
|
|
@ -26,6 +26,41 @@ api.Use(AuthMiddleware())
|
|||
}
|
||||
```
|
||||
|
||||
也可以在创建组时直接传入中间件:
|
||||
|
||||
```go
|
||||
api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware())
|
||||
{
|
||||
api.GET("/user", handleUser)
|
||||
api.POST("/data", handleData)
|
||||
}
|
||||
```
|
||||
|
||||
### 路由级中间件
|
||||
|
||||
为单个路由注册中间件,仅对该路由生效。
|
||||
|
||||
```go
|
||||
// 单个路由中间件
|
||||
r.GET("/protected", AuthMiddleware(), func(c *touka.Context) {
|
||||
c.String(http.StatusOK, "Protected content")
|
||||
})
|
||||
|
||||
// 多个路由中间件(按顺序执行)
|
||||
r.POST("/upload",
|
||||
RateLimitMiddleware(),
|
||||
AuthMiddleware(),
|
||||
PermissionCheckMiddleware(),
|
||||
func(c *touka.Context) {
|
||||
// 处理上传
|
||||
},
|
||||
)
|
||||
|
||||
// 路由组中的单个路由也可以使用路由级中间件
|
||||
api := r.Group("/api")
|
||||
api.GET("/admin", AdminAuthMiddleware(), adminHandler)
|
||||
```
|
||||
|
||||
## 编写自定义中间件
|
||||
|
||||
中间件的函数签名是 `touka.HandlerFunc`。
|
||||
|
|
@ -67,6 +102,36 @@ func APIKeyAuth() touka.HandlerFunc {
|
|||
}
|
||||
```
|
||||
|
||||
## 中间件执行顺序
|
||||
|
||||
理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行:
|
||||
|
||||
```go
|
||||
// 全局中间件
|
||||
r.Use(GlobalMiddleware1())
|
||||
r.Use(GlobalMiddleware2())
|
||||
|
||||
// 组中间件
|
||||
api := r.Group("/api", GroupMiddleware1())
|
||||
api.Use(GroupMiddleware2())
|
||||
|
||||
// 路由级中间件
|
||||
api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler)
|
||||
```
|
||||
|
||||
对于 `/api/users` 请求,执行顺序为:
|
||||
1. `GlobalMiddleware1()` - 全局中间件
|
||||
2. `GlobalMiddleware2()` - 全局中间件
|
||||
3. `GroupMiddleware1()` - 路由组中间件
|
||||
4. `GroupMiddleware2()` - 路由组中间件
|
||||
5. `RouteMiddleware1()` - 路由级中间件
|
||||
6. `RouteMiddleware2()` - 路由级中间件
|
||||
7. `userHandler` - 最终处理函数
|
||||
|
||||
```
|
||||
请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应
|
||||
```
|
||||
|
||||
## 内置中间件
|
||||
|
||||
- **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func main() {
|
|||
|
||||
// 4. 启动服务器并监听 8080 端口
|
||||
log.Println("Touka server is running on :8080")
|
||||
if err := r.Run(":8080"); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080")); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -66,11 +66,11 @@ go run main.go
|
|||
|
||||
## 优雅停机
|
||||
|
||||
在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。
|
||||
在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。
|
||||
|
||||
```go
|
||||
// 等待 10 秒以处理剩余请求
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ func main() {
|
|||
Target: target,
|
||||
}))
|
||||
|
||||
_ = r.Run(":8080")
|
||||
_ = r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -60,10 +60,15 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
```go
|
||||
type ReverseProxyConfig struct {
|
||||
Target *url.URL
|
||||
Targets []string
|
||||
|
||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||
|
||||
Transport http.RoundTripper
|
||||
FlushInterval time.Duration
|
||||
BufferPool BufferPool
|
||||
AllowH2CUpstream bool
|
||||
|
||||
ModifyRequest func(*http.Request)
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
|
@ -78,12 +83,133 @@ type ReverseProxyConfig struct {
|
|||
|
||||
### `Target`
|
||||
|
||||
必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||
|
||||
```go
|
||||
target, _ := url.Parse("http://backend:9000")
|
||||
```
|
||||
|
||||
### `Targets`
|
||||
|
||||
可选。用于配置多个后端目标地址。
|
||||
|
||||
- `Target` 与 `Targets` 互斥,只能使用其中一种
|
||||
- `Targets` 的每一项都必须是完整 URL
|
||||
- 每个 target 仍然可以自带 base path 和 query
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001/base?from=a",
|
||||
"http://127.0.0.1:9002/base?from=b",
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
|
||||
|
||||
### `LoadBalancing`
|
||||
|
||||
用于配置 upstream 选择策略和重试行为。
|
||||
|
||||
```go
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
```
|
||||
|
||||
当前内置策略:
|
||||
|
||||
- `touka.LBRandom()`
|
||||
- `touka.LBRoundRobin()`
|
||||
- `touka.LBFirst()`
|
||||
- `touka.LBLeastConn()`
|
||||
- `touka.LBIPHash()`
|
||||
- `touka.LBClientIPHash()`
|
||||
- `touka.LBURIHash()`
|
||||
- `touka.LBHeader("X-Upstream", fallback)`
|
||||
- `touka.LBQuery("tenant", fallback)`
|
||||
|
||||
其中:
|
||||
|
||||
- `LBFirst()` 适合主备/故障转移顺序
|
||||
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
|
||||
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()`
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
|
||||
Retries: 1,
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
重试说明:
|
||||
|
||||
- 只对未开始收到上游响应的失败进行重试
|
||||
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
|
||||
- `Retries` 表示额外重试次数
|
||||
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
|
||||
- `TryInterval` 表示两次重试之间的等待间隔
|
||||
|
||||
### `PassiveHealth`
|
||||
|
||||
用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
|
||||
|
||||
```go
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
```
|
||||
|
||||
- `FailDuration > 0` 时启用被动健康跟踪
|
||||
- `MaxFails <= 0` 时默认按 `1` 处理
|
||||
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
|
||||
|
||||
```go
|
||||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Targets: []string{
|
||||
"http://127.0.0.1:9001",
|
||||
"http://127.0.0.1:9002",
|
||||
},
|
||||
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||
Policy: touka.LBFirst(),
|
||||
},
|
||||
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||
},
|
||||
}))
|
||||
```
|
||||
|
||||
### `AllowH2CUpstream`
|
||||
|
||||
允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。
|
||||
|
||||
- 默认关闭
|
||||
- 这是一个显式配置项
|
||||
- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
|
||||
- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式
|
||||
|
||||
```go
|
||||
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
AllowH2CUpstream: true,
|
||||
}))
|
||||
```
|
||||
|
||||
对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
|
||||
|
||||
### `Transport`
|
||||
|
||||
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||
|
|
@ -150,6 +276,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
|||
|
||||
在请求真正发往后端前,对出站请求做最后修改。
|
||||
|
||||
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
|
||||
|
||||
常见用途:
|
||||
|
||||
- 覆盖 `Host`
|
||||
|
|
@ -242,11 +370,20 @@ const (
|
|||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
ForwardedHeaders: touka.ForwardedBoth,
|
||||
ForwardedBy: "gateway-1",
|
||||
ForwardedBy: "_gateway-1",
|
||||
Via: "edge-1",
|
||||
}))
|
||||
```
|
||||
|
||||
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
|
||||
|
||||
- IPv4:`203.0.113.43`
|
||||
- IPv6 / 带端口:`[2001:db8::17]:443`
|
||||
- 匿名标识:`_gateway-1`
|
||||
- 未知:`unknown`
|
||||
|
||||
像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
|
||||
|
||||
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||
|
||||
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||
|
|
@ -282,11 +419,14 @@ Touka 会尽量遵循代理链语义:
|
|||
|
||||
Touka 的反向代理实现支持以下能力:
|
||||
|
||||
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||
- HTTP/2 extended `CONNECT`
|
||||
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||
- WebSocket 等 101 Switching Protocols 场景
|
||||
- SSE(Server-Sent Events)立即刷新
|
||||
- Trailer 透传
|
||||
- 1xx 响应透传
|
||||
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
|
||||
|
||||
例如,代理 WebSocket 服务:
|
||||
|
||||
|
|
@ -341,7 +481,7 @@ func main() {
|
|||
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||
Target: target,
|
||||
ForwardedHeaders: touka.ForwardedBoth,
|
||||
ForwardedBy: "gateway-1",
|
||||
ForwardedBy: "_gateway-1",
|
||||
Via: "gateway-1",
|
||||
FlushInterval: 100 * time.Millisecond,
|
||||
ModifyRequest: func(req *http.Request) {
|
||||
|
|
@ -357,7 +497,7 @@ func main() {
|
|||
},
|
||||
}))
|
||||
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ r.ANY("/any", handle)
|
|||
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
||||
```
|
||||
|
||||
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
|
||||
|
||||
## 路径参数 (Named Parameters)
|
||||
|
||||
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
||||
|
|
@ -140,7 +142,7 @@ func main() {
|
|||
r := touka.Default()
|
||||
fsroot, _ := fs.Sub(content, "dist")
|
||||
r.StaticFS("/", http.FS(fsroot))
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
37
docs/sse.md
37
docs/sse.md
|
|
@ -40,43 +40,40 @@ r.GET("/events", func(c *touka.Context) {
|
|||
|
||||
## 模式二:通道模式 (EventStreamChan)
|
||||
|
||||
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。
|
||||
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。与回调模式类似,此方法是**阻塞的**:handler 会在此方法中停留,直到事件 channel 被关闭或客户端断开连接。
|
||||
|
||||
```go
|
||||
r.GET("/events-chan", func(c *touka.Context) {
|
||||
eventChan, errChan := c.EventStreamChan()
|
||||
eventChan := make(chan touka.Event)
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 监听错误/断开连接
|
||||
// 在独立的 goroutine 中发送事件.
|
||||
go func() {
|
||||
if err := <-errChan; err != nil {
|
||||
log.Printf("SSE 错误: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送数据
|
||||
go func() {
|
||||
defer close(eventChan) // 务必在结束时关闭
|
||||
defer close(eventChan) // 务必在结束时关闭以结束事件流.
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
default:
|
||||
eventChan <- touka.Event{
|
||||
case <-ctx.Done():
|
||||
return // 客户端已断开, 退出 goroutine.
|
||||
case eventChan <- touka.Event{
|
||||
Data: fmt.Sprintf("消息 #%d", i),
|
||||
}:
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// EventStreamChan 会阻塞直到流结束.
|
||||
c.EventStreamChan(eventChan)
|
||||
})
|
||||
```
|
||||
|
||||
## 最佳实践
|
||||
|
||||
1. **资源回收**: 确保在 `EventStreamChan` 模式下正确监听 `c.Request.Context().Done()` 以避免 Goroutine 泄漏。
|
||||
2. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
|
||||
3. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx)配置了足够大的写超时时间。
|
||||
1. **资源回收**: `EventStreamChan` 是阻塞的,handler 在事件流结束前不会返回。将 `c.Request.Context().Done()` 和 `eventChan <- ...` 作为同一个 `select` 的两个分支,确保发送操作本身能够响应客户端断开。
|
||||
2. **关闭 Channel**: 生产者完成发送后必须 `close(eventChan)`,否则 handler 会永远阻塞。
|
||||
3. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
|
||||
4. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx)配置了足够大的写超时时间。
|
||||
|
||||
## 优雅关闭与资源清理
|
||||
|
||||
|
|
@ -128,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
|
|||
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
|
||||
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
|
||||
|
||||
**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。
|
||||
**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func main() {
|
|||
// 您也可以使用 StaticFS 服务根路径
|
||||
// r.StaticFS("/", http.FS(fsroot))
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
2
ecw.go
2
ecw.go
|
|
@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
|
|||
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := ecw.w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
|
|
|||
59
ecw_benchmark_test.go
Normal file
59
ecw_benchmark_test.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) {
|
||||
c, _ := CreateTestContext(nil)
|
||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||
|
||||
ecw.capturedErrorSignal = true
|
||||
ecw.Header().Set("Content-Type", "text/plain")
|
||||
ecw.Header().Add("X-Test", "one")
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
|
||||
ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler)
|
||||
|
||||
if len(ecw.headerSnapshot) != 0 {
|
||||
t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) {
|
||||
c, _ := CreateTestContext(nil)
|
||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||
|
||||
rawWriter := httptest.NewRecorder()
|
||||
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
|
||||
keys := make([]string, 16)
|
||||
for i := range keys {
|
||||
keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i)))
|
||||
}
|
||||
values := []string{"one", "two", "three"}
|
||||
for _, key := range keys {
|
||||
ecw.headerSnapshot[key] = values
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler)
|
||||
for _, key := range keys {
|
||||
ecw.headerSnapshot[key] = values
|
||||
}
|
||||
}
|
||||
}
|
||||
384
engine.go
384
engine.go
|
|
@ -7,9 +7,11 @@ package touka
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"net/http"
|
||||
|
||||
|
|
@ -17,6 +19,7 @@ import (
|
|||
|
||||
"github.com/WJQSERVER-STUDIO/httpc"
|
||||
"github.com/fenthope/reco"
|
||||
"github.com/go-json-experiment/json"
|
||||
)
|
||||
|
||||
// Last 返回链中的最后一个处理函数
|
||||
|
|
@ -49,8 +52,14 @@ type Engine struct {
|
|||
|
||||
HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求
|
||||
|
||||
// LogReco 保留的 reco.Logger 字段
|
||||
// Deprecated: 使用 SetLogger/GetLogger 替代
|
||||
LogReco *reco.Logger
|
||||
|
||||
// logger 是新的日志接口,支持任意 Logger 实现
|
||||
// 优先级: logger > LogReco
|
||||
logger Logger
|
||||
|
||||
HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
|
||||
|
||||
routesInfo []RouteInfo // 存储所有注册的路由信息
|
||||
|
|
@ -81,6 +90,11 @@ type Engine struct {
|
|||
|
||||
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
||||
GlobalMaxRequestBodySize int64
|
||||
|
||||
notFoundChain HandlersChain
|
||||
notFoundNoMethodChain HandlersChain
|
||||
unmatchedFSChain HandlersChain
|
||||
unmatchedFSNoMethodChain HandlersChain
|
||||
}
|
||||
|
||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||
|
|
@ -116,6 +130,90 @@ type ErrorHandle struct {
|
|||
|
||||
type ErrorHandler func(c *Context, code int, err error)
|
||||
|
||||
var errMethodNotAllowed = errors.New("method not allowed")
|
||||
var errNotFound = errors.New("not found")
|
||||
|
||||
type defaultErrorResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error())
|
||||
var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error())
|
||||
|
||||
func mustMarshalDefaultErrorBody(code int, errMsg string) []byte {
|
||||
body, err := json.Marshal(defaultErrorResponse{
|
||||
Code: code,
|
||||
Message: http.StatusText(code),
|
||||
Error: errMsg,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func writeDefaultErrorJSON(c *Context, code int, body []byte) {
|
||||
if c == nil || c.Writer == nil {
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
c.Writer.WriteHeader(code)
|
||||
c.writeResponseBody(body, "failed to write default error response")
|
||||
c.Writer.Flush()
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := routeLookupPath(c.Request)
|
||||
engine := c.engine
|
||||
// 是否是OPTIONS方式
|
||||
if httpMethod == http.MethodOptions {
|
||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
|
||||
c.allowedMethodsBuf = allowedMethods[:0]
|
||||
if len(allowedMethods) > 0 {
|
||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||
allowHeader := c.allowHeaderBuf[:0]
|
||||
for i, method := range allowedMethods {
|
||||
if i > 0 {
|
||||
allowHeader = append(allowHeader, ',', ' ')
|
||||
}
|
||||
allowHeader = append(allowHeader, method...)
|
||||
}
|
||||
c.allowHeaderBuf = allowHeader[:0]
|
||||
c.Writer.Header().Set("Allow", string(allowHeader))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
||||
continue
|
||||
}
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||
if value.handlers != nil {
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
// 使用定义的ErrorHandle处理
|
||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
}
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
}
|
||||
|
||||
var notFoundHandler HandlerFunc = func(c *Context) {
|
||||
engine := c.engine
|
||||
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
|
||||
}
|
||||
|
||||
// defaultErrorHandle 默认错误处理
|
||||
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
||||
select {
|
||||
|
|
@ -126,16 +224,22 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
|
|||
if c.Writer.Written() {
|
||||
return
|
||||
}
|
||||
if len(c.Errors) == 0 {
|
||||
switch {
|
||||
case code == http.StatusNotFound && errors.Is(err, errNotFound):
|
||||
writeDefaultErrorJSON(c, code, defaultNotFoundBody)
|
||||
return
|
||||
case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed):
|
||||
writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody)
|
||||
return
|
||||
}
|
||||
}
|
||||
// 输出json 状态码与状态码对应描述
|
||||
var errMsg string
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
}
|
||||
c.JSON(code, H{
|
||||
"code": code,
|
||||
"message": http.StatusText(code),
|
||||
"error": errMsg,
|
||||
})
|
||||
c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg})
|
||||
c.Writer.Flush()
|
||||
c.Abort()
|
||||
return
|
||||
|
|
@ -210,6 +314,7 @@ func New() *Engine {
|
|||
TLSServerConfigurator: nil,
|
||||
GlobalMaxRequestBodySize: -1,
|
||||
}
|
||||
engine.rebuildFallbackChains()
|
||||
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
|
||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||
engine.SetDefaultProtocols()
|
||||
|
|
@ -265,16 +370,30 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
|
|||
// 是否开启MethodNotAllowed
|
||||
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
|
||||
engine.HandleMethodNotAllowed = enable
|
||||
engine.rebuildFallbackChains()
|
||||
}
|
||||
|
||||
// SetLogger传入实例
|
||||
func (engine *Engine) SetLogger(logger *reco.Logger) {
|
||||
engine.LogReco = logger
|
||||
// SetLogger 传入 Logger 接口实例
|
||||
func (engine *Engine) SetLogger(logger Logger) {
|
||||
engine.logger = logger
|
||||
// 同步更新 LogReco 以保持向后兼容
|
||||
if rl, ok := logger.(*reco.Logger); ok {
|
||||
engine.LogReco = rl
|
||||
} else {
|
||||
engine.LogReco = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 配置日志LoggerCfg
|
||||
// GetLogger 返回 Logger 接口实例
|
||||
func (engine *Engine) GetLogger() Logger {
|
||||
return engine.logger
|
||||
}
|
||||
|
||||
// SetLoggerCfg 使用 reco.Config 配置日志
|
||||
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
||||
engine.LogReco = NewLogger(logcfg)
|
||||
logger := NewLogger(logcfg)
|
||||
engine.logger = logger
|
||||
engine.LogReco = logger
|
||||
}
|
||||
|
||||
// 设置自定义错误处理
|
||||
|
|
@ -305,6 +424,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
|
|||
engine.unMatchFS.ServeUnmatchedAsFS = false
|
||||
engine.UnMatchFSRoutes = nil
|
||||
}
|
||||
engine.rebuildFallbackChains()
|
||||
}
|
||||
|
||||
// 获取默认Protocol配置
|
||||
|
|
@ -340,11 +460,28 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
|||
}()
|
||||
}
|
||||
|
||||
func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
|
||||
if protocols == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *protocols
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
|
||||
if protocols != nil {
|
||||
srv.Protocols = cloneServerProtocols(protocols)
|
||||
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
|
||||
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyDefaultServerConfig 应用框架的默认配置到 http.Server
|
||||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||
if engine.serverProtocols != nil {
|
||||
srv.Protocols = engine.serverProtocols
|
||||
}
|
||||
applyServerProtocols(srv, engine.serverProtocols)
|
||||
}
|
||||
|
||||
// 配置全局Req Body大小限制
|
||||
|
|
@ -473,66 +610,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
|||
|
||||
// 405中间件
|
||||
func MethodNotAllowed() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := c.Request.URL.Path
|
||||
engine := c.engine
|
||||
// 是否是OPTIONS方式
|
||||
if httpMethod == http.MethodOptions {
|
||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||
allowedMethods := []string{}
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
allowedMethods = append(allowedMethods, treeIter.method)
|
||||
}
|
||||
}
|
||||
if len(allowedMethods) > 0 {
|
||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
||||
continue
|
||||
}
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
if value.handlers != nil {
|
||||
// 使用定义的ErrorHandle处理
|
||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return methodNotAllowedHandler
|
||||
}
|
||||
|
||||
// 404最后处理
|
||||
func NotFound() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
engine := c.engine
|
||||
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
|
||||
}
|
||||
return notFoundHandler
|
||||
}
|
||||
|
||||
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||
func (Engine *Engine) NoRoute(handler HandlerFunc) {
|
||||
Engine.noRoute = handler
|
||||
Engine.noRoutes = nil
|
||||
Engine.rebuildFallbackChains()
|
||||
}
|
||||
|
||||
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
|
||||
Engine.noRoute = nil
|
||||
Engine.noRoutes = handlerFuncs
|
||||
Engine.rebuildFallbackChains()
|
||||
}
|
||||
|
||||
func (engine *Engine) rebuildFallbackChains() {
|
||||
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
|
||||
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
|
||||
if includeMethodNotAllowed {
|
||||
finalSize++
|
||||
}
|
||||
if includeUnmatchedFS {
|
||||
finalSize += len(engine.UnMatchFSRoutes)
|
||||
}
|
||||
if engine.noRoute != nil {
|
||||
finalSize++
|
||||
} else {
|
||||
finalSize += len(engine.noRoutes)
|
||||
}
|
||||
|
||||
chain := make(HandlersChain, 0, finalSize)
|
||||
chain = append(chain, engine.globalHandlers...)
|
||||
if includeMethodNotAllowed {
|
||||
chain = append(chain, methodNotAllowedHandler)
|
||||
}
|
||||
if includeUnmatchedFS {
|
||||
chain = append(chain, engine.UnMatchFSRoutes...)
|
||||
}
|
||||
if engine.noRoute != nil {
|
||||
chain = append(chain, engine.noRoute)
|
||||
} else if len(engine.noRoutes) > 0 {
|
||||
chain = append(chain, engine.noRoutes...)
|
||||
}
|
||||
chain = append(chain, notFoundHandler)
|
||||
return chain
|
||||
}
|
||||
|
||||
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
|
||||
engine.notFoundNoMethodChain = buildChain(false, false)
|
||||
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||
}
|
||||
|
||||
// combineHandlers 组合多个处理函数链为一个
|
||||
|
|
@ -547,8 +682,9 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
|
|||
|
||||
// Use 将全局中间件添加到 Engine
|
||||
// 这些中间件将应用于所有注册的路由
|
||||
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
||||
func (engine *Engine) Use(middleware ...HandlerFunc) Router {
|
||||
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
||||
engine.rebuildFallbackChains()
|
||||
return engine
|
||||
}
|
||||
|
||||
|
|
@ -615,7 +751,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
|
|||
|
||||
// Group 创建一个新的路由组
|
||||
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||
return &RouterGroup{
|
||||
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
||||
basePath: resolveRoutePath("/", relativePath),
|
||||
|
|
@ -624,7 +760,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
|
|||
}
|
||||
|
||||
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
||||
// 它也实现了 IRouter 接口,允许嵌套分组
|
||||
// 它也实现了 Router 接口,允许嵌套分组
|
||||
type RouterGroup struct {
|
||||
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
||||
basePath string // 组路径前缀
|
||||
|
|
@ -633,7 +769,7 @@ type RouterGroup struct {
|
|||
|
||||
// Use 将中间件应用于当前路由组
|
||||
// 这些中间件将应用于当前组及其子组的所有路由
|
||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter {
|
||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
|
||||
group.Handlers = append(group.Handlers, middleware...)
|
||||
return group
|
||||
}
|
||||
|
|
@ -679,7 +815,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
|
|||
}
|
||||
|
||||
// Group 为当前组创建一个新的子组
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||
return &RouterGroup{
|
||||
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
||||
basePath: resolveRoutePath(group.basePath, relativePath),
|
||||
|
|
@ -704,8 +840,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||
// 这是路由查找和执行的核心逻辑
|
||||
func (engine *Engine) handleRequest(c *Context) {
|
||||
if isGeneralOptionsRequest(c.Request) {
|
||||
engine.handleGeneralOptions(c)
|
||||
return
|
||||
}
|
||||
|
||||
httpMethod := c.Request.Method
|
||||
requestPath := c.Request.URL.Path
|
||||
requestPath := routeLookupPath(c.Request)
|
||||
|
||||
// 查找对应的路由树的根节点
|
||||
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
||||
|
|
@ -725,7 +866,7 @@ func (engine *Engine) handleRequest(c *Context) {
|
|||
}
|
||||
|
||||
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
||||
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
||||
if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
|
||||
if value.tsr && engine.RedirectTrailingSlash {
|
||||
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
||||
redirectPath := requestPath
|
||||
|
|
@ -737,51 +878,98 @@ func (engine *Engine) handleRequest(c *Context) {
|
|||
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
||||
return
|
||||
}
|
||||
// 尝试不区分大小写的查找
|
||||
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法
|
||||
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
|
||||
if found && engine.RedirectFixedPath {
|
||||
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
|
||||
if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
|
||||
// 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
|
||||
ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
|
||||
if found {
|
||||
c.fixedPathBuf = ciPath[:0]
|
||||
c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径
|
||||
return
|
||||
}
|
||||
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建处理链
|
||||
// 组合全局中间件和路由处理函数
|
||||
handlers := engine.globalHandlers
|
||||
|
||||
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
|
||||
// 则在全局中间件之后添加 MethodNotAllowed 处理器
|
||||
if engine.HandleMethodNotAllowed {
|
||||
handlers = append(handlers, MethodNotAllowed())
|
||||
}
|
||||
|
||||
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
|
||||
// 则在处理链的最后添加 UnMatchFS 处理器
|
||||
if engine.unMatchFS.ServeUnmatchedAsFS {
|
||||
/*
|
||||
var unMatchFSHandle = c.engine.unMatchFileServer
|
||||
handlers = append(handlers, unMatchFSHandle)
|
||||
*/
|
||||
handlers = append(handlers, engine.UnMatchFSRoutes...)
|
||||
c.handlers = engine.unmatchedFSChain
|
||||
} else {
|
||||
c.handlers = engine.notFoundChain
|
||||
}
|
||||
|
||||
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
|
||||
// 则在处理链的最后添加 NoRoute 处理器
|
||||
if engine.noRoute != nil {
|
||||
handlers = append(handlers, engine.noRoute)
|
||||
} else if len(engine.noRoutes) > 0 {
|
||||
handlers = append(handlers, engine.noRoutes...)
|
||||
}
|
||||
|
||||
handlers = append(handlers, NotFound())
|
||||
|
||||
c.handlers = handlers
|
||||
c.Next() // 执行处理函数链
|
||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||
}
|
||||
|
||||
func routeLookupPath(req *http.Request) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
|
||||
return "/" + req.RequestURI
|
||||
}
|
||||
if isGeneralOptionsRequest(req) {
|
||||
return ""
|
||||
}
|
||||
if req.URL == nil {
|
||||
return ""
|
||||
}
|
||||
return req.URL.Path
|
||||
}
|
||||
|
||||
func isGeneralOptionsRequest(req *http.Request) bool {
|
||||
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
||||
}
|
||||
|
||||
func shouldTryFixedPathLookup(path string, root *node) bool {
|
||||
if root != nil && root.hasCaseInsensitivePath {
|
||||
return true
|
||||
}
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if c >= utf8.RuneSelf {
|
||||
return true
|
||||
}
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
|
||||
if cap(allowedMethods) < len(engine.methodTrees) {
|
||||
allowedMethods = make([]string, 0, len(engine.methodTrees))
|
||||
} else {
|
||||
allowedMethods = allowedMethods[:0]
|
||||
}
|
||||
tempSkippedNodes := GetTempSkippedNodes()
|
||||
for _, treeIter := range engine.methodTrees {
|
||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||
if value.handlers != nil {
|
||||
allowedMethods = append(allowedMethods, treeIter.method)
|
||||
}
|
||||
}
|
||||
PutTempSkippedNodes(tempSkippedNodes)
|
||||
return allowedMethods
|
||||
}
|
||||
|
||||
func (engine *Engine) handleGeneralOptions(c *Context) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Length", "0")
|
||||
if c.Request.ContentLength != 0 {
|
||||
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
|
||||
_, _ = io.Copy(io.Discard, mb)
|
||||
}
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||
func (engine *Engine) Context() context.Context {
|
||||
|
|
|
|||
71
engine_benchmark_test.go
Normal file
71
engine_benchmark_test.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var benchmarkStatusCode int
|
||||
|
||||
func buildServeHTTPBenchmarkEngine() *Engine {
|
||||
engine := New()
|
||||
engine.GET("/api/v1/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
engine.GET("/api/v1/users/:id", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
engine.POST("/api/v1/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
return engine
|
||||
}
|
||||
|
||||
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
|
||||
b.Helper()
|
||||
|
||||
req, err := http.NewRequest(method, path, nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rr = httptest.NewRecorder()
|
||||
engine.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
benchmarkStatusCode = rr.Code
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTP(b *testing.B) {
|
||||
engine := buildServeHTTPBenchmarkEngine()
|
||||
|
||||
b.Run("StaticHit", func(b *testing.B) {
|
||||
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
|
||||
})
|
||||
|
||||
b.Run("NotFound", func(b *testing.B) {
|
||||
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
|
||||
})
|
||||
|
||||
b.Run("MethodNotAllowed", func(b *testing.B) {
|
||||
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
|
||||
})
|
||||
|
||||
b.Run("OptionsAllow", func(b *testing.B) {
|
||||
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
|
||||
})
|
||||
|
||||
b.Run("FixedPathRedirect", func(b *testing.B) {
|
||||
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
|
||||
})
|
||||
}
|
||||
306
engine_test.go
Normal file
306
engine_test.go
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type failingResponseWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
err error
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.status == 0 {
|
||||
w.status = statusCode
|
||||
}
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Write(p []byte) (int, error) {
|
||||
if w.status == 0 {
|
||||
w.status = http.StatusOK
|
||||
}
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Flush() {}
|
||||
|
||||
func (w *failingResponseWriter) Status() int {
|
||||
return w.status
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Written() bool {
|
||||
return w.status != 0
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) IsHijacked() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
func TestHandleRequestRedirectFixedPath(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
|
||||
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/Users/Profile", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
|
||||
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/Users/Profile", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
|
||||
engine := New()
|
||||
engine.NoRoute(func(c *Context) {
|
||||
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||
c.Next()
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
|
||||
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
engine.NoRoute(func(c *Context) {
|
||||
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||
c.Next()
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||
if rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("X-NoRoute"); got != "" {
|
||||
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
engine.POST("/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
allow := rr.Header().Get("Allow")
|
||||
if allow != "GET, POST" && allow != "POST, GET" {
|
||||
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultErrorHandleJSONShape(t *testing.T) {
|
||||
engine := New()
|
||||
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||
if rr.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||
}
|
||||
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
|
||||
t.Fatalf("unexpected error payload: %+v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultMethodNotAllowedJSONShape(t *testing.T) {
|
||||
engine := New()
|
||||
engine.GET("/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||
if rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||
}
|
||||
if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" {
|
||||
t.Fatalf("unexpected error payload: %+v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
|
||||
engine := New()
|
||||
engine.SetErrorHandler(func(c *Context, code int, err error) {
|
||||
c.Writer.Header().Set("X-Custom-Error", "1")
|
||||
c.String(code, "custom:%v", err)
|
||||
})
|
||||
engine.GET("/users", func(c *Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||
if rr.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
if got := rr.Header().Get("X-Custom-Error"); got != "1" {
|
||||
t.Fatalf("expected custom error header, got %q", got)
|
||||
}
|
||||
if rr.Body.String() != "custom:method not allowed" {
|
||||
t.Fatalf("expected custom error body, got %q", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseHelpersCaptureWriteErrors(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
run func(*Context)
|
||||
}{
|
||||
{name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }},
|
||||
{name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }},
|
||||
{name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }},
|
||||
{name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }},
|
||||
{name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }},
|
||||
{name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }},
|
||||
{name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }},
|
||||
{name: "HTMLBuf", run: func(c *Context) {
|
||||
c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`))
|
||||
c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"})
|
||||
}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
writerErr := errors.New("write failed")
|
||||
w := &failingResponseWriter{err: writerErr}
|
||||
c, _ := CreateTestContext(w)
|
||||
|
||||
tc.run(c)
|
||||
|
||||
if got := len(c.Errors); got != 1 {
|
||||
t.Fatalf("expected exactly one captured error, got %d", got)
|
||||
}
|
||||
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
|
||||
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) {
|
||||
writerErr := errors.New("write failed")
|
||||
w := &failingResponseWriter{err: writerErr}
|
||||
engine := New()
|
||||
c, _ := CreateTestContext(w)
|
||||
c.engine = engine
|
||||
req, err := http.NewRequest(http.MethodGet, "/missing", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
c.reset(w, req)
|
||||
|
||||
defaultErrorHandle(c, http.StatusNotFound, errNotFound)
|
||||
|
||||
if len(c.Errors) == 0 {
|
||||
t.Fatal("expected write error to be captured")
|
||||
}
|
||||
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
|
||||
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
|
||||
}
|
||||
if c.Writer.Status() != http.StatusNotFound {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status())
|
||||
}
|
||||
if !c.IsAborted() {
|
||||
t.Fatal("expected fast path to abort context")
|
||||
}
|
||||
}
|
||||
103
examples/httpc/main.go
Normal file
103
examples/httpc/main.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/infinite-iroha/touka"
|
||||
)
|
||||
|
||||
func main() {
|
||||
r := touka.Default()
|
||||
|
||||
// 示例 1:简单 GET 请求(自动关联请求 Context)
|
||||
r.GET("/proxy", func(c *touka.Context) {
|
||||
// 使用 HTTPC() 方法,自动关联请求 Context
|
||||
// 当客户端断开连接时,出站请求也会自动取消
|
||||
body, err := c.HTTPC().
|
||||
GET("https://httpbin.org/get").
|
||||
Text()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "%s", body)
|
||||
})
|
||||
|
||||
// 示例 2:带 Header 的 POST 请求
|
||||
r.POST("/users", func(c *touka.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// 链式调用,保持 httpc 风格
|
||||
// 注意:SetJSONBody 返回 (*RequestBuilder, error)
|
||||
rb, err := c.HTTPC().
|
||||
POST("https://httpbin.org/post").
|
||||
SetHeader("X-API-Key", "secret").
|
||||
SetJSONBody(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := rb.DecodeJSON(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
})
|
||||
|
||||
// 示例 3:带查询参数的请求
|
||||
r.GET("/search", func(c *touka.Context) {
|
||||
query := c.DefaultQuery("q", "")
|
||||
page := c.DefaultQuery("page", "1")
|
||||
|
||||
var result struct {
|
||||
Items []string `json:"items"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
err := c.HTTPC().
|
||||
GET("https://httpbin.org/get").
|
||||
SetQueryParam("q", query).
|
||||
SetQueryParam("page", page).
|
||||
DecodeJSON(&result)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
})
|
||||
|
||||
// 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐)
|
||||
r.GET("/legacy", func(c *touka.Context) {
|
||||
// 旧方式:需要手动 WithContext
|
||||
body, err := c.Client().
|
||||
GET("https://httpbin.org/get").
|
||||
WithContext(c.Context()).
|
||||
Text()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "%s", body)
|
||||
})
|
||||
|
||||
fmt.Println("Server running on :8080")
|
||||
fmt.Println("Try:")
|
||||
fmt.Println(" curl http://localhost:8080/proxy")
|
||||
fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users")
|
||||
fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'")
|
||||
|
||||
// r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
71
examples/logger_slog/main.go
Normal file
71
examples/logger_slog/main.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/infinite-iroha/touka"
|
||||
)
|
||||
|
||||
// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口
|
||||
type SlogAdapter struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewSlogAdapter(handler slog.Handler) *SlogAdapter {
|
||||
return &SlogAdapter{
|
||||
logger: slog.New(handler),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Debugf(format string, args ...any) {
|
||||
s.logger.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Infof(format string, args ...any) {
|
||||
s.logger.Info(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Warnf(format string, args ...any) {
|
||||
s.logger.Warn(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Errorf(format string, args ...any) {
|
||||
s.logger.Error(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Fatalf(format string, args ...any) {
|
||||
s.logger.Error(fmt.Sprintf(format, args...))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func (s *SlogAdapter) Panicf(format string, args ...any) {
|
||||
s.logger.Error(fmt.Sprintf(format, args...))
|
||||
panic(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func main() {
|
||||
engine := touka.New()
|
||||
|
||||
// 使用 slog 替换默认的 reco.Logger
|
||||
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})
|
||||
slogAdapter := NewSlogAdapter(handler)
|
||||
engine.SetLogger(slogAdapter)
|
||||
|
||||
engine.GET("/", func(c *touka.Context) {
|
||||
c.Infof("request received: %s", c.Request.URL.Path)
|
||||
c.JSON(http.StatusOK, map[string]string{"message": "hello"})
|
||||
})
|
||||
|
||||
// 也可以获取 Logger 接口
|
||||
logger := engine.GetLogger()
|
||||
logger.Debugf("engine started")
|
||||
|
||||
// 也可以直接使用 slog
|
||||
slog.Info("Server running", "addr", ":8080")
|
||||
// engine.Run(":8080")
|
||||
}
|
||||
5
go.mod
5
go.mod
|
|
@ -3,14 +3,15 @@ module github.com/infinite-iroha/touka
|
|||
go 1.26
|
||||
|
||||
require (
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.9.0
|
||||
github.com/WJQSERVER/wanf v0.0.8
|
||||
github.com/fenthope/reco v0.0.5
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
||||
golang.org/x/net v0.52.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -1,5 +1,7 @@
|
|||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM=
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck=
|
||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4=
|
||||
github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg=
|
||||
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=
|
||||
|
|
@ -12,3 +14,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
|||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
|
|
|
|||
88
http2xconnect.go
Normal file
88
http2xconnect.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var enableHTTP2ExtendedConnectOnce sync.Once
|
||||
|
||||
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
|
||||
var xnetDisableHTTP2ExtendedConnectProtocol bool
|
||||
|
||||
func enableHTTP2ExtendedConnectProtocol() {
|
||||
enableHTTP2ExtendedConnectOnce.Do(func() {
|
||||
xnetDisableHTTP2ExtendedConnectProtocol = false
|
||||
})
|
||||
}
|
||||
|
||||
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
||||
if srv == nil {
|
||||
return nil
|
||||
}
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
return http2.ConfigureServer(srv, nil)
|
||||
}
|
||||
|
||||
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
|
||||
enableHTTP2ExtendedConnectProtocol()
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetHTTP1(true)
|
||||
transport.Protocols.SetHTTP2(true)
|
||||
return transport
|
||||
}
|
||||
|
||||
func newHTTP1BridgeTransport() http.RoundTripper {
|
||||
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
|
||||
}
|
||||
|
||||
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetHTTP1(true)
|
||||
if tlsConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
} else {
|
||||
transport.TLSClientConfig = tlsConfig.Clone()
|
||||
}
|
||||
if len(transport.TLSClientConfig.NextProtos) == 0 {
|
||||
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func newH2CTransport() http.RoundTripper {
|
||||
transport := cloneDefaultTransport()
|
||||
transport.Protocols = new(http.Protocols)
|
||||
transport.Protocols.SetUnencryptedHTTP2(true)
|
||||
return transport
|
||||
}
|
||||
|
||||
func cloneDefaultTransport() *http.Transport {
|
||||
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
return transport.Clone()
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
150
iox_benchmark_test.go
Normal file
150
iox_benchmark_test.go
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
||||
)
|
||||
|
||||
type benchmarkResetReader struct {
|
||||
data []byte
|
||||
off int
|
||||
}
|
||||
|
||||
func (r *benchmarkResetReader) Read(p []byte) (int, error) {
|
||||
if r.off >= len(r.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.data[r.off:])
|
||||
r.off += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *benchmarkResetReader) Reset() {
|
||||
r.off = 0
|
||||
}
|
||||
|
||||
type benchmarkDiscardWriter struct{}
|
||||
|
||||
func (benchmarkDiscardWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
var benchmarkIOXResult int64
|
||||
var benchmarkIOXBytes []byte
|
||||
|
||||
func BenchmarkIOXCopyComparison(b *testing.B) {
|
||||
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||
|
||||
b.Run("io.Copy", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
w := benchmarkDiscardWriter{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
n, err := io.Copy(w, r)
|
||||
if err != nil {
|
||||
b.Fatalf("io.Copy failed: %v", err)
|
||||
}
|
||||
benchmarkIOXResult = n
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("iox.Copy", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
w := benchmarkDiscardWriter{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
n, err := iox.Copy(w, r)
|
||||
if err != nil {
|
||||
b.Fatalf("iox.Copy failed: %v", err)
|
||||
}
|
||||
benchmarkIOXResult = n
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkIOXCopyBufferComparison(b *testing.B) {
|
||||
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||
|
||||
b.Run("io.CopyBuffer", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
w := benchmarkDiscardWriter{}
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
n, err := io.CopyBuffer(w, r, buf)
|
||||
if err != nil {
|
||||
b.Fatalf("io.CopyBuffer failed: %v", err)
|
||||
}
|
||||
benchmarkIOXResult = n
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("iox.CopyBuffer", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
w := benchmarkDiscardWriter{}
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
n, err := iox.CopyBuffer(w, r, buf)
|
||||
if err != nil {
|
||||
b.Fatalf("iox.CopyBuffer failed: %v", err)
|
||||
}
|
||||
benchmarkIOXResult = n
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkIOXReadAllComparison(b *testing.B) {
|
||||
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||
|
||||
b.Run("io.ReadAll", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
b.Fatalf("io.ReadAll failed: %v", err)
|
||||
}
|
||||
benchmarkIOXBytes = data
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("iox.ReadAll", func(b *testing.B) {
|
||||
r := &benchmarkResetReader{data: payload}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.Reset()
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
b.Fatalf("iox.ReadAll failed: %v", err)
|
||||
}
|
||||
benchmarkIOXBytes = data
|
||||
}
|
||||
})
|
||||
}
|
||||
23
logger.go
Normal file
23
logger.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
// Logger 是日志接口,支持多种日志库实现(reco、zap、logrus 等)
|
||||
// 用户可以通过实现此接口来替换默认的日志实现
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Infof(format string, args ...any)
|
||||
Warnf(format string, args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
Fatalf(format string, args ...any)
|
||||
Panicf(format string, args ...any)
|
||||
}
|
||||
|
||||
// CloserLogger 可选扩展接口,支持关闭操作
|
||||
// 如果 Logger 实现了此接口,Engine 在关闭时会调用 Close()
|
||||
type CloserLogger interface {
|
||||
Logger
|
||||
Close() error
|
||||
}
|
||||
|
|
@ -39,7 +39,16 @@ func CloseLogger(logger *reco.Logger) {
|
|||
}
|
||||
}
|
||||
|
||||
// CloseLogger 关闭 Engine 的日志实现
|
||||
// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法
|
||||
func (engine *Engine) CloseLogger() {
|
||||
if cl, ok := engine.logger.(CloserLogger); ok {
|
||||
if err := cl.Close(); err != nil {
|
||||
log.Printf("Close Logger Error: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// 兼容旧代码
|
||||
if engine.LogReco != nil {
|
||||
CloseLogger(engine.LogReco)
|
||||
}
|
||||
|
|
|
|||
67
maxreader.go
67
maxreader.go
|
|
@ -23,19 +23,21 @@ type maxBytesReader struct {
|
|||
n int64
|
||||
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||
read atomic.Int64
|
||||
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
|
||||
emptyAtLimit atomic.Bool
|
||||
}
|
||||
|
||||
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
||||
//
|
||||
// 如果 r 为 nil, 会 panic.
|
||||
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
|
||||
// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r.
|
||||
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||
if r == nil {
|
||||
panic("NewMaxBytesReader called with a nil reader")
|
||||
}
|
||||
// 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||
if n < 0 {
|
||||
// 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||
if n <= 0 {
|
||||
return r
|
||||
}
|
||||
return &maxBytesReader{
|
||||
|
|
@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
|||
|
||||
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
||||
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||
readSoFar := mbr.read.Load()
|
||||
|
||||
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
|
||||
if readSoFar >= mbr.n {
|
||||
return 0, ErrBodyTooLarge
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 计算当前还可以读取多少字节.
|
||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||
readSoFar := mbr.read.Load()
|
||||
remaining := mbr.n - readSoFar
|
||||
if remaining < 0 {
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
if remaining == 0 {
|
||||
var probe [1]byte
|
||||
n, err := mbr.r.Read(probe[:])
|
||||
if n > 0 {
|
||||
mbr.read.Add(int64(n))
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if mbr.emptyAtLimit.Swap(true) {
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
mbr.emptyAtLimit.Store(false)
|
||||
|
||||
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
||||
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
||||
if int64(len(p)) > remaining {
|
||||
p = p[:remaining]
|
||||
// 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。
|
||||
if int64(len(p))-1 > remaining {
|
||||
p = p[:remaining+1]
|
||||
}
|
||||
|
||||
// 从底层 Reader 读取数据.
|
||||
n, err := mbr.r.Read(p)
|
||||
|
||||
// 如果实际读取到了数据, 更新原子计数器.
|
||||
if int64(n) <= remaining {
|
||||
if n > 0 {
|
||||
readSoFar = mbr.read.Add(int64(n))
|
||||
mbr.read.Add(int64(n))
|
||||
}
|
||||
|
||||
// 如果底层 Read 返回错误 (例如 io.EOF).
|
||||
if err != nil {
|
||||
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
||||
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
||||
return n, err
|
||||
}
|
||||
|
||||
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
||||
// 这是处理"跨越"限制情况的关键.
|
||||
if readSoFar > mbr.n {
|
||||
// 返回实际读取的字节数 n, 并附上超限错误.
|
||||
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
||||
return n, ErrBodyTooLarge
|
||||
// 读取结果跨过了限制,只向上层暴露允许的部分。
|
||||
if remaining > 0 {
|
||||
mbr.read.Add(remaining)
|
||||
}
|
||||
|
||||
// 一切正常, 返回读取的字节数和 nil 错误.
|
||||
return n, nil
|
||||
return int(remaining), ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||
|
|
|
|||
113
mergectx.go
113
mergectx.go
|
|
@ -11,18 +11,16 @@ import (
|
|||
)
|
||||
|
||||
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||
// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播.
|
||||
// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消.
|
||||
type mergedContext struct {
|
||||
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
||||
context.Context
|
||||
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
||||
parents []context.Context
|
||||
// 用于手动取消此 mergedContext 的函数.
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// MergeCtx 创建并返回一个新的 context.Context.
|
||||
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||
// 自动被取消 (逻辑或关系).
|
||||
// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context.
|
||||
//
|
||||
// 新的 context 会继承:
|
||||
// - Deadline: 所有父 context 中最早的截止时间.
|
||||
|
|
@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
|||
return context.WithCancel(context.Background())
|
||||
}
|
||||
if len(parents) == 1 {
|
||||
return context.WithCancel(parents[0])
|
||||
ctx, cancel := context.WithCancelCause(parents[0])
|
||||
return ctx, func() { cancel(nil) }
|
||||
}
|
||||
|
||||
var earliestDeadline time.Time
|
||||
|
|
@ -44,79 +43,93 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
|||
}
|
||||
}
|
||||
|
||||
var baseCtx context.Context
|
||||
var baseCancel context.CancelFunc
|
||||
// cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播.
|
||||
cancelCtx, cancelCause := context.WithCancelCause(context.Background())
|
||||
|
||||
// deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline).
|
||||
// 当 cancelCtx 被取消时, deadlineCtx 也会被取消;
|
||||
// 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx.
|
||||
var deadlineCtx context.Context
|
||||
var deadlineCancel context.CancelFunc
|
||||
if !earliestDeadline.IsZero() {
|
||||
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
||||
} else {
|
||||
baseCtx, baseCancel = context.WithCancel(context.Background())
|
||||
deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline),
|
||||
// 否则用 cancelCtx.
|
||||
embedCtx := cancelCtx
|
||||
if deadlineCtx != nil {
|
||||
embedCtx = deadlineCtx
|
||||
}
|
||||
|
||||
mc := &mergedContext{
|
||||
Context: baseCtx,
|
||||
Context: embedCtx,
|
||||
parents: parents,
|
||||
cancel: baseCancel,
|
||||
}
|
||||
|
||||
// 启动一个监控 goroutine.
|
||||
// 启动监控 goroutine, 监听 parent 取消或 deadline 到期.
|
||||
go func() {
|
||||
defer mc.cancel()
|
||||
// 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏.
|
||||
parentDone := orDone(append(mc.parents, cancelCtx)...)
|
||||
|
||||
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
||||
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
||||
if deadlineCtx != nil {
|
||||
defer deadlineCancel()
|
||||
select {
|
||||
case <-orDone(mc.parents...):
|
||||
case <-mc.Context.Done():
|
||||
case <-parentDone:
|
||||
// parent 取消或手动 cancel()
|
||||
for _, p := range mc.parents {
|
||||
if p.Err() != nil {
|
||||
cancelCause(context.Cause(p))
|
||||
return
|
||||
}
|
||||
}
|
||||
// 手动 cancel(), cause 已由 cancelCause() 设置
|
||||
case <-deadlineCtx.Done():
|
||||
// deadline 到期, 需要关闭 cancelCtx 并设置 cause
|
||||
cancelCause(context.DeadlineExceeded)
|
||||
}
|
||||
} else {
|
||||
<-parentDone
|
||||
for _, p := range mc.parents {
|
||||
if p.Err() != nil {
|
||||
cancelCause(context.Cause(p))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return mc, mc.cancel
|
||||
return mc, func() { cancelCause(nil) }
|
||||
}
|
||||
|
||||
// Value 返回当前Ctx Value
|
||||
// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause),
|
||||
// 再按传入顺序从 parents 中查找.
|
||||
func (mc *mergedContext) Value(key any) any {
|
||||
return mc.Context.Value(key)
|
||||
if v := mc.Context.Value(key); v != nil {
|
||||
return v
|
||||
}
|
||||
for _, p := range mc.parents {
|
||||
if val := p.Value(key); val != nil {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deadline 实现了 context.Context 的 Deadline 方法.
|
||||
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return mc.Context.Deadline()
|
||||
}
|
||||
// Deadline, Done, Err 均由嵌入的 context.Context 提供.
|
||||
|
||||
// Done 实现了 context.Context 的 Done 方法.
|
||||
func (mc *mergedContext) Done() <-chan struct{} {
|
||||
return mc.Context.Done()
|
||||
}
|
||||
|
||||
// Err 实现了 context.Context 的 Err 方法.
|
||||
func (mc *mergedContext) Err() error {
|
||||
return mc.Context.Err()
|
||||
}
|
||||
|
||||
// orDone 是一个辅助函数, 返回一个 channel.
|
||||
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
|
||||
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
|
||||
// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭.
|
||||
func orDone(contexts ...context.Context) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
|
||||
var once sync.Once
|
||||
closeDone := func() {
|
||||
once.Do(func() {
|
||||
close(done)
|
||||
})
|
||||
}
|
||||
|
||||
// 为每个父 context 启动一个 goroutine.
|
||||
for _, ctx := range contexts {
|
||||
go func(c context.Context) {
|
||||
select {
|
||||
case <-c.Done():
|
||||
closeDone()
|
||||
once.Do(func() { close(done) })
|
||||
case <-done:
|
||||
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
|
||||
}
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
|
|
|
|||
256
mergectx_test.go
Normal file
256
mergectx_test.go
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMergeCtx_NoParents(t *testing.T) {
|
||||
ctx, cancel := MergeCtx()
|
||||
defer cancel()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
t.Fatal("expected no error before cancel")
|
||||
}
|
||||
cancel()
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_SingleParent(t *testing.T) {
|
||||
parent, parentCancel := context.WithCancel(context.Background())
|
||||
|
||||
ctx, cancel := MergeCtx(parent)
|
||||
defer cancel()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
t.Fatal("expected no error before parent cancel")
|
||||
}
|
||||
|
||||
parentCancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after parent cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel1()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p1 cancel")
|
||||
}
|
||||
// p2 should still be fine
|
||||
if p2.Err() != nil {
|
||||
t.Fatal("expected p2 to be unaffected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel2()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p2 cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ExternalCancel(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
|
||||
cancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after external cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_CausePropagation(t *testing.T) {
|
||||
testErr := errors.New("test cause")
|
||||
|
||||
p1, cancel1 := context.WithCancelCause(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel1(testErr)
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p1 cancel")
|
||||
}
|
||||
|
||||
cause := context.Cause(ctx)
|
||||
if cause != testErr {
|
||||
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||
}
|
||||
cancel1(nil) // cleanup (already cancelled, no-op)
|
||||
}
|
||||
|
||||
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
|
||||
testErr := errors.New("second parent cause")
|
||||
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancelCause(context.Background())
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel2(testErr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p2 cancel")
|
||||
}
|
||||
|
||||
cause := context.Cause(ctx)
|
||||
if cause != testErr {
|
||||
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||
}
|
||||
|
||||
cancel1()
|
||||
}
|
||||
|
||||
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
|
||||
now := time.Now()
|
||||
early := now.Add(100 * time.Millisecond)
|
||||
late := now.Add(1 * time.Hour)
|
||||
|
||||
p1, cancel1 := context.WithDeadline(context.Background(), late)
|
||||
p2, cancel2 := context.WithDeadline(context.Background(), early)
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Fatal("expected deadline to be set")
|
||||
}
|
||||
if !dl.Equal(early) {
|
||||
t.Fatalf("expected deadline %v, got %v", early, dl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_Deadline_Expires(t *testing.T) {
|
||||
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancelP()
|
||||
|
||||
ctx, cancel := MergeCtx(p)
|
||||
defer cancel()
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after deadline expires")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ValueLookup(t *testing.T) {
|
||||
type key struct{}
|
||||
p1 := context.WithValue(context.Background(), key{}, "from_p1")
|
||||
p2 := context.WithValue(context.Background(), key{}, "from_p2")
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
val := ctx.Value(key{})
|
||||
if val != "from_p1" {
|
||||
t.Fatalf("expected 'from_p1', got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
|
||||
type key1 struct{}
|
||||
type key2 struct{}
|
||||
p1 := context.WithValue(context.Background(), key1{}, "val1")
|
||||
p2 := context.WithValue(context.Background(), key2{}, "val2")
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
if v := ctx.Value(key1{}); v != "val1" {
|
||||
t.Fatalf("expected 'val1', got %v", v)
|
||||
}
|
||||
if v := ctx.Value(key2{}); v != "val2" {
|
||||
t.Fatalf("expected 'val2', got %v", v)
|
||||
}
|
||||
if v := ctx.Value("missing"); v != nil {
|
||||
t.Fatalf("expected nil, got %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ContextInterface(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, _ = MergeCtx(p1, p2)
|
||||
|
||||
// Verify all Context interface methods work
|
||||
_ = ctx.Done()
|
||||
_ = ctx.Err()
|
||||
_, _ = ctx.Deadline()
|
||||
_ = ctx.Value("any")
|
||||
}
|
||||
|
||||
func TestOrDone_SingleContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := orDone(ctx)
|
||||
|
||||
cancel()
|
||||
<-done // should not block
|
||||
}
|
||||
|
||||
func TestOrDone_MultipleContexts(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
done := orDone(p1, p2)
|
||||
|
||||
cancel1()
|
||||
<-done // should not block
|
||||
}
|
||||
|
||||
func TestOrDone_SecondContextCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
|
||||
done := orDone(p1, p2)
|
||||
|
||||
cancel2()
|
||||
<-done // should not block
|
||||
}
|
||||
|
|
@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRunTLSProtocolInheritance(t *testing.T) {
|
||||
func TestTLSRunDefaultsProtocolInheritance(t *testing.T) {
|
||||
engine := New()
|
||||
|
||||
// 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true,
|
||||
})
|
||||
}
|
||||
|
||||
srv := &http.Server{TLSConfig: &tls.Config{}}
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
|
||||
if !srv.Protocols.HTTP2() {
|
||||
t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config")
|
||||
t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config")
|
||||
}
|
||||
|
||||
// 模拟用户设置了自定义协议后调用 RunTLS
|
||||
// 模拟用户设置了自定义协议后进入 TLS 运行模式
|
||||
engine = New()
|
||||
engine.SetProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: false, // 用户明确不想要 HTTP/2
|
||||
})
|
||||
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true,
|
||||
})
|
||||
}
|
||||
|
||||
srv2 := &http.Server{TLSConfig: &tls.Config{}}
|
||||
engine.applyDefaultServerConfig(srv2)
|
||||
srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
|
||||
if srv2.Protocols.HTTP2() {
|
||||
t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously")
|
||||
t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
2
respw.go
2
respw.go
|
|
@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
||||
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("http.Hijacker interface not supported")
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// 调用底层的 Hijack 方法
|
||||
|
|
|
|||
1289
reverseproxy.go
1289
reverseproxy.go
File diff suppressed because it is too large
Load diff
355
reverseproxy_benchmark_test.go
Normal file
355
reverseproxy_benchmark_test.go
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type benchmarkReadSeeker struct {
|
||||
data []byte
|
||||
off int
|
||||
}
|
||||
|
||||
func (r *benchmarkReadSeeker) Read(p []byte) (int, error) {
|
||||
if r.off >= len(r.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.data[r.off:])
|
||||
r.off += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *benchmarkReadSeeker) Reset() {
|
||||
r.off = 0
|
||||
}
|
||||
|
||||
type benchmarkResponseWriter struct {
|
||||
header http.Header
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func newBenchmarkResponseWriter() *benchmarkResponseWriter {
|
||||
return &benchmarkResponseWriter{header: make(http.Header)}
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.status == 0 {
|
||||
w.status = statusCode
|
||||
}
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Write(p []byte) (int, error) {
|
||||
if w.status == 0 {
|
||||
w.status = http.StatusOK
|
||||
}
|
||||
w.size += len(p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Flush() {}
|
||||
|
||||
func (w *benchmarkResponseWriter) Status() int {
|
||||
return w.status
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Size() int {
|
||||
return w.size
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Written() bool {
|
||||
return w.status != 0
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) IsHijacked() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
func (w *benchmarkResponseWriter) reset() {
|
||||
clear(w.header)
|
||||
w.status = 0
|
||||
w.size = 0
|
||||
}
|
||||
|
||||
var benchmarkReverseProxySink int
|
||||
|
||||
func BenchmarkReverseProxyCopyResponse(b *testing.B) {
|
||||
body := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||
proxy := newReverseProxyHandler(ReverseProxyConfig{})
|
||||
dst := newBenchmarkResponseWriter()
|
||||
src := &benchmarkReadSeeker{data: body}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
dst.reset()
|
||||
src.Reset()
|
||||
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||
b.Fatalf("copyResponse failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
benchmarkReverseProxySink = dst.Size()
|
||||
}
|
||||
|
||||
func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) {
|
||||
proxy := &reverseProxyHandler{
|
||||
upstreams: []*reverseProxyUpstream{
|
||||
{key: "a", index: 0},
|
||||
{key: "b", index: 1},
|
||||
{key: "c", index: 2},
|
||||
{key: "d", index: 3},
|
||||
},
|
||||
config: ReverseProxyConfig{
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
MaxFails: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)}
|
||||
proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReverseProxySelectUpstream(b *testing.B) {
|
||||
proxy := &reverseProxyHandler{
|
||||
upstreams: []*reverseProxyUpstream{
|
||||
{key: "a", index: 0},
|
||||
{key: "b", index: 1},
|
||||
{key: "c", index: 2},
|
||||
{key: "d", index: 3},
|
||||
},
|
||||
config: ReverseProxyConfig{
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
MaxFails: 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)}
|
||||
|
||||
c, _ := CreateTestContext(nil)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
selected, err := proxy.selectUpstream(c, nil)
|
||||
if err != nil {
|
||||
b.Fatalf("selectUpstream failed: %v", err)
|
||||
}
|
||||
benchmarkReverseProxySink = selected.index
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) {
|
||||
proxy := &reverseProxyHandler{
|
||||
upstreams: []*reverseProxyUpstream{
|
||||
{key: "a", index: 0},
|
||||
{key: "b", index: 1},
|
||||
{key: "c", index: 2},
|
||||
{key: "d", index: 3},
|
||||
},
|
||||
config: ReverseProxyConfig{
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
|
||||
},
|
||||
}
|
||||
c, _ := CreateTestContext(nil)
|
||||
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
selected, err := proxy.selectUpstream(c, nil)
|
||||
if err != nil {
|
||||
b.Fatalf("selectUpstream failed: %v", err)
|
||||
}
|
||||
benchmarkReverseProxySink = selected.index
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) {
|
||||
proxy := newReverseProxyHandler(ReverseProxyConfig{})
|
||||
dst := newBenchmarkResponseWriter()
|
||||
src := bytes.NewBufferString("hello, reverse proxy")
|
||||
|
||||
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||
t.Fatalf("copyResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if got, want := dst.Size(), len("hello, reverse proxy"); got != want {
|
||||
t.Fatalf("expected %d bytes copied, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
type fixedLenBufferPool struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (p *fixedLenBufferPool) Get() []byte {
|
||||
return p.buf
|
||||
}
|
||||
|
||||
func (p *fixedLenBufferPool) Put(buf []byte) {
|
||||
p.buf = buf
|
||||
}
|
||||
|
||||
type recordingReader struct {
|
||||
chunk int
|
||||
reads []int
|
||||
left int
|
||||
}
|
||||
|
||||
func (r *recordingReader) Read(p []byte) (int, error) {
|
||||
if r.left == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := min(r.chunk, len(p), r.left)
|
||||
if n == 0 {
|
||||
return 0, errors.New("reader received zero-length buffer")
|
||||
}
|
||||
for i := range n {
|
||||
p[i] = 'x'
|
||||
}
|
||||
r.left -= n
|
||||
r.reads = append(r.reads, len(p))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) {
|
||||
pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)}
|
||||
proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool})
|
||||
dst := newBenchmarkResponseWriter()
|
||||
src := &recordingReader{chunk: 8, left: 24}
|
||||
|
||||
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||
t.Fatalf("copyResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if len(src.reads) == 0 {
|
||||
t.Fatal("expected reader to be used")
|
||||
}
|
||||
for _, size := range src.reads {
|
||||
if size != 8 {
|
||||
t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) {
|
||||
now := time.Now()
|
||||
proxy := &reverseProxyHandler{
|
||||
upstreams: []*reverseProxyUpstream{
|
||||
{key: "a"},
|
||||
{key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}},
|
||||
{key: "c"},
|
||||
},
|
||||
config: ReverseProxyConfig{
|
||||
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||
FailDuration: time.Minute,
|
||||
MaxFails: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}})
|
||||
if len(available) != 1 {
|
||||
t.Fatalf("expected only one available upstream, got %d", len(available))
|
||||
}
|
||||
if available[0].key != "a" {
|
||||
t.Fatalf("expected upstream 'a', got %q", available[0].key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) {
|
||||
proxy := &reverseProxyHandler{
|
||||
upstreams: []*reverseProxyUpstream{
|
||||
{key: "a", index: 0},
|
||||
{key: "b", index: 1},
|
||||
{key: "c", index: 2},
|
||||
},
|
||||
config: ReverseProxyConfig{
|
||||
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
|
||||
},
|
||||
}
|
||||
|
||||
c, _ := CreateTestContext(nil)
|
||||
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"}
|
||||
|
||||
selectedA, err := proxy.selectUpstream(c, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("selectUpstream failed: %v", err)
|
||||
}
|
||||
selectedB, err := proxy.selectUpstream(c, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("selectUpstream failed: %v", err)
|
||||
}
|
||||
if selectedA.key != selectedB.key {
|
||||
t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key)
|
||||
}
|
||||
|
||||
c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"}
|
||||
selectedC, err := proxy.selectUpstream(c, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("selectUpstream failed: %v", err)
|
||||
}
|
||||
if selectedC == nil {
|
||||
t.Fatal("expected upstream for reordered multi-value header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) {
|
||||
candidates := []*reverseProxyUpstream{
|
||||
{key: "a", index: 0},
|
||||
{key: "b", index: 1},
|
||||
{key: "c", index: 2},
|
||||
}
|
||||
|
||||
testCases := [][]string{
|
||||
{"tenant-a"},
|
||||
{"tenant-a", "tenant-b"},
|
||||
{"", "tenant-b"},
|
||||
{"tenant-a", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, values := range testCases {
|
||||
got := reverseProxySelectHRWValues(candidates, values)
|
||||
want := reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
if got == nil || want == nil {
|
||||
t.Fatalf("expected non-nil upstreams for values %v", values)
|
||||
}
|
||||
if got.key != want.key {
|
||||
t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ io.Writer = (*benchmarkResponseWriter)(nil)
|
||||
530
reverseproxy_headers_replace_test.go
Normal file
530
reverseproxy_headers_replace_test.go
Normal file
|
|
@ -0,0 +1,530 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Server"); got != "Caddy" {
|
||||
t.Errorf("expected X-Server=Caddy, got %q", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Location"); got != "/api/v2/resource" {
|
||||
t.Errorf("expected X-Location=/api/v2/resource, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"X-Server": {{Search: "NGINX", Replace: "Caddy"}},
|
||||
"X-Location": {{Search: "v1", Replace: "v2"}},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
req.Header.Set("X-Server", "NGINX")
|
||||
req.Header.Set("X-Location", "/api/v1/resource")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Route"); got != "/proxy-upstream" {
|
||||
t.Errorf("expected X-Route=/proxy-upstream, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
req.Header.Set("X-Route", "/original/upstream")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Host-A"); got != "new.example.com" {
|
||||
t.Errorf("expected X-Host-A=new.example.com, got %q", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Host-B"); got != "new.example.com" {
|
||||
t.Errorf("expected X-Host-B=new.example.com, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"*": {{Search: "old.example.com", Replace: "new.example.com"}},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
req.Header.Set("X-Host-A", "old.example.com")
|
||||
req.Header.Set("X-Host-B", "old.example.com")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Backend", "backend-internal:8080")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
ResponseHeaders: &RespHeaderOps{
|
||||
HeaderOps: &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
resp, err := http.Get(proxy.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if got := resp.Header.Get("X-Backend"); got != "public.example.com" {
|
||||
t.Errorf("expected X-Backend=public.example.com, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"X-Test": {{SearchRegexp: "[invalid"}},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 500, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacementApply(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r *Replacement
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{name: "nil replacement", r: nil, s: "hello", want: "hello"},
|
||||
{name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""},
|
||||
{name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"},
|
||||
{name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"},
|
||||
{name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"},
|
||||
{name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"},
|
||||
{name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"},
|
||||
{name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.r.apply(tt.s); got != tt.want {
|
||||
t.Errorf("Replacement.apply() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsAdd(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Add: map[string][]string{
|
||||
"X-Custom-1": {"value-1"},
|
||||
"X-Custom-2": {"value-2"},
|
||||
"X-Custom-3": {"value-3"},
|
||||
},
|
||||
}
|
||||
hdr := make(http.Header)
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr = make(http.Header)
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsSet(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Set: map[string][]string{
|
||||
"X-Frame-Options": {"DENY"},
|
||||
"X-Content-Type-Options": {"nosniff"},
|
||||
"X-XSS-Protection": {"1; mode=block"},
|
||||
},
|
||||
}
|
||||
hdr := make(http.Header)
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr = make(http.Header)
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsDeleteSingle(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Delete: []string{"X-Powered-By"},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("X-Powered-By", "Express")
|
||||
hdr.Set("X-Keep", "value")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Delete: []string{"X-Debug-*"},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("X-Debug-1", "v1")
|
||||
hdr.Set("X-Debug-2", "v2")
|
||||
hdr.Set("X-Keep", "value")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
|
||||
},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("Location", "http://internal:8080/api/v1/users")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) {
|
||||
re := regexp.MustCompile(`^http://([^/]+)(/.*)$`)
|
||||
ops := &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}},
|
||||
},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("Location", "http://internal:8080/api/v1/users")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Replace: map[string][]Replacement{
|
||||
"*": {{Search: "internal.example.com", Replace: "public.example.com"}},
|
||||
},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("X-Host", "internal.example.com")
|
||||
hdr.Set("X-Origin", "internal.example.com")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHeaderOpsMixed(b *testing.B) {
|
||||
ops := &HeaderOps{
|
||||
Add: map[string][]string{
|
||||
"X-Request-ID": {"req-123"},
|
||||
},
|
||||
Set: map[string][]string{
|
||||
"X-Frame-Options": {"DENY"},
|
||||
},
|
||||
Delete: []string{"X-Powered-By"},
|
||||
Replace: map[string][]Replacement{
|
||||
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
|
||||
},
|
||||
}
|
||||
repl := &reverseProxyReplacer{}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hdr := make(http.Header)
|
||||
hdr.Set("X-Powered-By", "Express")
|
||||
hdr.Set("Location", "http://internal:8080/api")
|
||||
ops.applyTo(hdr, repl)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplacementApplySubstring(b *testing.B) {
|
||||
r := &Replacement{Search: "old.example.com", Replace: "new.example.com"}
|
||||
s := "https://old.example.com/api/v1/resource"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = r.apply(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplacementApplyRegexp(b *testing.B) {
|
||||
r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)}
|
||||
s := "https://old.example.com/api/v1/resource"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = r.apply(s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyReplacerDynamicVars(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil)
|
||||
req.Host = "example.com"
|
||||
repl := newReverseProxyReplacer(req)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"method", "{method}", "GET"},
|
||||
{"host", "{host}", "example.com"},
|
||||
{"path", "{path}", "/api/v1/users"},
|
||||
{"query", "{query}", "sort=name&limit=10"},
|
||||
{"scheme", "{scheme}", "http"},
|
||||
{"proto", "{proto}", "HTTP/1.1"},
|
||||
{"combined", "X-{method}-{path}", "X-GET-/api/v1/users"},
|
||||
{"no vars", "static-value", "static-value"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := repl.Replace(tt.input); got != tt.want {
|
||||
t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyReplacerNilRequest(t *testing.T) {
|
||||
repl := newReverseProxyReplacer(nil)
|
||||
if got := repl.Replace("{method}"); got != "{method}" {
|
||||
t.Errorf("expected unchanged string with nil request, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyReplacerNilReplacer(t *testing.T) {
|
||||
var repl *reverseProxyReplacer
|
||||
if got := repl.Replace("{method}"); got != "{method}" {
|
||||
t.Errorf("expected unchanged string with nil replacer, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyReplacerFromHeader(t *testing.T) {
|
||||
hdr := make(http.Header)
|
||||
repl := newReverseProxyReplacerFromHeader(hdr)
|
||||
if got := repl.Replace("{method}"); got != "{method}" {
|
||||
t.Errorf("expected unchanged string from header replacer, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" {
|
||||
t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Forwarded-Method"); got != "GET" {
|
||||
t.Errorf("expected X-Forwarded-Method=GET, got %q", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" {
|
||||
t.Errorf("expected X-Forwarded-Host=client.example, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Add: map[string][]string{
|
||||
"X-Forwarded-Path": {"{path}"},
|
||||
"X-Forwarded-Method": {"{method}"},
|
||||
"X-Forwarded-Host": {"{host}"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil)
|
||||
req.Host = "client.example"
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
220
reverseproxy_headers_test.go
Normal file
220
reverseproxy_headers_test.go
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
|
||||
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Add: map[string][]string{
|
||||
"X-Custom-Header": {"test-value"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
resp, err := http.Get(proxy.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Sensitive") != "" {
|
||||
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Delete: []string{"X-Sensitive"},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
req.Header.Set("X-Sensitive", "should-be-removed")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyHeaderOpsSet(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got := r.Header.Get("X-Replace")
|
||||
if got != "new-value" {
|
||||
t.Errorf("expected X-Replace=new-value, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
RequestHeaders: &HeaderOps{
|
||||
Set: map[string][]string{
|
||||
"X-Replace": {"new-value"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||
req.Header.Set("X-Replace", "old-value")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyResponseHeaderOps(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Backend", "backend-server")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
ResponseHeaders: &RespHeaderOps{
|
||||
HeaderOps: &HeaderOps{
|
||||
Set: map[string][]string{
|
||||
"X-Custom": {"custom-value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
resp, err := http.Get(proxy.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
|
||||
t.Errorf("expected X-Custom=custom-value, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Powered-By", "Express")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
target, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse target: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||
Target: target,
|
||||
ResponseHeaders: &RespHeaderOps{
|
||||
HeaderOps: &HeaderOps{
|
||||
Delete: []string{"X-Powered-By"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
proxy := httptest.NewServer(engine)
|
||||
defer proxy.Close()
|
||||
|
||||
resp, err := http.Get(proxy.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
if got := resp.Header.Get("X-Powered-By"); got != "" {
|
||||
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
|
||||
}
|
||||
}
|
||||
409
reverseproxy_lb.go
Normal file
409
reverseproxy_lb.go
Normal file
|
|
@ -0,0 +1,409 @@
|
|||
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||
package touka
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
|
||||
type ReverseProxyLoadBalancingConfig struct {
|
||||
Policy ReverseProxyLBPolicy
|
||||
Retries int
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
}
|
||||
|
||||
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
|
||||
type ReverseProxyPassiveHealthConfig struct {
|
||||
FailDuration time.Duration
|
||||
MaxFails int
|
||||
UnhealthyStatus []int
|
||||
}
|
||||
|
||||
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
|
||||
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
|
||||
type ReverseProxyLBPolicy struct {
|
||||
kind reverseProxyLBPolicyKind
|
||||
key string
|
||||
fallback *ReverseProxyLBPolicy
|
||||
}
|
||||
|
||||
type reverseProxyLBPolicyKind uint8
|
||||
|
||||
const (
|
||||
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
|
||||
reverseProxyLBPolicyRoundRobin
|
||||
reverseProxyLBPolicyFirst
|
||||
reverseProxyLBPolicyLeastConn
|
||||
reverseProxyLBPolicyIPHash
|
||||
reverseProxyLBPolicyClientIPHash
|
||||
reverseProxyLBPolicyURIHash
|
||||
reverseProxyLBPolicyHeader
|
||||
reverseProxyLBPolicyQuery
|
||||
)
|
||||
|
||||
type reverseProxyUpstream struct {
|
||||
key string
|
||||
target *url.URL
|
||||
index int
|
||||
useH2C bool
|
||||
extendedConnectTransport http.RoundTripper
|
||||
bridgeTransport http.RoundTripper
|
||||
h2cTransport http.RoundTripper
|
||||
inFlight atomic.Int64
|
||||
|
||||
passiveMu sync.Mutex
|
||||
failures []time.Time
|
||||
}
|
||||
|
||||
func LBRandom() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
|
||||
}
|
||||
|
||||
func LBRoundRobin() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
|
||||
}
|
||||
|
||||
func LBFirst() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
|
||||
}
|
||||
|
||||
func LBLeastConn() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
|
||||
}
|
||||
|
||||
func LBIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
|
||||
}
|
||||
|
||||
func LBClientIPHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
|
||||
}
|
||||
|
||||
func LBURIHash() ReverseProxyLBPolicy {
|
||||
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
|
||||
}
|
||||
|
||||
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
|
||||
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||
policy.fallback = &fallback
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
|
||||
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
|
||||
reverseProxyLBPolicyURIHash:
|
||||
return nil
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
|
||||
}
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if policy.key == "" {
|
||||
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
|
||||
}
|
||||
if policy.fallback != nil {
|
||||
return validateReverseProxyLBPolicy(*policy.fallback)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
|
||||
now := time.Now()
|
||||
policy := p.config.LoadBalancing.Policy
|
||||
candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream)
|
||||
candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf)
|
||||
if len(candidates) == 0 && len(excluded) > 0 {
|
||||
candidates = p.availableUpstreamsInto(now, nil, candidates[:0])
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
*candidateBuf = candidates[:0]
|
||||
reverseProxyCandidatePool.Put(candidateBuf)
|
||||
return nil, errReverseProxyNoAvailableUpstreams
|
||||
}
|
||||
selected := p.selectUpstreamWithPolicy(c, candidates, policy)
|
||||
*candidateBuf = candidates[:0]
|
||||
reverseProxyCandidatePool.Put(candidateBuf)
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
|
||||
return p.availableUpstreamsInto(now, excluded, nil)
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream {
|
||||
if cap(candidates) < len(p.upstreams) {
|
||||
candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams))
|
||||
} else {
|
||||
candidates = candidates[:0]
|
||||
}
|
||||
for _, upstream := range p.upstreams {
|
||||
if _, skip := excluded[upstream.key]; skip {
|
||||
continue
|
||||
}
|
||||
if !upstream.healthy(now, p.config.PassiveHealth) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, upstream)
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch policy.kind {
|
||||
case reverseProxyLBPolicyRoundRobin:
|
||||
return candidates[p.nextRoundRobinIndex(len(candidates))]
|
||||
case reverseProxyLBPolicyFirst:
|
||||
return candidates[0]
|
||||
case reverseProxyLBPolicyLeastConn:
|
||||
return p.selectLeastConnUpstream(candidates)
|
||||
case reverseProxyLBPolicyIPHash:
|
||||
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
|
||||
case reverseProxyLBPolicyClientIPHash:
|
||||
return reverseProxySelectHRW(candidates, c.RequestIP())
|
||||
case reverseProxyLBPolicyURIHash:
|
||||
if c.Request == nil || c.Request.URL == nil {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
|
||||
case reverseProxyLBPolicyHeader:
|
||||
if c.Request != nil && c.Request.Header != nil {
|
||||
if values, ok := c.Request.Header[policy.key]; ok {
|
||||
return reverseProxySelectHRWValues(candidates, values)
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyQuery:
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
if values, ok := c.Request.URL.Query()[policy.key]; ok {
|
||||
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||
}
|
||||
}
|
||||
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||
case reverseProxyLBPolicyRandom:
|
||||
fallthrough
|
||||
default:
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int((p.roundRobin.Add(1) - 1) % uint64(size))
|
||||
}
|
||||
|
||||
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
selected := candidates[0]
|
||||
lowest := selected.inFlight.Load()
|
||||
ties := []*reverseProxyUpstream{selected}
|
||||
for _, upstream := range candidates[1:] {
|
||||
count := upstream.inFlight.Load()
|
||||
switch {
|
||||
case count < lowest:
|
||||
selected = upstream
|
||||
lowest = count
|
||||
ties = []*reverseProxyUpstream{upstream}
|
||||
case count == lowest:
|
||||
ties = append(ties, upstream)
|
||||
}
|
||||
}
|
||||
if len(ties) == 1 {
|
||||
return selected
|
||||
}
|
||||
return ties[p.nextRoundRobinIndex(len(ties))]
|
||||
}
|
||||
|
||||
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(candidates) == 1 {
|
||||
return candidates[0]
|
||||
}
|
||||
return candidates[rand.IntN(len(candidates))]
|
||||
}
|
||||
|
||||
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if key == "" {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
selected := candidates[0]
|
||||
bestScore := reverseProxyHRWScore(key, selected.key)
|
||||
for _, upstream := range candidates[1:] {
|
||||
score := reverseProxyHRWScore(key, upstream.key)
|
||||
if score > bestScore {
|
||||
selected = upstream
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return reverseProxySelectRandom(candidates)
|
||||
}
|
||||
selected := candidates[0]
|
||||
bestScore := reverseProxyHRWValuesScore(values, selected.key)
|
||||
for _, upstream := range candidates[1:] {
|
||||
score := reverseProxyHRWValuesScore(values, upstream.key)
|
||||
if score > bestScore {
|
||||
selected = upstream
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
|
||||
const (
|
||||
offset64 = 14695981039346656037
|
||||
prime64 = 1099511628211
|
||||
)
|
||||
h := uint64(offset64)
|
||||
for i := 0; i < len(key); i++ {
|
||||
h ^= uint64(key[i])
|
||||
h *= prime64
|
||||
}
|
||||
h ^= 0xff
|
||||
h *= prime64
|
||||
for i := 0; i < len(upstreamKey); i++ {
|
||||
h ^= uint64(upstreamKey[i])
|
||||
h *= prime64
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 {
|
||||
const (
|
||||
offset64 = 14695981039346656037
|
||||
prime64 = 1099511628211
|
||||
)
|
||||
h := uint64(offset64)
|
||||
for valueIndex, value := range values {
|
||||
for i := 0; i < len(value); i++ {
|
||||
h ^= uint64(value[i])
|
||||
h *= prime64
|
||||
}
|
||||
if valueIndex+1 < len(values) {
|
||||
h ^= ','
|
||||
h *= prime64
|
||||
}
|
||||
}
|
||||
h ^= 0xff
|
||||
h *= prime64
|
||||
for i := 0; i < len(upstreamKey); i++ {
|
||||
h ^= uint64(upstreamKey[i])
|
||||
h *= prime64
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||
if policy.fallback != nil {
|
||||
return *policy.fallback
|
||||
}
|
||||
return LBRandom()
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
return len(u.failures) < maxFails
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
|
||||
maxFails := reverseProxyPassiveMaxFails(config)
|
||||
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
u.passiveMu.Lock()
|
||||
defer u.passiveMu.Unlock()
|
||||
u.pruneFailuresLocked(now, config.FailDuration)
|
||||
u.failures = append(u.failures, now)
|
||||
}
|
||||
|
||||
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
|
||||
if len(u.failures) == 0 || window <= 0 {
|
||||
if window <= 0 {
|
||||
u.failures = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
cutoff := now.Add(-window)
|
||||
keep := 0
|
||||
for _, failureAt := range u.failures {
|
||||
if failureAt.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
u.failures[keep] = failureAt
|
||||
keep++
|
||||
}
|
||||
u.failures = u.failures[:keep]
|
||||
}
|
||||
|
||||
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
|
||||
if config.FailDuration <= 0 {
|
||||
return 0
|
||||
}
|
||||
if config.MaxFails <= 0 {
|
||||
return 1
|
||||
}
|
||||
return config.MaxFails
|
||||
}
|
||||
|
||||
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
|
||||
if status <= 0 {
|
||||
return false
|
||||
}
|
||||
return slices.Contains(config.UnhealthyStatus, status)
|
||||
}
|
||||
1955
reverseproxy_test.go
1955
reverseproxy_test.go
File diff suppressed because it is too large
Load diff
130
route_match_benchmark_test.go
Normal file
130
route_match_benchmark_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package touka
|
||||
|
||||
import "testing"
|
||||
|
||||
var (
|
||||
benchmarkRouteHandlers HandlersChain
|
||||
benchmarkRouteFullPath string
|
||||
benchmarkRouteParamsLen int
|
||||
benchmarkRouteCIPath []byte
|
||||
benchmarkRouteCIFound bool
|
||||
)
|
||||
|
||||
func buildRouteMatchBenchmarkTree() *node {
|
||||
tree := &node{}
|
||||
routes := []string{
|
||||
"/",
|
||||
"/health",
|
||||
"/contact",
|
||||
"/api/v1/users",
|
||||
"/api/v1/users/:id",
|
||||
"/api/v1/users/:id/settings",
|
||||
"/assets/*filepath",
|
||||
"/abc/b",
|
||||
"/abc/:p1/cde",
|
||||
"/abc/:p1/:p2/def/*filepath",
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
}
|
||||
|
||||
return tree
|
||||
}
|
||||
|
||||
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
|
||||
b.Helper()
|
||||
|
||||
params := make(Params, 0, 4)
|
||||
skipped := make([]skippedNode, 0, 8)
|
||||
|
||||
value := tree.getValue(path, ¶ms, &skipped, true)
|
||||
if wantFullPath == "" {
|
||||
if value.handlers != nil {
|
||||
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
|
||||
}
|
||||
} else {
|
||||
if value.handlers == nil {
|
||||
b.Fatalf("expected match for %q, got nil handlers", path)
|
||||
}
|
||||
if value.fullPath != wantFullPath {
|
||||
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
|
||||
}
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
params = params[:0]
|
||||
skipped = skipped[:0]
|
||||
value = tree.getValue(path, ¶ms, &skipped, true)
|
||||
}
|
||||
|
||||
benchmarkRouteHandlers = value.handlers
|
||||
benchmarkRouteFullPath = value.fullPath
|
||||
if value.params != nil {
|
||||
benchmarkRouteParamsLen = len(*value.params)
|
||||
} else {
|
||||
benchmarkRouteParamsLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRouteMatch(b *testing.B) {
|
||||
tree := buildRouteMatchBenchmarkTree()
|
||||
|
||||
b.Run("StaticHit", func(b *testing.B) {
|
||||
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
|
||||
})
|
||||
|
||||
b.Run("ParamHit", func(b *testing.B) {
|
||||
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
|
||||
})
|
||||
|
||||
b.Run("BacktrackingHit", func(b *testing.B) {
|
||||
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
|
||||
})
|
||||
|
||||
b.Run("Miss", func(b *testing.B) {
|
||||
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
|
||||
})
|
||||
|
||||
b.Run("CaseInsensitiveHit", func(b *testing.B) {
|
||||
path := "/API/V1/USERS/123/SETTINGS"
|
||||
out, found := tree.findCaseInsensitivePath(path, true)
|
||||
if !found {
|
||||
b.Fatalf("expected fixed-path match for %q", path)
|
||||
}
|
||||
if got := string(out); got != "/api/v1/users/123/settings" {
|
||||
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
out, found = tree.findCaseInsensitivePath(path, true)
|
||||
}
|
||||
|
||||
benchmarkRouteCIPath = out
|
||||
benchmarkRouteCIFound = found
|
||||
})
|
||||
|
||||
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
|
||||
path := "/DOES/NOT/EXIST"
|
||||
out, found := tree.findCaseInsensitivePath(path, true)
|
||||
if found || out != nil {
|
||||
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
out, found = tree.findCaseInsensitivePath(path, true)
|
||||
}
|
||||
|
||||
benchmarkRouteCIPath = out
|
||||
benchmarkRouteCIFound = found
|
||||
})
|
||||
}
|
||||
759
serve.go
759
serve.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -21,329 +22,322 @@ import (
|
|||
"github.com/fenthope/reco"
|
||||
)
|
||||
|
||||
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
|
||||
const defaultShutdownTimeout = 5 * time.Second
|
||||
|
||||
// --- 内部辅助函数 ---
|
||||
type runMode uint8
|
||||
|
||||
// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080"
|
||||
func resolveAddress(addr []string) string {
|
||||
switch len(addr) {
|
||||
case 0:
|
||||
return ":8080"
|
||||
case 1:
|
||||
return addr[0]
|
||||
default:
|
||||
panic("too many parameters provided for server address")
|
||||
const (
|
||||
runModeHTTP runMode = iota
|
||||
runModeHTTPS
|
||||
runModeHTTPSRedirect
|
||||
)
|
||||
|
||||
type runConfig struct {
|
||||
addr string
|
||||
httpRedirectAddr string
|
||||
tlsConfig *tls.Config
|
||||
redirectHost string
|
||||
redirectHostHeaders []string
|
||||
useHeaderHost bool
|
||||
useHeaderHostSet bool
|
||||
graceful bool
|
||||
shutdownTimeout time.Duration
|
||||
gracefulCtx context.Context
|
||||
mode runMode
|
||||
shutdownDefaultSet bool
|
||||
shutdownTimeoutSet bool
|
||||
}
|
||||
|
||||
type RunOption interface {
|
||||
apply(*runConfig) error
|
||||
}
|
||||
|
||||
type runOptionFunc func(*runConfig) error
|
||||
|
||||
func (f runOptionFunc) apply(cfg *runConfig) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
func defaultRunConfig() runConfig {
|
||||
return runConfig{
|
||||
addr: ":8080",
|
||||
shutdownTimeout: defaultShutdownTimeout,
|
||||
mode: runModeHTTP,
|
||||
useHeaderHost: true,
|
||||
}
|
||||
}
|
||||
|
||||
// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值
|
||||
func getShutdownTimeout(timeouts []time.Duration) time.Duration {
|
||||
if len(timeouts) > 0 && timeouts[0] > 0 {
|
||||
return timeouts[0]
|
||||
}
|
||||
return defaultShutdownTimeout
|
||||
type HTTPRedirectOption interface {
|
||||
applyRedirect(*runConfig) error
|
||||
}
|
||||
|
||||
// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server,
|
||||
// 并处理其启动失败的致命错误
|
||||
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS")
|
||||
func runServer(serverType string, srv *http.Server) {
|
||||
type redirectOptionFunc func(*runConfig) error
|
||||
|
||||
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
func WithAddr(addr string) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
return errors.New("run address must not be empty")
|
||||
}
|
||||
cfg.addr = addr
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithTLS(tlsConfig *tls.Config) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil")
|
||||
}
|
||||
cfg.tlsConfig = tlsConfig
|
||||
if cfg.mode == runModeHTTP {
|
||||
cfg.mode = runModeHTTPS
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
return errors.New("http redirect address must not be empty")
|
||||
}
|
||||
cfg.httpRedirectAddr = addr
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
if err := opt.applyRedirect(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.useHeaderHost = enabled
|
||||
cfg.useHeaderHostSet = true
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHost(host string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
if host == "" {
|
||||
return errors.New("redirect host must not be empty")
|
||||
}
|
||||
cfg.redirectHost = host
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
||||
for _, header := range headers {
|
||||
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
||||
if trimmed != "" {
|
||||
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithGracefulShutdown(timeout time.Duration) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.graceful = true
|
||||
cfg.shutdownTimeoutSet = true
|
||||
if timeout > 0 {
|
||||
cfg.shutdownTimeout = timeout
|
||||
} else {
|
||||
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithGracefulShutdownDefault() RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.graceful = true
|
||||
cfg.shutdownDefaultSet = true
|
||||
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithShutdownContext(ctx context.Context) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown context must not be nil")
|
||||
}
|
||||
cfg.gracefulCtx = ctx
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func serveServer(srv *http.Server, serveTLS bool) error {
|
||||
if serveTLS {
|
||||
return srv.ListenAndServeTLS("", "")
|
||||
}
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
func runServer(serverType string, srv *http.Server, serveTLS bool) {
|
||||
go func() {
|
||||
var err error
|
||||
protocol := "http"
|
||||
if srv.TLSConfig != nil {
|
||||
if serveTLS {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
|
||||
|
||||
if srv.TLSConfig != nil {
|
||||
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
|
||||
// ListenAndServeTLS 的前两个参数可以为空字符串
|
||||
err = srv.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
|
||||
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
|
||||
// 则认为是一个严重错误,并终止程序
|
||||
err := serveServer(srv, serveTLS)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Touka %s server failed: %v", serverType, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器
|
||||
// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿
|
||||
func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error {
|
||||
// 创建一个 channel 来接收操作系统信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
||||
<-quit // 阻塞,直到接收到上述信号之一
|
||||
log.Println("Shutting down Touka server(s)...")
|
||||
|
||||
// 关闭日志记录器
|
||||
if logger != nil {
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建一个带超时的上下文,用于 Shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
||||
|
||||
// 并发地关闭所有服务器
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(ctx); err != nil {
|
||||
// 将错误发送到 channel
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
|
||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
||||
|
||||
// 收集所有关闭过程中发生的错误
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error {
|
||||
// 创建一个 channel 来接收操作系统信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
||||
|
||||
// 启动服务器
|
||||
serverStopped := make(chan error, 1)
|
||||
for _, srv := range servers {
|
||||
go func(s *http.Server) {
|
||||
serverStopped <- s.ListenAndServe()
|
||||
}(srv)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 被取消 (例如,通过外部取消函数)
|
||||
log.Println("Context cancelled, shutting down Touka server(s)...")
|
||||
case err := <-serverStopped:
|
||||
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("Touka HTTP server failed: %w", err)
|
||||
}
|
||||
log.Println("Touka HTTP server stopped gracefully.")
|
||||
return nil // 服务器已自行优雅关闭,无需进一步处理
|
||||
case <-quit:
|
||||
// 接收到操作系统信号
|
||||
log.Println("Shutting down Touka server(s) due to OS signal...")
|
||||
}
|
||||
|
||||
// 关闭日志记录器
|
||||
if logger != nil {
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建一个带超时的上下文,用于 Shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
||||
|
||||
// 并发地关闭所有服务器
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(shutdownCtx); err != nil {
|
||||
// 将错误发送到 channel
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
||||
|
||||
// 收集所有关闭过程中发生的错误
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 公共 Run 方法 ---
|
||||
|
||||
// Run 启动一个不支持优雅关闭的 HTTP 服务器
|
||||
// 这是一个阻塞调用,主要用于简单的场景或快速测试
|
||||
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
|
||||
func (engine *Engine) Run(addr ...string) error {
|
||||
address := resolveAddress(addr)
|
||||
srv := &http.Server{Addr: address, Handler: engine}
|
||||
|
||||
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
||||
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error {
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
runServer("HTTP", srv)
|
||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
||||
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error {
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器
|
||||
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil for RunTLS")
|
||||
return nil
|
||||
}
|
||||
return tlsConfig.Clone()
|
||||
}
|
||||
|
||||
// 配置 HTTP/2 支持 (如果使用默认配置)
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true, // 默认在 TLS 上启用 HTTP/2
|
||||
})
|
||||
func parseHTTPSPort(addr string) (string, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
|
||||
}
|
||||
return port, nil
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
TLSConfig: tlsConfig,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
|
||||
if serveTLS {
|
||||
if engine.TLSServerConfigurator != nil {
|
||||
engine.TLSServerConfigurator(srv)
|
||||
} else if engine.ServerConfigurator != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
runServer("HTTPS", srv)
|
||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名
|
||||
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
return engine.RunTLS(addr, tlsConfig, timeouts...)
|
||||
func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
|
||||
applyServerProtocols(srv, engine.serverProtocols)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
}
|
||||
|
||||
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭
|
||||
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil for RunTLSRedir")
|
||||
func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
if serveTLS && engine.useDefaultProtocols {
|
||||
protocols := &http.Protocols{}
|
||||
protocols.SetHTTP1(true)
|
||||
protocols.SetHTTP2(true)
|
||||
return protocols
|
||||
}
|
||||
return cloneServerProtocols(engine.serverProtocols)
|
||||
}
|
||||
|
||||
// --- HTTPS 服务器 ---
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true})
|
||||
}
|
||||
httpsSrv := &http.Server{
|
||||
Addr: httpsAddr,
|
||||
func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
||||
serveTLS := cfg.mode != runModeHTTP
|
||||
server := &http.Server{
|
||||
Addr: cfg.addr,
|
||||
Handler: engine,
|
||||
TLSConfig: tlsConfig,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
TLSConfig: cloneTLSConfig(cfg.tlsConfig),
|
||||
}
|
||||
if cfg.graceful {
|
||||
server.BaseContext = func(net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
httpsSrv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
engine.applyDefaultServerConfig(httpsSrv)
|
||||
if engine.TLSServerConfigurator != nil {
|
||||
engine.TLSServerConfigurator(httpsSrv)
|
||||
} else if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(httpsSrv)
|
||||
server.RegisterOnShutdown(engine.shutdownCancel)
|
||||
}
|
||||
applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
|
||||
applyMainServerConfig(engine, server, serveTLS)
|
||||
return server
|
||||
}
|
||||
|
||||
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
for _, header := range headers {
|
||||
value := strings.TrimSpace(r.Header.Get(header))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
||||
value = strings.TrimSpace(value[:comma])
|
||||
}
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
||||
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return "", http.StatusInternalServerError, false
|
||||
}
|
||||
return cfg.redirectHost, 0, true
|
||||
}
|
||||
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
host := strings.TrimSpace(r.Host)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
||||
httpsAddr := cfg.addr
|
||||
httpAddr := cfg.httpRedirectAddr
|
||||
httpsPort, err := parseHTTPSPort(httpsAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// --- HTTP 重定向服务器 ---
|
||||
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
host, statusCode, ok := redirectTargetHost(r, cfg)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
_, httpsPort, err := net.SplitHostPort(httpsAddr)
|
||||
if err != nil {
|
||||
// 如果 httpsAddr 没有端口,这是一个配置错误
|
||||
|
||||
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr)
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = parsedHost
|
||||
if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
|
||||
host = "[" + host + "]"
|
||||
}
|
||||
}
|
||||
|
||||
targetURL := "https://" + host
|
||||
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
|
||||
if httpsPort != "443" {
|
||||
targetURL = "https://" + net.JoinHostPort(host, httpsPort)
|
||||
}
|
||||
|
|
@ -351,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con
|
|||
|
||||
http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
|
||||
})
|
||||
httpSrv := &http.Server{
|
||||
Addr: httpAddr,
|
||||
Handler: redirectHandler,
|
||||
}
|
||||
engine.applyDefaultServerConfig(httpSrv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(httpSrv)
|
||||
}
|
||||
|
||||
// --- 启动服务器和优雅关闭 ---
|
||||
runServer("HTTPS", httpsSrv)
|
||||
runServer("HTTP Redirect", httpSrv)
|
||||
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
|
||||
applyRedirectServerConfig(engine, server)
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性
|
||||
func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...)
|
||||
func validateRunConfig(cfg runConfig) error {
|
||||
if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
|
||||
return errors.New("WithHTTPRedirect requires WithTLS")
|
||||
}
|
||||
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
||||
return errors.New("https mode requires WithTLS")
|
||||
}
|
||||
if cfg.gracefulCtx != nil && !cfg.graceful {
|
||||
return errors.New("WithShutdownContext requires graceful shutdown")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
||||
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
||||
}
|
||||
}
|
||||
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
||||
if cfg.redirectHost != "" {
|
||||
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
||||
}
|
||||
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
|
||||
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
|
||||
if cfg.shutdownTimeout > 0 {
|
||||
return cfg.shutdownTimeout
|
||||
}
|
||||
}
|
||||
return defaultShutdownTimeout
|
||||
}
|
||||
|
||||
func closeLoggerAsync(logger *reco.Logger) {
|
||||
if logger == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers))
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(ctx); err != nil {
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(quit)
|
||||
|
||||
serverStopped := make(chan error, len(servers))
|
||||
for i, srv := range servers {
|
||||
serveTLSFlag := serveTLS[i]
|
||||
go func(server *http.Server, useTLS bool) {
|
||||
serverStopped <- serveServer(server, useTLS)
|
||||
}(srv, serveTLSFlag)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverStopped:
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
|
||||
return errors.Join(err, shutdownErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Println("Touka server stopped gracefully.")
|
||||
return nil
|
||||
case <-quit:
|
||||
log.Println("Shutting down Touka server(s) due to OS signal...")
|
||||
case <-shutdownCtx.Done():
|
||||
log.Println("Context cancelled, shutting down Touka server(s)...")
|
||||
}
|
||||
|
||||
closeLoggerAsync(logger)
|
||||
if err := shutdownServers(servers, timeout); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run starts the engine with the provided startup options.
|
||||
//
|
||||
// Default behavior with no options:
|
||||
// - HTTP only
|
||||
// - listens on :8080
|
||||
// - no graceful shutdown orchestration
|
||||
//
|
||||
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
|
||||
// signal-aware graceful shutdown and request-context cancellation semantics.
|
||||
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
|
||||
func (engine *Engine) Run(opts ...RunOption) error {
|
||||
cfg := defaultRunConfig()
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if cfg.httpRedirectAddr != "" {
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
} else if cfg.tlsConfig != nil {
|
||||
cfg.mode = runModeHTTPS
|
||||
}
|
||||
if err := validateRunConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serveTLS := cfg.mode != runModeHTTP
|
||||
|
||||
mainServer := buildMainServer(engine, cfg)
|
||||
servers := []*http.Server{mainServer}
|
||||
serveTLSFlags := []bool{serveTLS}
|
||||
if cfg.mode == runModeHTTPSRedirect {
|
||||
redirectServer, err := buildRedirectServer(engine, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
servers = append(servers, redirectServer)
|
||||
serveTLSFlags = append(serveTLSFlags, false)
|
||||
}
|
||||
|
||||
if !cfg.graceful {
|
||||
if len(servers) > 1 {
|
||||
serverStopped := make(chan error, len(servers))
|
||||
for i, srv := range servers {
|
||||
serveTLSFlag := serveTLSFlags[i]
|
||||
go func(server *http.Server, useTLS bool) {
|
||||
serverStopped <- serveServer(server, useTLS)
|
||||
}(srv, serveTLSFlag)
|
||||
}
|
||||
|
||||
err := <-serverStopped
|
||||
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return errors.Join(err, shutdownErr)
|
||||
}
|
||||
return shutdownErr
|
||||
}
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
protocolLabel := "HTTP"
|
||||
if serveTLS {
|
||||
protocolLabel = "HTTPS"
|
||||
}
|
||||
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
|
||||
return serveServer(mainServer, serveTLS)
|
||||
}
|
||||
|
||||
shutdownCtx := context.Background()
|
||||
if cfg.gracefulCtx != nil {
|
||||
shutdownCtx = cfg.gracefulCtx
|
||||
}
|
||||
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
|
||||
}
|
||||
|
|
|
|||
492
serve_test.go
Normal file
492
serve_test.go
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate private key: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "127.0.0.1"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("create self-signed cert: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse self-signed cert: %v", err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen on ephemeral port: %v", err)
|
||||
}
|
||||
addr := listener.Addr().String()
|
||||
if err := listener.Close(); err != nil {
|
||||
t.Fatalf("close temporary listener: %v", err)
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}),
|
||||
// RunShutdown uses the HTTP startup path and must not let a shared
|
||||
// ServerConfigurator accidentally turn it into HTTPS.
|
||||
TLSConfig: &tls.Config{},
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- serveServer(srv, false)
|
||||
}()
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
var resp *http.Response
|
||||
requestURL := "http://" + addr
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
resp, err = client.Get(requestURL)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case serveErr := <-errCh:
|
||||
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
|
||||
default:
|
||||
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read response body: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
|
||||
}
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
t.Fatalf("shutdown server: %v", err)
|
||||
}
|
||||
|
||||
if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
|
||||
engine := New()
|
||||
err := engine.Run(WithHTTPRedirect(":80"))
|
||||
if err == nil {
|
||||
t.Fatal("expected redirect mode without TLS to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
|
||||
engine := New()
|
||||
err := engine.Run(
|
||||
WithAddr(":443"),
|
||||
WithTLS(&tls.Config{}),
|
||||
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
||||
t.Fatalf("apply graceful default option: %v", err)
|
||||
}
|
||||
if !cfg.graceful {
|
||||
t.Fatal("expected graceful shutdown to be enabled")
|
||||
}
|
||||
if cfg.shutdownTimeout != defaultShutdownTimeout {
|
||||
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
|
||||
t.Fatalf("apply TLS option: %v", err)
|
||||
}
|
||||
if cfg.mode != runModeHTTPS {
|
||||
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
|
||||
}
|
||||
if cfg.graceful {
|
||||
t.Fatal("expected TLS option to remain independent from graceful shutdown")
|
||||
}
|
||||
if cfg.tlsConfig != tlsConfig {
|
||||
t.Fatal("expected TLS config to be preserved in run config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
||||
engine := New()
|
||||
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
|
||||
t.Fatal("expected redirect server builder to reject https address without port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
ctx := t.Context()
|
||||
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
|
||||
t.Fatalf("apply shutdown context option: %v", err)
|
||||
}
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.httpRedirectAddr = ":80"
|
||||
if err := validateRunConfig(cfg); err != nil {
|
||||
t.Fatalf("validate run config: %v", err)
|
||||
}
|
||||
if cfg.mode != runModeHTTP {
|
||||
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
cfg.tlsConfig = &tls.Config{}
|
||||
cfg.useHeaderHost = false
|
||||
cfg.useHeaderHostSet = true
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected configured host mode without redirect host to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
cfg.tlsConfig = &tls.Config{}
|
||||
cfg.useHeaderHost = true
|
||||
cfg.useHeaderHostSet = true
|
||||
cfg.redirectHost = "configured.example"
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
||||
engine := New()
|
||||
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
||||
if server.BaseContext == nil {
|
||||
t.Fatal("expected graceful main server to set BaseContext")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for base context check: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
if got := server.BaseContext(listener); got != engine.shutdownCtx {
|
||||
t.Fatal("expected graceful main server to use engine shutdown context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
|
||||
engine := New()
|
||||
serverConfigured := false
|
||||
tlsConfigured := false
|
||||
engine.SetServerConfigurator(func(s *http.Server) {
|
||||
serverConfigured = true
|
||||
s.ReadTimeout = time.Second
|
||||
})
|
||||
engine.SetTLSServerConfigurator(func(s *http.Server) {
|
||||
tlsConfigured = true
|
||||
s.IdleTimeout = time.Second
|
||||
})
|
||||
|
||||
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
if !tlsConfigured {
|
||||
t.Fatal("expected TLS configurator to run for HTTPS main server")
|
||||
}
|
||||
if serverConfigured {
|
||||
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
|
||||
}
|
||||
if server.IdleTimeout != time.Second {
|
||||
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
|
||||
engine := New()
|
||||
configured := false
|
||||
engine.SetServerConfigurator(func(s *http.Server) {
|
||||
configured = true
|
||||
s.ReadTimeout = time.Second
|
||||
})
|
||||
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
if !configured {
|
||||
t.Fatal("expected redirect server to use generic server configurator")
|
||||
}
|
||||
if server.ReadTimeout != time.Second {
|
||||
t.Fatal("expected redirect server configurator changes to be applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
|
||||
engine := New()
|
||||
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
if !httpsServer.Protocols.HTTP2() {
|
||||
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
|
||||
}
|
||||
|
||||
httpServer := buildMainServer(engine, defaultRunConfig())
|
||||
if httpServer.Protocols.HTTP2() {
|
||||
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: true,
|
||||
useHeaderHostSet: true,
|
||||
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||
req.Header.Set("X-First-Host", "first.example")
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: true,
|
||||
useHeaderHostSet: true,
|
||||
redirectHostHeaders: []string{"X-Forwarded-Host"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUpgradeRequired {
|
||||
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: false,
|
||||
useHeaderHostSet: true,
|
||||
redirectHost: "configured.example",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
|
||||
req.Host = "[::1]:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
|
||||
t.Fatalf("unexpected IPv6 redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen on occupied addr: %v", err)
|
||||
}
|
||||
occupiedAddr := occupied.Addr().String()
|
||||
defer occupied.Close()
|
||||
|
||||
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for redirect addr: %v", err)
|
||||
}
|
||||
redirectAddr := redirectListener.Addr().String()
|
||||
if err := redirectListener.Close(); err != nil {
|
||||
t.Fatalf("close redirect addr probe: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
|
||||
|
||||
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected gracefulServe to fail when one server cannot bind")
|
||||
}
|
||||
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
|
||||
}
|
||||
|
||||
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
|
||||
if dialErr == nil {
|
||||
conn.Close()
|
||||
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
|
||||
}
|
||||
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
|
||||
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen on occupied addr: %v", err)
|
||||
}
|
||||
occupiedAddr := occupied.Addr().String()
|
||||
defer occupied.Close()
|
||||
|
||||
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for redirect addr: %v", err)
|
||||
}
|
||||
redirectAddr := redirectListener.Addr().String()
|
||||
if err := redirectListener.Close(); err != nil {
|
||||
t.Fatalf("close redirect addr probe: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
err = engine.Run(
|
||||
WithAddr(occupiedAddr),
|
||||
WithTLS(&tls.Config{}),
|
||||
WithHTTPRedirect(redirectAddr),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
|
||||
}
|
||||
}
|
||||
66
sse.go
66
sse.go
|
|
@ -111,46 +111,40 @@ func (c *Context) EventStream(streamer func(w io.Writer) bool) {
|
|||
// EventStreamChan 返回用于 SSE 事件流的 channel.
|
||||
// 这是为高级并发场景设计的、更灵活的API.
|
||||
//
|
||||
// 重要:
|
||||
// - 调用者必须 close(eventChan) 来结束事件流.
|
||||
// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开.
|
||||
// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done().
|
||||
// 与 EventStream 回调模式类似, 此方法是阻塞的: handler 会在此方法中停留,
|
||||
// 直到事件 channel 被关闭 (close eventChan) 或客户端断开连接.
|
||||
// 这保证了 Context 不会在 SSE 流期间被 pool 回收.
|
||||
//
|
||||
// eventChan 必须在调用此方法之前创建, 以便调用者可以在独立的 goroutine 中发送事件.
|
||||
// 调用者必须在完成后 close(eventChan) 来结束流.
|
||||
// 生产者 goroutine 必须在 select 中监听 c.Request.Context().Done(), 否则在客户端断开时会产生 goroutine 泄漏.
|
||||
//
|
||||
// 详细用法:
|
||||
//
|
||||
// r.GET("/sse/channel", func(c *touka.Context) {
|
||||
// eventChan, errChan := c.EventStreamChan()
|
||||
// eventChan := make(chan touka.Event)
|
||||
//
|
||||
// // 必须在独立的goroutine中处理错误和连接断开.
|
||||
// // 在独立的 goroutine 中异步发送事件.
|
||||
// go func() {
|
||||
// if err := <-errChan; err != nil {
|
||||
// c.Errorf("SSE channel error: %v", err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// // 在另一个goroutine中异步发送事件.
|
||||
// go func() {
|
||||
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
|
||||
// defer close(eventChan)
|
||||
// defer close(eventChan) // 完成后关闭 channel 以结束事件流.
|
||||
//
|
||||
// for i := 1; i <= 5; i++ {
|
||||
// select {
|
||||
// case <-c.Request.Context().Done():
|
||||
// return // 客户端已断开, 退出 goroutine.
|
||||
// default:
|
||||
// eventChan <- touka.Event{
|
||||
// case eventChan <- touka.Event{
|
||||
// Id: fmt.Sprintf("%d", i),
|
||||
// Data: "hello from channel",
|
||||
// }:
|
||||
// }
|
||||
// time.Sleep(2 * time.Second)
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// // 阻塞直到事件流结束.
|
||||
// c.EventStreamChan(eventChan)
|
||||
// })
|
||||
func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
||||
eventChan := make(chan Event)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
func (c *Context) EventStreamChan(eventChan <-chan Event) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
c.Writer.Header().Del("Connection")
|
||||
|
|
@ -159,8 +153,16 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
|||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Flush()
|
||||
|
||||
// 捕获稳定的引用, 不持有 *Context 指针, 以免 Context 被 pool 回收后出现竞态.
|
||||
w := c.Writer
|
||||
fl, _ := w.(http.Flusher)
|
||||
reqCtx := c.Request.Context()
|
||||
|
||||
goroutineExited := make(chan struct{})
|
||||
|
||||
// 写入 goroutine: 从 eventChan 消费事件并写入响应.
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(goroutineExited)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
|
@ -168,17 +170,23 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := event.Render(c.Writer); err != nil {
|
||||
errChan <- err
|
||||
if err := event.Render(w); err != nil {
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
case <-c.Request.Context().Done():
|
||||
errChan <- c.Request.Context().Err()
|
||||
if fl != nil {
|
||||
fl.Flush()
|
||||
}
|
||||
case <-reqCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return eventChan, errChan
|
||||
// 阻塞直到:
|
||||
// 1. 写入 goroutine 退出 (eventChan 关闭或写入失败)
|
||||
// 2. 客户端断开连接 (reqCtx 取消)
|
||||
select {
|
||||
case <-goroutineExited:
|
||||
case <-reqCtx.Done():
|
||||
}
|
||||
}
|
||||
|
|
|
|||
142
sse_test.go
Normal file
142
sse_test.go
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestEventStreamChanBlocksHandler verifies that EventStreamChan blocks until
|
||||
// the event channel is closed.
|
||||
func TestEventStreamChanBlocksHandler(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/sse", nil)
|
||||
c, _ := CreateTestContextWithRequest(rr, req)
|
||||
|
||||
handlerReturned := make(chan struct{})
|
||||
eventChan := make(chan Event)
|
||||
|
||||
// Start producer goroutine before EventStreamChan blocks
|
||||
go func() {
|
||||
defer close(eventChan)
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
eventChan <- Event{Data: "hello"}
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
c.EventStreamChan(eventChan)
|
||||
close(handlerReturned)
|
||||
}()
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Handler should NOT have returned (eventChan not closed)
|
||||
select {
|
||||
case <-handlerReturned:
|
||||
t.Fatal("Handler returned before eventChan was closed - EventStreamChan is not blocking")
|
||||
case <-time.After(40 * time.Millisecond):
|
||||
// good, still blocking
|
||||
}
|
||||
|
||||
// Wait for producer to finish (30+30ms + margin)
|
||||
select {
|
||||
case <-handlerReturned:
|
||||
// good, handler returned
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Handler did not return after eventChan was closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventStreamChanUnblocksOnClientDisconnect verifies the handler returns
|
||||
// when the request context is cancelled, even if eventChan is never closed.
|
||||
func TestEventStreamChanUnblocksOnClientDisconnect(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/sse", nil).WithContext(ctx)
|
||||
c, _ := CreateTestContextWithRequest(rr, req)
|
||||
|
||||
eventChan := make(chan Event)
|
||||
handlerReturned := make(chan struct{})
|
||||
|
||||
// Producer never closes eventChan
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case eventChan <- Event{Data: "tick"}:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
c.EventStreamChan(eventChan)
|
||||
close(handlerReturned)
|
||||
}()
|
||||
|
||||
// Handler should NOT have returned
|
||||
select {
|
||||
case <-handlerReturned:
|
||||
t.Fatal("Handler returned before stream ended")
|
||||
case <-time.After(60 * time.Millisecond):
|
||||
// good, still blocked
|
||||
}
|
||||
|
||||
// Cancel context to simulate client disconnect
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-handlerReturned:
|
||||
// good
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("Handler did not return after client disconnect")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventStreamChanWritesEvents verifies the SSE event format is correct.
|
||||
func TestEventStreamChanWritesEvents(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/sse", nil)
|
||||
c, _ := CreateTestContextWithRequest(rr, req)
|
||||
|
||||
eventChan := make(chan Event)
|
||||
|
||||
go func() {
|
||||
defer close(eventChan)
|
||||
eventChan <- Event{Id: "1", Event: "tick", Data: "hello\nworld"}
|
||||
eventChan <- Event{Id: "2", Data: "second"}
|
||||
}()
|
||||
|
||||
c.EventStreamChan(eventChan)
|
||||
|
||||
body := rr.Body.String()
|
||||
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
if !strings.Contains(ct, "text/event-stream") {
|
||||
t.Fatalf("expected text/event-stream content type, got %q", ct)
|
||||
}
|
||||
|
||||
if !strings.Contains(body, "id: 1") {
|
||||
t.Fatal("missing id field in first event")
|
||||
}
|
||||
if !strings.Contains(body, "event: tick") {
|
||||
t.Fatal("missing event field in first event")
|
||||
}
|
||||
if !strings.Contains(body, "data: hello") {
|
||||
t.Fatal("missing data line 1 in first event")
|
||||
}
|
||||
if !strings.Contains(body, "data: world") {
|
||||
t.Fatal("missing data line 2 in first event")
|
||||
}
|
||||
if !strings.Contains(body, "id: 2") {
|
||||
t.Fatal("missing id field in second event")
|
||||
}
|
||||
if !strings.Contains(body, "data: second") {
|
||||
t.Fatal("missing data in second event")
|
||||
}
|
||||
}
|
||||
8
touka.go
8
touka.go
|
|
@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
|
|||
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
||||
type HandlersChain []HandlerFunc
|
||||
|
||||
// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||
type IRouter interface {
|
||||
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组
|
||||
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组
|
||||
// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||
type Router interface {
|
||||
Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
|
||||
Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
|
||||
|
||||
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
||||
GET(relativePath string, handlers ...HandlerFunc)
|
||||
|
|
|
|||
82
tree.go
82
tree.go
|
|
@ -124,6 +124,7 @@ type node struct {
|
|||
path string // 当前节点的路径段
|
||||
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
|
||||
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
|
||||
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
|
||||
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
|
||||
priority uint32 // 节点的优先级, 用于查找时优先匹配
|
||||
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
|
||||
|
|
@ -131,6 +132,19 @@ type node struct {
|
|||
fullPath string // 完整路径, 用于调试和错误信息
|
||||
}
|
||||
|
||||
func routeNeedsCaseInsensitiveLookup(path string) bool {
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if c >= utf8.RuneSelf {
|
||||
return true
|
||||
}
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
|
||||
func (n *node) incrementChildPrio(pos int) int {
|
||||
cs := n.children // 获取子节点切片
|
||||
|
|
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
|
|||
func (n *node) addRoute(path string, handlers HandlersChain) {
|
||||
fullPath := path // 记录完整的路径
|
||||
n.priority++ // 增加当前节点的优先级
|
||||
if routeNeedsCaseInsensitiveLookup(path) {
|
||||
n.hasCaseInsensitivePath = true
|
||||
}
|
||||
|
||||
// 如果是空树(根节点)
|
||||
if len(n.path) == 0 && len(n.children) == 0 {
|
||||
|
|
@ -452,12 +469,14 @@ type skippedNode struct {
|
|||
// 建议进行 TSR(尾部斜杠重定向).
|
||||
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
|
||||
var globalParamsCount int16 // 全局参数计数
|
||||
var backtrackToWildChild bool
|
||||
|
||||
walk: // 外部循环用于遍历路由树
|
||||
for {
|
||||
prefix := n.path // 当前节点的路径前缀
|
||||
if len(path) > len(prefix) {
|
||||
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
|
||||
pathAtNode := path
|
||||
path = path[len(prefix):] // 移除已匹配的前缀
|
||||
|
||||
// 在访问 path[0] 之前进行安全检查
|
||||
|
|
@ -467,23 +486,16 @@ walk: // 外部循环用于遍历路由树
|
|||
|
||||
// 优先尝试所有非通配符子节点, 通过匹配索引字符
|
||||
idxc := path[0] // 剩余路径的第一个字符
|
||||
for i, c := range []byte(n.indices) {
|
||||
if c == idxc { // 如果找到匹配的索引字符
|
||||
if !backtrackToWildChild {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == idxc { // 如果找到匹配的索引字符
|
||||
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
|
||||
if n.wildChild {
|
||||
index := len(*skippedNodes)
|
||||
*skippedNodes = (*skippedNodes)[:index+1]
|
||||
(*skippedNodes)[index] = skippedNode{
|
||||
path: prefix + path, // 记录跳过的路径
|
||||
node: &node{ // 复制当前节点的状态
|
||||
path: n.path,
|
||||
wildChild: n.wildChild,
|
||||
nType: n.nType,
|
||||
priority: n.priority,
|
||||
children: n.children,
|
||||
handlers: n.handlers,
|
||||
fullPath: n.fullPath,
|
||||
},
|
||||
path: pathAtNode, // 记录进入当前节点时的剩余路径
|
||||
node: n,
|
||||
paramsCount: globalParamsCount, // 记录当前参数计数
|
||||
}
|
||||
}
|
||||
|
|
@ -492,6 +504,9 @@ walk: // 外部循环用于遍历路由树
|
|||
continue walk // 继续外部循环
|
||||
}
|
||||
}
|
||||
} else {
|
||||
backtrackToWildChild = false
|
||||
}
|
||||
|
||||
if !n.wildChild {
|
||||
// 如果路径在循环结束时不等于 '/' 且当前节点没有子节点
|
||||
|
|
@ -507,6 +522,7 @@ walk: // 外部循环用于遍历路由树
|
|||
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
|
||||
}
|
||||
globalParamsCount = skippedNode.paramsCount // 恢复参数计数
|
||||
backtrackToWildChild = true
|
||||
continue walk // 继续外部循环
|
||||
}
|
||||
}
|
||||
|
|
@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树
|
|||
i := len(*value.params)
|
||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||
val := path[:end] // 提取参数值
|
||||
if unescape { // 如果需要进行 URL 解码
|
||||
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||
if v, err := url.QueryUnescape(val); err == nil {
|
||||
val = v // 解码成功则更新值
|
||||
}
|
||||
|
|
@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树
|
|||
i := len(*value.params)
|
||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||
val := path // 参数值是剩余的整个路径
|
||||
if unescape { // 如果需要进行 URL 解码
|
||||
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||
if v, err := url.QueryUnescape(path); err == nil {
|
||||
val = v // 解码成功则更新值
|
||||
}
|
||||
|
|
@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树
|
|||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||
}
|
||||
globalParamsCount = skippedNode.paramsCount
|
||||
backtrackToWildChild = true
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
|
@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树
|
|||
}
|
||||
|
||||
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
|
||||
for i, c := range []byte(n.indices) {
|
||||
if c == '/' { // 如果索引中包含 '/'
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||
n = n.children[i] // 移动到对应的子节点
|
||||
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
|
||||
|
|
@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树
|
|||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||
}
|
||||
globalParamsCount = skippedNode.paramsCount
|
||||
backtrackToWildChild = true
|
||||
continue walk
|
||||
}
|
||||
}
|
||||
|
|
@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树
|
|||
// 它还可以选择修复尾部斜杠.
|
||||
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
|
||||
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
|
||||
const stackBufSize = 128 // 栈上缓冲区的默认大小
|
||||
return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
|
||||
}
|
||||
|
||||
// 在常见情况下使用栈上静态大小的缓冲区.
|
||||
// 如果路径太长, 则在堆上分配缓冲区.
|
||||
buf := make([]byte, 0, stackBufSize)
|
||||
if length := len(path) + 1; length > stackBufSize {
|
||||
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区
|
||||
func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
|
||||
if buf != nil {
|
||||
buf = buf[:0]
|
||||
}
|
||||
if cap(buf) < len(path)+1 {
|
||||
buf = make([]byte, 0, len(path)+1)
|
||||
}
|
||||
|
||||
ciPath := n.findCaseInsensitivePathRec(
|
||||
|
|
@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树
|
|||
// 未找到处理函数.
|
||||
// 尝试通过添加尾部斜杠来修复路径
|
||||
if fixTrailingSlash {
|
||||
for i, c := range []byte(n.indices) {
|
||||
if c == '/' { // 如果索引中包含 '/'
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||
n = n.children[i] // 移动到对应的子节点
|
||||
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
|
||||
|
|
@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树
|
|||
if rb[0] != 0 {
|
||||
// 旧 rune 未处理完
|
||||
idxc := rb[0]
|
||||
for i, c := range []byte(n.indices) {
|
||||
if c == idxc {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
if n.indices[i] == idxc {
|
||||
// 继续处理子节点
|
||||
n = n.children[i]
|
||||
npLen = len(n.path)
|
||||
|
|
@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树
|
|||
rb = shiftNRuneBytes(rb, off)
|
||||
|
||||
idxc := rb[0]
|
||||
for i, c := range []byte(n.indices) {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
// 小写匹配
|
||||
if c == idxc {
|
||||
if n.indices[i] == idxc {
|
||||
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
|
||||
if out := n.children[i].findCaseInsensitivePathRec(
|
||||
path, ciPath, rb, fixTrailingSlash,
|
||||
|
|
@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树
|
|||
rb = shiftNRuneBytes(rb, off)
|
||||
|
||||
idxc := rb[0]
|
||||
for i, c := range []byte(n.indices) {
|
||||
for i := 0; i < len(n.indices); i++ {
|
||||
// 大写匹配
|
||||
if c == idxc {
|
||||
if n.indices[i] == idxc {
|
||||
// 继续处理子节点
|
||||
n = n.children[i]
|
||||
npLen = len(n.path)
|
||||
|
|
@ -852,7 +872,7 @@ walk: // 外部循环用于遍历路由树
|
|||
return nil // 未找到, 返回 nil
|
||||
}
|
||||
|
||||
n = n.children[0] // 移动到通配符子节点(通常是唯一一个)
|
||||
n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾
|
||||
switch n.nType {
|
||||
case param: // 参数节点
|
||||
// 查找参数结束位置('/' 或路径末尾)
|
||||
|
|
|
|||
94
tree_test.go
94
tree_test.go
|
|
@ -11,6 +11,7 @@ import (
|
|||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Used as a workaround since we can't compare functions or their addresses
|
||||
|
|
@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode {
|
|||
return &ps
|
||||
}
|
||||
|
||||
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
|
||||
t.Helper()
|
||||
|
||||
resultCh := make(chan nodeValue, 1)
|
||||
go func() {
|
||||
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
|
||||
}()
|
||||
|
||||
select {
|
||||
case value := <-resultCh:
|
||||
return value
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
|
||||
return nodeValue{}
|
||||
}
|
||||
}
|
||||
|
||||
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
|
||||
unescape := false
|
||||
if len(unescapes) >= 1 {
|
||||
|
|
@ -901,6 +919,34 @@ func TestTreeInvalidNodeType(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) {
|
||||
tree := &node{}
|
||||
routes := [...]string{
|
||||
"/:user/:repo/info/refs",
|
||||
"/healthz",
|
||||
"/api/db/data",
|
||||
"/api/db/sum",
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatalf("unexpected panic while looking up missing path: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil {
|
||||
t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found)
|
||||
}
|
||||
|
||||
if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil {
|
||||
t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTreeInvalidParamsType(t *testing.T) {
|
||||
tree := &node{}
|
||||
// add a child with wildcard
|
||||
|
|
@ -1076,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
|
|||
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
routes []string
|
||||
requestPath string
|
||||
wantFullPath string
|
||||
wantParams Params
|
||||
}{
|
||||
{
|
||||
name: "param route after static dead end",
|
||||
routes: []string{"/foo/bar", "/foo/:id/details"},
|
||||
requestPath: "/foo/bar/details",
|
||||
wantFullPath: "/foo/:id/details",
|
||||
wantParams: Params{{Key: "id", Value: "bar"}},
|
||||
},
|
||||
{
|
||||
name: "catch-all route after static dead end",
|
||||
routes: []string{"/foo/bar", "/foo/:id/*rest"},
|
||||
requestPath: "/foo/bar/baz.txt",
|
||||
wantFullPath: "/foo/:id/*rest",
|
||||
wantParams: Params{
|
||||
{Key: "id", Value: "bar"},
|
||||
{Key: "rest", Value: "/baz.txt"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tree := &node{}
|
||||
for _, route := range tt.routes {
|
||||
tree.addRoute(route, fakeHandler(route))
|
||||
}
|
||||
|
||||
value := getValueWithTimeout(t, tree, tt.requestPath, false)
|
||||
if value.handlers == nil {
|
||||
t.Fatalf("expected handlers for %q", tt.requestPath)
|
||||
}
|
||||
if value.fullPath != tt.wantFullPath {
|
||||
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
|
||||
}
|
||||
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
|
||||
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue