mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Compare commits
106 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d439662adf | ||
|
|
810ba788ae | ||
|
|
de0e16852f | ||
|
|
8ec77ecc9f | ||
|
|
b3b82b3c61 | ||
|
|
52db699db9 | ||
|
|
43fede96d5 | ||
|
|
01395dc942 | ||
|
|
3c40a3d6b5 | ||
|
|
9dcab4b1ae | ||
|
|
2d693e3b13 | ||
|
|
d8a5f200c1 | ||
|
|
6006267d25 | ||
|
|
390190695f | ||
|
|
7487369125 | ||
|
|
e7c7d5e41f | ||
|
|
4f262b2497 | ||
|
|
f2295c3084 | ||
|
|
b83e536def | ||
|
|
10033f4a17 | ||
|
|
c8b14ef43a | ||
|
|
2581697771 | ||
|
|
58fd877ae2 | ||
|
|
fce12ee7e7 | ||
|
|
d9328c3176 | ||
|
|
8fdb16ae1e | ||
|
|
1243d2d37a | ||
|
|
fa925582d7 | ||
|
|
5d9bb3187d | ||
|
|
c0e31c449e | ||
|
|
93f5edc6eb | ||
|
|
06a6d42de1 | ||
|
|
3b5f2c81af | ||
|
|
b008fc8e61 | ||
|
|
0f7cf23abb | ||
|
|
54f7de0c60 | ||
|
|
02861b5537 | ||
|
|
7c37d4c38c | ||
|
|
271e54eb4d | ||
|
|
017bb13295 | ||
|
|
71a344a3de | ||
|
|
efa1e3fb3f | ||
|
|
7cb777225f | ||
|
|
121679b44e | ||
|
|
9e57f5a5f5 | ||
|
|
e2cf08d5dd | ||
|
|
e4d3eed379 | ||
|
|
fca9bbd3ef | ||
|
|
987ea81329 | ||
|
|
fa027347d3 | ||
|
|
57847fa446 | ||
|
|
2d4aefc86e | ||
|
|
5d979e5670 | ||
|
|
6acac9edce | ||
|
|
b1ce4d584e | ||
|
|
7db3d32d7b | ||
|
|
d12e887858 | ||
|
|
7f69d5668e | ||
|
|
70f8cc6159 | ||
|
|
863f984990 | ||
|
|
1a6325d461 | ||
|
|
d53693952a | ||
|
|
dcdb1504a3 | ||
|
|
20dc6e4047 | ||
|
|
7abedc1ace | ||
|
|
50c6a23614 | ||
|
|
a9c1662333 | ||
|
|
0d7721a24c | ||
|
|
919236665b | ||
|
|
59f190ce3a | ||
|
|
2165cc4114 | ||
|
|
ed44c592d3 | ||
|
|
c019f24e99 | ||
|
|
e6ff0fa6b9 | ||
|
|
91c50536c4 | ||
|
|
85cc9b5cf6 | ||
|
|
64e2ad9e7b | ||
|
|
ef965f4a6a | ||
|
|
d90d043811 | ||
|
|
8dc7d8c136 | ||
|
|
9f210deadf | ||
|
|
7be49b96c8 | ||
|
|
3aa84f5dcf | ||
|
|
fba6fedfc5 | ||
|
|
d0fa14c3c5 | ||
|
|
45c6d36748 | ||
|
|
b4e45610b2 | ||
|
|
b09595e745 | ||
|
|
6e33bc48aa | ||
|
|
7e15181c0b | ||
|
|
559aefeb85 | ||
|
|
2f94763c65 | ||
|
|
c7a9a889e4 | ||
|
|
8031e799d9 | ||
|
|
6d89b8674f | ||
|
|
1946216c0e | ||
|
|
e4ca20e848 | ||
|
|
764a764720 | ||
|
|
e5400c2da7 | ||
|
|
67a7e21d81 | ||
|
|
91a330f51b | ||
|
|
a98fb27058 | ||
|
|
3be2c05f0c | ||
|
|
fcc23745b6 | ||
|
|
7b8c0d7dcb | ||
|
|
8af515059a |
52 changed files with 11820 additions and 698 deletions
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
|
|
@ -2,8 +2,6 @@ name: Go Test
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
|
||||||
- '*'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
|
@ -13,9 +11,9 @@ jobs:
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.24'
|
go-version-file: 'go.mod'
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: go test -v ./...
|
run: go test -v ./...
|
||||||
|
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -1 +1,2 @@
|
||||||
test
|
test
|
||||||
|
/bench_route_match_baseline.txt
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ Touka(灯花) 是一个基于 Go 语言构建的多层次、高性能 Web 框架
|
||||||
- **[中间件 (middleware.md)](docs/middleware.md)**
|
- **[中间件 (middleware.md)](docs/middleware.md)**
|
||||||
- **[统一错误处理 (error-handling.md)](docs/error-handling.md)**
|
- **[统一错误处理 (error-handling.md)](docs/error-handling.md)**
|
||||||
- **[静态文件与资源 (static-files.md)](docs/static-files.md)**
|
- **[静态文件与资源 (static-files.md)](docs/static-files.md)**
|
||||||
|
- **[反向代理 (reverse-proxy.md)](docs/reverse-proxy.md)**
|
||||||
- **[Server-Sent Events (sse.md)](docs/sse.md)**
|
- **[Server-Sent Events (sse.md)](docs/sse.md)**
|
||||||
- **[高级特性与优化 (advanced.md)](docs/advanced.md)**
|
- **[高级特性与优化 (advanced.md)](docs/advanced.md)**
|
||||||
|
|
||||||
|
|
@ -58,9 +59,9 @@ func main() {
|
||||||
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
|
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 启动服务器 (支持优雅关闭)
|
// 启动服务器(通过 WithGracefulShutdown 启用优雅关闭)
|
||||||
log.Println("Touka Server starting on :8080...")
|
log.Println("Touka Server starting on :8080...")
|
||||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||||
log.Fatalf("Touka server failed to start: %v", err)
|
log.Fatalf("Touka server failed to start: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -70,13 +70,13 @@ func main() {
|
||||||
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
|
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
|
||||||
|
|
||||||
// ... 其他配置
|
// ... 其他配置
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 1.3. 服务器生命周期管理
|
#### 1.3. 服务器生命周期管理
|
||||||
|
|
||||||
Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。
|
Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -90,11 +90,11 @@ func main() {
|
||||||
fmt.Println("自定义的 HTTP 服务器配置已应用")
|
fmt.Println("自定义的 HTTP 服务器配置已应用")
|
||||||
})
|
})
|
||||||
|
|
||||||
// 启动服务器,并支持优雅关闭
|
// 启动服务器,并通过 Run 选项启用优雅关闭
|
||||||
// RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号
|
// Run(...) 会阻塞当前 goroutine
|
||||||
// 第二个参数是优雅关闭的超时时间
|
// WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒
|
||||||
fmt.Println("服务器启动于 :8080")
|
fmt.Println("服务器启动于 :8080")
|
||||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||||
log.Fatalf("服务器启动失败: %v", err)
|
log.Fatalf("服务器启动失败: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -187,7 +187,7 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthMiddleware() touka.HandlerFunc {
|
func AuthMiddleware() touka.HandlerFunc {
|
||||||
|
|
@ -313,7 +313,7 @@ func main() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// templates/index.html
|
// templates/index.html
|
||||||
|
|
@ -400,7 +400,7 @@ func main() {
|
||||||
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
|
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -483,7 +483,7 @@ func main() {
|
||||||
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
|
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
|
||||||
r.StaticDir("/files", "./non-existent-dir")
|
r.StaticDir("/files", "./non-existent-dir")
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -546,7 +546,7 @@ func main() {
|
||||||
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
|
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
|
||||||
r.StaticFS("/", http.FS(subFS))
|
r.StaticFS("/", http.FS(subFS))
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
52
compat.go
Normal file
52
compat.go
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
|
"github.com/fenthope/reco"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- reco 兼容函数 ---
|
||||||
|
|
||||||
|
// GetLogReco 返回底层的 reco.Logger 实例
|
||||||
|
// 用于需要访问 reco 特定功能的场景
|
||||||
|
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (engine *Engine) GetLogReco() *reco.Logger {
|
||||||
|
return engine.LogReco
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogReco 设置 reco.Logger 实例
|
||||||
|
// 用于向后兼容,等价于 SetLogger(l)
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (engine *Engine) SetLogReco(l *reco.Logger) {
|
||||||
|
engine.LogReco = l
|
||||||
|
engine.logger = l
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoggerReco 返回底层的 reco.Logger 实例
|
||||||
|
// 用于需要访问 reco 特定功能的场景
|
||||||
|
// 如果当前 logger 不是 *reco.Logger 类型,返回 nil
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) GetLoggerReco() *reco.Logger {
|
||||||
|
if rl, ok := c.engine.logger.(*reco.Logger); ok {
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
return c.engine.LogReco
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- httpc 兼容函数 ---
|
||||||
|
|
||||||
|
// GetHTTPC 返回底层的 httpc.Client 实例
|
||||||
|
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) GetHTTPC() *httpc.Client {
|
||||||
|
return c.Client()
|
||||||
|
}
|
||||||
314
context.go
314
context.go
|
|
@ -26,7 +26,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/WJQSERVER/wanf"
|
"github.com/WJQSERVER/wanf"
|
||||||
"github.com/fenthope/reco"
|
|
||||||
"github.com/go-json-experiment/json"
|
"github.com/go-json-experiment/json"
|
||||||
|
|
||||||
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
||||||
|
|
@ -44,6 +43,8 @@ type Context struct {
|
||||||
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
||||||
index int8 // 当前执行到处理链的哪个位置
|
index int8 // 当前执行到处理链的哪个位置
|
||||||
|
|
||||||
|
requestBodyPrepared bool
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
Keys map[string]any // 用于在中间件之间传递数据
|
Keys map[string]any // 用于在中间件之间传递数据
|
||||||
|
|
||||||
|
|
@ -71,6 +72,12 @@ type Context struct {
|
||||||
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
// skippedNodes 用于记录跳过的节点信息,以便回溯
|
||||||
// 通常在处理嵌套路由时使用
|
// 通常在处理嵌套路由时使用
|
||||||
SkippedNodes []skippedNode
|
SkippedNodes []skippedNode
|
||||||
|
|
||||||
|
// fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲.
|
||||||
|
fixedPathBuf []byte
|
||||||
|
|
||||||
|
allowedMethodsBuf []string
|
||||||
|
allowHeaderBuf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Context 相关方法实现 ---
|
// --- Context 相关方法实现 ---
|
||||||
|
|
@ -95,19 +102,42 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
c.handlers = nil
|
c.handlers = nil
|
||||||
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
c.index = -1 // 初始为 -1,`Next()` 将其设置为 0
|
||||||
c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染
|
c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map
|
||||||
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
c.Errors = c.Errors[:0] // 清空 Errors 切片
|
||||||
c.queryCache = nil // 清空查询参数缓存
|
c.queryCache = nil // 清空查询参数缓存
|
||||||
c.formCache = nil // 清空表单数据缓存
|
c.formCache = nil // 清空表单数据缓存
|
||||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||||
|
c.requestBodyPrepared = false
|
||||||
|
|
||||||
if cap(c.SkippedNodes) > 0 {
|
if cap(c.SkippedNodes) > 0 {
|
||||||
c.SkippedNodes = c.SkippedNodes[:0]
|
c.SkippedNodes = c.SkippedNodes[:0]
|
||||||
} else {
|
} else {
|
||||||
c.SkippedNodes = make([]skippedNode, 0, 256)
|
c.SkippedNodes = make([]skippedNode, 0, 256)
|
||||||
}
|
}
|
||||||
|
if cap(c.fixedPathBuf) > 0 {
|
||||||
|
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||||
|
}
|
||||||
|
if cap(c.allowedMethodsBuf) > 0 {
|
||||||
|
c.allowedMethodsBuf = c.allowedMethodsBuf[:0]
|
||||||
|
}
|
||||||
|
if cap(c.allowHeaderBuf) > 0 {
|
||||||
|
c.allowHeaderBuf = c.allowHeaderBuf[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := c.Writer.Write(data); err != nil {
|
||||||
|
wrapped := fmt.Errorf("%s: %w", contextMsg, err)
|
||||||
|
c.AddError(wrapped)
|
||||||
|
if c.engine != nil && c.engine.logger != nil {
|
||||||
|
c.engine.logger.Errorf("%s: %v", contextMsg, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next 在处理链中执行下一个处理函数
|
// Next 在处理链中执行下一个处理函数
|
||||||
|
|
@ -237,6 +267,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
|
||||||
c.MaxRequestBodySize = size
|
c.MaxRequestBodySize = size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) prepareRequestBody() io.ReadCloser {
|
||||||
|
if c.Request == nil || c.Request.Body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
|
||||||
|
return c.Request.Body
|
||||||
|
}
|
||||||
|
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||||
|
c.requestBodyPrepared = true
|
||||||
|
return c.Request.Body
|
||||||
|
}
|
||||||
|
|
||||||
// Query 从 URL 查询参数中获取值
|
// Query 从 URL 查询参数中获取值
|
||||||
// 懒加载解析查询参数,并进行缓存
|
// 懒加载解析查询参数,并进行缓存
|
||||||
func (c *Context) Query(key string) string {
|
func (c *Context) Query(key string) string {
|
||||||
|
|
@ -258,7 +300,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
|
||||||
// 懒加载解析表单数据,并进行缓存
|
// 懒加载解析表单数据,并进行缓存
|
||||||
func (c *Context) PostForm(key string) string {
|
func (c *Context) PostForm(key string) string {
|
||||||
if c.formCache == nil {
|
if c.formCache == nil {
|
||||||
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch mediaType {
|
||||||
|
case "multipart/form-data":
|
||||||
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case "application/x-www-form-urlencoded":
|
||||||
|
if err := c.Request.ParseForm(); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
|
if !errors.Is(err, http.ErrNotMultipart) {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
c.formCache = c.Request.PostForm
|
c.formCache = c.Request.PostForm
|
||||||
}
|
}
|
||||||
return c.formCache.Get(key)
|
return c.formCache.Get(key)
|
||||||
|
|
@ -282,20 +356,20 @@ func (c *Context) Param(key string) string {
|
||||||
func (c *Context) Raw(code int, contentType string, data []byte) {
|
func (c *Context) Raw(code int, contentType string, data []byte) {
|
||||||
c.Writer.Header().Set("Content-Type", contentType)
|
c.Writer.Header().Set("Content-Type", contentType)
|
||||||
c.Writer.WriteHeader(code)
|
c.Writer.WriteHeader(code)
|
||||||
c.Writer.Write(data)
|
c.writeResponseBody(data, "failed to write raw response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// String 向响应写入格式化的字符串
|
// String 向响应写入格式化的字符串
|
||||||
func (c *Context) String(code int, format string, values ...any) {
|
func (c *Context) String(code int, format string, values ...any) {
|
||||||
c.Writer.WriteHeader(code)
|
c.Writer.WriteHeader(code)
|
||||||
c.Writer.Write(fmt.Appendf(nil, format, values...))
|
c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Text 向响应写入无需格式化的string
|
// Text 向响应写入无需格式化的string
|
||||||
func (c *Context) Text(code int, text string) {
|
func (c *Context) Text(code int, text string) {
|
||||||
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
c.Writer.WriteHeader(code)
|
c.Writer.WriteHeader(code)
|
||||||
c.Writer.Write([]byte(text))
|
c.writeResponseBody([]byte(text), "failed to write text response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// FileText
|
// FileText
|
||||||
|
|
@ -338,8 +412,11 @@ func (c *Context) FileText(code int, filePath string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
c.Writer.WriteHeader(code)
|
||||||
|
if _, err := iox.Copy(c.Writer, file); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
@ -417,6 +494,22 @@ func (c *Context) JSON(code int, obj any) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSONBuf 先将 JSON 编码到 buffer, 成功后再写入状态码和响应体.
|
||||||
|
// 与 JSON 相比,编码失败时可以正确返回 500 状态码,代价是多一次内存分配.
|
||||||
|
func (c *Context) JSONBuf(code int, obj any) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := json.MarshalWrite(&buf, obj); err != nil {
|
||||||
|
errMsg := fmt.Errorf("failed to marshal JSON: %w", err)
|
||||||
|
c.AddError(errMsg)
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response")
|
||||||
|
}
|
||||||
|
|
||||||
// GOB 向响应写入GOB数据
|
// GOB 向响应写入GOB数据
|
||||||
// 设置 Content-Type 为 application/octet-stream
|
// 设置 Content-Type 为 application/octet-stream
|
||||||
func (c *Context) GOB(code int, obj any) {
|
func (c *Context) GOB(code int, obj any) {
|
||||||
|
|
@ -431,6 +524,21 @@ func (c *Context) GOB(code int, obj any) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GOBBuf 先将 GOB 编码到 buffer, 成功后再写入状态码和响应体.
|
||||||
|
func (c *Context) GOBBuf(code int, obj any) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
encoder := gob.NewEncoder(&buf)
|
||||||
|
if err := encoder.Encode(obj); err != nil {
|
||||||
|
errMsg := fmt.Errorf("failed to encode GOB: %w", err)
|
||||||
|
c.AddError(errMsg)
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/octet-stream")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response")
|
||||||
|
}
|
||||||
|
|
||||||
// WANF向响应写入WANF数据
|
// WANF向响应写入WANF数据
|
||||||
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
|
// 设置 application/vnd.wjqserver.wanf; charset=utf-8
|
||||||
func (c *Context) WANF(code int, obj any) {
|
func (c *Context) WANF(code int, obj any) {
|
||||||
|
|
@ -445,6 +553,21 @@ func (c *Context) WANF(code int, obj any) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WANFBuf 先将 WANF 编码到 buffer, 成功后再写入状态码和响应体.
|
||||||
|
func (c *Context) WANFBuf(code int, obj any) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
encoder := wanf.NewStreamEncoder(&buf)
|
||||||
|
if err := encoder.Encode(obj); err != nil {
|
||||||
|
errMsg := fmt.Errorf("failed to encode WANF: %w", err)
|
||||||
|
c.AddError(errMsg)
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response")
|
||||||
|
}
|
||||||
|
|
||||||
// HTML 渲染 HTML 模板
|
// HTML 渲染 HTML 模板
|
||||||
// 如果 Engine 配置了 HTMLRender,则使用它进行渲染
|
// 如果 Engine 配置了 HTMLRender,则使用它进行渲染
|
||||||
// 否则,会进行简单的字符串输出
|
// 否则,会进行简单的字符串输出
|
||||||
|
|
@ -466,7 +589,37 @@ func (c *Context) HTML(code int, name string, obj any) {
|
||||||
// 可以扩展支持其他渲染器接口
|
// 可以扩展支持其他渲染器接口
|
||||||
}
|
}
|
||||||
// 默认简单输出,用于未配置 HTMLRender 的情况
|
// 默认简单输出,用于未配置 HTMLRender 的情况
|
||||||
c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj))
|
c.writeResponseBody(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj), "failed to write HTML response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
|
||||||
|
// 如果模板渲染失败,则返回 500 错误且不写入任何内容.
|
||||||
|
func (c *Context) HTMLBuf(code int, name string, obj any) {
|
||||||
|
if c.engine == nil || c.engine.HTMLRender == nil {
|
||||||
|
// 没有渲染器,回退到简单输出
|
||||||
|
c.HTML(code, name, obj)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tpl, ok := c.engine.HTMLRender.(*template.Template); ok {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := tpl.ExecuteTemplate(&buf, name, obj)
|
||||||
|
if err != nil {
|
||||||
|
// 渲染失败,记录错误并返回 500,不写入任何内容
|
||||||
|
errMsg := fmt.Errorf("failed to render HTML template '%s': %w", name, err)
|
||||||
|
c.AddError(errMsg)
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, errMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 渲染成功,写入响应
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不支持的渲染器类型,回退到简单输出
|
||||||
|
c.HTML(code, name, obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect 执行 HTTP 重定向
|
// Redirect 执行 HTTP 重定向
|
||||||
|
|
@ -481,10 +634,16 @@ func (c *Context) Redirect(code int, location string) {
|
||||||
|
|
||||||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||||||
func (c *Context) ShouldBindJSON(obj any) error {
|
func (c *Context) ShouldBindJSON(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
err := json.UnmarshalRead(c.Request.Body, obj)
|
err := json.UnmarshalRead(body, obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("json binding error: %w", err)
|
return fmt.Errorf("json binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -493,10 +652,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
|
||||||
|
|
||||||
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
||||||
func (c *Context) ShouldBindWANF(obj any) error {
|
func (c *Context) ShouldBindWANF(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
|
decoder, err := wanf.NewStreamDecoder(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -509,10 +674,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
|
||||||
|
|
||||||
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
||||||
func (c *Context) ShouldBindGOB(obj any) error {
|
func (c *Context) ShouldBindGOB(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
decoder := gob.NewDecoder(c.Request.Body)
|
decoder := gob.NewDecoder(body)
|
||||||
if err := decoder.Decode(obj); err != nil {
|
if err := decoder.Decode(obj); err != nil {
|
||||||
return fmt.Errorf("GOB binding error: %w", err)
|
return fmt.Errorf("GOB binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -629,6 +800,10 @@ func setFieldValue(field reflect.Value, values []string) error {
|
||||||
// ShouldBindForm 尝试将表单数据绑定到结构体
|
// ShouldBindForm 尝试将表单数据绑定到结构体
|
||||||
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
||||||
func (c *Context) ShouldBindForm(obj any) error {
|
func (c *Context) ShouldBindForm(obj any) error {
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -637,7 +812,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
||||||
|
|
||||||
switch mediaType {
|
switch mediaType {
|
||||||
case "multipart/form-data":
|
case "multipart/form-data":
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
return fmt.Errorf("parse multipart form error: %w", err)
|
return fmt.Errorf("parse multipart form error: %w", err)
|
||||||
}
|
}
|
||||||
case "application/x-www-form-urlencoded":
|
case "application/x-www-form-urlencoded":
|
||||||
|
|
@ -651,6 +826,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
||||||
if err := bindForm(c.Request.Form, obj); err != nil {
|
if err := bindForm(c.Request.Form, obj); err != nil {
|
||||||
return fmt.Errorf("form binding error: %w", err)
|
return fmt.Errorf("form binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
c.formCache = c.Request.PostForm
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -688,10 +864,29 @@ func (c *Context) GetErrors() []error {
|
||||||
return c.Errors
|
return c.Errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client 返回 Engine 提供的 HTTPClient
|
// Client 返回当前请求的 HTTPClient
|
||||||
// 方便在请求处理函数中进行出站 HTTP 请求
|
// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例;
|
||||||
|
// 否则返回 Engine 提供的默认实例
|
||||||
|
//
|
||||||
|
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||||
func (c *Context) Client() *httpc.Client {
|
func (c *Context) Client() *httpc.Client {
|
||||||
|
if c.HTTPClient != nil {
|
||||||
return c.HTTPClient
|
return c.HTTPClient
|
||||||
|
}
|
||||||
|
return c.engine.HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPC 返回自动关联请求 Context 的 HTTP 客户端
|
||||||
|
// 当请求被取消时,通过此客户端发起的出站请求也会自动取消
|
||||||
|
func (c *Context) HTTPC() *contextHTTPClient {
|
||||||
|
client := c.HTTPClient
|
||||||
|
if client == nil {
|
||||||
|
client = c.engine.HTTPClient
|
||||||
|
}
|
||||||
|
return &contextHTTPClient{
|
||||||
|
client: client,
|
||||||
|
ctx: c.ctx,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context() 返回请求的上下文,用于取消操作
|
// Context() 返回请求的上下文,用于取消操作
|
||||||
|
|
@ -751,37 +946,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
|
||||||
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
||||||
// 注意:请求体只能读取一次
|
// 注意:请求体只能读取一次
|
||||||
func (c *Context) GetReqBody() io.ReadCloser {
|
func (c *Context) GetReqBody() io.ReadCloser {
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
return c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
if c.Request == nil || c.Request.Body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return c.Request.Body
|
return c.Request.Body
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetReqBodyFull 读取并返回请求体的所有内容
|
// GetReqBodyFull 读取并返回请求体的所有内容
|
||||||
// 注意:请求体只能读取一次
|
// 注意:请求体只能读取一次
|
||||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||||
if c.Request.Body == nil {
|
body := c.GetReqBody()
|
||||||
|
if body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var limitBytesReader io.ReadCloser
|
|
||||||
|
|
||||||
if c.MaxRequestBodySize > 0 {
|
|
||||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := limitBytesReader.Close()
|
err := body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
} else {
|
|
||||||
limitBytesReader = c.Request.Body
|
|
||||||
defer func() {
|
|
||||||
err := limitBytesReader.Close()
|
|
||||||
if err != nil {
|
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := iox.ReadAll(limitBytesReader)
|
data, err := io.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
@ -791,31 +979,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||||
|
|
||||||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||||
if c.Request.Body == nil {
|
body := c.GetReqBody()
|
||||||
|
if body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var limitBytesReader io.ReadCloser
|
|
||||||
|
|
||||||
if c.MaxRequestBodySize > 0 {
|
|
||||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := limitBytesReader.Close()
|
err := body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
} else {
|
|
||||||
limitBytesReader = c.Request.Body
|
|
||||||
defer func() {
|
|
||||||
err := limitBytesReader.Close()
|
|
||||||
if err != nil {
|
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := iox.ReadAll(limitBytesReader)
|
data, err := io.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
@ -974,14 +1149,9 @@ func (c *Context) GetProtocol() string {
|
||||||
return c.Request.Proto
|
return c.Request.Proto
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHTTPC 获取框架自带传递的httpc
|
// GetLogger 获取engine的Logger接口
|
||||||
func (c *Context) GetHTTPC() *httpc.Client {
|
func (c *Context) GetLogger() Logger {
|
||||||
return c.HTTPClient
|
return c.engine.logger
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogger 获取engine的Logger
|
|
||||||
func (c *Context) GetLogger() *reco.Logger {
|
|
||||||
return c.engine.LogReco
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetReqQueryString
|
// GetReqQueryString
|
||||||
|
|
@ -1084,17 +1254,25 @@ func (c *Context) SetSameSite(samesite http.SameSite) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCookie 设置一个 HTTP cookie
|
// SetCookie 设置一个 HTTP cookie
|
||||||
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) {
|
// sameSite 参数是可选的,如果不提供则使用通过 SetSameSite 设置的值
|
||||||
|
func (c *Context) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool, sameSite ...http.SameSite) {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
path = "/"
|
path = "/"
|
||||||
}
|
}
|
||||||
|
site := c.sameSite
|
||||||
|
if len(sameSite) > 0 {
|
||||||
|
if len(sameSite) > 1 {
|
||||||
|
c.Warnf("SetCookie: only the first SameSite value will be used, got %d values", len(sameSite))
|
||||||
|
}
|
||||||
|
site = sameSite[0]
|
||||||
|
}
|
||||||
http.SetCookie(c.Writer, &http.Cookie{
|
http.SetCookie(c.Writer, &http.Cookie{
|
||||||
Name: name,
|
Name: name,
|
||||||
Value: url.QueryEscape(value),
|
Value: url.QueryEscape(value),
|
||||||
MaxAge: maxAge,
|
MaxAge: maxAge,
|
||||||
Path: path,
|
Path: path,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
SameSite: c.sameSite,
|
SameSite: site,
|
||||||
Secure: secure,
|
Secure: secure,
|
||||||
HttpOnly: httpOnly,
|
HttpOnly: httpOnly,
|
||||||
})
|
})
|
||||||
|
|
@ -1132,25 +1310,25 @@ func (c *Context) DeleteCookie(name string) {
|
||||||
|
|
||||||
// === 日志记录 ===
|
// === 日志记录 ===
|
||||||
func (c *Context) Debugf(format string, args ...any) {
|
func (c *Context) Debugf(format string, args ...any) {
|
||||||
c.engine.LogReco.Debugf(format, args...)
|
c.engine.logger.Debugf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Infof(format string, args ...any) {
|
func (c *Context) Infof(format string, args ...any) {
|
||||||
c.engine.LogReco.Infof(format, args...)
|
c.engine.logger.Infof(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Warnf(format string, args ...any) {
|
func (c *Context) Warnf(format string, args ...any) {
|
||||||
c.engine.LogReco.Warnf(format, args...)
|
c.engine.logger.Warnf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Errorf(format string, args ...any) {
|
func (c *Context) Errorf(format string, args ...any) {
|
||||||
c.engine.LogReco.Errorf(format, args...)
|
c.engine.logger.Errorf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Fatalf(format string, args ...any) {
|
func (c *Context) Fatalf(format string, args ...any) {
|
||||||
c.engine.LogReco.Fatalf(format, args...)
|
c.engine.logger.Fatalf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Panicf(format string, args ...any) {
|
func (c *Context) Panicf(format string, args ...any) {
|
||||||
c.engine.LogReco.Panicf(format, args...)
|
c.engine.logger.Panicf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
81
context_benchmark_test.go
Normal file
81
context_benchmark_test.go
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextResetKeepsKeysNilUntilSet(t *testing.T) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
if c.Keys != nil {
|
||||||
|
t.Fatalf("expected fresh test context Keys to be nil before first Set")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("answer", 42)
|
||||||
|
if c.Keys == nil {
|
||||||
|
t.Fatalf("expected Set to allocate Keys map")
|
||||||
|
}
|
||||||
|
if value, exists := c.Get("answer"); !exists || value != 42 {
|
||||||
|
t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
c.reset(UnwrapResponseWriter(c.Writer), req)
|
||||||
|
|
||||||
|
if c.Keys != nil {
|
||||||
|
t.Fatalf("expected reset to clear Keys without allocating a new map")
|
||||||
|
}
|
||||||
|
if value, exists := c.Get("answer"); exists || value != nil {
|
||||||
|
t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxValue := c.Value("missing")
|
||||||
|
if ctxValue != nil {
|
||||||
|
t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatalf("expected MustGet to panic for missing key after reset")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
_ = c.MustGet("answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContextReset(b *testing.B) {
|
||||||
|
b.Run("NoKeysUse", func(b *testing.B) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
rawWriter := UnwrapResponseWriter(c.Writer)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
c.reset(rawWriter, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WithKeysUse", func(b *testing.B) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
rawWriter := UnwrapResponseWriter(c.Writer)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
c.reset(rawWriter, req)
|
||||||
|
c.Set("request-id", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
174
context_bodylimit_test.go
Normal file
174
context_bodylimit_test.go
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type zeroNilThenEOFReader struct {
|
||||||
|
readCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
|
||||||
|
r.readCalls++
|
||||||
|
if r.readCalls == 1 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *zeroNilThenEOFReader) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
filePath := filepath.Join(dir, "hello.txt")
|
||||||
|
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
c, _ := CreateTestContext(rr)
|
||||||
|
|
||||||
|
c.FileText(http.StatusCreated, filePath)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
|
||||||
|
t.Fatalf("unexpected content type: %q", got)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); body != "hello touka" {
|
||||||
|
t.Fatalf("unexpected body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected exact limit read to succeed, got %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "abcd" {
|
||||||
|
t.Fatalf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
_, err := io.ReadAll(reader)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n != 0 || err != nil {
|
||||||
|
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = reader.Read(buf)
|
||||||
|
if n != 0 || !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected zero limit to leave body unlimited, got %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "abc" {
|
||||||
|
t.Fatalf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"name":"abcdef"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/json", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(8)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.ShouldBindJSON(&payload)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader("name=abcdef")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(4)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.ShouldBindForm(&payload)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader("name=abcdef")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(4)
|
||||||
|
|
||||||
|
if got := c.PostForm("name"); got != "" {
|
||||||
|
t.Fatalf("expected empty value on over-limit form body, got %q", got)
|
||||||
|
}
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
t.Fatal("expected parse error to be recorded")
|
||||||
|
}
|
||||||
|
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
58
context_httpc.go
Normal file
58
context_httpc.go
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context
|
||||||
|
// 当请求被取消时,出站 HTTP 请求也会自动取消
|
||||||
|
type contextHTTPClient struct {
|
||||||
|
client *httpc.Client
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequestBuilder 创建请求构建器,自动关联请求 Context
|
||||||
|
func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET 创建 GET 请求构建器
|
||||||
|
func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.GET(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// POST 创建 POST 请求构建器
|
||||||
|
func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.POST(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PUT 创建 PUT 请求构建器
|
||||||
|
func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.PUT(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DELETE 创建 DELETE 请求构建器
|
||||||
|
func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.DELETE(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PATCH 创建 PATCH 请求构建器
|
||||||
|
func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.PATCH(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HEAD 创建 HEAD 请求构建器
|
||||||
|
func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.HEAD(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPTIONS 创建 OPTIONS 请求构建器
|
||||||
|
func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.OPTIONS(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
368
docs/advanced.md
368
docs/advanced.md
|
|
@ -14,9 +14,192 @@ Touka 使用 `sync.Pool` 来重用 `touka.Context` 对象。这极大减少了
|
||||||
|
|
||||||
在路由匹配过程中,Touka 会预分配路径参数切片,并根据路由深度进行缓存,从而在路由查找时实现几乎零分配。
|
在路由匹配过程中,Touka 会预分配路径参数切片,并根据路由深度进行缓存,从而在路由查找时实现几乎零分配。
|
||||||
|
|
||||||
|
## 服务器配置
|
||||||
|
|
||||||
|
### 服务器配置器 (ServerConfigurator)
|
||||||
|
|
||||||
|
Touka 允许您在服务器启动前对底层 `*http.Server` 进行自定义配置:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := touka.New()
|
||||||
|
|
||||||
|
// 配置 HTTP 服务器
|
||||||
|
r.SetServerConfigurator(func(server *http.Server) {
|
||||||
|
server.ReadTimeout = 30 * time.Second
|
||||||
|
server.WriteTimeout = 30 * time.Second
|
||||||
|
server.IdleTimeout = 120 * time.Second
|
||||||
|
server.MaxHeaderBytes = 1 << 20 // 1MB
|
||||||
|
})
|
||||||
|
|
||||||
|
// 专门配置 HTTPS 服务器(优先级高于 ServerConfigurator)
|
||||||
|
r.SetTLSServerConfigurator(func(server *http.Server) {
|
||||||
|
server.ReadTimeout = 30 * time.Second
|
||||||
|
server.WriteTimeout = 30 * time.Second
|
||||||
|
// HTTPS 特定配置...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 协议配置
|
||||||
|
|
||||||
|
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext):
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 使用默认协议配置
|
||||||
|
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
|
||||||
|
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
|
||||||
|
r.SetDefaultProtocols()
|
||||||
|
|
||||||
|
// 自定义协议配置
|
||||||
|
r.SetProtocols(&touka.ProtocolsConfig{
|
||||||
|
Http1: true, // 启用 HTTP/1.1
|
||||||
|
Http2: true, // 启用 HTTP/2(需要 TLS)
|
||||||
|
Http2_Cleartext: true, // 启用 H2C(无需 TLS 的 HTTP/2)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启动方式
|
||||||
|
|
||||||
|
Touka 统一通过 `Run(opts...)` 启动服务器:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 1. 简单启动(无优雅停机)
|
||||||
|
r.Run(touka.WithAddr(":8080"))
|
||||||
|
|
||||||
|
// 2. 带优雅停机的启动
|
||||||
|
r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second))
|
||||||
|
|
||||||
|
// 3. 带上下文的优雅停机
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":8080"),
|
||||||
|
touka.WithGracefulShutdown(10*time.Second),
|
||||||
|
touka.WithShutdownContext(ctx),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 4. HTTPS 启动
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
// 其他 TLS 配置...
|
||||||
|
}
|
||||||
|
// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithGracefulShutdownDefault(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 5. HTTPS + HTTP 重定向
|
||||||
|
// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(":80"),
|
||||||
|
touka.WithGracefulShutdown(10*time.Second),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host)
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(true),
|
||||||
|
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host)
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(false),
|
||||||
|
touka.WithRedirectHost("example.com"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### HTTPS Redirect Host 策略
|
||||||
|
|
||||||
|
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
|
||||||
|
|
||||||
|
可用的 redirect 子选项:
|
||||||
|
|
||||||
|
- `touka.WithUseHeaderHost(true|false)`
|
||||||
|
- `touka.WithRedirectHostHeaders([]string{...})`
|
||||||
|
- `touka.WithRedirectHost("example.com")`
|
||||||
|
|
||||||
|
#### 模式一:使用请求输入侧的 host
|
||||||
|
|
||||||
|
当 `WithUseHeaderHost(true)` 时:
|
||||||
|
|
||||||
|
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
|
||||||
|
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值
|
||||||
|
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(true),
|
||||||
|
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 模式二:使用配置的固定 host
|
||||||
|
|
||||||
|
当 `WithUseHeaderHost(false)` 时:
|
||||||
|
|
||||||
|
- 不读取 `Request.Host`
|
||||||
|
- 不读取 `WithRedirectHostHeaders(...)`
|
||||||
|
- 必须配置 `WithRedirectHost("example.com")`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(false),
|
||||||
|
touka.WithRedirectHost("example.com"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 严格校验规则
|
||||||
|
|
||||||
|
以下组合会直接返回配置错误:
|
||||||
|
|
||||||
|
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
|
||||||
|
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
|
||||||
|
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
|
||||||
|
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
|
||||||
|
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
|
||||||
|
|
||||||
|
#### 优先级关系
|
||||||
|
|
||||||
|
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
|
||||||
|
2. `WithUseHeaderHost(...)` 决定 host 来源模式
|
||||||
|
3. 当 `WithUseHeaderHost(true)` 时:
|
||||||
|
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
|
||||||
|
- 未配置时使用 `Request.Host`
|
||||||
|
4. 当 `WithUseHeaderHost(false)` 时:
|
||||||
|
- 只使用 `WithRedirectHost(...)`
|
||||||
|
|
||||||
|
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
|
||||||
|
|
||||||
## 优雅停机 (Graceful Shutdown)
|
## 优雅停机 (Graceful Shutdown)
|
||||||
|
|
||||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。
|
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
r := touka.Default()
|
r := touka.Default()
|
||||||
|
|
@ -24,11 +207,133 @@ r := touka.Default()
|
||||||
|
|
||||||
// 监听 SIGINT 和 SIGTERM 信号
|
// 监听 SIGINT 和 SIGTERM 信号
|
||||||
// 如果在 10 秒内未处理完,则强制关闭
|
// 如果在 10 秒内未处理完,则强制关闭
|
||||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||||
log.Fatal("服务器退出异常:", err)
|
log.Fatal("服务器退出异常:", err)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### SSE 长连接的优雅关闭
|
||||||
|
|
||||||
|
对于 SSE 等长连接场景,Touka 会自动将引擎的关闭信号注入到请求的 Context 中:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/events", func(c *touka.Context) {
|
||||||
|
c.EventStream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
// 收到关闭信号,优雅退出
|
||||||
|
return false
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
// 发送数据
|
||||||
|
event := touka.Event{Data: "tick"}
|
||||||
|
event.Render(w)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 路由行为配置
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := touka.New()
|
||||||
|
|
||||||
|
// 是否自动重定向尾部斜杠(默认 true)
|
||||||
|
// /foo/ -> /foo 或 /foo -> /foo/
|
||||||
|
r.SetRedirectTrailingSlash(true)
|
||||||
|
|
||||||
|
// 是否自动修复路径大小写(默认 true)
|
||||||
|
// /FOO -> /foo
|
||||||
|
r.SetRedirectFixedPath(true)
|
||||||
|
|
||||||
|
// 是否处理 405 Method Not Allowed(默认 true)
|
||||||
|
// 当路径匹配但方法不匹配时返回 405 而非 404
|
||||||
|
r.SetHandleMethodNotAllowed(true)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 自定义 404 处理
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 单个处理器
|
||||||
|
r.NoRoute(func(c *touka.Context) {
|
||||||
|
c.JSON(http.StatusNotFound, touka.H{
|
||||||
|
"error": "Page not found",
|
||||||
|
"path": c.Request.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 处理器链(可以在 404 前执行额外中间件)
|
||||||
|
r.NoRoutes(
|
||||||
|
LogNotFoundMiddleware(),
|
||||||
|
func(c *touka.Context) {
|
||||||
|
c.JSON(http.StatusNotFound, touka.H{"error": "Not found"})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 未匹配路径作为静态文件服务
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 当没有路由匹配时,尝试从文件系统中查找文件
|
||||||
|
// 非常适合单页应用(SPA)部署
|
||||||
|
r.SetUnMatchFS(http.Dir("./frontend/dist"))
|
||||||
|
|
||||||
|
// 也可以添加额外的中间件
|
||||||
|
r.SetUnMatchFS(http.Dir("./frontend/dist"), AuthMiddleware())
|
||||||
|
```
|
||||||
|
|
||||||
|
## IP 地址解析配置
|
||||||
|
|
||||||
|
在反向代理环境中,正确配置 IP 解析非常重要:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := touka.New()
|
||||||
|
|
||||||
|
// 是否信任代理头部获取客户端 IP(默认 true)
|
||||||
|
r.SetForwardByClientIP(true)
|
||||||
|
|
||||||
|
// 设置用于获取客户端 IP 的头部列表(按优先级排序)
|
||||||
|
r.SetRemoteIPHeaders([]string{
|
||||||
|
"X-Forwarded-For",
|
||||||
|
"X-Real-IP",
|
||||||
|
"CF-Connecting-IP", // Cloudflare
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您同时使用 Touka 的 `ReverseProxy` 把请求继续转发给其他后端,请再参考 `docs/reverse-proxy.md` 中关于 `Forwarded`、`X-Forwarded-*` 与 `Via` 的说明。前者解决“当前请求的客户端 IP 如何被 Touka 正确解析”,后者解决“代理后的请求如何把链路信息继续传给下一跳”。
|
||||||
|
|
||||||
|
## 请求体大小限制
|
||||||
|
|
||||||
|
为了防止恶意的大数据包攻击(如慢速 HTTP 攻击或内存溢出),Touka 内置了请求体大小限制机制。
|
||||||
|
|
||||||
|
### 全局限制
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 设置全局最大请求体大小(例如 10MB)
|
||||||
|
r.SetGlobalMaxRequestBodySize(10 << 20)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 单个请求限制
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/upload", func(c *touka.Context) {
|
||||||
|
// 为特定请求设置限制(覆盖全局设置)
|
||||||
|
c.SetMaxRequestBodySize(100 << 20) // 100MB
|
||||||
|
|
||||||
|
body, err := c.GetReqBodyFull()
|
||||||
|
if err != nil {
|
||||||
|
// 如果超过限制,会返回 ErrBodyTooLarge
|
||||||
|
if errors.Is(err, touka.ErrBodyTooLarge) {
|
||||||
|
c.ErrorUseHandle(http.StatusRequestEntityTooLarge, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.ErrorUseHandle(http.StatusBadRequest, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 处理 body...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
## 与标准库集成
|
## 与标准库集成
|
||||||
|
|
||||||
Touka 遵循 `net/http` 哲学。您可以方便地使用现有的标准库组件。
|
Touka 遵循 `net/http` 哲学。您可以方便地使用现有的标准库组件。
|
||||||
|
|
@ -39,6 +344,14 @@ Touka 遵循 `net/http` 哲学。您可以方便地使用现有的标准库组
|
||||||
r.GET("/pprof/*any", touka.AdapterStdFunc(pprof.Index))
|
r.GET("/pprof/*any", touka.AdapterStdFunc(pprof.Index))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 适配 `http.Handler`
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 适配 http.FileServer
|
||||||
|
fileServer := http.FileServer(http.Dir("./static"))
|
||||||
|
r.GET("/static/*filepath", touka.AdapterStdHandle(http.StripPrefix("/static", fileServer)))
|
||||||
|
```
|
||||||
|
|
||||||
### 手动注入
|
### 手动注入
|
||||||
|
|
||||||
由于 `Engine` 实现了 `http.Handler` 接口,您可以将其挂载到任何地方。
|
由于 `Engine` 实现了 `http.Handler` 接口,您可以将其挂载到任何地方。
|
||||||
|
|
@ -61,17 +374,60 @@ Touka 默认集成了 `reco` 日志库。您可以自定义其输出行为。
|
||||||
```go
|
```go
|
||||||
logConfig := reco.Config{
|
logConfig := reco.Config{
|
||||||
Level: reco.LevelInfo,
|
Level: reco.LevelInfo,
|
||||||
|
Mode: reco.ModeText, // 或 reco.ModeJSON
|
||||||
Output: os.Stdout,
|
Output: os.Stdout,
|
||||||
Async: true, // 异步写入提高性能
|
Async: true, // 异步写入提高性能
|
||||||
|
TimeFormat: time.RFC3339,
|
||||||
}
|
}
|
||||||
r.SetLoggerCfg(logConfig)
|
r.SetLoggerCfg(logConfig)
|
||||||
|
|
||||||
|
// 或直接传入日志实例
|
||||||
|
logger, _ := reco.New(logConfig)
|
||||||
|
r.SetLogger(logger)
|
||||||
|
|
||||||
|
// 关闭日志(在服务器关闭时)
|
||||||
|
defer r.CloseLogger()
|
||||||
```
|
```
|
||||||
|
|
||||||
## 内存读取限制 (MaxReader)
|
## HTTP 客户端配置
|
||||||
|
|
||||||
为了防止恶意的大数据包攻击(如慢速 HTTP 攻击或内存溢出),Touka 内置了 `MaxReader` 机制。
|
Touka 内置了 `httpc` HTTP 客户端,可以在请求处理中方便地发起出站请求:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 设置全局最大读取限制(例如 2MB)
|
// 创建自定义 HTTP 客户端
|
||||||
r.SetMaxReader(2 << 20)
|
customClient := httpc.New()
|
||||||
|
r.SetHTTPClient(customClient)
|
||||||
|
|
||||||
|
// 在处理器中使用
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
resp, err := c.GetHTTPC().Get("https://api.example.com/data")
|
||||||
|
// ...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 条件中间件
|
||||||
|
|
||||||
|
Touka 支持根据条件动态启用或禁用中间件:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 单个条件中间件
|
||||||
|
r.Use(r.UseIf(config.EnableLogging, AccessLoggerMiddleware()))
|
||||||
|
|
||||||
|
// 条件中间件链
|
||||||
|
r.Use(r.UseChainIf(config.EnableMetrics,
|
||||||
|
MetricsMiddleware,
|
||||||
|
PrometheusMiddleware,
|
||||||
|
MonitoringMiddleware,
|
||||||
|
))
|
||||||
|
```
|
||||||
|
|
||||||
|
## 获取路由信息
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 获取所有已注册的路由信息
|
||||||
|
routes := r.GetRouterInfo()
|
||||||
|
for _, route := range routes {
|
||||||
|
fmt.Printf("Method: %s, Path: %s, Handler: %s, Group: %s\n",
|
||||||
|
route.Method, route.Path, route.Handler, route.Group)
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
357
docs/context.md
357
docs/context.md
|
|
@ -4,6 +4,16 @@
|
||||||
|
|
||||||
## 请求数据解析
|
## 请求数据解析
|
||||||
|
|
||||||
|
### 路径参数 (Path Parameters)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 路由: /users/:id
|
||||||
|
r.GET("/users/:id", func(c *touka.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
c.String(http.StatusOK, "User ID: %s", id)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
### 查询参数 (Query Parameters)
|
### 查询参数 (Query Parameters)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
|
@ -31,14 +41,80 @@ r.POST("/form_post", func(c *touka.Context) {
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 请求体读取
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 读取完整请求体
|
||||||
|
r.POST("/raw", func(c *touka.Context) {
|
||||||
|
body, err := c.GetReqBodyFull()
|
||||||
|
if err != nil {
|
||||||
|
c.ErrorUseHandle(http.StatusBadRequest, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Raw(http.StatusOK, "application/octet-stream", body)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取 io.ReadCloser(只能读取一次)
|
||||||
|
r.POST("/stream", func(c *touka.Context) {
|
||||||
|
reader := c.GetReqBody()
|
||||||
|
defer reader.Close()
|
||||||
|
// 处理 reader...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 客户端信息
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/client-info", func(c *touka.Context) {
|
||||||
|
// 获取客户端 IP(支持代理转发)
|
||||||
|
ip := c.RequestIP()
|
||||||
|
// 或使用别名
|
||||||
|
ip = c.ClientIP()
|
||||||
|
|
||||||
|
// 获取 User-Agent
|
||||||
|
ua := c.UserAgent()
|
||||||
|
|
||||||
|
// 获取 Content-Type
|
||||||
|
ct := c.ContentType()
|
||||||
|
|
||||||
|
// 获取请求协议
|
||||||
|
proto := c.GetProtocol()
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, touka.H{
|
||||||
|
"ip": ip,
|
||||||
|
"userAgent": ua,
|
||||||
|
"protocol": proto,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 请求头
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/headers", func(c *touka.Context) {
|
||||||
|
// 获取单个请求头
|
||||||
|
auth := c.GetReqHeader("Authorization")
|
||||||
|
|
||||||
|
// 获取所有请求头
|
||||||
|
allHeaders := c.GetAllReqHeader()
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, touka.H{
|
||||||
|
"authorization": auth,
|
||||||
|
"allHeaders": allHeaders,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 数据绑定
|
||||||
|
|
||||||
### JSON 绑定
|
### JSON 绑定
|
||||||
|
|
||||||
Touka 提供了非常便捷的 JSON 绑定功能,它会自动解析请求体并填充到结构体中,同时进行基本的验证。
|
Touka 提供了非常便捷的 JSON 绑定功能,它会自动解析请求体并填充到结构体中。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
User string `json:"user" binding:"required"`
|
User string `json:"user"`
|
||||||
Password string `json:"password" binding:"required"`
|
Password string `json:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
r.POST("/login", func(c *touka.Context) {
|
r.POST("/login", func(c *touka.Context) {
|
||||||
|
|
@ -57,6 +133,67 @@ r.POST("/login", func(c *touka.Context) {
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 表单绑定
|
||||||
|
|
||||||
|
```go
|
||||||
|
type UserForm struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
Email string `form:"email"`
|
||||||
|
Age int `form:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r.POST("/user", func(c *touka.Context) {
|
||||||
|
var form UserForm
|
||||||
|
if err := c.ShouldBindForm(&form); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, form)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 通用绑定
|
||||||
|
|
||||||
|
`ShouldBind` 方法会根据请求的 `Content-Type` 自动选择绑定方式:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/data", func(c *touka.Context) {
|
||||||
|
var data MyData
|
||||||
|
// 自动根据 Content-Type 绑定(支持 JSON、Form、WANF、GOB)
|
||||||
|
if err := c.ShouldBind(&data); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, data)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### WANF 绑定
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/wanf", func(c *touka.Context) {
|
||||||
|
var data MyData
|
||||||
|
if err := c.ShouldBindWANF(&data); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, data)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### GOB 绑定
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/gob", func(c *touka.Context) {
|
||||||
|
var data MyData
|
||||||
|
if err := c.ShouldBindGOB(&data); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, data)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
## 响应构建
|
## 响应构建
|
||||||
|
|
||||||
### 基础格式
|
### 基础格式
|
||||||
|
|
@ -73,21 +210,73 @@ c.String(http.StatusOK, "welcome %s", name)
|
||||||
// 纯文本
|
// 纯文本
|
||||||
c.Text(http.StatusOK, "just text")
|
c.Text(http.StatusOK, "just text")
|
||||||
|
|
||||||
|
// 原始数据
|
||||||
|
c.Raw(http.StatusOK, "application/octet-stream", []byte("raw bytes"))
|
||||||
|
|
||||||
// HTML 模板
|
// HTML 模板
|
||||||
c.HTML(http.StatusOK, "index.tmpl", touka.H{"title": "Main website"})
|
c.HTML(http.StatusOK, "index.tmpl", touka.H{"title": "Main website"})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### WANF 响应
|
||||||
|
|
||||||
|
```go
|
||||||
|
// WANF 格式响应
|
||||||
|
c.WANF(http.StatusOK, touka.H{"message": "wanf format"})
|
||||||
|
```
|
||||||
|
|
||||||
|
### GOB 响应
|
||||||
|
|
||||||
|
```go
|
||||||
|
// GOB 格式响应
|
||||||
|
c.GOB(http.StatusOK, myData)
|
||||||
|
```
|
||||||
|
|
||||||
### 文件与流
|
### 文件与流
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 服务本地文件
|
// 服务本地文件(触发浏览器下载)
|
||||||
c.File("/local/file.go")
|
c.File("/local/file.go")
|
||||||
|
|
||||||
// 将文件内容作为响应体(不触发下载)
|
// 将文件内容作为响应体(不触发下载)
|
||||||
c.SetRespBodyFile(http.StatusOK, "config.json")
|
c.SetRespBodyFile(http.StatusOK, "config.json")
|
||||||
|
|
||||||
|
// 以文本形式发送文件
|
||||||
|
c.FileText(http.StatusOK, "/path/to/file.txt")
|
||||||
|
|
||||||
// 写入数据流
|
// 写入数据流
|
||||||
c.WriteStream(reader)
|
c.WriteStream(reader)
|
||||||
|
|
||||||
|
// 设置响应体为流
|
||||||
|
c.SetBodyStream(reader, contentSize) // contentSize 为 -1 表示未知大小
|
||||||
|
```
|
||||||
|
|
||||||
|
### 响应头操作
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 设置响应头
|
||||||
|
c.SetHeader("X-Custom-Header", "value")
|
||||||
|
|
||||||
|
// 添加响应头(不覆盖已有值)
|
||||||
|
c.AddHeader("X-Custom-Header", "another-value")
|
||||||
|
|
||||||
|
// 删除响应头
|
||||||
|
c.DelHeader("X-Custom-Header")
|
||||||
|
|
||||||
|
// 批量设置响应头
|
||||||
|
c.SetHeaders(map[string][]string{
|
||||||
|
"X-Header-1": {"value1"},
|
||||||
|
"X-Header-2": {"value2a", "value2b"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取所有响应头
|
||||||
|
headers := c.GetAllRespHeader()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 状态码
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 设置状态码(不写入响应体)
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 重定向
|
### 重定向
|
||||||
|
|
@ -96,6 +285,34 @@ c.WriteStream(reader)
|
||||||
c.Redirect(http.StatusMovedPermanently, "http://google.com/")
|
c.Redirect(http.StatusMovedPermanently, "http://google.com/")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Cookie 操作
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 设置 Cookie
|
||||||
|
c.SetCookie("session_id", "abc123", 3600, "/", "example.com", true, true)
|
||||||
|
|
||||||
|
// 设置 SameSite 属性
|
||||||
|
c.SetSameSite(http.SameSiteStrictMode)
|
||||||
|
|
||||||
|
// 使用完整 Cookie 对象
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: "token",
|
||||||
|
Value: "xyz",
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
c.SetCookieData(cookie)
|
||||||
|
|
||||||
|
// 获取 Cookie
|
||||||
|
value, err := c.GetCookie("session_id")
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusUnauthorized, "Cookie not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除 Cookie
|
||||||
|
c.DeleteCookie("session_id")
|
||||||
|
```
|
||||||
|
|
||||||
## 数据传递 (Keys/Values)
|
## 数据传递 (Keys/Values)
|
||||||
|
|
||||||
您可以在中间件和处理器之间共享数据。
|
您可以在中间件和处理器之间共享数据。
|
||||||
|
|
@ -107,14 +324,146 @@ c.Set("user_id", 12345)
|
||||||
// 在处理器中获取
|
// 在处理器中获取
|
||||||
id, exists := c.Get("user_id")
|
id, exists := c.Get("user_id")
|
||||||
val := c.MustGet("user_id").(int)
|
val := c.MustGet("user_id").(int)
|
||||||
|
|
||||||
|
// 类型安全的获取方法
|
||||||
|
str, exists := c.GetString("key")
|
||||||
|
i, exists := c.GetInt("key")
|
||||||
|
b, exists := c.GetBool("key")
|
||||||
|
f, exists := c.GetFloat64("key")
|
||||||
|
t, exists := c.GetTime("key")
|
||||||
|
d, exists := c.GetDuration("key")
|
||||||
|
```
|
||||||
|
|
||||||
|
## 错误处理
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/error", func(c *touka.Context) {
|
||||||
|
// 添加错误到上下文(可以添加多个)
|
||||||
|
c.AddError(errors.New("error 1"))
|
||||||
|
c.AddError(errors.New("error 2"))
|
||||||
|
|
||||||
|
// 获取所有错误
|
||||||
|
errs := c.GetErrors()
|
||||||
|
|
||||||
|
// 使用全局错误处理器
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, errors.New("something went wrong"))
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 日志记录
|
||||||
|
|
||||||
|
Touka 集成了 `reco` 日志库,可以直接在 Context 中使用:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/log", func(c *touka.Context) {
|
||||||
|
c.Debugf("Debug message: %s", "details")
|
||||||
|
c.Infof("User accessed /log")
|
||||||
|
c.Warnf("Warning: %v", someWarning)
|
||||||
|
c.Errorf("Error occurred: %v", someError)
|
||||||
|
|
||||||
|
// 获取底层日志器
|
||||||
|
logger := c.GetLogger()
|
||||||
|
logger.CustomLog("level", "message")
|
||||||
|
|
||||||
|
c.String(http.StatusOK, "Logged")
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP 客户端
|
||||||
|
|
||||||
|
Touka 集成了 `httpc` HTTP 客户端,方便发起出站请求:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
// 获取 HTTP 客户端
|
||||||
|
client := c.GetHTTPC()
|
||||||
|
// 或
|
||||||
|
client = c.Client()
|
||||||
|
|
||||||
|
// 发起请求
|
||||||
|
resp, err := client.Get("https://api.example.com/data")
|
||||||
|
if err != nil {
|
||||||
|
c.ErrorUseHandle(http.StatusBadGateway, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 将响应流式传输给客户端
|
||||||
|
c.SetHeader("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
c.WriteStream(resp.Body)
|
||||||
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## 状态管理
|
## 状态管理
|
||||||
|
|
||||||
- `c.Abort()`: 停止执行后续的处理器/中间件。
|
- `c.Abort()`: 停止执行后续的处理器/中间件。
|
||||||
|
- `c.AbortWithStatus(code)`: 中止并设置状态码。
|
||||||
|
- `c.IsAborted()`: 检查是否已中止。
|
||||||
- `c.Next()`: 执行后续的处理链。这常用于中间件中,在执行完某些前置逻辑后,显式调用 `Next`,并在其返回后执行后置逻辑。
|
- `c.Next()`: 执行后续的处理链。这常用于中间件中,在执行完某些前置逻辑后,显式调用 `Next`,并在其返回后执行后置逻辑。
|
||||||
|
|
||||||
|
## 请求上下文 (Go Context)
|
||||||
|
|
||||||
|
Touka Context 实现了 Go 标准库的 `context.Context` 接口:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/long-task", func(c *touka.Context) {
|
||||||
|
// 获取 Go context
|
||||||
|
ctx := c.Context()
|
||||||
|
|
||||||
|
// 监听取消信号
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// 客户端断开连接或超时
|
||||||
|
return
|
||||||
|
case result := <-doLongTask(ctx):
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 其他 context 方法
|
||||||
|
done := c.Done() // 获取 Done channel
|
||||||
|
err := c.Err() // 获取错误
|
||||||
|
val := c.Value("key") // 获取值(同时查找 Keys 和 Go context)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 其他方法
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 获取原始请求 URI
|
||||||
|
uri := c.GetRequestURI()
|
||||||
|
|
||||||
|
// 获取请求路径
|
||||||
|
path := c.GetRequestURIPath()
|
||||||
|
|
||||||
|
// 获取查询字符串
|
||||||
|
query := c.GetReqQueryString()
|
||||||
|
|
||||||
|
// 获取请求协议版本
|
||||||
|
proto := c.GetProtocol() // 例如 "HTTP/1.1"
|
||||||
|
```
|
||||||
|
|
||||||
## 对象池化
|
## 对象池化
|
||||||
|
|
||||||
为了提高性能,Touka 的 Context 对象是复用的。
|
为了提高性能,Touka 的 Context 对象是复用的。
|
||||||
|
|
||||||
**重要提示:不要在 Goroutine 中持久化持有 `touka.Context` 指针。如果您需要在 Goroutine 中使用请求数据,请务必在派生 Goroutine 前提取所需的值。**
|
**重要提示:不要在 Goroutine 中持久化持有 `touka.Context` 指针。如果您需要在 Goroutine 中使用请求数据,请务必在派生 Goroutine 前提取所需的值。**
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 错误示例 ❌
|
||||||
|
r.GET("/bad", func(c *touka.Context) {
|
||||||
|
go func() {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
// 此时 c 可能已被复用,数据不安全
|
||||||
|
log.Println(c.Query("name"))
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
// 正确示例 ✓
|
||||||
|
r.GET("/good", func(c *touka.Context) {
|
||||||
|
name := c.Query("name") // 提前提取值
|
||||||
|
go func() {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
log.Println(name) // 使用提取的值,安全
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
|
||||||
188
docs/httpc.md
Normal file
188
docs/httpc.md
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
# HTTP Client (httpc)
|
||||||
|
|
||||||
|
Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。
|
||||||
|
|
||||||
|
## 核心特性
|
||||||
|
|
||||||
|
- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context
|
||||||
|
- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏
|
||||||
|
- **链式调用**:保持 httpc 原有的组合式构建器风格
|
||||||
|
|
||||||
|
## 基本用法
|
||||||
|
|
||||||
|
### 简单 GET 请求
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
body, err := c.HTTPC().
|
||||||
|
GET("https://api.example.com/data").
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(200, body)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST JSON 请求
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/users", func(c *touka.Context) {
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.HTTPC().
|
||||||
|
POST("https://api.example.com/users").
|
||||||
|
SetHeader("Authorization", "Bearer "+token).
|
||||||
|
SetJSONBody(req).
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, result)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 带查询参数
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/search", func(c *touka.Context) {
|
||||||
|
query := c.Query("q")
|
||||||
|
|
||||||
|
var result SearchResult
|
||||||
|
err := c.HTTPC().
|
||||||
|
GET("https://api.example.com/search").
|
||||||
|
SetQueryParam("q", query).
|
||||||
|
SetQueryParam("limit", "10").
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, result)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 对比
|
||||||
|
|
||||||
|
### 旧方式(Deprecated)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 需要手动 WithContext,容易忘记
|
||||||
|
resp, err := c.Client().
|
||||||
|
WithContext(c.Context()).
|
||||||
|
GET(url).
|
||||||
|
Execute()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 新方式(推荐)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 自动关联请求 Context
|
||||||
|
resp, err := c.HTTPC().
|
||||||
|
GET(url).
|
||||||
|
Execute()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Context 取消机制
|
||||||
|
|
||||||
|
使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/long-task", func(c *touka.Context) {
|
||||||
|
// 这个请求会在客户端断开时自动取消
|
||||||
|
resp, err := c.HTTPC().
|
||||||
|
GET("https://slow-api.example.com/data").
|
||||||
|
Execute()
|
||||||
|
|
||||||
|
// 如果客户端已断开,err 会包含 context.Canceled
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return // 客户端已断开,无需处理
|
||||||
|
}
|
||||||
|
// ...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 完整 API
|
||||||
|
|
||||||
|
### contextHTTPClient 方法
|
||||||
|
|
||||||
|
| 方法 | 返回类型 | 说明 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 |
|
||||||
|
| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 |
|
||||||
|
| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 |
|
||||||
|
| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 |
|
||||||
|
| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 |
|
||||||
|
| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 |
|
||||||
|
| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 |
|
||||||
|
| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 |
|
||||||
|
|
||||||
|
### httpc.RequestBuilder 链式方法
|
||||||
|
|
||||||
|
返回 `*httpc.RequestBuilder`(用于链式调用):
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) |
|
||||||
|
| `NoDefaultHeaders()` | 不添加默认 Header |
|
||||||
|
| `SetHeader(key, value)` | 设置 Header |
|
||||||
|
| `AddHeader(key, value)` | 添加 Header(可重复) |
|
||||||
|
| `SetHeaders(map)` | 批量设置 Headers |
|
||||||
|
| `SetQueryParam(key, value)` | 设置查询参数 |
|
||||||
|
| `AddQueryParam(key, value)` | 添加查询参数(可重复) |
|
||||||
|
| `SetQueryParams(map)` | 批量设置查询参数 |
|
||||||
|
| `SetBody(io.Reader)` | 设置请求 Body |
|
||||||
|
| `SetRawBody([]byte)` | 设置字节 Body |
|
||||||
|
|
||||||
|
返回 `(*httpc.RequestBuilder, error)`(可能失败):
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `SetJSONBody(any)` | 设置 JSON Body |
|
||||||
|
| `SetXMLBody(any)` | 设置 XML Body |
|
||||||
|
| `SetGOBBody(any)` | 设置 GOB Body |
|
||||||
|
|
||||||
|
### 终结方法
|
||||||
|
|
||||||
|
| 方法 | 返回类型 | 说明 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `Build()` | `(*http.Request, error)` | 构建请求但不执行 |
|
||||||
|
| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 |
|
||||||
|
| `DecodeJSON(v)` | `error` | 执行并解码 JSON |
|
||||||
|
| `DecodeXML(v)` | `error` | 执行并解码 XML |
|
||||||
|
| `DecodeGOB(v)` | `error` | 执行并解码 GOB |
|
||||||
|
| `Text()` | `(string, error)` | 执行并返回文本 |
|
||||||
|
| `Bytes()` | `([]byte, error)` | 执行并返回字节 |
|
||||||
|
| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 |
|
||||||
|
|
||||||
|
## 迁移指南
|
||||||
|
|
||||||
|
### go:fix inline 兼容
|
||||||
|
|
||||||
|
旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go fix ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 手动迁移
|
||||||
|
|
||||||
|
| 旧代码 | 新代码 |
|
||||||
|
|--------|--------|
|
||||||
|
| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` |
|
||||||
|
| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` |
|
||||||
|
|
||||||
|
## 示例
|
||||||
|
|
||||||
|
完整示例请参考 [examples/httpc](../examples/httpc)。
|
||||||
|
|
@ -14,6 +14,7 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
|
||||||
- **最小化内存分配**: 在热点路径上尽可能减少临时对象的产生。
|
- **最小化内存分配**: 在热点路径上尽可能减少临时对象的产生。
|
||||||
- **统一错误处理**: 独创的 `errorCapturingResponseWriter` 机制,能够捕获包括标准库 `http.FileServer` 在内的所有组件产生的错误状态码,并交由全局处理器统一处理。
|
- **统一错误处理**: 独创的 `errorCapturingResponseWriter` 机制,能够捕获包括标准库 `http.FileServer` 在内的所有组件产生的错误状态码,并交由全局处理器统一处理。
|
||||||
- **无缝集成 SSE**: 内置对 Server-Sent Events 的支持,提供简单易用的回调式 API 和高度灵活的通道式 API。
|
- **无缝集成 SSE**: 内置对 Server-Sent Events 的支持,提供简单易用的回调式 API 和高度灵活的通道式 API。
|
||||||
|
- **内置反向代理**: 支持请求转发、协议升级、转发头维护、Trailer 与流式响应透传。
|
||||||
- **静态资源增强**: 针对本地文件、目录以及 Go 嵌入式文件系统(embed.FS)提供了开箱即用的支持。
|
- **静态资源增强**: 针对本地文件、目录以及 Go 嵌入式文件系统(embed.FS)提供了开箱即用的支持。
|
||||||
- **标准库兼容**: 提供了适配器,可以轻松将现有的 `http.Handler` 或 `http.HandlerFunc` 集成到 Touka 中。
|
- **标准库兼容**: 提供了适配器,可以轻松将现有的 `http.Handler` 或 `http.HandlerFunc` 集成到 Touka 中。
|
||||||
|
|
||||||
|
|
@ -21,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
|
||||||
|
|
||||||
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
|
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
|
||||||
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
|
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
|
||||||
3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。
|
3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。
|
||||||
|
|
||||||
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。
|
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。
|
||||||
|
|
|
||||||
400
docs/logger-migration-design.md
Normal file
400
docs/logger-migration-design.md
Normal file
|
|
@ -0,0 +1,400 @@
|
||||||
|
# Touka Logger 接口迁移方案
|
||||||
|
|
||||||
|
## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 一、问题分析
|
||||||
|
|
||||||
|
当前架构问题:
|
||||||
|
```
|
||||||
|
Engine.LogReco → *reco.Logger (公开字段, 直接访问)
|
||||||
|
Context.GetLogger() → 返回 *reco.Logger (具体类型)
|
||||||
|
Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
这导致用户无法替换日志实现(如 zap/logrus)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 二、目标架构
|
||||||
|
|
||||||
|
```
|
||||||
|
Engine.logger → Logger 接口 (私有)
|
||||||
|
Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容)
|
||||||
|
Engine.GetLogger() → 返回 Logger 接口
|
||||||
|
Engine.SetLogger(Logger)→ 设置日志实现
|
||||||
|
Context.GetLogger() → 返回 Logger 接口
|
||||||
|
Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 三、Logger 接口定义
|
||||||
|
|
||||||
|
```go
|
||||||
|
// logger.go
|
||||||
|
package touka
|
||||||
|
|
||||||
|
// Logger 是日志接口,支持任意日志库实现
|
||||||
|
type Logger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
|
Infof(format string, args ...any)
|
||||||
|
Warnf(format string, args ...any)
|
||||||
|
Errorf(format string, args ...any)
|
||||||
|
Fatalf(format string, args ...any)
|
||||||
|
Panicf(format string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloserLogger 可选扩展,支持关闭操作
|
||||||
|
type CloserLogger interface {
|
||||||
|
Logger
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 四、Engine 结构变更
|
||||||
|
|
||||||
|
```go
|
||||||
|
// engine.go 变更
|
||||||
|
type Engine struct {
|
||||||
|
// ... 其他字段保持不变
|
||||||
|
|
||||||
|
// logger 是新的日志接口 (私有)
|
||||||
|
logger Logger
|
||||||
|
|
||||||
|
// logReco 是保留的 reco.Logger 引用 (私有)
|
||||||
|
// 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger
|
||||||
|
logReco *reco.Logger
|
||||||
|
|
||||||
|
// 其他字段...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
新增/修改方法:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// GetLogger 返回日志接口
|
||||||
|
func (engine *Engine) GetLogger() Logger {
|
||||||
|
return engine.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogger 设置任意 Logger 实现
|
||||||
|
func (engine *Engine) SetLogger(l Logger) {
|
||||||
|
engine.logger = l
|
||||||
|
// 如果是 *reco.Logger 类型,同步更新 logReco
|
||||||
|
if rl, ok := l.(*reco.Logger); ok {
|
||||||
|
engine.logReco = rl
|
||||||
|
} else {
|
||||||
|
engine.logReco = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLoggerCfg 使用 reco.Config 配置日志
|
||||||
|
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
||||||
|
logger := NewLogger(logcfg)
|
||||||
|
engine.logger = logger
|
||||||
|
engine.logReco = logger
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 五、`go:fix inline` 兼容性函数
|
||||||
|
|
||||||
|
### 5.1 旧 API 包装函数
|
||||||
|
|
||||||
|
在 `compat.go` 中定义:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// compat.go
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import "github.com/fenthope/reco"
|
||||||
|
|
||||||
|
// GetLogReco 返回 reco.Logger,用于向后兼容
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (engine *Engine) GetLogReco() *reco.Logger {
|
||||||
|
return engine.logReco
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogReco 设置 reco.Logger,用于向后兼容
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (engine *Engine) SetLogReco(l *reco.Logger) {
|
||||||
|
engine.logReco = l
|
||||||
|
engine.logger = l
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 Context 日志方法的 inline 包装
|
||||||
|
|
||||||
|
```go
|
||||||
|
// context_compat.go
|
||||||
|
package touka
|
||||||
|
|
||||||
|
// Debugf 记录 Debug 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Debugf(format string, args ...any) {
|
||||||
|
c.engine.logger.Debugf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infof 记录 Info 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Infof(format string, args ...any) {
|
||||||
|
c.engine.logger.Infof(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warnf 记录 Warn 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Warnf(format string, args ...any) {
|
||||||
|
c.engine.logger.Warnf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errorf 记录 Error 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Errorf(format string, args ...any) {
|
||||||
|
c.engine.logger.Errorf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fatalf 记录 Fatal 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Fatalf(format string, args ...any) {
|
||||||
|
c.engine.logger.Fatalf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Panicf 记录 Panic 级别日志
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) Panicf(format string, args ...any) {
|
||||||
|
c.engine.logger.Panicf(format, args...)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 GetLogger 返回类型的兼容处理
|
||||||
|
|
||||||
|
由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// context_compat.go (续)
|
||||||
|
|
||||||
|
// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) GetLoggerReco() *reco.Logger {
|
||||||
|
if rl, ok := c.engine.logger.(*reco.Logger); ok {
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 六、go:fix inline 工作原理
|
||||||
|
|
||||||
|
### 迁移前用户代码:
|
||||||
|
```go
|
||||||
|
func handler(c *touka.Context) {
|
||||||
|
// 旧 API 调用
|
||||||
|
c.Debugf("request: %s", c.Request.URL.Path)
|
||||||
|
c.engine.LogReco.Infof("server started")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### go fix 执行后(自动替换):
|
||||||
|
```go
|
||||||
|
func handler(c *touka.Context) {
|
||||||
|
// Debugf 被替换为函数体
|
||||||
|
c.engine.logger.Debugf("request: %s", c.Request.URL.Path)
|
||||||
|
|
||||||
|
// LogReco 访问无法通过 inline 自动处理,需要手动迁移
|
||||||
|
// 或者通过 getter 调用
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 对于字段访问的处理策略:
|
||||||
|
|
||||||
|
`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略:
|
||||||
|
|
||||||
|
1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated
|
||||||
|
2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco`
|
||||||
|
3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 七、迁移前后对比
|
||||||
|
|
||||||
|
### 场景 1:基本日志调用
|
||||||
|
|
||||||
|
**迁移前:**
|
||||||
|
```go
|
||||||
|
func myHandler(c *touka.Context) {
|
||||||
|
c.Debugf("processing request %s", c.Request.URL.Path)
|
||||||
|
c.Infof("user %s logged in", username)
|
||||||
|
c.Warnf("slow query: %v", duration)
|
||||||
|
c.Errorf("db error: %v", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**迁移后(自动替换):**
|
||||||
|
```go
|
||||||
|
func myHandler(c *touka.Context) {
|
||||||
|
c.engine.logger.Debugf("processing request %s", c.Request.URL.Path)
|
||||||
|
c.engine.logger.Infof("user %s logged in", username)
|
||||||
|
c.engine.logger.Warnf("slow query: %v", duration)
|
||||||
|
c.engine.logger.Errorf("db error: %v", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 场景 2:Engine 配置日志
|
||||||
|
|
||||||
|
**迁移前:**
|
||||||
|
```go
|
||||||
|
engine := touka.New()
|
||||||
|
engine.LogReco = myLogger // 直接赋值
|
||||||
|
logger := engine.LogReco // 直接读取
|
||||||
|
```
|
||||||
|
|
||||||
|
**迁移后(手动 + 自动混合):**
|
||||||
|
```go
|
||||||
|
engine := touka.New()
|
||||||
|
|
||||||
|
// 方式 1:使用新 API(推荐)
|
||||||
|
engine.SetLogger(myLogger)
|
||||||
|
logger := engine.GetLogger()
|
||||||
|
|
||||||
|
// 方式 2:通过 go:fix inline 自动替换为 getter
|
||||||
|
// engine.SetLogReco(myLogger) ← go fix 替换
|
||||||
|
// logger := engine.GetLogReco() ← go fix 替换
|
||||||
|
```
|
||||||
|
|
||||||
|
### 场景 3:使用第三方日志库(新功能)
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "go.uber.org/zap"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
zapLogger, _ := zap.NewProduction()
|
||||||
|
defer zapLogger.Sync()
|
||||||
|
|
||||||
|
engine := touka.New()
|
||||||
|
// 使用 zap 替代默认的 reco.Logger
|
||||||
|
engine.SetLogger(&ZapAdapter{logger: zapLogger})
|
||||||
|
|
||||||
|
engine.GET("/api", func(c *touka.Context) {
|
||||||
|
c.Infof("api called") // 自动使用 zap 输出
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZapAdapter 适配 zap 到 touka.Logger 接口
|
||||||
|
type ZapAdapter struct {
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Debugf(format string, args ...any) {
|
||||||
|
z.logger.Debug(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Infof(format string, args ...any) {
|
||||||
|
z.logger.Info(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Warnf(format string, args ...any) {
|
||||||
|
z.logger.Warn(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Errorf(format string, args ...any) {
|
||||||
|
z.logger.Error(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Fatalf(format string, args ...any) {
|
||||||
|
z.logger.Fatal(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZapAdapter) Panicf(format string, args ...any) {
|
||||||
|
z.logger.Panic(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 八、内部使用迁移
|
||||||
|
|
||||||
|
框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`:
|
||||||
|
|
||||||
|
需要修改的文件:
|
||||||
|
- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf`
|
||||||
|
- `recovery.go`: 如有使用日志
|
||||||
|
- `logreco.go`: CloseLogger 方法
|
||||||
|
|
||||||
|
```go
|
||||||
|
// context.go 修改前
|
||||||
|
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||||
|
if _, err := c.Writer.Write(data); err != nil {
|
||||||
|
if c.engine.LogReco != nil {
|
||||||
|
c.engine.LogReco.Errorf("%s: %v", contextMsg, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// context.go 修改后
|
||||||
|
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
|
||||||
|
if _, err := c.Writer.Write(data); err != nil {
|
||||||
|
if c.engine.logger != nil {
|
||||||
|
c.engine.logger.Errorf("%s: %v", contextMsg, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 九、完整文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
touka/
|
||||||
|
├── logger.go # Logger 接口定义
|
||||||
|
├── logreco.go # reco.Logger 相关工具函数
|
||||||
|
├── compat.go # go:fix inline 兼容性函数 (Engine)
|
||||||
|
├── context_compat.go # go:fix inline 兼容性函数 (Context)
|
||||||
|
├── engine.go # Engine 结构变更
|
||||||
|
├── context.go # Context 日志方法变更
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十、版本策略
|
||||||
|
|
||||||
|
| 版本 | 变更内容 |
|
||||||
|
|------|---------|
|
||||||
|
| v1.x | 引入 Logger 接口,LogReco 标记 deprecated |
|
||||||
|
| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 |
|
||||||
|
| v3.x | 移除 go:fix inline 兼容函数 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十一、go:fix inline 限制说明
|
||||||
|
|
||||||
|
1. **字段访问无法自动迁移**:`engine.LogReco` 字段访问需要用户手动修改
|
||||||
|
2. **返回类型变更需谨慎**:`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败
|
||||||
|
3. **inline 函数有大小限制**:函数体过大会影响内联效果
|
||||||
|
4. **跨包迁移**:`go:fix inline` 支持跨包,但用户必须运行 `go fix`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 十二、推荐迁移步骤
|
||||||
|
|
||||||
|
1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数
|
||||||
|
2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分
|
||||||
|
3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()`
|
||||||
|
4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置
|
||||||
|
|
@ -26,6 +26,41 @@ api.Use(AuthMiddleware())
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
也可以在创建组时直接传入中间件:
|
||||||
|
|
||||||
|
```go
|
||||||
|
api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware())
|
||||||
|
{
|
||||||
|
api.GET("/user", handleUser)
|
||||||
|
api.POST("/data", handleData)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 路由级中间件
|
||||||
|
|
||||||
|
为单个路由注册中间件,仅对该路由生效。
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 单个路由中间件
|
||||||
|
r.GET("/protected", AuthMiddleware(), func(c *touka.Context) {
|
||||||
|
c.String(http.StatusOK, "Protected content")
|
||||||
|
})
|
||||||
|
|
||||||
|
// 多个路由中间件(按顺序执行)
|
||||||
|
r.POST("/upload",
|
||||||
|
RateLimitMiddleware(),
|
||||||
|
AuthMiddleware(),
|
||||||
|
PermissionCheckMiddleware(),
|
||||||
|
func(c *touka.Context) {
|
||||||
|
// 处理上传
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// 路由组中的单个路由也可以使用路由级中间件
|
||||||
|
api := r.Group("/api")
|
||||||
|
api.GET("/admin", AdminAuthMiddleware(), adminHandler)
|
||||||
|
```
|
||||||
|
|
||||||
## 编写自定义中间件
|
## 编写自定义中间件
|
||||||
|
|
||||||
中间件的函数签名是 `touka.HandlerFunc`。
|
中间件的函数签名是 `touka.HandlerFunc`。
|
||||||
|
|
@ -53,7 +88,7 @@ func TimerMiddleware() touka.HandlerFunc {
|
||||||
```go
|
```go
|
||||||
func APIKeyAuth() touka.HandlerFunc {
|
func APIKeyAuth() touka.HandlerFunc {
|
||||||
return func(c *touka.Context) {
|
return func(c *touka.Context) {
|
||||||
apiKey := c.GetHeader("X-API-KEY")
|
apiKey := c.GetReqHeader("X-API-KEY")
|
||||||
if apiKey != "secret-token" {
|
if apiKey != "secret-token" {
|
||||||
// 验证失败,返回错误并中止后续逻辑
|
// 验证失败,返回错误并中止后续逻辑
|
||||||
c.JSON(http.StatusUnauthorized, touka.H{"error": "Invalid API Key"})
|
c.JSON(http.StatusUnauthorized, touka.H{"error": "Invalid API Key"})
|
||||||
|
|
@ -67,6 +102,36 @@ func APIKeyAuth() touka.HandlerFunc {
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 中间件执行顺序
|
||||||
|
|
||||||
|
理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 全局中间件
|
||||||
|
r.Use(GlobalMiddleware1())
|
||||||
|
r.Use(GlobalMiddleware2())
|
||||||
|
|
||||||
|
// 组中间件
|
||||||
|
api := r.Group("/api", GroupMiddleware1())
|
||||||
|
api.Use(GroupMiddleware2())
|
||||||
|
|
||||||
|
// 路由级中间件
|
||||||
|
api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
对于 `/api/users` 请求,执行顺序为:
|
||||||
|
1. `GlobalMiddleware1()` - 全局中间件
|
||||||
|
2. `GlobalMiddleware2()` - 全局中间件
|
||||||
|
3. `GroupMiddleware1()` - 路由组中间件
|
||||||
|
4. `GroupMiddleware2()` - 路由组中间件
|
||||||
|
5. `RouteMiddleware1()` - 路由级中间件
|
||||||
|
6. `RouteMiddleware2()` - 路由级中间件
|
||||||
|
7. `userHandler` - 最终处理函数
|
||||||
|
|
||||||
|
```
|
||||||
|
请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应
|
||||||
|
```
|
||||||
|
|
||||||
## 内置中间件
|
## 内置中间件
|
||||||
|
|
||||||
- **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。
|
- **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
## 安装
|
## 安装
|
||||||
|
|
||||||
确保您的环境中已经安装了 Go 1.25 或更高版本。
|
确保您的环境中已经安装了 Go 1.26 或更高版本。
|
||||||
|
|
||||||
在您的项目目录中运行:
|
在您的项目目录中运行:
|
||||||
|
|
||||||
|
|
@ -46,7 +46,7 @@ func main() {
|
||||||
|
|
||||||
// 4. 启动服务器并监听 8080 端口
|
// 4. 启动服务器并监听 8080 端口
|
||||||
log.Println("Touka server is running on :8080")
|
log.Println("Touka server is running on :8080")
|
||||||
if err := r.Run(":8080"); err != nil {
|
if err := r.Run(touka.WithAddr(":8080")); err != nil {
|
||||||
log.Fatalf("Server failed: %v", err)
|
log.Fatalf("Server failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -66,11 +66,11 @@ go run main.go
|
||||||
|
|
||||||
## 优雅停机
|
## 优雅停机
|
||||||
|
|
||||||
在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。
|
在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 等待 10 秒以处理剩余请求
|
// 等待 10 秒以处理剩余请求
|
||||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||||
log.Fatalf("Server forced to shutdown: %v", err)
|
log.Fatalf("Server forced to shutdown: %v", err)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
517
docs/reverse-proxy.md
Normal file
517
docs/reverse-proxy.md
Normal file
|
|
@ -0,0 +1,517 @@
|
||||||
|
# 反向代理
|
||||||
|
|
||||||
|
Touka 内置了反向代理能力,可以直接把某一组请求转发到后端服务,同时保留 Touka 的路由、中间件与统一错误处理风格。
|
||||||
|
|
||||||
|
`touka.ReverseProxy` 返回一个 `HandlerFunc`,因此它可以像普通路由处理器一样直接挂到 `GET`、`ANY`、路由组等位置。
|
||||||
|
|
||||||
|
## 最简单的用法
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/infinite-iroha/touka"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := touka.Default()
|
||||||
|
|
||||||
|
target, err := url.Parse("http://127.0.0.1:9000")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
}))
|
||||||
|
|
||||||
|
_ = r.Run(touka.WithAddr(":8080"))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
当客户端访问 `http://127.0.0.1:8080/api/users` 时,请求会被转发到 `http://127.0.0.1:9000/api/users`。
|
||||||
|
|
||||||
|
## 带基础路径的代理
|
||||||
|
|
||||||
|
如果目标服务部署在一个子路径下,可以直接把目标地址写成带路径的 URL:
|
||||||
|
|
||||||
|
```go
|
||||||
|
target, _ := url.Parse("http://127.0.0.1:9000/backend")
|
||||||
|
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
此时:
|
||||||
|
|
||||||
|
- `/api/users` 会转发到 `/backend/api/users`
|
||||||
|
- `/api/orders?id=10` 会转发到 `/backend/api/orders?id=10`
|
||||||
|
|
||||||
|
目标 URL 自身携带的查询参数也会被保留并与原请求查询参数合并。
|
||||||
|
合并后的出站查询串会再经过一次规范化处理,因此某些非标准分隔符(例如 `;`)或非法参数片段可能被重编码、折叠或直接丢弃。
|
||||||
|
这是为了尽量让代理链各跳对查询参数的解析结果保持一致,并减少参数走私这类解析歧义风险。
|
||||||
|
|
||||||
|
## 配置项说明
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ReverseProxyConfig struct {
|
||||||
|
Target *url.URL
|
||||||
|
Targets []string
|
||||||
|
|
||||||
|
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||||
|
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||||
|
|
||||||
|
Transport http.RoundTripper
|
||||||
|
FlushInterval time.Duration
|
||||||
|
BufferPool BufferPool
|
||||||
|
AllowH2CUpstream bool
|
||||||
|
|
||||||
|
ModifyRequest func(*http.Request)
|
||||||
|
ModifyResponse func(*http.Response) error
|
||||||
|
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||||
|
|
||||||
|
ForwardedHeaders ForwardedHeadersPolicy
|
||||||
|
ForwardedBy string
|
||||||
|
Via string
|
||||||
|
PreserveHost bool
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `Target`
|
||||||
|
|
||||||
|
与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。
|
||||||
|
|
||||||
|
```go
|
||||||
|
target, _ := url.Parse("http://backend:9000")
|
||||||
|
```
|
||||||
|
|
||||||
|
### `Targets`
|
||||||
|
|
||||||
|
可选。用于配置多个后端目标地址。
|
||||||
|
|
||||||
|
- `Target` 与 `Targets` 互斥,只能使用其中一种
|
||||||
|
- `Targets` 的每一项都必须是完整 URL
|
||||||
|
- 每个 target 仍然可以自带 base path 和 query
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001/base?from=a",
|
||||||
|
"http://127.0.0.1:9002/base?from=b",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。
|
||||||
|
|
||||||
|
### `LoadBalancing`
|
||||||
|
|
||||||
|
用于配置 upstream 选择策略和重试行为。
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ReverseProxyLoadBalancingConfig struct {
|
||||||
|
Policy ReverseProxyLBPolicy
|
||||||
|
Retries int
|
||||||
|
TryDuration time.Duration
|
||||||
|
TryInterval time.Duration
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
当前内置策略:
|
||||||
|
|
||||||
|
- `touka.LBRandom()`
|
||||||
|
- `touka.LBRoundRobin()`
|
||||||
|
- `touka.LBFirst()`
|
||||||
|
- `touka.LBLeastConn()`
|
||||||
|
- `touka.LBIPHash()`
|
||||||
|
- `touka.LBClientIPHash()`
|
||||||
|
- `touka.LBURIHash()`
|
||||||
|
- `touka.LBHeader("X-Upstream", fallback)`
|
||||||
|
- `touka.LBQuery("tenant", fallback)`
|
||||||
|
|
||||||
|
其中:
|
||||||
|
|
||||||
|
- `LBFirst()` 适合主备/故障转移顺序
|
||||||
|
- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback
|
||||||
|
- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()`
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001",
|
||||||
|
"http://127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: touka.LBHeader("X-Upstream", touka.LBFirst()),
|
||||||
|
Retries: 1,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
重试说明:
|
||||||
|
|
||||||
|
- 只对未开始收到上游响应的失败进行重试
|
||||||
|
- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试
|
||||||
|
- `Retries` 表示额外重试次数
|
||||||
|
- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机
|
||||||
|
- `TryInterval` 表示两次重试之间的等待间隔
|
||||||
|
|
||||||
|
### `PassiveHealth`
|
||||||
|
|
||||||
|
用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。
|
||||||
|
|
||||||
|
```go
|
||||||
|
type ReverseProxyPassiveHealthConfig struct {
|
||||||
|
FailDuration time.Duration
|
||||||
|
MaxFails int
|
||||||
|
UnhealthyStatus []int
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `FailDuration > 0` 时启用被动健康跟踪
|
||||||
|
- `MaxFails <= 0` 时默认按 `1` 处理
|
||||||
|
- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Targets: []string{
|
||||||
|
"http://127.0.0.1:9001",
|
||||||
|
"http://127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
LoadBalancing: touka.ReverseProxyLoadBalancingConfig{
|
||||||
|
Policy: touka.LBFirst(),
|
||||||
|
},
|
||||||
|
PassiveHealth: touka.ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
UnhealthyStatus: []int{http.StatusServiceUnavailable},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### `AllowH2CUpstream`
|
||||||
|
|
||||||
|
允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。
|
||||||
|
|
||||||
|
- 默认关闭
|
||||||
|
- 这是一个显式配置项
|
||||||
|
- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游
|
||||||
|
- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
AllowH2CUpstream: true,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。
|
||||||
|
|
||||||
|
### `Transport`
|
||||||
|
|
||||||
|
可选。用于自定义底层转发所使用的 `http.RoundTripper`。
|
||||||
|
|
||||||
|
如果留空,则默认使用 `http.DefaultTransport`。
|
||||||
|
|
||||||
|
```go
|
||||||
|
proxyTransport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
Transport: proxyTransport,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### `FlushInterval`
|
||||||
|
|
||||||
|
控制代理在复制响应体时的主动刷新间隔:
|
||||||
|
|
||||||
|
- `0`:不额外定时刷新
|
||||||
|
- `> 0`:按指定间隔刷新
|
||||||
|
- `< 0`:每次写入后立即刷新
|
||||||
|
|
||||||
|
对于 SSE 和无 `Content-Length` 的流式响应,Touka 会自动立即刷新,不依赖该配置。
|
||||||
|
|
||||||
|
### `BufferPool`
|
||||||
|
|
||||||
|
可选。用于为响应体复制过程提供可复用的字节缓冲区,以减少大响应或高并发代理场景下的临时内存分配。
|
||||||
|
|
||||||
|
如果留空,Touka 会在复制响应体时按需分配默认缓冲区。
|
||||||
|
|
||||||
|
```go
|
||||||
|
type bytePool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bytePool) Get() []byte {
|
||||||
|
if buf, ok := p.pool.Get().([]byte); ok {
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
return make([]byte, 32*1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bytePool) Put(buf []byte) {
|
||||||
|
if cap(buf) >= 32*1024 {
|
||||||
|
p.pool.Put(buf[:32*1024])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyPool := &bytePool{}
|
||||||
|
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
BufferPool: proxyPool,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
通常只有在您已经观察到明显的分配压力,或代理的响应体较大、吞吐较高时,才需要专门配置它。
|
||||||
|
|
||||||
|
### `ModifyRequest`
|
||||||
|
|
||||||
|
在请求真正发往后端前,对出站请求做最后修改。
|
||||||
|
|
||||||
|
如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。
|
||||||
|
|
||||||
|
常见用途:
|
||||||
|
|
||||||
|
- 覆盖 `Host`
|
||||||
|
- 增加鉴权头
|
||||||
|
- 重写路径
|
||||||
|
- 注入内部追踪头
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ModifyRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Internal-Token", "gateway-token")
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### `ModifyResponse`
|
||||||
|
|
||||||
|
在后端返回响应后、写回客户端前,对响应做额外处理。
|
||||||
|
|
||||||
|
注意:`ModifyResponse` 也会作用于 `101 Switching Protocols` 响应。
|
||||||
|
如果该代理路由需要转发 WebSocket 或其他 Upgrade 流量,请不要在这里消费、完全缓冲,或替换 `resp.Body` 为只读对象;后续升级流程仍然要求它保留 `io.ReadWriteCloser` 能力。
|
||||||
|
更稳妥的做法是对 `101` 响应直接跳过这类处理。
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
|
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp.Header.Set("X-Proxy", "touka")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
如果该函数返回错误,会转入 `ErrorHandler` 或默认的 `502 Bad Gateway` 处理流程。
|
||||||
|
|
||||||
|
### `ErrorHandler`
|
||||||
|
|
||||||
|
用于处理连接后端失败、协议升级失败、`ModifyResponse` 返回错误等情况。
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
_, _ = w.Write([]byte("upstream unavailable"))
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
### `PreserveHost`
|
||||||
|
|
||||||
|
默认情况下,代理请求的 `Host` 会跟随后端目标地址。
|
||||||
|
|
||||||
|
如果设置为 `true`,则会保留客户端原始 `Host`。
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
PreserveHost: true,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
这在某些依赖原始域名进行路由或租户识别的后端服务中会比较有用。
|
||||||
|
|
||||||
|
## 转发头策略
|
||||||
|
|
||||||
|
Touka 支持两类常见的代理转发头:
|
||||||
|
|
||||||
|
- 兼容性更好的 `X-Forwarded-*`
|
||||||
|
- 标准化的 `Forwarded`(RFC 7239)
|
||||||
|
|
||||||
|
可选值:
|
||||||
|
|
||||||
|
```go
|
||||||
|
const (
|
||||||
|
ForwardedBoth ForwardedHeadersPolicy = iota
|
||||||
|
ForwardedNone
|
||||||
|
ForwardedXForwardedOnly
|
||||||
|
ForwardedRFC7239Only
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
推荐默认使用 `ForwardedBoth`。
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ForwardedHeaders: touka.ForwardedBoth,
|
||||||
|
ForwardedBy: "_gateway-1",
|
||||||
|
Via: "edge-1",
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。
|
||||||
|
|
||||||
|
- IPv4:`203.0.113.43`
|
||||||
|
- IPv6 / 带端口:`[2001:db8::17]:443`
|
||||||
|
- 匿名标识:`_gateway-1`
|
||||||
|
- 未知:`unknown`
|
||||||
|
|
||||||
|
像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。
|
||||||
|
|
||||||
|
`Via` 不是“留空即禁用”的开关。当前实现中:
|
||||||
|
|
||||||
|
- 如果 `Via` 非空,则使用该值追加 `Via`
|
||||||
|
- 如果 `Via` 为空,则会回退到固定值 `touka-engine`
|
||||||
|
|
||||||
|
因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。
|
||||||
|
|
||||||
|
如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
Via: "edge-gateway",
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。
|
||||||
|
|
||||||
|
### Touka 会如何处理这些头?
|
||||||
|
|
||||||
|
Touka 会尽量遵循代理链语义:
|
||||||
|
|
||||||
|
- 已有的 `X-Forwarded-For` 会保留,并在末尾追加当前 hop 的客户端 IP
|
||||||
|
- 已有的 `Forwarded` 会保留,并在末尾追加当前 hop 的条目
|
||||||
|
- 已有的 `X-Forwarded-Host` 与 `X-Forwarded-Proto` 会优先保留;如果缺失,则由当前请求补齐
|
||||||
|
- `Via` 会追加当前代理标识
|
||||||
|
|
||||||
|
这意味着在 Touka 前面还有一层可信代理(如 Nginx、Traefik、Cloudflare、网关)时,上游服务仍然可以看到完整的代理链。
|
||||||
|
|
||||||
|
如果您**不信任**客户端传入的这些头,请在进入 `ReverseProxy` 之前自行清理,或在 `ModifyRequest` 中显式重写。
|
||||||
|
|
||||||
|
## 协议升级与流式响应
|
||||||
|
|
||||||
|
Touka 的反向代理实现支持以下能力:
|
||||||
|
|
||||||
|
- `CONNECT` 隧道转发(HTTP/1.x)
|
||||||
|
- HTTP/2 extended `CONNECT`
|
||||||
|
- `Connection: Upgrade` / `Upgrade` 协议升级转发
|
||||||
|
- WebSocket 等 101 Switching Protocols 场景
|
||||||
|
- SSE(Server-Sent Events)立即刷新
|
||||||
|
- Trailer 透传
|
||||||
|
- 1xx 响应透传
|
||||||
|
- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理
|
||||||
|
|
||||||
|
例如,代理 WebSocket 服务:
|
||||||
|
|
||||||
|
```go
|
||||||
|
target, _ := url.Parse("http://127.0.0.1:9001")
|
||||||
|
|
||||||
|
r.ANY("/ws/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
}))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Hop-by-hop 头处理
|
||||||
|
|
||||||
|
根据 HTTP 代理语义,Touka 在转发时会移除连接级别的 hop-by-hop 头,避免把只应作用于单跳连接的头继续传给下游。
|
||||||
|
|
||||||
|
典型包括:
|
||||||
|
|
||||||
|
- `Connection`
|
||||||
|
- `Proxy-Connection`
|
||||||
|
- `Keep-Alive`
|
||||||
|
- `Proxy-Authenticate`
|
||||||
|
- `Proxy-Authorization`
|
||||||
|
- `TE`
|
||||||
|
- `Trailer`
|
||||||
|
- `Transfer-Encoding`
|
||||||
|
- `Upgrade`
|
||||||
|
|
||||||
|
同时,若请求本身是合法的协议升级请求,Touka 会在剥离后重新补回必要的 `Connection: Upgrade` 与 `Upgrade` 头。
|
||||||
|
|
||||||
|
## 一个更完整的例子
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/infinite-iroha/touka"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := touka.Default()
|
||||||
|
|
||||||
|
target, err := url.Parse("http://127.0.0.1:9000")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ForwardedHeaders: touka.ForwardedBoth,
|
||||||
|
ForwardedBy: "_gateway-1",
|
||||||
|
Via: "gateway-1",
|
||||||
|
FlushInterval: 100 * time.Millisecond,
|
||||||
|
ModifyRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set("X-Gateway", "touka")
|
||||||
|
},
|
||||||
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
|
resp.Header.Set("X-Proxy", "touka")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
_, _ = w.Write([]byte("bad gateway"))
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与 `SetForwardByClientIP` 的关系
|
||||||
|
|
||||||
|
`ReverseProxy` 负责把请求转发给后端,并维护代理链头。
|
||||||
|
|
||||||
|
而 `SetForwardByClientIP` / `SetRemoteIPHeaders` 是 Touka 在**接收请求**时,用于解析当前请求客户端 IP 的逻辑。
|
||||||
|
|
||||||
|
两者通常会一起出现,但解决的是两个不同方向的问题:
|
||||||
|
|
||||||
|
- `ReverseProxy`:出站转发
|
||||||
|
- `SetForwardByClientIP`:入站解析
|
||||||
|
|
||||||
|
如果您的 Touka 本身就部署在其他代理之后,建议同时正确配置这两部分。
|
||||||
|
|
@ -17,8 +17,13 @@ r.OPTIONS("/someOptions", handle)
|
||||||
|
|
||||||
// 注册所有上述方法的路由
|
// 注册所有上述方法的路由
|
||||||
r.ANY("/any", handle)
|
r.ANY("/any", handle)
|
||||||
|
|
||||||
|
// 同时注册多个方法
|
||||||
|
r.HandleFunc([]string{"GET", "POST"}, "/multi", handle)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。
|
||||||
|
|
||||||
## 路径参数 (Named Parameters)
|
## 路径参数 (Named Parameters)
|
||||||
|
|
||||||
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。
|
||||||
|
|
@ -78,8 +83,8 @@ Touka 允许您自定义路由匹配的行为:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
r := touka.New()
|
r := touka.New()
|
||||||
r.RedirectTrailingSlash = true
|
r.SetRedirectTrailingSlash(true)
|
||||||
r.HandleMethodNotAllowed = true
|
r.SetHandleMethodNotAllowed(true)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 获取已注册路由信息
|
## 获取已注册路由信息
|
||||||
|
|
@ -92,3 +97,59 @@ for _, route := range routes {
|
||||||
fmt.Printf("Method: %s, Path: %s\n", route.Method, route.Path)
|
fmt.Printf("Method: %s, Path: %s\n", route.Method, route.Path)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 自定义 404 处理
|
||||||
|
|
||||||
|
当请求没有匹配到任何路由时,Touka 会返回 404。您可以自定义 404 的处理逻辑:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 使用单个处理器
|
||||||
|
r.NoRoute(func(c *touka.Context) {
|
||||||
|
c.JSON(http.StatusNotFound, touka.H{
|
||||||
|
"error": "资源未找到",
|
||||||
|
"path": c.Request.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用处理器链
|
||||||
|
r.NoRoutes(
|
||||||
|
LogNotFoundMiddleware(),
|
||||||
|
func(c *touka.Context) {
|
||||||
|
c.JSON(http.StatusNotFound, touka.H{"error": "Not found"})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**注意**:`NoRoute` 和 `NoRoutes` 不是处理链的终点,您仍然可以在其中调用 `c.Next()` 来继续执行默认的 404 处理。
|
||||||
|
|
||||||
|
## 静态文件路由
|
||||||
|
|
||||||
|
Touka 提供了便捷的方法来注册静态文件路由:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 服务整个目录
|
||||||
|
r.StaticDir("/assets", "./static")
|
||||||
|
// 访问 /assets/js/main.js 将返回 ./static/js/main.js
|
||||||
|
|
||||||
|
// 服务单个文件
|
||||||
|
r.StaticFile("/favicon.ico", "./resources/favicon.ico")
|
||||||
|
|
||||||
|
// 服务嵌入式文件系统
|
||||||
|
//go:embed dist/*
|
||||||
|
var content embed.FS
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := touka.Default()
|
||||||
|
fsroot, _ := fs.Sub(content, "dist")
|
||||||
|
r.StaticFS("/", http.FS(fsroot))
|
||||||
|
r.Run(touka.WithAddr(":8080"))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
这些方法同样可以在路由组中使用:
|
||||||
|
|
||||||
|
```go
|
||||||
|
api := r.Group("/api")
|
||||||
|
api.StaticDir("/files", "./uploads")
|
||||||
|
api.StaticFile("/logo", "./assets/logo.png")
|
||||||
|
```
|
||||||
|
|
|
||||||
37
docs/sse.md
37
docs/sse.md
|
|
@ -40,43 +40,40 @@ r.GET("/events", func(c *touka.Context) {
|
||||||
|
|
||||||
## 模式二:通道模式 (EventStreamChan)
|
## 模式二:通道模式 (EventStreamChan)
|
||||||
|
|
||||||
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。
|
如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。与回调模式类似,此方法是**阻塞的**:handler 会在此方法中停留,直到事件 channel 被关闭或客户端断开连接。
|
||||||
|
|
||||||
```go
|
```go
|
||||||
r.GET("/events-chan", func(c *touka.Context) {
|
r.GET("/events-chan", func(c *touka.Context) {
|
||||||
eventChan, errChan := c.EventStreamChan()
|
eventChan := make(chan touka.Event)
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// 监听错误/断开连接
|
// 在独立的 goroutine 中发送事件.
|
||||||
go func() {
|
go func() {
|
||||||
if err := <-errChan; err != nil {
|
defer close(eventChan) // 务必在结束时关闭以结束事件流.
|
||||||
log.Printf("SSE 错误: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 发送数据
|
|
||||||
go func() {
|
|
||||||
defer close(eventChan) // 务必在结束时关闭
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
select {
|
select {
|
||||||
case <-c.Request.Context().Done():
|
case <-ctx.Done():
|
||||||
return
|
return // 客户端已断开, 退出 goroutine.
|
||||||
default:
|
case eventChan <- touka.Event{
|
||||||
eventChan <- touka.Event{
|
|
||||||
Data: fmt.Sprintf("消息 #%d", i),
|
Data: fmt.Sprintf("消息 #%d", i),
|
||||||
|
}:
|
||||||
}
|
}
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// EventStreamChan 会阻塞直到流结束.
|
||||||
|
c.EventStreamChan(eventChan)
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
## 最佳实践
|
## 最佳实践
|
||||||
|
|
||||||
1. **资源回收**: 确保在 `EventStreamChan` 模式下正确监听 `c.Request.Context().Done()` 以避免 Goroutine 泄漏。
|
1. **资源回收**: `EventStreamChan` 是阻塞的,handler 在事件流结束前不会返回。将 `c.Request.Context().Done()` 和 `eventChan <- ...` 作为同一个 `select` 的两个分支,确保发送操作本身能够响应客户端断开。
|
||||||
2. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
|
2. **关闭 Channel**: 生产者完成发送后必须 `close(eventChan)`,否则 handler 会永远阻塞。
|
||||||
3. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx)配置了足够大的写超时时间。
|
3. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。
|
||||||
|
4. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx)配置了足够大的写超时时间。
|
||||||
|
|
||||||
## 优雅关闭与资源清理
|
## 优雅关闭与资源清理
|
||||||
|
|
||||||
|
|
@ -128,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
|
||||||
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
|
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
|
||||||
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
|
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
|
||||||
|
|
||||||
**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。
|
**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ func main() {
|
||||||
// 您也可以使用 StaticFS 服务根路径
|
// 您也可以使用 StaticFS 服务根路径
|
||||||
// r.StaticFS("/", http.FS(fsroot))
|
// r.StaticFS("/", http.FS(fsroot))
|
||||||
|
|
||||||
r.Run(":8080")
|
r.Run(touka.WithAddr(":8080"))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -49,7 +49,7 @@ func main() {
|
||||||
|
|
||||||
```go
|
```go
|
||||||
r := touka.New()
|
r := touka.New()
|
||||||
r.SetUnMatchFS(http.Dir("./frontend/dist"), true)
|
r.SetUnMatchFS(http.Dir("./frontend/dist"))
|
||||||
|
|
||||||
// API 路由
|
// API 路由
|
||||||
r.GET("/api/status", handleStatus)
|
r.GET("/api/status", handleStatus)
|
||||||
|
|
|
||||||
2
ecw.go
2
ecw.go
|
|
@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool {
|
||||||
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
hijacker, ok := ecw.w.(http.Hijacker)
|
hijacker, ok := ecw.w.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
|
return nil, nil, http.ErrNotSupported
|
||||||
}
|
}
|
||||||
return hijacker.Hijack()
|
return hijacker.Hijack()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
59
ecw_benchmark_test.go
Normal file
59
ecw_benchmark_test.go
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
|
|
||||||
|
ecw.capturedErrorSignal = true
|
||||||
|
ecw.Header().Set("Content-Type", "text/plain")
|
||||||
|
ecw.Header().Add("X-Test", "one")
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler)
|
||||||
|
|
||||||
|
if len(ecw.headerSnapshot) != 0 {
|
||||||
|
t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) {
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
|
|
||||||
|
rawWriter := httptest.NewRecorder()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 16)
|
||||||
|
for i := range keys {
|
||||||
|
keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i)))
|
||||||
|
}
|
||||||
|
values := []string{"one", "two", "three"}
|
||||||
|
for _, key := range keys {
|
||||||
|
ecw.headerSnapshot[key] = values
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler)
|
||||||
|
for _, key := range keys {
|
||||||
|
ecw.headerSnapshot[key] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
393
engine.go
393
engine.go
|
|
@ -7,9 +7,11 @@ package touka
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
|
@ -17,6 +19,7 @@ import (
|
||||||
|
|
||||||
"github.com/WJQSERVER-STUDIO/httpc"
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
"github.com/fenthope/reco"
|
"github.com/fenthope/reco"
|
||||||
|
"github.com/go-json-experiment/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Last 返回链中的最后一个处理函数
|
// Last 返回链中的最后一个处理函数
|
||||||
|
|
@ -49,8 +52,14 @@ type Engine struct {
|
||||||
|
|
||||||
HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求
|
HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求
|
||||||
|
|
||||||
|
// LogReco 保留的 reco.Logger 字段
|
||||||
|
// Deprecated: 使用 SetLogger/GetLogger 替代
|
||||||
LogReco *reco.Logger
|
LogReco *reco.Logger
|
||||||
|
|
||||||
|
// logger 是新的日志接口,支持任意 Logger 实现
|
||||||
|
// 优先级: logger > LogReco
|
||||||
|
logger Logger
|
||||||
|
|
||||||
HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
|
HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口
|
||||||
|
|
||||||
routesInfo []RouteInfo // 存储所有注册的路由信息
|
routesInfo []RouteInfo // 存储所有注册的路由信息
|
||||||
|
|
@ -81,6 +90,11 @@ type Engine struct {
|
||||||
|
|
||||||
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
||||||
GlobalMaxRequestBodySize int64
|
GlobalMaxRequestBodySize int64
|
||||||
|
|
||||||
|
notFoundChain HandlersChain
|
||||||
|
notFoundNoMethodChain HandlersChain
|
||||||
|
unmatchedFSChain HandlersChain
|
||||||
|
unmatchedFSNoMethodChain HandlersChain
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||||
|
|
@ -116,6 +130,90 @@ type ErrorHandle struct {
|
||||||
|
|
||||||
type ErrorHandler func(c *Context, code int, err error)
|
type ErrorHandler func(c *Context, code int, err error)
|
||||||
|
|
||||||
|
var errMethodNotAllowed = errors.New("method not allowed")
|
||||||
|
var errNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
type defaultErrorResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error())
|
||||||
|
var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error())
|
||||||
|
|
||||||
|
func mustMarshalDefaultErrorBody(code int, errMsg string) []byte {
|
||||||
|
body, err := json.Marshal(defaultErrorResponse{
|
||||||
|
Code: code,
|
||||||
|
Message: http.StatusText(code),
|
||||||
|
Error: errMsg,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDefaultErrorJSON(c *Context, code int, body []byte) {
|
||||||
|
if c == nil || c.Writer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
c.writeResponseBody(body, "failed to write default error response")
|
||||||
|
c.Writer.Flush()
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
|
||||||
|
httpMethod := c.Request.Method
|
||||||
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
engine := c.engine
|
||||||
|
// 是否是OPTIONS方式
|
||||||
|
if httpMethod == http.MethodOptions {
|
||||||
|
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
||||||
|
allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0])
|
||||||
|
c.allowedMethodsBuf = allowedMethods[:0]
|
||||||
|
if len(allowedMethods) > 0 {
|
||||||
|
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
||||||
|
allowHeader := c.allowHeaderBuf[:0]
|
||||||
|
for i, method := range allowedMethods {
|
||||||
|
if i > 0 {
|
||||||
|
allowHeader = append(allowHeader, ',', ' ')
|
||||||
|
}
|
||||||
|
allowHeader = append(allowHeader, method...)
|
||||||
|
}
|
||||||
|
c.allowHeaderBuf = allowHeader[:0]
|
||||||
|
c.Writer.Header().Set("Allow", string(allowHeader))
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
||||||
|
tempSkippedNodes := GetTempSkippedNodes()
|
||||||
|
for _, treeIter := range engine.methodTrees {
|
||||||
|
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||||
|
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||||
|
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
||||||
|
if value.handlers != nil {
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
// 使用定义的ErrorHandle处理
|
||||||
|
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
var notFoundHandler HandlerFunc = func(c *Context) {
|
||||||
|
engine := c.engine
|
||||||
|
engine.errorHandle.handler(c, http.StatusNotFound, errNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
// defaultErrorHandle 默认错误处理
|
// defaultErrorHandle 默认错误处理
|
||||||
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
||||||
select {
|
select {
|
||||||
|
|
@ -126,16 +224,22 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
|
||||||
if c.Writer.Written() {
|
if c.Writer.Written() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
switch {
|
||||||
|
case code == http.StatusNotFound && errors.Is(err, errNotFound):
|
||||||
|
writeDefaultErrorJSON(c, code, defaultNotFoundBody)
|
||||||
|
return
|
||||||
|
case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed):
|
||||||
|
writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
// 输出json 状态码与状态码对应描述
|
// 输出json 状态码与状态码对应描述
|
||||||
var errMsg string
|
var errMsg string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg = err.Error()
|
errMsg = err.Error()
|
||||||
}
|
}
|
||||||
c.JSON(code, H{
|
c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg})
|
||||||
"code": code,
|
|
||||||
"message": http.StatusText(code),
|
|
||||||
"error": errMsg,
|
|
||||||
})
|
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
|
|
@ -210,6 +314,7 @@ func New() *Engine {
|
||||||
TLSServerConfigurator: nil,
|
TLSServerConfigurator: nil,
|
||||||
GlobalMaxRequestBodySize: -1,
|
GlobalMaxRequestBodySize: -1,
|
||||||
}
|
}
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
|
engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background())
|
||||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||||
engine.SetDefaultProtocols()
|
engine.SetDefaultProtocols()
|
||||||
|
|
@ -265,16 +370,30 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) {
|
||||||
// 是否开启MethodNotAllowed
|
// 是否开启MethodNotAllowed
|
||||||
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
|
func (engine *Engine) SetHandleMethodNotAllowed(enable bool) {
|
||||||
engine.HandleMethodNotAllowed = enable
|
engine.HandleMethodNotAllowed = enable
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogger传入实例
|
// SetLogger 传入 Logger 接口实例
|
||||||
func (engine *Engine) SetLogger(logger *reco.Logger) {
|
func (engine *Engine) SetLogger(logger Logger) {
|
||||||
engine.LogReco = logger
|
engine.logger = logger
|
||||||
|
// 同步更新 LogReco 以保持向后兼容
|
||||||
|
if rl, ok := logger.(*reco.Logger); ok {
|
||||||
|
engine.LogReco = rl
|
||||||
|
} else {
|
||||||
|
engine.LogReco = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 配置日志LoggerCfg
|
// GetLogger 返回 Logger 接口实例
|
||||||
|
func (engine *Engine) GetLogger() Logger {
|
||||||
|
return engine.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLoggerCfg 使用 reco.Config 配置日志
|
||||||
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
|
||||||
engine.LogReco = NewLogger(logcfg)
|
logger := NewLogger(logcfg)
|
||||||
|
engine.logger = logger
|
||||||
|
engine.LogReco = logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置自定义错误处理
|
// 设置自定义错误处理
|
||||||
|
|
@ -305,6 +424,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = false
|
engine.unMatchFS.ServeUnmatchedAsFS = false
|
||||||
engine.UnMatchFSRoutes = nil
|
engine.UnMatchFSRoutes = nil
|
||||||
}
|
}
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取默认Protocol配置
|
// 获取默认Protocol配置
|
||||||
|
|
@ -319,11 +439,16 @@ func GetDefaultProtocolsConfig() *ProtocolsConfig {
|
||||||
// 设置默认Protocols
|
// 设置默认Protocols
|
||||||
func (engine *Engine) SetDefaultProtocols() {
|
func (engine *Engine) SetDefaultProtocols() {
|
||||||
engine.useDefaultProtocols = true
|
engine.useDefaultProtocols = true
|
||||||
engine.SetProtocols(GetDefaultProtocolsConfig())
|
engine.setProtocols(GetDefaultProtocolsConfig())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置Protocol
|
// 设置Protocol
|
||||||
func (engine *Engine) SetProtocols(config *ProtocolsConfig) {
|
func (engine *Engine) SetProtocols(config *ProtocolsConfig) {
|
||||||
|
engine.setProtocols(config)
|
||||||
|
engine.useDefaultProtocols = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
||||||
engine.Protocols = *config
|
engine.Protocols = *config
|
||||||
engine.serverProtocols = &http.Protocols{} // 初始化指针
|
engine.serverProtocols = &http.Protocols{} // 初始化指针
|
||||||
func() {
|
func() {
|
||||||
|
|
@ -333,7 +458,30 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) {
|
||||||
p.SetUnencryptedHTTP2(config.Http2_Cleartext)
|
p.SetUnencryptedHTTP2(config.Http2_Cleartext)
|
||||||
*engine.serverProtocols = p // 将值赋给指针指向的结构体
|
*engine.serverProtocols = p // 将值赋给指针指向的结构体
|
||||||
}()
|
}()
|
||||||
engine.useDefaultProtocols = false
|
}
|
||||||
|
|
||||||
|
func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
|
||||||
|
if protocols == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *protocols
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
|
||||||
|
if protocols != nil {
|
||||||
|
srv.Protocols = cloneServerProtocols(protocols)
|
||||||
|
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
|
||||||
|
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyDefaultServerConfig 应用框架的默认配置到 http.Server
|
||||||
|
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||||
|
applyServerProtocols(srv, engine.serverProtocols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 配置全局Req Body大小限制
|
// 配置全局Req Body大小限制
|
||||||
|
|
@ -462,66 +610,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) {
|
||||||
|
|
||||||
// 405中间件
|
// 405中间件
|
||||||
func MethodNotAllowed() HandlerFunc {
|
func MethodNotAllowed() HandlerFunc {
|
||||||
return func(c *Context) {
|
return methodNotAllowedHandler
|
||||||
httpMethod := c.Request.Method
|
|
||||||
requestPath := c.Request.URL.Path
|
|
||||||
engine := c.engine
|
|
||||||
// 是否是OPTIONS方式
|
|
||||||
if httpMethod == http.MethodOptions {
|
|
||||||
// 如果是 OPTIONS 请求,尝试查找所有允许的方法
|
|
||||||
allowedMethods := []string{}
|
|
||||||
for _, treeIter := range engine.methodTrees {
|
|
||||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
|
||||||
tempSkippedNodes := GetTempSkippedNodes()
|
|
||||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
|
||||||
PutTempSkippedNodes(tempSkippedNodes)
|
|
||||||
if value.handlers != nil {
|
|
||||||
allowedMethods = append(allowedMethods, treeIter.method)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(allowedMethods) > 0 {
|
|
||||||
// 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部
|
|
||||||
c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", "))
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径
|
|
||||||
for _, treeIter := range engine.methodTrees {
|
|
||||||
if treeIter.method == httpMethod { // 已经处理过当前方法,跳过
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
|
||||||
tempSkippedNodes := GetTempSkippedNodes()
|
|
||||||
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数
|
|
||||||
PutTempSkippedNodes(tempSkippedNodes)
|
|
||||||
if value.handlers != nil {
|
|
||||||
// 使用定义的ErrorHandle处理
|
|
||||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 404最后处理
|
// 404最后处理
|
||||||
func NotFound() HandlerFunc {
|
func NotFound() HandlerFunc {
|
||||||
return func(c *Context) {
|
return notFoundHandler
|
||||||
engine := c.engine
|
|
||||||
engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
|
// 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||||
func (Engine *Engine) NoRoute(handler HandlerFunc) {
|
func (Engine *Engine) NoRoute(handler HandlerFunc) {
|
||||||
Engine.noRoute = handler
|
Engine.noRoute = handler
|
||||||
Engine.noRoutes = nil
|
Engine.noRoutes = nil
|
||||||
|
Engine.rebuildFallbackChains()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
|
// 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理)
|
||||||
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
|
func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) {
|
||||||
Engine.noRoute = nil
|
Engine.noRoute = nil
|
||||||
Engine.noRoutes = handlerFuncs
|
Engine.noRoutes = handlerFuncs
|
||||||
|
Engine.rebuildFallbackChains()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) rebuildFallbackChains() {
|
||||||
|
buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain {
|
||||||
|
finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound
|
||||||
|
if includeMethodNotAllowed {
|
||||||
|
finalSize++
|
||||||
|
}
|
||||||
|
if includeUnmatchedFS {
|
||||||
|
finalSize += len(engine.UnMatchFSRoutes)
|
||||||
|
}
|
||||||
|
if engine.noRoute != nil {
|
||||||
|
finalSize++
|
||||||
|
} else {
|
||||||
|
finalSize += len(engine.noRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
chain := make(HandlersChain, 0, finalSize)
|
||||||
|
chain = append(chain, engine.globalHandlers...)
|
||||||
|
if includeMethodNotAllowed {
|
||||||
|
chain = append(chain, methodNotAllowedHandler)
|
||||||
|
}
|
||||||
|
if includeUnmatchedFS {
|
||||||
|
chain = append(chain, engine.UnMatchFSRoutes...)
|
||||||
|
}
|
||||||
|
if engine.noRoute != nil {
|
||||||
|
chain = append(chain, engine.noRoute)
|
||||||
|
} else if len(engine.noRoutes) > 0 {
|
||||||
|
chain = append(chain, engine.noRoutes...)
|
||||||
|
}
|
||||||
|
chain = append(chain, notFoundHandler)
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false)
|
||||||
|
engine.notFoundNoMethodChain = buildChain(false, false)
|
||||||
|
engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||||
|
engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// combineHandlers 组合多个处理函数链为一个
|
// combineHandlers 组合多个处理函数链为一个
|
||||||
|
|
@ -536,8 +682,9 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
|
||||||
|
|
||||||
// Use 将全局中间件添加到 Engine
|
// Use 将全局中间件添加到 Engine
|
||||||
// 这些中间件将应用于所有注册的路由
|
// 这些中间件将应用于所有注册的路由
|
||||||
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
func (engine *Engine) Use(middleware ...HandlerFunc) Router {
|
||||||
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
||||||
|
engine.rebuildFallbackChains()
|
||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -604,7 +751,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
|
||||||
|
|
||||||
// Group 创建一个新的路由组
|
// Group 创建一个新的路由组
|
||||||
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
||||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||||
return &RouterGroup{
|
return &RouterGroup{
|
||||||
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
||||||
basePath: resolveRoutePath("/", relativePath),
|
basePath: resolveRoutePath("/", relativePath),
|
||||||
|
|
@ -613,7 +760,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
||||||
// 它也实现了 IRouter 接口,允许嵌套分组
|
// 它也实现了 Router 接口,允许嵌套分组
|
||||||
type RouterGroup struct {
|
type RouterGroup struct {
|
||||||
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
||||||
basePath string // 组路径前缀
|
basePath string // 组路径前缀
|
||||||
|
|
@ -622,7 +769,7 @@ type RouterGroup struct {
|
||||||
|
|
||||||
// Use 将中间件应用于当前路由组
|
// Use 将中间件应用于当前路由组
|
||||||
// 这些中间件将应用于当前组及其子组的所有路由
|
// 这些中间件将应用于当前组及其子组的所有路由
|
||||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter {
|
func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
|
||||||
group.Handlers = append(group.Handlers, middleware...)
|
group.Handlers = append(group.Handlers, middleware...)
|
||||||
return group
|
return group
|
||||||
}
|
}
|
||||||
|
|
@ -668,7 +815,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group 为当前组创建一个新的子组
|
// Group 为当前组创建一个新的子组
|
||||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||||
return &RouterGroup{
|
return &RouterGroup{
|
||||||
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
||||||
basePath: resolveRoutePath(group.basePath, relativePath),
|
basePath: resolveRoutePath(group.basePath, relativePath),
|
||||||
|
|
@ -693,8 +840,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||||
// 这是路由查找和执行的核心逻辑
|
// 这是路由查找和执行的核心逻辑
|
||||||
func (engine *Engine) handleRequest(c *Context) {
|
func (engine *Engine) handleRequest(c *Context) {
|
||||||
|
if isGeneralOptionsRequest(c.Request) {
|
||||||
|
engine.handleGeneralOptions(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
httpMethod := c.Request.Method
|
httpMethod := c.Request.Method
|
||||||
requestPath := c.Request.URL.Path
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
|
||||||
// 查找对应的路由树的根节点
|
// 查找对应的路由树的根节点
|
||||||
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
||||||
|
|
@ -714,7 +866,7 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
||||||
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向
|
||||||
if value.tsr && engine.RedirectTrailingSlash {
|
if value.tsr && engine.RedirectTrailingSlash {
|
||||||
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
||||||
redirectPath := requestPath
|
redirectPath := requestPath
|
||||||
|
|
@ -726,51 +878,98 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 尝试不区分大小写的查找
|
if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) {
|
||||||
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法
|
// 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历.
|
||||||
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
|
ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash)
|
||||||
if found && engine.RedirectFixedPath {
|
if found {
|
||||||
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
|
c.fixedPathBuf = ciPath[:0]
|
||||||
|
c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.fixedPathBuf = c.fixedPathBuf[:0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建处理链
|
|
||||||
// 组合全局中间件和路由处理函数
|
|
||||||
handlers := engine.globalHandlers
|
|
||||||
|
|
||||||
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
|
|
||||||
// 则在全局中间件之后添加 MethodNotAllowed 处理器
|
|
||||||
if engine.HandleMethodNotAllowed {
|
|
||||||
handlers = append(handlers, MethodNotAllowed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
|
|
||||||
// 则在处理链的最后添加 UnMatchFS 处理器
|
|
||||||
if engine.unMatchFS.ServeUnmatchedAsFS {
|
if engine.unMatchFS.ServeUnmatchedAsFS {
|
||||||
/*
|
c.handlers = engine.unmatchedFSChain
|
||||||
var unMatchFSHandle = c.engine.unMatchFileServer
|
} else {
|
||||||
handlers = append(handlers, unMatchFSHandle)
|
c.handlers = engine.notFoundChain
|
||||||
*/
|
|
||||||
handlers = append(handlers, engine.UnMatchFSRoutes...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
|
|
||||||
// 则在处理链的最后添加 NoRoute 处理器
|
|
||||||
if engine.noRoute != nil {
|
|
||||||
handlers = append(handlers, engine.noRoute)
|
|
||||||
} else if len(engine.noRoutes) > 0 {
|
|
||||||
handlers = append(handlers, engine.noRoutes...)
|
|
||||||
}
|
|
||||||
|
|
||||||
handlers = append(handlers, NotFound())
|
|
||||||
|
|
||||||
c.handlers = handlers
|
|
||||||
c.Next() // 执行处理函数链
|
c.Next() // 执行处理函数链
|
||||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func routeLookupPath(req *http.Request) string {
|
||||||
|
if req == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") {
|
||||||
|
return "/" + req.RequestURI
|
||||||
|
}
|
||||||
|
if isGeneralOptionsRequest(req) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if req.URL == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return req.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGeneralOptionsRequest(req *http.Request) bool {
|
||||||
|
return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*"
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldTryFixedPathLookup(path string, root *node) bool {
|
||||||
|
if root != nil && root.hasCaseInsensitivePath {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i := 0; i < len(path); i++ {
|
||||||
|
c := path[i]
|
||||||
|
if c >= utf8.RuneSelf {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c >= 'A' && c <= 'Z' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string {
|
||||||
|
if cap(allowedMethods) < len(engine.methodTrees) {
|
||||||
|
allowedMethods = make([]string, 0, len(engine.methodTrees))
|
||||||
|
} else {
|
||||||
|
allowedMethods = allowedMethods[:0]
|
||||||
|
}
|
||||||
|
tempSkippedNodes := GetTempSkippedNodes()
|
||||||
|
for _, treeIter := range engine.methodTrees {
|
||||||
|
// 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型
|
||||||
|
*tempSkippedNodes = (*tempSkippedNodes)[:0]
|
||||||
|
value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false)
|
||||||
|
if value.handlers != nil {
|
||||||
|
allowedMethods = append(allowedMethods, treeIter.method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PutTempSkippedNodes(tempSkippedNodes)
|
||||||
|
return allowedMethods
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) handleGeneralOptions(c *Context) {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Length", "0")
|
||||||
|
if c.Request.ContentLength != 0 {
|
||||||
|
mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10)
|
||||||
|
_, _ = io.Copy(io.Discard, mb)
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消.
|
||||||
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
// 它可以用于在长连接 (如 SSE) 中监听关闭信号.
|
||||||
func (engine *Engine) Context() context.Context {
|
func (engine *Engine) Context() context.Context {
|
||||||
|
|
|
||||||
71
engine_benchmark_test.go
Normal file
71
engine_benchmark_test.go
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var benchmarkStatusCode int
|
||||||
|
|
||||||
|
func buildServeHTTPBenchmarkEngine() *Engine {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.GET("/api/v1/users/:id", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.POST("/api/v1/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
return engine
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(method, path, nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
engine.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkStatusCode = rr.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkServeHTTP(b *testing.B) {
|
||||||
|
engine := buildServeHTTPBenchmarkEngine()
|
||||||
|
|
||||||
|
b.Run("StaticHit", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("NotFound", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("MethodNotAllowed", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("OptionsAllow", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("FixedPathRedirect", func(b *testing.B) {
|
||||||
|
benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS")
|
||||||
|
})
|
||||||
|
}
|
||||||
306
engine_test.go
Normal file
306
engine_test.go
Normal file
|
|
@ -0,0 +1,306 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"html/template"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type failingResponseWriter struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Header() http.Header {
|
||||||
|
if w.header == nil {
|
||||||
|
w.header = make(http.Header)
|
||||||
|
}
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if w.status == 0 {
|
||||||
|
w.status = statusCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.status == 0 {
|
||||||
|
w.status = http.StatusOK
|
||||||
|
}
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Flush() {}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Status() int {
|
||||||
|
return w.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Size() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Written() bool {
|
||||||
|
return w.status != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) IsHijacked() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestRedirectFixedPath(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil)
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" {
|
||||||
|
t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/api/v1/users/:id/settings", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/Users/Profile", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil)
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "/Users/Profile" {
|
||||||
|
t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/Users/Profile", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("unexpected panic for fixed-path miss: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.NoRoute(func(c *Context) {
|
||||||
|
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-NoRoute"); got != "hit" {
|
||||||
|
t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.NoRoute(func(c *Context) {
|
||||||
|
c.Writer.Header().Set("X-NoRoute", "hit")
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-NoRoute"); got != "" {
|
||||||
|
t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
engine.POST("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code)
|
||||||
|
}
|
||||||
|
allow := rr.Header().Get("Allow")
|
||||||
|
if allow != "GET, POST" && allow != "POST, GET" {
|
||||||
|
t.Fatalf("expected Allow header to list matching methods, got %q", allow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultErrorHandleJSONShape(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil)
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||||
|
}
|
||||||
|
if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" {
|
||||||
|
t.Fatalf("unexpected error payload: %+v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultMethodNotAllowedJSONShape(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||||
|
}
|
||||||
|
if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" {
|
||||||
|
t.Fatalf("unexpected error payload: %+v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.SetErrorHandler(func(c *Context, code int, err error) {
|
||||||
|
c.Writer.Header().Set("X-Custom-Error", "1")
|
||||||
|
c.String(code, "custom:%v", err)
|
||||||
|
})
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-Custom-Error"); got != "1" {
|
||||||
|
t.Fatalf("expected custom error header, got %q", got)
|
||||||
|
}
|
||||||
|
if rr.Body.String() != "custom:method not allowed" {
|
||||||
|
t.Fatalf("expected custom error body, got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseHelpersCaptureWriteErrors(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
run func(*Context)
|
||||||
|
}{
|
||||||
|
{name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }},
|
||||||
|
{name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }},
|
||||||
|
{name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }},
|
||||||
|
{name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }},
|
||||||
|
{name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }},
|
||||||
|
{name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }},
|
||||||
|
{name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }},
|
||||||
|
{name: "HTMLBuf", run: func(c *Context) {
|
||||||
|
c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`))
|
||||||
|
c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"})
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
writerErr := errors.New("write failed")
|
||||||
|
w := &failingResponseWriter{err: writerErr}
|
||||||
|
c, _ := CreateTestContext(w)
|
||||||
|
|
||||||
|
tc.run(c)
|
||||||
|
|
||||||
|
if got := len(c.Errors); got != 1 {
|
||||||
|
t.Fatalf("expected exactly one captured error, got %d", got)
|
||||||
|
}
|
||||||
|
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
|
||||||
|
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) {
|
||||||
|
writerErr := errors.New("write failed")
|
||||||
|
w := &failingResponseWriter{err: writerErr}
|
||||||
|
engine := New()
|
||||||
|
c, _ := CreateTestContext(w)
|
||||||
|
c.engine = engine
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/missing", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to build request: %v", err)
|
||||||
|
}
|
||||||
|
c.reset(w, req)
|
||||||
|
|
||||||
|
defaultErrorHandle(c, http.StatusNotFound, errNotFound)
|
||||||
|
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
t.Fatal("expected write error to be captured")
|
||||||
|
}
|
||||||
|
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
|
||||||
|
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
|
||||||
|
}
|
||||||
|
if c.Writer.Status() != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status())
|
||||||
|
}
|
||||||
|
if !c.IsAborted() {
|
||||||
|
t.Fatal("expected fast path to abort context")
|
||||||
|
}
|
||||||
|
}
|
||||||
103
examples/httpc/main.go
Normal file
103
examples/httpc/main.go
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/infinite-iroha/touka"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := touka.Default()
|
||||||
|
|
||||||
|
// 示例 1:简单 GET 请求(自动关联请求 Context)
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
// 使用 HTTPC() 方法,自动关联请求 Context
|
||||||
|
// 当客户端断开连接时,出站请求也会自动取消
|
||||||
|
body, err := c.HTTPC().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(http.StatusOK, "%s", body)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 2:带 Header 的 POST 请求
|
||||||
|
r.POST("/users", func(c *touka.Context) {
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 链式调用,保持 httpc 风格
|
||||||
|
// 注意:SetJSONBody 返回 (*RequestBuilder, error)
|
||||||
|
rb, err := c.HTTPC().
|
||||||
|
POST("https://httpbin.org/post").
|
||||||
|
SetHeader("X-API-Key", "secret").
|
||||||
|
SetJSONBody(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := rb.DecodeJSON(&result); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 3:带查询参数的请求
|
||||||
|
r.GET("/search", func(c *touka.Context) {
|
||||||
|
query := c.DefaultQuery("q", "")
|
||||||
|
page := c.DefaultQuery("page", "1")
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Items []string `json:"items"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.HTTPC().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
SetQueryParam("q", query).
|
||||||
|
SetQueryParam("page", page).
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐)
|
||||||
|
r.GET("/legacy", func(c *touka.Context) {
|
||||||
|
// 旧方式:需要手动 WithContext
|
||||||
|
body, err := c.Client().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
WithContext(c.Context()).
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(http.StatusOK, "%s", body)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("Server running on :8080")
|
||||||
|
fmt.Println("Try:")
|
||||||
|
fmt.Println(" curl http://localhost:8080/proxy")
|
||||||
|
fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users")
|
||||||
|
fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'")
|
||||||
|
|
||||||
|
// r.Run(touka.WithAddr(":8080"))
|
||||||
|
}
|
||||||
71
examples/logger_slog/main.go
Normal file
71
examples/logger_slog/main.go
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/infinite-iroha/touka"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口
|
||||||
|
type SlogAdapter struct {
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSlogAdapter(handler slog.Handler) *SlogAdapter {
|
||||||
|
return &SlogAdapter{
|
||||||
|
logger: slog.New(handler),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Debugf(format string, args ...any) {
|
||||||
|
s.logger.Debug(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Infof(format string, args ...any) {
|
||||||
|
s.logger.Info(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Warnf(format string, args ...any) {
|
||||||
|
s.logger.Warn(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Errorf(format string, args ...any) {
|
||||||
|
s.logger.Error(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Fatalf(format string, args ...any) {
|
||||||
|
s.logger.Error(fmt.Sprintf(format, args...))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SlogAdapter) Panicf(format string, args ...any) {
|
||||||
|
s.logger.Error(fmt.Sprintf(format, args...))
|
||||||
|
panic(fmt.Sprintf(format, args...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
engine := touka.New()
|
||||||
|
|
||||||
|
// 使用 slog 替换默认的 reco.Logger
|
||||||
|
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||||
|
Level: slog.LevelDebug,
|
||||||
|
})
|
||||||
|
slogAdapter := NewSlogAdapter(handler)
|
||||||
|
engine.SetLogger(slogAdapter)
|
||||||
|
|
||||||
|
engine.GET("/", func(c *touka.Context) {
|
||||||
|
c.Infof("request received: %s", c.Request.URL.Path)
|
||||||
|
c.JSON(http.StatusOK, map[string]string{"message": "hello"})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 也可以获取 Logger 接口
|
||||||
|
logger := engine.GetLogger()
|
||||||
|
logger.Debugf("engine started")
|
||||||
|
|
||||||
|
// 也可以直接使用 slog
|
||||||
|
slog.Info("Server running", "addr", ":8080")
|
||||||
|
// engine.Run(":8080")
|
||||||
|
}
|
||||||
7
go.mod
7
go.mod
|
|
@ -3,14 +3,15 @@ module github.com/infinite-iroha/touka
|
||||||
go 1.26
|
go 1.26
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2
|
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3
|
||||||
github.com/WJQSERVER-STUDIO/httpc v0.8.3
|
github.com/WJQSERVER-STUDIO/httpc v0.9.3
|
||||||
github.com/WJQSERVER/wanf v0.0.8
|
github.com/WJQSERVER/wanf v0.0.8
|
||||||
github.com/fenthope/reco v0.0.5
|
github.com/fenthope/reco v0.0.5
|
||||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433
|
||||||
|
golang.org/x/net v0.53.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
golang.org/x/net v0.50.0 // indirect
|
golang.org/x/text v0.36.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
14
go.sum
14
go.sum
|
|
@ -1,7 +1,7 @@
|
||||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM=
|
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck=
|
||||||
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
|
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA=
|
||||||
github.com/WJQSERVER-STUDIO/httpc v0.8.3 h1:g3CvOimwPonQuKDfbH8Ex35J/VSz+W1k5Q1FiHg2xn8=
|
github.com/WJQSERVER-STUDIO/httpc v0.9.3 h1:wYZkz9f/+2WuDuzPlExebvnn0q6QeArM15Y51HJ5UUI=
|
||||||
github.com/WJQSERVER-STUDIO/httpc v0.8.3/go.mod h1:/+NKun9LIUW5YFdvpOf7JbChSVsvdySOGn04FB3rTPg=
|
github.com/WJQSERVER-STUDIO/httpc v0.9.3/go.mod h1:vtaDmN/8gN8Es1DJsGvvrFr8kErysJndu87i+KOWUHY=
|
||||||
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=
|
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=
|
||||||
github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8=
|
github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8=
|
||||||
github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw=
|
github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw=
|
||||||
|
|
@ -10,5 +10,7 @@ github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVw
|
||||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||||
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||||
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
|
|
|
||||||
88
http2xconnect.go
Normal file
88
http2xconnect.go
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
_ "unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var enableHTTP2ExtendedConnectOnce sync.Once
|
||||||
|
|
||||||
|
//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol
|
||||||
|
var xnetDisableHTTP2ExtendedConnectProtocol bool
|
||||||
|
|
||||||
|
func enableHTTP2ExtendedConnectProtocol() {
|
||||||
|
enableHTTP2ExtendedConnectOnce.Do(func() {
|
||||||
|
xnetDisableHTTP2ExtendedConnectProtocol = false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureHTTP2ExtendedConnectServer(srv *http.Server) error {
|
||||||
|
if srv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
return http2.ConfigureServer(srv, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP2ExtendedConnectTransport() http.RoundTripper {
|
||||||
|
enableHTTP2ExtendedConnectProtocol()
|
||||||
|
transport := cloneDefaultTransport()
|
||||||
|
transport.Protocols = new(http.Protocols)
|
||||||
|
transport.Protocols.SetHTTP1(true)
|
||||||
|
transport.Protocols.SetHTTP2(true)
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP1BridgeTransport() http.RoundTripper {
|
||||||
|
return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper {
|
||||||
|
transport := cloneDefaultTransport()
|
||||||
|
transport.Protocols = new(http.Protocols)
|
||||||
|
transport.Protocols.SetHTTP1(true)
|
||||||
|
if tlsConfig == nil {
|
||||||
|
transport.TLSClientConfig = &tls.Config{}
|
||||||
|
} else {
|
||||||
|
transport.TLSClientConfig = tlsConfig.Clone()
|
||||||
|
}
|
||||||
|
if len(transport.TLSClientConfig.NextProtos) == 0 {
|
||||||
|
transport.TLSClientConfig.NextProtos = []string{"http/1.1"}
|
||||||
|
}
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newH2CTransport() http.RoundTripper {
|
||||||
|
transport := cloneDefaultTransport()
|
||||||
|
transport.Protocols = new(http.Protocols)
|
||||||
|
transport.Protocols.SetUnencryptedHTTP2(true)
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneDefaultTransport() *http.Transport {
|
||||||
|
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
|
return transport.Clone()
|
||||||
|
}
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
150
iox_benchmark_test.go
Normal file
150
iox_benchmark_test.go
Normal file
|
|
@ -0,0 +1,150 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/WJQSERVER-STUDIO/go-utils/iox"
|
||||||
|
)
|
||||||
|
|
||||||
|
type benchmarkResetReader struct {
|
||||||
|
data []byte
|
||||||
|
off int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *benchmarkResetReader) Read(p []byte) (int, error) {
|
||||||
|
if r.off >= len(r.data) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, r.data[r.off:])
|
||||||
|
r.off += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *benchmarkResetReader) Reset() {
|
||||||
|
r.off = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type benchmarkDiscardWriter struct{}
|
||||||
|
|
||||||
|
func (benchmarkDiscardWriter) Write(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchmarkIOXResult int64
|
||||||
|
var benchmarkIOXBytes []byte
|
||||||
|
|
||||||
|
func BenchmarkIOXCopyComparison(b *testing.B) {
|
||||||
|
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||||
|
|
||||||
|
b.Run("io.Copy", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
w := benchmarkDiscardWriter{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
n, err := io.Copy(w, r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("io.Copy failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXResult = n
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("iox.Copy", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
w := benchmarkDiscardWriter{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
n, err := iox.Copy(w, r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("iox.Copy failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXResult = n
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIOXCopyBufferComparison(b *testing.B) {
|
||||||
|
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||||
|
|
||||||
|
b.Run("io.CopyBuffer", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
w := benchmarkDiscardWriter{}
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
n, err := io.CopyBuffer(w, r, buf)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("io.CopyBuffer failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXResult = n
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("iox.CopyBuffer", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
w := benchmarkDiscardWriter{}
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
n, err := iox.CopyBuffer(w, r, buf)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("iox.CopyBuffer failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXResult = n
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIOXReadAllComparison(b *testing.B) {
|
||||||
|
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||||
|
|
||||||
|
b.Run("io.ReadAll", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("io.ReadAll failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXBytes = data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("iox.ReadAll", func(b *testing.B) {
|
||||||
|
r := &benchmarkResetReader{data: payload}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
r.Reset()
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("iox.ReadAll failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkIOXBytes = data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
23
logger.go
Normal file
23
logger.go
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
// Logger 是日志接口,支持多种日志库实现(reco、zap、logrus 等)
|
||||||
|
// 用户可以通过实现此接口来替换默认的日志实现
|
||||||
|
type Logger interface {
|
||||||
|
Debugf(format string, args ...any)
|
||||||
|
Infof(format string, args ...any)
|
||||||
|
Warnf(format string, args ...any)
|
||||||
|
Errorf(format string, args ...any)
|
||||||
|
Fatalf(format string, args ...any)
|
||||||
|
Panicf(format string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloserLogger 可选扩展接口,支持关闭操作
|
||||||
|
// 如果 Logger 实现了此接口,Engine 在关闭时会调用 Close()
|
||||||
|
type CloserLogger interface {
|
||||||
|
Logger
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
@ -39,7 +39,16 @@ func CloseLogger(logger *reco.Logger) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseLogger 关闭 Engine 的日志实现
|
||||||
|
// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法
|
||||||
func (engine *Engine) CloseLogger() {
|
func (engine *Engine) CloseLogger() {
|
||||||
|
if cl, ok := engine.logger.(CloserLogger); ok {
|
||||||
|
if err := cl.Close(); err != nil {
|
||||||
|
log.Printf("Close Logger Error: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 兼容旧代码
|
||||||
if engine.LogReco != nil {
|
if engine.LogReco != nil {
|
||||||
CloseLogger(engine.LogReco)
|
CloseLogger(engine.LogReco)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
67
maxreader.go
67
maxreader.go
|
|
@ -23,19 +23,21 @@ type maxBytesReader struct {
|
||||||
n int64
|
n int64
|
||||||
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||||
read atomic.Int64
|
read atomic.Int64
|
||||||
|
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
|
||||||
|
emptyAtLimit atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||||
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
||||||
//
|
//
|
||||||
// 如果 r 为 nil, 会 panic.
|
// 如果 r 为 nil, 会 panic.
|
||||||
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
|
// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r.
|
||||||
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
panic("NewMaxBytesReader called with a nil reader")
|
panic("NewMaxBytesReader called with a nil reader")
|
||||||
}
|
}
|
||||||
// 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser.
|
// 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||||
if n < 0 {
|
if n <= 0 {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return &maxBytesReader{
|
return &maxBytesReader{
|
||||||
|
|
@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
|
|
||||||
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
||||||
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
||||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
if len(p) == 0 {
|
||||||
readSoFar := mbr.read.Load()
|
return 0, nil
|
||||||
|
|
||||||
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
|
|
||||||
if readSoFar >= mbr.n {
|
|
||||||
return 0, ErrBodyTooLarge
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算当前还可以读取多少字节.
|
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||||
|
readSoFar := mbr.read.Load()
|
||||||
remaining := mbr.n - readSoFar
|
remaining := mbr.n - readSoFar
|
||||||
|
if remaining < 0 {
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
if remaining == 0 {
|
||||||
|
var probe [1]byte
|
||||||
|
n, err := mbr.r.Read(probe[:])
|
||||||
|
if n > 0 {
|
||||||
|
mbr.read.Add(int64(n))
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if mbr.emptyAtLimit.Swap(true) {
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
mbr.emptyAtLimit.Store(false)
|
||||||
|
|
||||||
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
// 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。
|
||||||
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
if int64(len(p))-1 > remaining {
|
||||||
if int64(len(p)) > remaining {
|
p = p[:remaining+1]
|
||||||
p = p[:remaining]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从底层 Reader 读取数据.
|
// 从底层 Reader 读取数据.
|
||||||
n, err := mbr.r.Read(p)
|
n, err := mbr.r.Read(p)
|
||||||
|
|
||||||
// 如果实际读取到了数据, 更新原子计数器.
|
if int64(n) <= remaining {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
readSoFar = mbr.read.Add(int64(n))
|
mbr.read.Add(int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果底层 Read 返回错误 (例如 io.EOF).
|
|
||||||
if err != nil {
|
|
||||||
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
|
||||||
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
// 读取结果跨过了限制,只向上层暴露允许的部分。
|
||||||
// 这是处理"跨越"限制情况的关键.
|
if remaining > 0 {
|
||||||
if readSoFar > mbr.n {
|
mbr.read.Add(remaining)
|
||||||
// 返回实际读取的字节数 n, 并附上超限错误.
|
|
||||||
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
|
||||||
return n, ErrBodyTooLarge
|
|
||||||
}
|
}
|
||||||
|
return int(remaining), ErrBodyTooLarge
|
||||||
// 一切正常, 返回读取的字节数和 nil 错误.
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||||
|
|
|
||||||
113
mergectx.go
113
mergectx.go
|
|
@ -11,18 +11,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||||
|
// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播.
|
||||||
|
// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消.
|
||||||
type mergedContext struct {
|
type mergedContext struct {
|
||||||
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
|
||||||
context.Context
|
context.Context
|
||||||
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
|
||||||
parents []context.Context
|
parents []context.Context
|
||||||
// 用于手动取消此 mergedContext 的函数.
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeCtx 创建并返回一个新的 context.Context.
|
// MergeCtx 创建并返回一个新的 context.Context.
|
||||||
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||||
// 自动被取消 (逻辑或关系).
|
// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context.
|
||||||
//
|
//
|
||||||
// 新的 context 会继承:
|
// 新的 context 会继承:
|
||||||
// - Deadline: 所有父 context 中最早的截止时间.
|
// - Deadline: 所有父 context 中最早的截止时间.
|
||||||
|
|
@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
||||||
return context.WithCancel(context.Background())
|
return context.WithCancel(context.Background())
|
||||||
}
|
}
|
||||||
if len(parents) == 1 {
|
if len(parents) == 1 {
|
||||||
return context.WithCancel(parents[0])
|
ctx, cancel := context.WithCancelCause(parents[0])
|
||||||
|
return ctx, func() { cancel(nil) }
|
||||||
}
|
}
|
||||||
|
|
||||||
var earliestDeadline time.Time
|
var earliestDeadline time.Time
|
||||||
|
|
@ -44,79 +43,93 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var baseCtx context.Context
|
// cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播.
|
||||||
var baseCancel context.CancelFunc
|
cancelCtx, cancelCause := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
|
// deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline).
|
||||||
|
// 当 cancelCtx 被取消时, deadlineCtx 也会被取消;
|
||||||
|
// 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx.
|
||||||
|
var deadlineCtx context.Context
|
||||||
|
var deadlineCancel context.CancelFunc
|
||||||
if !earliestDeadline.IsZero() {
|
if !earliestDeadline.IsZero() {
|
||||||
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded)
|
||||||
} else {
|
}
|
||||||
baseCtx, baseCancel = context.WithCancel(context.Background())
|
|
||||||
|
// 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline),
|
||||||
|
// 否则用 cancelCtx.
|
||||||
|
embedCtx := cancelCtx
|
||||||
|
if deadlineCtx != nil {
|
||||||
|
embedCtx = deadlineCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
mc := &mergedContext{
|
mc := &mergedContext{
|
||||||
Context: baseCtx,
|
Context: embedCtx,
|
||||||
parents: parents,
|
parents: parents,
|
||||||
cancel: baseCancel,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动一个监控 goroutine.
|
// 启动监控 goroutine, 监听 parent 取消或 deadline 到期.
|
||||||
go func() {
|
go func() {
|
||||||
defer mc.cancel()
|
// 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏.
|
||||||
|
parentDone := orDone(append(mc.parents, cancelCtx)...)
|
||||||
|
|
||||||
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
if deadlineCtx != nil {
|
||||||
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
defer deadlineCancel()
|
||||||
select {
|
select {
|
||||||
case <-orDone(mc.parents...):
|
case <-parentDone:
|
||||||
case <-mc.Context.Done():
|
// parent 取消或手动 cancel()
|
||||||
|
for _, p := range mc.parents {
|
||||||
|
if p.Err() != nil {
|
||||||
|
cancelCause(context.Cause(p))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 手动 cancel(), cause 已由 cancelCause() 设置
|
||||||
|
case <-deadlineCtx.Done():
|
||||||
|
// deadline 到期, 需要关闭 cancelCtx 并设置 cause
|
||||||
|
cancelCause(context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
<-parentDone
|
||||||
|
for _, p := range mc.parents {
|
||||||
|
if p.Err() != nil {
|
||||||
|
cancelCause(context.Cause(p))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return mc, mc.cancel
|
return mc, func() { cancelCause(nil) }
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value 返回当前Ctx Value
|
// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause),
|
||||||
|
// 再按传入顺序从 parents 中查找.
|
||||||
func (mc *mergedContext) Value(key any) any {
|
func (mc *mergedContext) Value(key any) any {
|
||||||
return mc.Context.Value(key)
|
if v := mc.Context.Value(key); v != nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
for _, p := range mc.parents {
|
||||||
|
if val := p.Value(key); val != nil {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deadline 实现了 context.Context 的 Deadline 方法.
|
// Deadline, Done, Err 均由嵌入的 context.Context 提供.
|
||||||
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
|
||||||
return mc.Context.Deadline()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Done 实现了 context.Context 的 Done 方法.
|
// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭.
|
||||||
func (mc *mergedContext) Done() <-chan struct{} {
|
|
||||||
return mc.Context.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Err 实现了 context.Context 的 Err 方法.
|
|
||||||
func (mc *mergedContext) Err() error {
|
|
||||||
return mc.Context.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// orDone 是一个辅助函数, 返回一个 channel.
|
|
||||||
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
|
|
||||||
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
|
|
||||||
func orDone(contexts ...context.Context) <-chan struct{} {
|
func orDone(contexts ...context.Context) <-chan struct{} {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
closeDone := func() {
|
|
||||||
once.Do(func() {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// 为每个父 context 启动一个 goroutine.
|
|
||||||
for _, ctx := range contexts {
|
for _, ctx := range contexts {
|
||||||
go func(c context.Context) {
|
go func(c context.Context) {
|
||||||
select {
|
select {
|
||||||
case <-c.Done():
|
case <-c.Done():
|
||||||
closeDone()
|
once.Do(func() { close(done) })
|
||||||
case <-done:
|
case <-done:
|
||||||
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
|
|
||||||
}
|
}
|
||||||
}(ctx)
|
}(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return done
|
return done
|
||||||
}
|
}
|
||||||
|
|
|
||||||
256
mergectx_test.go
Normal file
256
mergectx_test.go
Normal file
|
|
@ -0,0 +1,256 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeCtx_NoParents(t *testing.T) {
|
||||||
|
ctx, cancel := MergeCtx()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
t.Fatal("expected no error before cancel")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_SingleParent(t *testing.T) {
|
||||||
|
parent, parentCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(parent)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
t.Fatal("expected no error before parent cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
parentCancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after parent cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p1 cancel")
|
||||||
|
}
|
||||||
|
// p2 should still be fine
|
||||||
|
if p2.Err() != nil {
|
||||||
|
t.Fatal("expected p2 to be unaffected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel2()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p2 cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ExternalCancel(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after external cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_CausePropagation(t *testing.T) {
|
||||||
|
testErr := errors.New("test cause")
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithCancelCause(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel1(testErr)
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p1 cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
if cause != testErr {
|
||||||
|
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||||
|
}
|
||||||
|
cancel1(nil) // cleanup (already cancelled, no-op)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
|
||||||
|
testErr := errors.New("second parent cause")
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel2(testErr)
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p2 cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
if cause != testErr {
|
||||||
|
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
early := now.Add(100 * time.Millisecond)
|
||||||
|
late := now.Add(1 * time.Hour)
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithDeadline(context.Background(), late)
|
||||||
|
p2, cancel2 := context.WithDeadline(context.Background(), early)
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dl, ok := ctx.Deadline()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected deadline to be set")
|
||||||
|
}
|
||||||
|
if !dl.Equal(early) {
|
||||||
|
t.Fatalf("expected deadline %v, got %v", early, dl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_Deadline_Expires(t *testing.T) {
|
||||||
|
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancelP()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after deadline expires")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ValueLookup(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
p1 := context.WithValue(context.Background(), key{}, "from_p1")
|
||||||
|
p2 := context.WithValue(context.Background(), key{}, "from_p2")
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
val := ctx.Value(key{})
|
||||||
|
if val != "from_p1" {
|
||||||
|
t.Fatalf("expected 'from_p1', got %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
|
||||||
|
type key1 struct{}
|
||||||
|
type key2 struct{}
|
||||||
|
p1 := context.WithValue(context.Background(), key1{}, "val1")
|
||||||
|
p2 := context.WithValue(context.Background(), key2{}, "val2")
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if v := ctx.Value(key1{}); v != "val1" {
|
||||||
|
t.Fatalf("expected 'val1', got %v", v)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(key2{}); v != "val2" {
|
||||||
|
t.Fatalf("expected 'val2', got %v", v)
|
||||||
|
}
|
||||||
|
if v := ctx.Value("missing"); v != nil {
|
||||||
|
t.Fatalf("expected nil, got %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ContextInterface(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, _ = MergeCtx(p1, p2)
|
||||||
|
|
||||||
|
// Verify all Context interface methods work
|
||||||
|
_ = ctx.Done()
|
||||||
|
_ = ctx.Err()
|
||||||
|
_, _ = ctx.Deadline()
|
||||||
|
_ = ctx.Value("any")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_SingleContext(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := orDone(ctx)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_MultipleContexts(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
done := orDone(p1, p2)
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_SecondContextCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
done := orDone(p1, p2)
|
||||||
|
|
||||||
|
cancel2()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
94
protocols_test.go
Normal file
94
protocols_test.go
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyDefaultServerConfig(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
|
||||||
|
// 1. 测试默认协议
|
||||||
|
srv1 := &http.Server{}
|
||||||
|
engine.applyDefaultServerConfig(srv1)
|
||||||
|
|
||||||
|
if srv1.Protocols == nil {
|
||||||
|
t.Fatal("srv1.Protocols should not be nil after applyDefaultServerConfig")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认配置是 Http1: true, Http2: false, Http2_Cleartext: false
|
||||||
|
if !srv1.Protocols.HTTP1() {
|
||||||
|
t.Error("Expected HTTP/1 to be enabled by default")
|
||||||
|
}
|
||||||
|
if srv1.Protocols.HTTP2() {
|
||||||
|
t.Error("Expected HTTP/2 to be disabled by default")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 测试自定义协议
|
||||||
|
engine.SetProtocols(&ProtocolsConfig{
|
||||||
|
Http1: true,
|
||||||
|
Http2: true,
|
||||||
|
Http2_Cleartext: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
srv2 := &http.Server{}
|
||||||
|
engine.applyDefaultServerConfig(srv2)
|
||||||
|
|
||||||
|
if srv2.Protocols == nil {
|
||||||
|
t.Fatal("srv2.Protocols should not be nil after applyDefaultServerConfig")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !srv2.Protocols.HTTP1() {
|
||||||
|
t.Error("Expected HTTP/1 to be enabled after SetProtocols")
|
||||||
|
}
|
||||||
|
if !srv2.Protocols.HTTP2() {
|
||||||
|
t.Error("Expected HTTP/2 to be enabled after SetProtocols")
|
||||||
|
}
|
||||||
|
if !srv2.Protocols.UnencryptedHTTP2() {
|
||||||
|
t.Error("Expected Unencrypted HTTP/2 to be enabled after SetProtocols")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 再次更改协议并验证
|
||||||
|
engine.SetProtocols(&ProtocolsConfig{
|
||||||
|
Http1: false,
|
||||||
|
Http2: true,
|
||||||
|
Http2_Cleartext: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
srv3 := &http.Server{}
|
||||||
|
engine.applyDefaultServerConfig(srv3)
|
||||||
|
|
||||||
|
if srv3.Protocols == nil {
|
||||||
|
t.Fatal("srv3.Protocols should not be nil")
|
||||||
|
}
|
||||||
|
if srv3.Protocols.HTTP1() {
|
||||||
|
t.Error("Expected HTTP/1 to be disabled")
|
||||||
|
}
|
||||||
|
if !srv3.Protocols.HTTP2() {
|
||||||
|
t.Error("Expected HTTP/2 to be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSRunDefaultsProtocolInheritance(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
|
||||||
|
srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||||
|
|
||||||
|
if !srv.Protocols.HTTP2() {
|
||||||
|
t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 模拟用户设置了自定义协议后进入 TLS 运行模式
|
||||||
|
engine = New()
|
||||||
|
engine.SetProtocols(&ProtocolsConfig{
|
||||||
|
Http1: true,
|
||||||
|
Http2: false, // 用户明确不想要 HTTP/2
|
||||||
|
})
|
||||||
|
|
||||||
|
srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||||
|
|
||||||
|
if srv2.Protocols.HTTP2() {
|
||||||
|
t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols")
|
||||||
|
}
|
||||||
|
}
|
||||||
15
respw.go
15
respw.go
|
|
@ -45,6 +45,15 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnwrapResponseWriter returns the underlying stdlib response writer when the
|
||||||
|
// provided writer is Touka's internal wrapper.
|
||||||
|
func UnwrapResponseWriter(w ResponseWriter) http.ResponseWriter {
|
||||||
|
if wrapped, ok := w.(*responseWriterImpl); ok && wrapped.ResponseWriter != nil {
|
||||||
|
return wrapped.ResponseWriter
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
func (rw *responseWriterImpl) reset(w http.ResponseWriter) {
|
func (rw *responseWriterImpl) reset(w http.ResponseWriter) {
|
||||||
rw.ResponseWriter = w
|
rw.ResponseWriter = w
|
||||||
rw.status = 0
|
rw.status = 0
|
||||||
|
|
@ -56,6 +65,10 @@ func (rw *responseWriterImpl) WriteHeader(statusCode int) {
|
||||||
if rw.hijacked {
|
if rw.hijacked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if statusCode >= 100 && statusCode < 200 && statusCode != http.StatusSwitchingProtocols {
|
||||||
|
rw.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
if rw.status == 0 { // 确保只设置一次
|
if rw.status == 0 { // 确保只设置一次
|
||||||
rw.status = statusCode
|
rw.status = statusCode
|
||||||
rw.ResponseWriter.WriteHeader(statusCode)
|
rw.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
|
@ -100,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
|
||||||
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
hj, ok := rw.ResponseWriter.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("http.Hijacker interface not supported")
|
return nil, nil, http.ErrNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用底层的 Hijack 方法
|
// 调用底层的 Hijack 方法
|
||||||
|
|
|
||||||
2084
reverseproxy.go
Normal file
2084
reverseproxy.go
Normal file
File diff suppressed because it is too large
Load diff
355
reverseproxy_benchmark_test.go
Normal file
355
reverseproxy_benchmark_test.go
Normal file
|
|
@ -0,0 +1,355 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type benchmarkReadSeeker struct {
|
||||||
|
data []byte
|
||||||
|
off int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *benchmarkReadSeeker) Read(p []byte) (int, error) {
|
||||||
|
if r.off >= len(r.data) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, r.data[r.off:])
|
||||||
|
r.off += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *benchmarkReadSeeker) Reset() {
|
||||||
|
r.off = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type benchmarkResponseWriter struct {
|
||||||
|
header http.Header
|
||||||
|
status int
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBenchmarkResponseWriter() *benchmarkResponseWriter {
|
||||||
|
return &benchmarkResponseWriter{header: make(http.Header)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Header() http.Header {
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if w.status == 0 {
|
||||||
|
w.status = statusCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.status == 0 {
|
||||||
|
w.status = http.StatusOK
|
||||||
|
}
|
||||||
|
w.size += len(p)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Flush() {}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Status() int {
|
||||||
|
return w.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Size() int {
|
||||||
|
return w.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Written() bool {
|
||||||
|
return w.status != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) IsHijacked() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *benchmarkResponseWriter) reset() {
|
||||||
|
clear(w.header)
|
||||||
|
w.status = 0
|
||||||
|
w.size = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchmarkReverseProxySink int
|
||||||
|
|
||||||
|
func BenchmarkReverseProxyCopyResponse(b *testing.B) {
|
||||||
|
body := bytes.Repeat([]byte("0123456789abcdef"), 4096)
|
||||||
|
proxy := newReverseProxyHandler(ReverseProxyConfig{})
|
||||||
|
dst := newBenchmarkResponseWriter()
|
||||||
|
src := &benchmarkReadSeeker{data: body}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
dst.reset()
|
||||||
|
src.Reset()
|
||||||
|
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||||
|
b.Fatalf("copyResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkReverseProxySink = dst.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) {
|
||||||
|
proxy := &reverseProxyHandler{
|
||||||
|
upstreams: []*reverseProxyUpstream{
|
||||||
|
{key: "a", index: 0},
|
||||||
|
{key: "b", index: 1},
|
||||||
|
{key: "c", index: 2},
|
||||||
|
{key: "d", index: 3},
|
||||||
|
},
|
||||||
|
config: ReverseProxyConfig{
|
||||||
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
MaxFails: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)}
|
||||||
|
proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReverseProxySelectUpstream(b *testing.B) {
|
||||||
|
proxy := &reverseProxyHandler{
|
||||||
|
upstreams: []*reverseProxyUpstream{
|
||||||
|
{key: "a", index: 0},
|
||||||
|
{key: "b", index: 1},
|
||||||
|
{key: "c", index: 2},
|
||||||
|
{key: "d", index: 3},
|
||||||
|
},
|
||||||
|
config: ReverseProxyConfig{
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()},
|
||||||
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
MaxFails: 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)}
|
||||||
|
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
selected, err := proxy.selectUpstream(c, nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("selectUpstream failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkReverseProxySink = selected.index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) {
|
||||||
|
proxy := &reverseProxyHandler{
|
||||||
|
upstreams: []*reverseProxyUpstream{
|
||||||
|
{key: "a", index: 0},
|
||||||
|
{key: "b", index: 1},
|
||||||
|
{key: "c", index: 2},
|
||||||
|
{key: "d", index: 3},
|
||||||
|
},
|
||||||
|
config: ReverseProxyConfig{
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
selected, err := proxy.selectUpstream(c, nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("selectUpstream failed: %v", err)
|
||||||
|
}
|
||||||
|
benchmarkReverseProxySink = selected.index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) {
|
||||||
|
proxy := newReverseProxyHandler(ReverseProxyConfig{})
|
||||||
|
dst := newBenchmarkResponseWriter()
|
||||||
|
src := bytes.NewBufferString("hello, reverse proxy")
|
||||||
|
|
||||||
|
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||||
|
t.Fatalf("copyResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := dst.Size(), len("hello, reverse proxy"); got != want {
|
||||||
|
t.Fatalf("expected %d bytes copied, got %d", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fixedLenBufferPool struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fixedLenBufferPool) Get() []byte {
|
||||||
|
return p.buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fixedLenBufferPool) Put(buf []byte) {
|
||||||
|
p.buf = buf
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordingReader struct {
|
||||||
|
chunk int
|
||||||
|
reads []int
|
||||||
|
left int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *recordingReader) Read(p []byte) (int, error) {
|
||||||
|
if r.left == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := min(r.chunk, len(p), r.left)
|
||||||
|
if n == 0 {
|
||||||
|
return 0, errors.New("reader received zero-length buffer")
|
||||||
|
}
|
||||||
|
for i := range n {
|
||||||
|
p[i] = 'x'
|
||||||
|
}
|
||||||
|
r.left -= n
|
||||||
|
r.reads = append(r.reads, len(p))
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) {
|
||||||
|
pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)}
|
||||||
|
proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool})
|
||||||
|
dst := newBenchmarkResponseWriter()
|
||||||
|
src := &recordingReader{chunk: 8, left: 24}
|
||||||
|
|
||||||
|
if err := proxy.copyResponse(dst, src, 0); err != nil {
|
||||||
|
t.Fatalf("copyResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.reads) == 0 {
|
||||||
|
t.Fatal("expected reader to be used")
|
||||||
|
}
|
||||||
|
for _, size := range src.reads {
|
||||||
|
if size != 8 {
|
||||||
|
t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
proxy := &reverseProxyHandler{
|
||||||
|
upstreams: []*reverseProxyUpstream{
|
||||||
|
{key: "a"},
|
||||||
|
{key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}},
|
||||||
|
{key: "c"},
|
||||||
|
},
|
||||||
|
config: ReverseProxyConfig{
|
||||||
|
PassiveHealth: ReverseProxyPassiveHealthConfig{
|
||||||
|
FailDuration: time.Minute,
|
||||||
|
MaxFails: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}})
|
||||||
|
if len(available) != 1 {
|
||||||
|
t.Fatalf("expected only one available upstream, got %d", len(available))
|
||||||
|
}
|
||||||
|
if available[0].key != "a" {
|
||||||
|
t.Fatalf("expected upstream 'a', got %q", available[0].key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) {
|
||||||
|
proxy := &reverseProxyHandler{
|
||||||
|
upstreams: []*reverseProxyUpstream{
|
||||||
|
{key: "a", index: 0},
|
||||||
|
{key: "b", index: 1},
|
||||||
|
{key: "c", index: 2},
|
||||||
|
},
|
||||||
|
config: ReverseProxyConfig{
|
||||||
|
LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c, _ := CreateTestContext(nil)
|
||||||
|
c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"}
|
||||||
|
|
||||||
|
selectedA, err := proxy.selectUpstream(c, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectUpstream failed: %v", err)
|
||||||
|
}
|
||||||
|
selectedB, err := proxy.selectUpstream(c, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectUpstream failed: %v", err)
|
||||||
|
}
|
||||||
|
if selectedA.key != selectedB.key {
|
||||||
|
t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"}
|
||||||
|
selectedC, err := proxy.selectUpstream(c, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectUpstream failed: %v", err)
|
||||||
|
}
|
||||||
|
if selectedC == nil {
|
||||||
|
t.Fatal("expected upstream for reordered multi-value header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) {
|
||||||
|
candidates := []*reverseProxyUpstream{
|
||||||
|
{key: "a", index: 0},
|
||||||
|
{key: "b", index: 1},
|
||||||
|
{key: "c", index: 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := [][]string{
|
||||||
|
{"tenant-a"},
|
||||||
|
{"tenant-a", "tenant-b"},
|
||||||
|
{"", "tenant-b"},
|
||||||
|
{"tenant-a", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, values := range testCases {
|
||||||
|
got := reverseProxySelectHRWValues(candidates, values)
|
||||||
|
want := reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||||
|
if got == nil || want == nil {
|
||||||
|
t.Fatalf("expected non-nil upstreams for values %v", values)
|
||||||
|
}
|
||||||
|
if got.key != want.key {
|
||||||
|
t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ io.Writer = (*benchmarkResponseWriter)(nil)
|
||||||
530
reverseproxy_headers_replace_test.go
Normal file
530
reverseproxy_headers_replace_test.go
Normal file
|
|
@ -0,0 +1,530 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Server"); got != "Caddy" {
|
||||||
|
t.Errorf("expected X-Server=Caddy, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Location"); got != "/api/v2/resource" {
|
||||||
|
t.Errorf("expected X-Location=/api/v2/resource, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"X-Server": {{Search: "NGINX", Replace: "Caddy"}},
|
||||||
|
"X-Location": {{Search: "v1", Replace: "v2"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Server", "NGINX")
|
||||||
|
req.Header.Set("X-Location", "/api/v1/resource")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Route"); got != "/proxy-upstream" {
|
||||||
|
t.Errorf("expected X-Route=/proxy-upstream, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Route", "/original/upstream")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Host-A"); got != "new.example.com" {
|
||||||
|
t.Errorf("expected X-Host-A=new.example.com, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Host-B"); got != "new.example.com" {
|
||||||
|
t.Errorf("expected X-Host-B=new.example.com, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"*": {{Search: "old.example.com", Replace: "new.example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Host-A", "old.example.com")
|
||||||
|
req.Header.Set("X-Host-B", "old.example.com")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Backend", "backend-internal:8080")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ResponseHeaders: &RespHeaderOps{
|
||||||
|
HeaderOps: &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Backend"); got != "public.example.com" {
|
||||||
|
t.Errorf("expected X-Backend=public.example.com, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"X-Test": {{SearchRegexp: "[invalid"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusInternalServerError {
|
||||||
|
t.Errorf("expected status 500, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacementApply(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
r *Replacement
|
||||||
|
s string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "nil replacement", r: nil, s: "hello", want: "hello"},
|
||||||
|
{name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""},
|
||||||
|
{name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"},
|
||||||
|
{name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"},
|
||||||
|
{name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"},
|
||||||
|
{name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"},
|
||||||
|
{name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"},
|
||||||
|
{name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.r.apply(tt.s); got != tt.want {
|
||||||
|
t.Errorf("Replacement.apply() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsAdd(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Add: map[string][]string{
|
||||||
|
"X-Custom-1": {"value-1"},
|
||||||
|
"X-Custom-2": {"value-2"},
|
||||||
|
"X-Custom-3": {"value-3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
hdr := make(http.Header)
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr = make(http.Header)
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsSet(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Frame-Options": {"DENY"},
|
||||||
|
"X-Content-Type-Options": {"nosniff"},
|
||||||
|
"X-XSS-Protection": {"1; mode=block"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
hdr := make(http.Header)
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr = make(http.Header)
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsDeleteSingle(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Delete: []string{"X-Powered-By"},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("X-Powered-By", "Express")
|
||||||
|
hdr.Set("X-Keep", "value")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Delete: []string{"X-Debug-*"},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("X-Debug-1", "v1")
|
||||||
|
hdr.Set("X-Debug-2", "v2")
|
||||||
|
hdr.Set("X-Keep", "value")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("Location", "http://internal:8080/api/v1/users")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) {
|
||||||
|
re := regexp.MustCompile(`^http://([^/]+)(/.*)$`)
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("Location", "http://internal:8080/api/v1/users")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"*": {{Search: "internal.example.com", Replace: "public.example.com"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("X-Host", "internal.example.com")
|
||||||
|
hdr.Set("X-Origin", "internal.example.com")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHeaderOpsMixed(b *testing.B) {
|
||||||
|
ops := &HeaderOps{
|
||||||
|
Add: map[string][]string{
|
||||||
|
"X-Request-ID": {"req-123"},
|
||||||
|
},
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Frame-Options": {"DENY"},
|
||||||
|
},
|
||||||
|
Delete: []string{"X-Powered-By"},
|
||||||
|
Replace: map[string][]Replacement{
|
||||||
|
"Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repl := &reverseProxyReplacer{}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
hdr.Set("X-Powered-By", "Express")
|
||||||
|
hdr.Set("Location", "http://internal:8080/api")
|
||||||
|
ops.applyTo(hdr, repl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReplacementApplySubstring(b *testing.B) {
|
||||||
|
r := &Replacement{Search: "old.example.com", Replace: "new.example.com"}
|
||||||
|
s := "https://old.example.com/api/v1/resource"
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = r.apply(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReplacementApplyRegexp(b *testing.B) {
|
||||||
|
r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)}
|
||||||
|
s := "https://old.example.com/api/v1/resource"
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = r.apply(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyReplacerDynamicVars(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil)
|
||||||
|
req.Host = "example.com"
|
||||||
|
repl := newReverseProxyReplacer(req)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"method", "{method}", "GET"},
|
||||||
|
{"host", "{host}", "example.com"},
|
||||||
|
{"path", "{path}", "/api/v1/users"},
|
||||||
|
{"query", "{query}", "sort=name&limit=10"},
|
||||||
|
{"scheme", "{scheme}", "http"},
|
||||||
|
{"proto", "{proto}", "HTTP/1.1"},
|
||||||
|
{"combined", "X-{method}-{path}", "X-GET-/api/v1/users"},
|
||||||
|
{"no vars", "static-value", "static-value"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := repl.Replace(tt.input); got != tt.want {
|
||||||
|
t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyReplacerNilRequest(t *testing.T) {
|
||||||
|
repl := newReverseProxyReplacer(nil)
|
||||||
|
if got := repl.Replace("{method}"); got != "{method}" {
|
||||||
|
t.Errorf("expected unchanged string with nil request, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyReplacerNilReplacer(t *testing.T) {
|
||||||
|
var repl *reverseProxyReplacer
|
||||||
|
if got := repl.Replace("{method}"); got != "{method}" {
|
||||||
|
t.Errorf("expected unchanged string with nil replacer, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyReplacerFromHeader(t *testing.T) {
|
||||||
|
hdr := make(http.Header)
|
||||||
|
repl := newReverseProxyReplacerFromHeader(hdr)
|
||||||
|
if got := repl.Replace("{method}"); got != "{method}" {
|
||||||
|
t.Errorf("expected unchanged string from header replacer, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" {
|
||||||
|
t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Forwarded-Method"); got != "GET" {
|
||||||
|
t.Errorf("expected X-Forwarded-Method=GET, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" {
|
||||||
|
t.Errorf("expected X-Forwarded-Host=client.example, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Add: map[string][]string{
|
||||||
|
"X-Forwarded-Path": {"{path}"},
|
||||||
|
"X-Forwarded-Method": {"{method}"},
|
||||||
|
"X-Forwarded-Host": {"{host}"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil)
|
||||||
|
req.Host = "client.example"
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
220
reverseproxy_headers_test.go
Normal file
220
reverseproxy_headers_test.go
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
|
||||||
|
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Add: map[string][]string{
|
||||||
|
"X-Custom-Header": {"test-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("X-Sensitive") != "" {
|
||||||
|
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Delete: []string{"X-Sensitive"},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Sensitive", "should-be-removed")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsSet(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
got := r.Header.Get("X-Replace")
|
||||||
|
if got != "new-value" {
|
||||||
|
t.Errorf("expected X-Replace=new-value, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Replace": {"new-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Replace", "old-value")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyResponseHeaderOps(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Backend", "backend-server")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ResponseHeaders: &RespHeaderOps{
|
||||||
|
HeaderOps: &HeaderOps{
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Custom": {"custom-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
|
||||||
|
t.Errorf("expected X-Custom=custom-value, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Powered-By", "Express")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ResponseHeaders: &RespHeaderOps{
|
||||||
|
HeaderOps: &HeaderOps{
|
||||||
|
Delete: []string{"X-Powered-By"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Powered-By"); got != "" {
|
||||||
|
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
409
reverseproxy_lb.go
Normal file
409
reverseproxy_lb.go
Normal file
|
|
@ -0,0 +1,409 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2026 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReverseProxyLoadBalancingConfig configures upstream selection and retries.
|
||||||
|
type ReverseProxyLoadBalancingConfig struct {
|
||||||
|
Policy ReverseProxyLBPolicy
|
||||||
|
Retries int
|
||||||
|
TryDuration time.Duration
|
||||||
|
TryInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyPassiveHealthConfig configures inline passive health tracking.
|
||||||
|
type ReverseProxyPassiveHealthConfig struct {
|
||||||
|
FailDuration time.Duration
|
||||||
|
MaxFails int
|
||||||
|
UnhealthyStatus []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyLBPolicy selects an upstream from the configured target pool.
|
||||||
|
// Use the helper constructors such as LBRandom or LBHeader to build a policy.
|
||||||
|
type ReverseProxyLBPolicy struct {
|
||||||
|
kind reverseProxyLBPolicyKind
|
||||||
|
key string
|
||||||
|
fallback *ReverseProxyLBPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
type reverseProxyLBPolicyKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota
|
||||||
|
reverseProxyLBPolicyRoundRobin
|
||||||
|
reverseProxyLBPolicyFirst
|
||||||
|
reverseProxyLBPolicyLeastConn
|
||||||
|
reverseProxyLBPolicyIPHash
|
||||||
|
reverseProxyLBPolicyClientIPHash
|
||||||
|
reverseProxyLBPolicyURIHash
|
||||||
|
reverseProxyLBPolicyHeader
|
||||||
|
reverseProxyLBPolicyQuery
|
||||||
|
)
|
||||||
|
|
||||||
|
type reverseProxyUpstream struct {
|
||||||
|
key string
|
||||||
|
target *url.URL
|
||||||
|
index int
|
||||||
|
useH2C bool
|
||||||
|
extendedConnectTransport http.RoundTripper
|
||||||
|
bridgeTransport http.RoundTripper
|
||||||
|
h2cTransport http.RoundTripper
|
||||||
|
inFlight atomic.Int64
|
||||||
|
|
||||||
|
passiveMu sync.Mutex
|
||||||
|
failures []time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBRandom() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBRoundRobin() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBFirst() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBLeastConn() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBIPHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBClientIPHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBURIHash() ReverseProxyLBPolicy {
|
||||||
|
return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))}
|
||||||
|
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||||
|
policy.fallback = &fallback
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)}
|
||||||
|
if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil {
|
||||||
|
policy.fallback = &fallback
|
||||||
|
}
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error {
|
||||||
|
switch policy.kind {
|
||||||
|
case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst,
|
||||||
|
reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash,
|
||||||
|
reverseProxyLBPolicyURIHash:
|
||||||
|
return nil
|
||||||
|
case reverseProxyLBPolicyHeader:
|
||||||
|
if policy.key == "" {
|
||||||
|
return fmt.Errorf("reverse proxy header load-balancing policy requires a header field")
|
||||||
|
}
|
||||||
|
case reverseProxyLBPolicyQuery:
|
||||||
|
if policy.key == "" {
|
||||||
|
return fmt.Errorf("reverse proxy query load-balancing policy requires a query key")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("reverse proxy load-balancing policy is invalid")
|
||||||
|
}
|
||||||
|
if policy.fallback != nil {
|
||||||
|
return validateReverseProxyLBPolicy(*policy.fallback)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) {
|
||||||
|
now := time.Now()
|
||||||
|
policy := p.config.LoadBalancing.Policy
|
||||||
|
candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream)
|
||||||
|
candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf)
|
||||||
|
if len(candidates) == 0 && len(excluded) > 0 {
|
||||||
|
candidates = p.availableUpstreamsInto(now, nil, candidates[:0])
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
*candidateBuf = candidates[:0]
|
||||||
|
reverseProxyCandidatePool.Put(candidateBuf)
|
||||||
|
return nil, errReverseProxyNoAvailableUpstreams
|
||||||
|
}
|
||||||
|
selected := p.selectUpstreamWithPolicy(c, candidates, policy)
|
||||||
|
*candidateBuf = candidates[:0]
|
||||||
|
reverseProxyCandidatePool.Put(candidateBuf)
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream {
|
||||||
|
return p.availableUpstreamsInto(now, excluded, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream {
|
||||||
|
if cap(candidates) < len(p.upstreams) {
|
||||||
|
candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams))
|
||||||
|
} else {
|
||||||
|
candidates = candidates[:0]
|
||||||
|
}
|
||||||
|
for _, upstream := range p.upstreams {
|
||||||
|
if _, skip := excluded[upstream.key]; skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !upstream.healthy(now, p.config.PassiveHealth) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, upstream)
|
||||||
|
}
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy.kind {
|
||||||
|
case reverseProxyLBPolicyRoundRobin:
|
||||||
|
return candidates[p.nextRoundRobinIndex(len(candidates))]
|
||||||
|
case reverseProxyLBPolicyFirst:
|
||||||
|
return candidates[0]
|
||||||
|
case reverseProxyLBPolicyLeastConn:
|
||||||
|
return p.selectLeastConnUpstream(candidates)
|
||||||
|
case reverseProxyLBPolicyIPHash:
|
||||||
|
return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr))
|
||||||
|
case reverseProxyLBPolicyClientIPHash:
|
||||||
|
return reverseProxySelectHRW(candidates, c.RequestIP())
|
||||||
|
case reverseProxyLBPolicyURIHash:
|
||||||
|
if c.Request == nil || c.Request.URL == nil {
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI())
|
||||||
|
case reverseProxyLBPolicyHeader:
|
||||||
|
if c.Request != nil && c.Request.Header != nil {
|
||||||
|
if values, ok := c.Request.Header[policy.key]; ok {
|
||||||
|
return reverseProxySelectHRWValues(candidates, values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||||
|
case reverseProxyLBPolicyQuery:
|
||||||
|
if c.Request != nil && c.Request.URL != nil {
|
||||||
|
if values, ok := c.Request.URL.Query()[policy.key]; ok {
|
||||||
|
return reverseProxySelectHRW(candidates, strings.Join(values, ","))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy))
|
||||||
|
case reverseProxyLBPolicyRandom:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int {
|
||||||
|
if size <= 1 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int((p.roundRobin.Add(1) - 1) % uint64(size))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
selected := candidates[0]
|
||||||
|
lowest := selected.inFlight.Load()
|
||||||
|
ties := []*reverseProxyUpstream{selected}
|
||||||
|
for _, upstream := range candidates[1:] {
|
||||||
|
count := upstream.inFlight.Load()
|
||||||
|
switch {
|
||||||
|
case count < lowest:
|
||||||
|
selected = upstream
|
||||||
|
lowest = count
|
||||||
|
ties = []*reverseProxyUpstream{upstream}
|
||||||
|
case count == lowest:
|
||||||
|
ties = append(ties, upstream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ties) == 1 {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
return ties[p.nextRoundRobinIndex(len(ties))]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(candidates) == 1 {
|
||||||
|
return candidates[0]
|
||||||
|
}
|
||||||
|
return candidates[rand.IntN(len(candidates))]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
selected := candidates[0]
|
||||||
|
bestScore := reverseProxyHRWScore(key, selected.key)
|
||||||
|
for _, upstream := range candidates[1:] {
|
||||||
|
score := reverseProxyHRWScore(key, upstream.key)
|
||||||
|
if score > bestScore {
|
||||||
|
selected = upstream
|
||||||
|
bestScore = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream {
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(values) == 0 {
|
||||||
|
return reverseProxySelectRandom(candidates)
|
||||||
|
}
|
||||||
|
selected := candidates[0]
|
||||||
|
bestScore := reverseProxyHRWValuesScore(values, selected.key)
|
||||||
|
for _, upstream := range candidates[1:] {
|
||||||
|
score := reverseProxyHRWValuesScore(values, upstream.key)
|
||||||
|
if score > bestScore {
|
||||||
|
selected = upstream
|
||||||
|
bestScore = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyHRWScore(key, upstreamKey string) uint64 {
|
||||||
|
const (
|
||||||
|
offset64 = 14695981039346656037
|
||||||
|
prime64 = 1099511628211
|
||||||
|
)
|
||||||
|
h := uint64(offset64)
|
||||||
|
for i := 0; i < len(key); i++ {
|
||||||
|
h ^= uint64(key[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
h ^= 0xff
|
||||||
|
h *= prime64
|
||||||
|
for i := 0; i < len(upstreamKey); i++ {
|
||||||
|
h ^= uint64(upstreamKey[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 {
|
||||||
|
const (
|
||||||
|
offset64 = 14695981039346656037
|
||||||
|
prime64 = 1099511628211
|
||||||
|
)
|
||||||
|
h := uint64(offset64)
|
||||||
|
for valueIndex, value := range values {
|
||||||
|
for i := 0; i < len(value); i++ {
|
||||||
|
h ^= uint64(value[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
if valueIndex+1 < len(values) {
|
||||||
|
h ^= ','
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h ^= 0xff
|
||||||
|
h *= prime64
|
||||||
|
for i := 0; i < len(upstreamKey); i++ {
|
||||||
|
h ^= uint64(upstreamKey[i])
|
||||||
|
h *= prime64
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy {
|
||||||
|
if policy.fallback != nil {
|
||||||
|
return *policy.fallback
|
||||||
|
}
|
||||||
|
return LBRandom()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool {
|
||||||
|
maxFails := reverseProxyPassiveMaxFails(config)
|
||||||
|
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
u.passiveMu.Lock()
|
||||||
|
defer u.passiveMu.Unlock()
|
||||||
|
u.pruneFailuresLocked(now, config.FailDuration)
|
||||||
|
return len(u.failures) < maxFails
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) {
|
||||||
|
maxFails := reverseProxyPassiveMaxFails(config)
|
||||||
|
if config.FailDuration <= 0 || maxFails <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u.passiveMu.Lock()
|
||||||
|
defer u.passiveMu.Unlock()
|
||||||
|
u.pruneFailuresLocked(now, config.FailDuration)
|
||||||
|
u.failures = append(u.failures, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) {
|
||||||
|
if len(u.failures) == 0 || window <= 0 {
|
||||||
|
if window <= 0 {
|
||||||
|
u.failures = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := now.Add(-window)
|
||||||
|
keep := 0
|
||||||
|
for _, failureAt := range u.failures {
|
||||||
|
if failureAt.Before(cutoff) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u.failures[keep] = failureAt
|
||||||
|
keep++
|
||||||
|
}
|
||||||
|
u.failures = u.failures[:keep]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int {
|
||||||
|
if config.FailDuration <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if config.MaxFails <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return config.MaxFails
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool {
|
||||||
|
if status <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return slices.Contains(config.UnhealthyStatus, status)
|
||||||
|
}
|
||||||
2519
reverseproxy_test.go
Normal file
2519
reverseproxy_test.go
Normal file
File diff suppressed because it is too large
Load diff
130
route_match_benchmark_test.go
Normal file
130
route_match_benchmark_test.go
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
var (
|
||||||
|
benchmarkRouteHandlers HandlersChain
|
||||||
|
benchmarkRouteFullPath string
|
||||||
|
benchmarkRouteParamsLen int
|
||||||
|
benchmarkRouteCIPath []byte
|
||||||
|
benchmarkRouteCIFound bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildRouteMatchBenchmarkTree() *node {
|
||||||
|
tree := &node{}
|
||||||
|
routes := []string{
|
||||||
|
"/",
|
||||||
|
"/health",
|
||||||
|
"/contact",
|
||||||
|
"/api/v1/users",
|
||||||
|
"/api/v1/users/:id",
|
||||||
|
"/api/v1/users/:id/settings",
|
||||||
|
"/assets/*filepath",
|
||||||
|
"/abc/b",
|
||||||
|
"/abc/:p1/cde",
|
||||||
|
"/abc/:p1/:p2/def/*filepath",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
tree.addRoute(route, fakeHandler(route))
|
||||||
|
}
|
||||||
|
|
||||||
|
return tree
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
params := make(Params, 0, 4)
|
||||||
|
skipped := make([]skippedNode, 0, 8)
|
||||||
|
|
||||||
|
value := tree.getValue(path, ¶ms, &skipped, true)
|
||||||
|
if wantFullPath == "" {
|
||||||
|
if value.handlers != nil {
|
||||||
|
b.Fatalf("expected no match for %q, got %q", path, value.fullPath)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if value.handlers == nil {
|
||||||
|
b.Fatalf("expected match for %q, got nil handlers", path)
|
||||||
|
}
|
||||||
|
if value.fullPath != wantFullPath {
|
||||||
|
b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
params = params[:0]
|
||||||
|
skipped = skipped[:0]
|
||||||
|
value = tree.getValue(path, ¶ms, &skipped, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteHandlers = value.handlers
|
||||||
|
benchmarkRouteFullPath = value.fullPath
|
||||||
|
if value.params != nil {
|
||||||
|
benchmarkRouteParamsLen = len(*value.params)
|
||||||
|
} else {
|
||||||
|
benchmarkRouteParamsLen = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouteMatch(b *testing.B) {
|
||||||
|
tree := buildRouteMatchBenchmarkTree()
|
||||||
|
|
||||||
|
b.Run("StaticHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ParamHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("BacktrackingHit", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Miss", func(b *testing.B) {
|
||||||
|
benchmarkRouteLookup(b, tree, "/does/not/exist", "")
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("CaseInsensitiveHit", func(b *testing.B) {
|
||||||
|
path := "/API/V1/USERS/123/SETTINGS"
|
||||||
|
out, found := tree.findCaseInsensitivePath(path, true)
|
||||||
|
if !found {
|
||||||
|
b.Fatalf("expected fixed-path match for %q", path)
|
||||||
|
}
|
||||||
|
if got := string(out); got != "/api/v1/users/123/settings" {
|
||||||
|
b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
out, found = tree.findCaseInsensitivePath(path, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteCIPath = out
|
||||||
|
benchmarkRouteCIFound = found
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("CaseInsensitiveMiss", func(b *testing.B) {
|
||||||
|
path := "/DOES/NOT/EXIST"
|
||||||
|
out, found := tree.findCaseInsensitivePath(path, true)
|
||||||
|
if found || out != nil {
|
||||||
|
b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
out, found = tree.findCaseInsensitivePath(path, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkRouteCIPath = out
|
||||||
|
benchmarkRouteCIFound = found
|
||||||
|
})
|
||||||
|
}
|
||||||
759
serve.go
759
serve.go
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -21,329 +22,322 @@ import (
|
||||||
"github.com/fenthope/reco"
|
"github.com/fenthope/reco"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
|
|
||||||
const defaultShutdownTimeout = 5 * time.Second
|
const defaultShutdownTimeout = 5 * time.Second
|
||||||
|
|
||||||
// --- 内部辅助函数 ---
|
type runMode uint8
|
||||||
|
|
||||||
// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080"
|
const (
|
||||||
func resolveAddress(addr []string) string {
|
runModeHTTP runMode = iota
|
||||||
switch len(addr) {
|
runModeHTTPS
|
||||||
case 0:
|
runModeHTTPSRedirect
|
||||||
return ":8080"
|
)
|
||||||
case 1:
|
|
||||||
return addr[0]
|
type runConfig struct {
|
||||||
default:
|
addr string
|
||||||
panic("too many parameters provided for server address")
|
httpRedirectAddr string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
redirectHost string
|
||||||
|
redirectHostHeaders []string
|
||||||
|
useHeaderHost bool
|
||||||
|
useHeaderHostSet bool
|
||||||
|
graceful bool
|
||||||
|
shutdownTimeout time.Duration
|
||||||
|
gracefulCtx context.Context
|
||||||
|
mode runMode
|
||||||
|
shutdownDefaultSet bool
|
||||||
|
shutdownTimeoutSet bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type RunOption interface {
|
||||||
|
apply(*runConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type runOptionFunc func(*runConfig) error
|
||||||
|
|
||||||
|
func (f runOptionFunc) apply(cfg *runConfig) error {
|
||||||
|
return f(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultRunConfig() runConfig {
|
||||||
|
return runConfig{
|
||||||
|
addr: ":8080",
|
||||||
|
shutdownTimeout: defaultShutdownTimeout,
|
||||||
|
mode: runModeHTTP,
|
||||||
|
useHeaderHost: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值
|
type HTTPRedirectOption interface {
|
||||||
func getShutdownTimeout(timeouts []time.Duration) time.Duration {
|
applyRedirect(*runConfig) error
|
||||||
if len(timeouts) > 0 && timeouts[0] > 0 {
|
|
||||||
return timeouts[0]
|
|
||||||
}
|
|
||||||
return defaultShutdownTimeout
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server,
|
type redirectOptionFunc func(*runConfig) error
|
||||||
// 并处理其启动失败的致命错误
|
|
||||||
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS")
|
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
||||||
func runServer(serverType string, srv *http.Server) {
|
return f(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAddr(addr string) RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if addr == "" {
|
||||||
|
return errors.New("run address must not be empty")
|
||||||
|
}
|
||||||
|
cfg.addr = addr
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTLS(tlsConfig *tls.Config) RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if tlsConfig == nil {
|
||||||
|
return errors.New("tls.Config must not be nil")
|
||||||
|
}
|
||||||
|
cfg.tlsConfig = tlsConfig
|
||||||
|
if cfg.mode == runModeHTTP {
|
||||||
|
cfg.mode = runModeHTTPS
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if addr == "" {
|
||||||
|
return errors.New("http redirect address must not be empty")
|
||||||
|
}
|
||||||
|
cfg.httpRedirectAddr = addr
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := opt.applyRedirect(cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.useHeaderHost = enabled
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRedirectHost(host string) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if host == "" {
|
||||||
|
return errors.New("redirect host must not be empty")
|
||||||
|
}
|
||||||
|
cfg.redirectHost = host
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
||||||
|
for _, header := range headers {
|
||||||
|
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
||||||
|
if trimmed != "" {
|
||||||
|
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithGracefulShutdown(timeout time.Duration) RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.graceful = true
|
||||||
|
cfg.shutdownTimeoutSet = true
|
||||||
|
if timeout > 0 {
|
||||||
|
cfg.shutdownTimeout = timeout
|
||||||
|
} else {
|
||||||
|
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithGracefulShutdownDefault() RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.graceful = true
|
||||||
|
cfg.shutdownDefaultSet = true
|
||||||
|
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithShutdownContext(ctx context.Context) RunOption {
|
||||||
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if ctx == nil {
|
||||||
|
return errors.New("shutdown context must not be nil")
|
||||||
|
}
|
||||||
|
cfg.gracefulCtx = ctx
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveServer(srv *http.Server, serveTLS bool) error {
|
||||||
|
if serveTLS {
|
||||||
|
return srv.ListenAndServeTLS("", "")
|
||||||
|
}
|
||||||
|
return srv.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServer(serverType string, srv *http.Server, serveTLS bool) {
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if srv.TLSConfig != nil {
|
if serveTLS {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
|
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
|
||||||
|
|
||||||
if srv.TLSConfig != nil {
|
err := serveServer(srv, serveTLS)
|
||||||
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
|
|
||||||
// ListenAndServeTLS 的前两个参数可以为空字符串
|
|
||||||
err = srv.ListenAndServeTLS("", "")
|
|
||||||
} else {
|
|
||||||
err = srv.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
|
|
||||||
// 则认为是一个严重错误,并终止程序
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("Touka %s server failed: %v", serverType, err)
|
log.Fatalf("Touka %s server failed: %v", serverType, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器
|
func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
|
||||||
// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿
|
|
||||||
func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error {
|
|
||||||
// 创建一个 channel 来接收操作系统信号
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
|
||||||
<-quit // 阻塞,直到接收到上述信号之一
|
|
||||||
log.Println("Shutting down Touka server(s)...")
|
|
||||||
|
|
||||||
// 关闭日志记录器
|
|
||||||
if logger != nil {
|
|
||||||
go func() {
|
|
||||||
log.Println("Closing Touka logger...")
|
|
||||||
CloseLogger(logger)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建一个带超时的上下文,用于 Shutdown
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
|
||||||
|
|
||||||
// 并发地关闭所有服务器
|
|
||||||
for _, srv := range servers {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(s *http.Server) {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := s.Shutdown(ctx); err != nil {
|
|
||||||
// 将错误发送到 channel
|
|
||||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
|
||||||
}
|
|
||||||
}(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
|
|
||||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
|
||||||
|
|
||||||
// 收集所有关闭过程中发生的错误
|
|
||||||
var shutdownErrors []error
|
|
||||||
for err := range errChan {
|
|
||||||
shutdownErrors = append(shutdownErrors, err)
|
|
||||||
log.Printf("Shutdown error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(shutdownErrors) > 0 {
|
|
||||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
|
||||||
}
|
|
||||||
log.Println("Touka server(s) exited gracefully.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error {
|
|
||||||
// 创建一个 channel 来接收操作系统信号
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
|
||||||
|
|
||||||
// 启动服务器
|
|
||||||
serverStopped := make(chan error, 1)
|
|
||||||
for _, srv := range servers {
|
|
||||||
go func(s *http.Server) {
|
|
||||||
serverStopped <- s.ListenAndServe()
|
|
||||||
}(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Context 被取消 (例如,通过外部取消函数)
|
|
||||||
log.Println("Context cancelled, shutting down Touka server(s)...")
|
|
||||||
case err := <-serverStopped:
|
|
||||||
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
return fmt.Errorf("Touka HTTP server failed: %w", err)
|
|
||||||
}
|
|
||||||
log.Println("Touka HTTP server stopped gracefully.")
|
|
||||||
return nil // 服务器已自行优雅关闭,无需进一步处理
|
|
||||||
case <-quit:
|
|
||||||
// 接收到操作系统信号
|
|
||||||
log.Println("Shutting down Touka server(s) due to OS signal...")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 关闭日志记录器
|
|
||||||
if logger != nil {
|
|
||||||
go func() {
|
|
||||||
log.Println("Closing Touka logger...")
|
|
||||||
CloseLogger(logger)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建一个带超时的上下文,用于 Shutdown
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
|
||||||
|
|
||||||
// 并发地关闭所有服务器
|
|
||||||
for _, srv := range servers {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(s *http.Server) {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := s.Shutdown(shutdownCtx); err != nil {
|
|
||||||
// 将错误发送到 channel
|
|
||||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
|
||||||
}
|
|
||||||
}(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
|
||||||
|
|
||||||
// 收集所有关闭过程中发生的错误
|
|
||||||
var shutdownErrors []error
|
|
||||||
for err := range errChan {
|
|
||||||
shutdownErrors = append(shutdownErrors, err)
|
|
||||||
log.Printf("Shutdown error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(shutdownErrors) > 0 {
|
|
||||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
|
||||||
}
|
|
||||||
log.Println("Touka server(s) exited gracefully.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- 公共 Run 方法 ---
|
|
||||||
|
|
||||||
// Run 启动一个不支持优雅关闭的 HTTP 服务器
|
|
||||||
// 这是一个阻塞调用,主要用于简单的场景或快速测试
|
|
||||||
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
|
|
||||||
func (engine *Engine) Run(addr ...string) error {
|
|
||||||
address := resolveAddress(addr)
|
|
||||||
srv := &http.Server{Addr: address, Handler: engine}
|
|
||||||
|
|
||||||
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
|
|
||||||
//engine.applyDefaultServerConfig(srv)
|
|
||||||
if engine.ServerConfigurator != nil {
|
|
||||||
engine.ServerConfigurator(srv)
|
|
||||||
}
|
|
||||||
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
|
|
||||||
return srv.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
|
||||||
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error {
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: engine,
|
|
||||||
BaseContext: func(l net.Listener) context.Context {
|
|
||||||
return engine.shutdownCtx
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
|
||||||
|
|
||||||
// 应用框架的默认配置和用户提供的自定义配置
|
|
||||||
//engine.applyDefaultServerConfig(srv)
|
|
||||||
if engine.ServerConfigurator != nil {
|
|
||||||
engine.ServerConfigurator(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
runServer("HTTP", srv)
|
|
||||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
|
||||||
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error {
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: engine,
|
|
||||||
BaseContext: func(l net.Listener) context.Context {
|
|
||||||
return engine.shutdownCtx
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
|
||||||
|
|
||||||
// 应用框架的默认配置和用户提供的自定义配置
|
|
||||||
//engine.applyDefaultServerConfig(srv)
|
|
||||||
if engine.ServerConfigurator != nil {
|
|
||||||
engine.ServerConfigurator(srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器
|
|
||||||
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
return errors.New("tls.Config must not be nil for RunTLS")
|
return nil
|
||||||
}
|
}
|
||||||
|
return tlsConfig.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
// 配置 HTTP/2 支持 (如果使用默认配置)
|
func parseHTTPSPort(addr string) (string, error) {
|
||||||
if engine.useDefaultProtocols {
|
_, port, err := net.SplitHostPort(addr)
|
||||||
engine.SetProtocols(&ProtocolsConfig{
|
if err != nil {
|
||||||
Http1: true,
|
return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
|
||||||
Http2: true, // 默认在 TLS 上启用 HTTP/2
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
|
||||||
Addr: addr,
|
if serveTLS {
|
||||||
Handler: engine,
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
BaseContext: func(l net.Listener) context.Context {
|
|
||||||
return engine.shutdownCtx
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
|
||||||
|
|
||||||
// 应用框架的默认配置和用户提供的自定义配置
|
|
||||||
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
|
|
||||||
//engine.applyDefaultServerConfig(srv)
|
|
||||||
if engine.TLSServerConfigurator != nil {
|
if engine.TLSServerConfigurator != nil {
|
||||||
engine.TLSServerConfigurator(srv)
|
engine.TLSServerConfigurator(srv)
|
||||||
} else if engine.ServerConfigurator != nil {
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if engine.ServerConfigurator != nil {
|
||||||
engine.ServerConfigurator(srv)
|
engine.ServerConfigurator(srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
runServer("HTTPS", srv)
|
|
||||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名
|
func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
|
||||||
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
applyServerProtocols(srv, engine.serverProtocols)
|
||||||
return engine.RunTLS(addr, tlsConfig, timeouts...)
|
if engine.ServerConfigurator != nil {
|
||||||
|
engine.ServerConfigurator(srv)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭
|
func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
|
||||||
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
if engine == nil {
|
||||||
if tlsConfig == nil {
|
return nil
|
||||||
return errors.New("tls.Config must not be nil for RunTLSRedir")
|
|
||||||
}
|
}
|
||||||
|
if serveTLS && engine.useDefaultProtocols {
|
||||||
|
protocols := &http.Protocols{}
|
||||||
|
protocols.SetHTTP1(true)
|
||||||
|
protocols.SetHTTP2(true)
|
||||||
|
return protocols
|
||||||
|
}
|
||||||
|
return cloneServerProtocols(engine.serverProtocols)
|
||||||
|
}
|
||||||
|
|
||||||
// --- HTTPS 服务器 ---
|
func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
||||||
if engine.useDefaultProtocols {
|
serveTLS := cfg.mode != runModeHTTP
|
||||||
engine.SetProtocols(&ProtocolsConfig{Http1: true, Http2: true})
|
server := &http.Server{
|
||||||
}
|
Addr: cfg.addr,
|
||||||
httpsSrv := &http.Server{
|
|
||||||
Addr: httpsAddr,
|
|
||||||
Handler: engine,
|
Handler: engine,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: cloneTLSConfig(cfg.tlsConfig),
|
||||||
BaseContext: func(l net.Listener) context.Context {
|
}
|
||||||
|
if cfg.graceful {
|
||||||
|
server.BaseContext = func(net.Listener) context.Context {
|
||||||
return engine.shutdownCtx
|
return engine.shutdownCtx
|
||||||
},
|
|
||||||
}
|
}
|
||||||
httpsSrv.RegisterOnShutdown(engine.shutdownCancel)
|
server.RegisterOnShutdown(engine.shutdownCancel)
|
||||||
//engine.applyDefaultServerConfig(httpsSrv)
|
}
|
||||||
if engine.TLSServerConfigurator != nil {
|
applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
|
||||||
engine.TLSServerConfigurator(httpsSrv)
|
applyMainServerConfig(engine, server, serveTLS)
|
||||||
} else if engine.ServerConfigurator != nil {
|
return server
|
||||||
engine.ServerConfigurator(httpsSrv)
|
}
|
||||||
|
|
||||||
|
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, header := range headers {
|
||||||
|
value := strings.TrimSpace(r.Header.Get(header))
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
||||||
|
value = strings.TrimSpace(value[:comma])
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
||||||
|
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost == "" {
|
||||||
|
return "", http.StatusInternalServerError, false
|
||||||
|
}
|
||||||
|
return cfg.redirectHost, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
||||||
|
if host == "" {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
return host, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == nil {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(r.Host)
|
||||||
|
if host == "" {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
return host, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
||||||
|
httpsAddr := cfg.addr
|
||||||
|
httpAddr := cfg.httpRedirectAddr
|
||||||
|
httpsPort, err := parseHTTPSPort(httpsAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- HTTP 重定向服务器 ---
|
|
||||||
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, statusCode, ok := redirectTargetHost(r, cfg)
|
||||||
if err != nil {
|
if !ok {
|
||||||
host = r.Host
|
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, httpsPort, err := net.SplitHostPort(httpsAddr)
|
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||||
if err != nil {
|
host = parsedHost
|
||||||
// 如果 httpsAddr 没有端口,这是一个配置错误
|
if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
|
||||||
|
host = "[" + host + "]"
|
||||||
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL := "https://" + host
|
targetURL := "https://" + host
|
||||||
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
|
|
||||||
if httpsPort != "443" {
|
if httpsPort != "443" {
|
||||||
targetURL = "https://" + net.JoinHostPort(host, httpsPort)
|
targetURL = "https://" + net.JoinHostPort(host, httpsPort)
|
||||||
}
|
}
|
||||||
|
|
@ -351,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con
|
||||||
|
|
||||||
http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
|
http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
|
||||||
})
|
})
|
||||||
httpSrv := &http.Server{
|
|
||||||
Addr: httpAddr,
|
|
||||||
Handler: redirectHandler,
|
|
||||||
}
|
|
||||||
//engine.applyDefaultServerConfig(httpSrv)
|
|
||||||
if engine.ServerConfigurator != nil {
|
|
||||||
engine.ServerConfigurator(httpSrv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- 启动服务器和优雅关闭 ---
|
server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
|
||||||
runServer("HTTPS", httpsSrv)
|
applyRedirectServerConfig(engine, server)
|
||||||
runServer("HTTP Redirect", httpSrv)
|
return server, nil
|
||||||
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性
|
func validateRunConfig(cfg runConfig) error {
|
||||||
func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
|
||||||
return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...)
|
return errors.New("WithHTTPRedirect requires WithTLS")
|
||||||
|
}
|
||||||
|
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
||||||
|
return errors.New("https mode requires WithTLS")
|
||||||
|
}
|
||||||
|
if cfg.gracefulCtx != nil && !cfg.graceful {
|
||||||
|
return errors.New("WithShutdownContext requires graceful shutdown")
|
||||||
|
}
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
||||||
|
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost != "" {
|
||||||
|
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
||||||
|
}
|
||||||
|
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost == "" {
|
||||||
|
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
||||||
|
}
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
|
||||||
|
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
|
||||||
|
if cfg.shutdownTimeout > 0 {
|
||||||
|
return cfg.shutdownTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultShutdownTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeLoggerAsync(logger *reco.Logger) {
|
||||||
|
if logger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
log.Println("Closing Touka logger...")
|
||||||
|
CloseLogger(logger)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errChan := make(chan error, len(servers))
|
||||||
|
for _, srv := range servers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(s *http.Server) {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := s.Shutdown(ctx); err != nil {
|
||||||
|
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||||
|
}
|
||||||
|
}(srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
var shutdownErrors []error
|
||||||
|
for err := range errChan {
|
||||||
|
shutdownErrors = append(shutdownErrors, err)
|
||||||
|
log.Printf("Shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
if len(shutdownErrors) > 0 {
|
||||||
|
return errors.Join(shutdownErrors...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer signal.Stop(quit)
|
||||||
|
|
||||||
|
serverStopped := make(chan error, len(servers))
|
||||||
|
for i, srv := range servers {
|
||||||
|
serveTLSFlag := serveTLS[i]
|
||||||
|
go func(server *http.Server, useTLS bool) {
|
||||||
|
serverStopped <- serveServer(server, useTLS)
|
||||||
|
}(srv, serveTLSFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-serverStopped:
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
|
||||||
|
return errors.Join(err, shutdownErr)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("Touka server stopped gracefully.")
|
||||||
|
return nil
|
||||||
|
case <-quit:
|
||||||
|
log.Println("Shutting down Touka server(s) due to OS signal...")
|
||||||
|
case <-shutdownCtx.Done():
|
||||||
|
log.Println("Context cancelled, shutting down Touka server(s)...")
|
||||||
|
}
|
||||||
|
|
||||||
|
closeLoggerAsync(logger)
|
||||||
|
if err := shutdownServers(servers, timeout); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("Touka server(s) exited gracefully.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the engine with the provided startup options.
|
||||||
|
//
|
||||||
|
// Default behavior with no options:
|
||||||
|
// - HTTP only
|
||||||
|
// - listens on :8080
|
||||||
|
// - no graceful shutdown orchestration
|
||||||
|
//
|
||||||
|
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
|
||||||
|
// signal-aware graceful shutdown and request-context cancellation semantics.
|
||||||
|
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
|
||||||
|
func (engine *Engine) Run(opts ...RunOption) error {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := opt.apply(&cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.httpRedirectAddr != "" {
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
} else if cfg.tlsConfig != nil {
|
||||||
|
cfg.mode = runModeHTTPS
|
||||||
|
}
|
||||||
|
if err := validateRunConfig(cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
serveTLS := cfg.mode != runModeHTTP
|
||||||
|
|
||||||
|
mainServer := buildMainServer(engine, cfg)
|
||||||
|
servers := []*http.Server{mainServer}
|
||||||
|
serveTLSFlags := []bool{serveTLS}
|
||||||
|
if cfg.mode == runModeHTTPSRedirect {
|
||||||
|
redirectServer, err := buildRedirectServer(engine, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
servers = append(servers, redirectServer)
|
||||||
|
serveTLSFlags = append(serveTLSFlags, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.graceful {
|
||||||
|
if len(servers) > 1 {
|
||||||
|
serverStopped := make(chan error, len(servers))
|
||||||
|
for i, srv := range servers {
|
||||||
|
serveTLSFlag := serveTLSFlags[i]
|
||||||
|
go func(server *http.Server, useTLS bool) {
|
||||||
|
serverStopped <- serveServer(server, useTLS)
|
||||||
|
}(srv, serveTLSFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := <-serverStopped
|
||||||
|
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return errors.Join(err, shutdownErr)
|
||||||
|
}
|
||||||
|
return shutdownErr
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protocolLabel := "HTTP"
|
||||||
|
if serveTLS {
|
||||||
|
protocolLabel = "HTTPS"
|
||||||
|
}
|
||||||
|
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
|
||||||
|
return serveServer(mainServer, serveTLS)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx := context.Background()
|
||||||
|
if cfg.gracefulCtx != nil {
|
||||||
|
shutdownCtx = cfg.gracefulCtx
|
||||||
|
}
|
||||||
|
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
492
serve_test.go
Normal file
492
serve_test.go
Normal file
|
|
@ -0,0 +1,492 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateSelfSignedCert(t *testing.T) tls.Certificate {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate private key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: "127.0.0.1"},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageServerAuth,
|
||||||
|
},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create self-signed cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse self-signed cert: %v", err)
|
||||||
|
}
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen on ephemeral port: %v", err)
|
||||||
|
}
|
||||||
|
addr := listener.Addr().String()
|
||||||
|
if err := listener.Close(); err != nil {
|
||||||
|
t.Fatalf("close temporary listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}),
|
||||||
|
// RunShutdown uses the HTTP startup path and must not let a shared
|
||||||
|
// ServerConfigurator accidentally turn it into HTTPS.
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- serveServer(srv, false)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||||
|
var resp *http.Response
|
||||||
|
requestURL := "http://" + addr
|
||||||
|
|
||||||
|
deadline := time.Now().Add(3 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
resp, err = client.Get(requestURL)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case serveErr := <-errCh:
|
||||||
|
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
|
||||||
|
default:
|
||||||
|
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read response body: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if string(body) != "ok" {
|
||||||
|
t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
|
t.Fatalf("shutdown server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
err := engine.Run(WithHTTPRedirect(":80"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected redirect mode without TLS to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
err := engine.Run(
|
||||||
|
WithAddr(":443"),
|
||||||
|
WithTLS(&tls.Config{}),
|
||||||
|
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
||||||
|
t.Fatalf("apply graceful default option: %v", err)
|
||||||
|
}
|
||||||
|
if !cfg.graceful {
|
||||||
|
t.Fatal("expected graceful shutdown to be enabled")
|
||||||
|
}
|
||||||
|
if cfg.shutdownTimeout != defaultShutdownTimeout {
|
||||||
|
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||||
|
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
|
||||||
|
t.Fatalf("apply TLS option: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.mode != runModeHTTPS {
|
||||||
|
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
|
||||||
|
}
|
||||||
|
if cfg.graceful {
|
||||||
|
t.Fatal("expected TLS option to remain independent from graceful shutdown")
|
||||||
|
}
|
||||||
|
if cfg.tlsConfig != tlsConfig {
|
||||||
|
t.Fatal("expected TLS config to be preserved in run config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
|
||||||
|
t.Fatal("expected redirect server builder to reject https address without port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
ctx := t.Context()
|
||||||
|
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
|
||||||
|
t.Fatalf("apply shutdown context option: %v", err)
|
||||||
|
}
|
||||||
|
if err := validateRunConfig(cfg); err == nil {
|
||||||
|
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.httpRedirectAddr = ":80"
|
||||||
|
if err := validateRunConfig(cfg); err != nil {
|
||||||
|
t.Fatalf("validate run config: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.mode != runModeHTTP {
|
||||||
|
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
cfg.tlsConfig = &tls.Config{}
|
||||||
|
cfg.useHeaderHost = false
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
if err := validateRunConfig(cfg); err == nil {
|
||||||
|
t.Fatal("expected configured host mode without redirect host to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
cfg.tlsConfig = &tls.Config{}
|
||||||
|
cfg.useHeaderHost = true
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
cfg.redirectHost = "configured.example"
|
||||||
|
if err := validateRunConfig(cfg); err == nil {
|
||||||
|
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
||||||
|
if server.BaseContext == nil {
|
||||||
|
t.Fatal("expected graceful main server to set BaseContext")
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen for base context check: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
if got := server.BaseContext(listener); got != engine.shutdownCtx {
|
||||||
|
t.Fatal("expected graceful main server to use engine shutdown context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
serverConfigured := false
|
||||||
|
tlsConfigured := false
|
||||||
|
engine.SetServerConfigurator(func(s *http.Server) {
|
||||||
|
serverConfigured = true
|
||||||
|
s.ReadTimeout = time.Second
|
||||||
|
})
|
||||||
|
engine.SetTLSServerConfigurator(func(s *http.Server) {
|
||||||
|
tlsConfigured = true
|
||||||
|
s.IdleTimeout = time.Second
|
||||||
|
})
|
||||||
|
|
||||||
|
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||||
|
if !tlsConfigured {
|
||||||
|
t.Fatal("expected TLS configurator to run for HTTPS main server")
|
||||||
|
}
|
||||||
|
if serverConfigured {
|
||||||
|
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
|
||||||
|
}
|
||||||
|
if server.IdleTimeout != time.Second {
|
||||||
|
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
configured := false
|
||||||
|
engine.SetServerConfigurator(func(s *http.Server) {
|
||||||
|
configured = true
|
||||||
|
s.ReadTimeout = time.Second
|
||||||
|
})
|
||||||
|
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
if !configured {
|
||||||
|
t.Fatal("expected redirect server to use generic server configurator")
|
||||||
|
}
|
||||||
|
if server.ReadTimeout != time.Second {
|
||||||
|
t.Fatal("expected redirect server configurator changes to be applied")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||||
|
if !httpsServer.Protocols.HTTP2() {
|
||||||
|
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpServer := buildMainServer(engine, defaultRunConfig())
|
||||||
|
if httpServer.Protocols.HTTP2() {
|
||||||
|
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: true,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||||
|
req.Header.Set("X-First-Host", "first.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: true,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHostHeaders: []string{"X-Forwarded-Host"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUpgradeRequired {
|
||||||
|
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: false,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHost: "configured.example",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
|
||||||
|
req.Host = "[::1]:80"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected IPv6 redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||||
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen on occupied addr: %v", err)
|
||||||
|
}
|
||||||
|
occupiedAddr := occupied.Addr().String()
|
||||||
|
defer occupied.Close()
|
||||||
|
|
||||||
|
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen for redirect addr: %v", err)
|
||||||
|
}
|
||||||
|
redirectAddr := redirectListener.Addr().String()
|
||||||
|
if err := redirectListener.Close(); err != nil {
|
||||||
|
t.Fatalf("close redirect addr probe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
|
||||||
|
|
||||||
|
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected gracefulServe to fail when one server cannot bind")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||||
|
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
|
||||||
|
if dialErr == nil {
|
||||||
|
conn.Close()
|
||||||
|
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
|
||||||
|
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
|
||||||
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen on occupied addr: %v", err)
|
||||||
|
}
|
||||||
|
occupiedAddr := occupied.Addr().String()
|
||||||
|
defer occupied.Close()
|
||||||
|
|
||||||
|
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen for redirect addr: %v", err)
|
||||||
|
}
|
||||||
|
redirectAddr := redirectListener.Addr().String()
|
||||||
|
if err := redirectListener.Close(); err != nil {
|
||||||
|
t.Fatalf("close redirect addr probe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
err = engine.Run(
|
||||||
|
WithAddr(occupiedAddr),
|
||||||
|
WithTLS(&tls.Config{}),
|
||||||
|
WithHTTPRedirect(redirectAddr),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||||
|
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
66
sse.go
66
sse.go
|
|
@ -111,46 +111,40 @@ func (c *Context) EventStream(streamer func(w io.Writer) bool) {
|
||||||
// EventStreamChan 返回用于 SSE 事件流的 channel.
|
// EventStreamChan 返回用于 SSE 事件流的 channel.
|
||||||
// 这是为高级并发场景设计的、更灵活的API.
|
// 这是为高级并发场景设计的、更灵活的API.
|
||||||
//
|
//
|
||||||
// 重要:
|
// 与 EventStream 回调模式类似, 此方法是阻塞的: handler 会在此方法中停留,
|
||||||
// - 调用者必须 close(eventChan) 来结束事件流.
|
// 直到事件 channel 被关闭 (close eventChan) 或客户端断开连接.
|
||||||
// - 调用者必须在独立的 goroutine 中消费 errChan 来处理错误和连接断开.
|
// 这保证了 Context 不会在 SSE 流期间被 pool 回收.
|
||||||
// - 为防止 goroutine 泄漏, 建议发送方在 select 中同时监听 c.Request.Context().Done().
|
//
|
||||||
|
// eventChan 必须在调用此方法之前创建, 以便调用者可以在独立的 goroutine 中发送事件.
|
||||||
|
// 调用者必须在完成后 close(eventChan) 来结束流.
|
||||||
|
// 生产者 goroutine 必须在 select 中监听 c.Request.Context().Done(), 否则在客户端断开时会产生 goroutine 泄漏.
|
||||||
//
|
//
|
||||||
// 详细用法:
|
// 详细用法:
|
||||||
//
|
//
|
||||||
// r.GET("/sse/channel", func(c *touka.Context) {
|
// r.GET("/sse/channel", func(c *touka.Context) {
|
||||||
// eventChan, errChan := c.EventStreamChan()
|
// eventChan := make(chan touka.Event)
|
||||||
//
|
//
|
||||||
// // 必须在独立的goroutine中处理错误和连接断开.
|
// // 在独立的 goroutine 中异步发送事件.
|
||||||
// go func() {
|
// go func() {
|
||||||
// if err := <-errChan; err != nil {
|
// defer close(eventChan) // 完成后关闭 channel 以结束事件流.
|
||||||
// c.Errorf("SSE channel error: %v", err)
|
|
||||||
// }
|
|
||||||
// }()
|
|
||||||
//
|
|
||||||
// // 在另一个goroutine中异步发送事件.
|
|
||||||
// go func() {
|
|
||||||
// // 重要: 必须在逻辑结束时关闭channel, 以通知框架.
|
|
||||||
// defer close(eventChan)
|
|
||||||
//
|
//
|
||||||
// for i := 1; i <= 5; i++ {
|
// for i := 1; i <= 5; i++ {
|
||||||
// select {
|
// select {
|
||||||
// case <-c.Request.Context().Done():
|
// case <-c.Request.Context().Done():
|
||||||
// return // 客户端已断开, 退出 goroutine.
|
// return // 客户端已断开, 退出 goroutine.
|
||||||
// default:
|
// case eventChan <- touka.Event{
|
||||||
// eventChan <- touka.Event{
|
|
||||||
// Id: fmt.Sprintf("%d", i),
|
// Id: fmt.Sprintf("%d", i),
|
||||||
// Data: "hello from channel",
|
// Data: "hello from channel",
|
||||||
|
// }:
|
||||||
// }
|
// }
|
||||||
// time.Sleep(2 * time.Second)
|
// time.Sleep(2 * time.Second)
|
||||||
// }
|
// }
|
||||||
// }
|
|
||||||
// }()
|
// }()
|
||||||
|
//
|
||||||
|
// // 阻塞直到事件流结束.
|
||||||
|
// c.EventStreamChan(eventChan)
|
||||||
// })
|
// })
|
||||||
func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
func (c *Context) EventStreamChan(eventChan <-chan Event) {
|
||||||
eventChan := make(chan Event)
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
|
c.Writer.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||||
c.Writer.Header().Del("Connection")
|
c.Writer.Header().Del("Connection")
|
||||||
|
|
@ -159,8 +153,16 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
||||||
c.Writer.WriteHeader(http.StatusOK)
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// 捕获稳定的引用, 不持有 *Context 指针, 以免 Context 被 pool 回收后出现竞态.
|
||||||
|
w := c.Writer
|
||||||
|
fl, _ := w.(http.Flusher)
|
||||||
|
reqCtx := c.Request.Context()
|
||||||
|
|
||||||
|
goroutineExited := make(chan struct{})
|
||||||
|
|
||||||
|
// 写入 goroutine: 从 eventChan 消费事件并写入响应.
|
||||||
go func() {
|
go func() {
|
||||||
defer close(errChan)
|
defer close(goroutineExited)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
|
@ -168,17 +170,23 @@ func (c *Context) EventStreamChan() (chan<- Event, <-chan error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := event.Render(c.Writer); err != nil {
|
if err := event.Render(w); err != nil {
|
||||||
errChan <- err
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Writer.Flush()
|
if fl != nil {
|
||||||
case <-c.Request.Context().Done():
|
fl.Flush()
|
||||||
errChan <- c.Request.Context().Err()
|
}
|
||||||
|
case <-reqCtx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return eventChan, errChan
|
// 阻塞直到:
|
||||||
|
// 1. 写入 goroutine 退出 (eventChan 关闭或写入失败)
|
||||||
|
// 2. 客户端断开连接 (reqCtx 取消)
|
||||||
|
select {
|
||||||
|
case <-goroutineExited:
|
||||||
|
case <-reqCtx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
142
sse_test.go
Normal file
142
sse_test.go
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEventStreamChanBlocksHandler verifies that EventStreamChan blocks until
|
||||||
|
// the event channel is closed.
|
||||||
|
func TestEventStreamChanBlocksHandler(t *testing.T) {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil)
|
||||||
|
c, _ := CreateTestContextWithRequest(rr, req)
|
||||||
|
|
||||||
|
handlerReturned := make(chan struct{})
|
||||||
|
eventChan := make(chan Event)
|
||||||
|
|
||||||
|
// Start producer goroutine before EventStreamChan blocks
|
||||||
|
go func() {
|
||||||
|
defer close(eventChan)
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
eventChan <- Event{Data: "hello"}
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c.EventStreamChan(eventChan)
|
||||||
|
close(handlerReturned)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for goroutine to start
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Handler should NOT have returned (eventChan not closed)
|
||||||
|
select {
|
||||||
|
case <-handlerReturned:
|
||||||
|
t.Fatal("Handler returned before eventChan was closed - EventStreamChan is not blocking")
|
||||||
|
case <-time.After(40 * time.Millisecond):
|
||||||
|
// good, still blocking
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for producer to finish (30+30ms + margin)
|
||||||
|
select {
|
||||||
|
case <-handlerReturned:
|
||||||
|
// good, handler returned
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("Handler did not return after eventChan was closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventStreamChanUnblocksOnClientDisconnect verifies the handler returns
|
||||||
|
// when the request context is cancelled, even if eventChan is never closed.
|
||||||
|
func TestEventStreamChanUnblocksOnClientDisconnect(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil).WithContext(ctx)
|
||||||
|
c, _ := CreateTestContextWithRequest(rr, req)
|
||||||
|
|
||||||
|
eventChan := make(chan Event)
|
||||||
|
handlerReturned := make(chan struct{})
|
||||||
|
|
||||||
|
// Producer never closes eventChan
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case eventChan <- Event{Data: "tick"}:
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c.EventStreamChan(eventChan)
|
||||||
|
close(handlerReturned)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Handler should NOT have returned
|
||||||
|
select {
|
||||||
|
case <-handlerReturned:
|
||||||
|
t.Fatal("Handler returned before stream ended")
|
||||||
|
case <-time.After(60 * time.Millisecond):
|
||||||
|
// good, still blocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel context to simulate client disconnect
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-handlerReturned:
|
||||||
|
// good
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Fatal("Handler did not return after client disconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventStreamChanWritesEvents verifies the SSE event format is correct.
|
||||||
|
func TestEventStreamChanWritesEvents(t *testing.T) {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil)
|
||||||
|
c, _ := CreateTestContextWithRequest(rr, req)
|
||||||
|
|
||||||
|
eventChan := make(chan Event)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(eventChan)
|
||||||
|
eventChan <- Event{Id: "1", Event: "tick", Data: "hello\nworld"}
|
||||||
|
eventChan <- Event{Id: "2", Data: "second"}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.EventStreamChan(eventChan)
|
||||||
|
|
||||||
|
body := rr.Body.String()
|
||||||
|
|
||||||
|
ct := rr.Header().Get("Content-Type")
|
||||||
|
if !strings.Contains(ct, "text/event-stream") {
|
||||||
|
t.Fatalf("expected text/event-stream content type, got %q", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(body, "id: 1") {
|
||||||
|
t.Fatal("missing id field in first event")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "event: tick") {
|
||||||
|
t.Fatal("missing event field in first event")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "data: hello") {
|
||||||
|
t.Fatal("missing data line 1 in first event")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "data: world") {
|
||||||
|
t.Fatal("missing data line 2 in first event")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "id: 2") {
|
||||||
|
t.Fatal("missing id field in second event")
|
||||||
|
}
|
||||||
|
if !strings.Contains(body, "data: second") {
|
||||||
|
t.Fatal("missing data in second event")
|
||||||
|
}
|
||||||
|
}
|
||||||
8
touka.go
8
touka.go
|
|
@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
|
||||||
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
||||||
type HandlersChain []HandlerFunc
|
type HandlersChain []HandlerFunc
|
||||||
|
|
||||||
// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||||
type IRouter interface {
|
type Router interface {
|
||||||
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组
|
Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
|
||||||
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组
|
Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
|
||||||
|
|
||||||
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
||||||
GET(relativePath string, handlers ...HandlerFunc)
|
GET(relativePath string, handlers ...HandlerFunc)
|
||||||
|
|
|
||||||
82
tree.go
82
tree.go
|
|
@ -124,6 +124,7 @@ type node struct {
|
||||||
path string // 当前节点的路径段
|
path string // 当前节点的路径段
|
||||||
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
|
indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点
|
||||||
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
|
wildChild bool // 是否包含通配符子节点(:param 或 *catchAll)
|
||||||
|
hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由
|
||||||
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
|
nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有)
|
||||||
priority uint32 // 节点的优先级, 用于查找时优先匹配
|
priority uint32 // 节点的优先级, 用于查找时优先匹配
|
||||||
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
|
children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾
|
||||||
|
|
@ -131,6 +132,19 @@ type node struct {
|
||||||
fullPath string // 完整路径, 用于调试和错误信息
|
fullPath string // 完整路径, 用于调试和错误信息
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func routeNeedsCaseInsensitiveLookup(path string) bool {
|
||||||
|
for i := 0; i < len(path); i++ {
|
||||||
|
c := path[i]
|
||||||
|
if c >= utf8.RuneSelf {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c >= 'A' && c <= 'Z' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
|
// incrementChildPrio 增加给定子节点的优先级并在必要时重新排序.
|
||||||
func (n *node) incrementChildPrio(pos int) int {
|
func (n *node) incrementChildPrio(pos int) int {
|
||||||
cs := n.children // 获取子节点切片
|
cs := n.children // 获取子节点切片
|
||||||
|
|
@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int {
|
||||||
func (n *node) addRoute(path string, handlers HandlersChain) {
|
func (n *node) addRoute(path string, handlers HandlersChain) {
|
||||||
fullPath := path // 记录完整的路径
|
fullPath := path // 记录完整的路径
|
||||||
n.priority++ // 增加当前节点的优先级
|
n.priority++ // 增加当前节点的优先级
|
||||||
|
if routeNeedsCaseInsensitiveLookup(path) {
|
||||||
|
n.hasCaseInsensitivePath = true
|
||||||
|
}
|
||||||
|
|
||||||
// 如果是空树(根节点)
|
// 如果是空树(根节点)
|
||||||
if len(n.path) == 0 && len(n.children) == 0 {
|
if len(n.path) == 0 && len(n.children) == 0 {
|
||||||
|
|
@ -452,12 +469,14 @@ type skippedNode struct {
|
||||||
// 建议进行 TSR(尾部斜杠重定向).
|
// 建议进行 TSR(尾部斜杠重定向).
|
||||||
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
|
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
|
||||||
var globalParamsCount int16 // 全局参数计数
|
var globalParamsCount int16 // 全局参数计数
|
||||||
|
var backtrackToWildChild bool
|
||||||
|
|
||||||
walk: // 外部循环用于遍历路由树
|
walk: // 外部循环用于遍历路由树
|
||||||
for {
|
for {
|
||||||
prefix := n.path // 当前节点的路径前缀
|
prefix := n.path // 当前节点的路径前缀
|
||||||
if len(path) > len(prefix) {
|
if len(path) > len(prefix) {
|
||||||
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
|
if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头
|
||||||
|
pathAtNode := path
|
||||||
path = path[len(prefix):] // 移除已匹配的前缀
|
path = path[len(prefix):] // 移除已匹配的前缀
|
||||||
|
|
||||||
// 在访问 path[0] 之前进行安全检查
|
// 在访问 path[0] 之前进行安全检查
|
||||||
|
|
@ -467,23 +486,16 @@ walk: // 外部循环用于遍历路由树
|
||||||
|
|
||||||
// 优先尝试所有非通配符子节点, 通过匹配索引字符
|
// 优先尝试所有非通配符子节点, 通过匹配索引字符
|
||||||
idxc := path[0] // 剩余路径的第一个字符
|
idxc := path[0] // 剩余路径的第一个字符
|
||||||
for i, c := range []byte(n.indices) {
|
if !backtrackToWildChild {
|
||||||
if c == idxc { // 如果找到匹配的索引字符
|
for i := 0; i < len(n.indices); i++ {
|
||||||
|
if n.indices[i] == idxc { // 如果找到匹配的索引字符
|
||||||
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
|
// 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯
|
||||||
if n.wildChild {
|
if n.wildChild {
|
||||||
index := len(*skippedNodes)
|
index := len(*skippedNodes)
|
||||||
*skippedNodes = (*skippedNodes)[:index+1]
|
*skippedNodes = (*skippedNodes)[:index+1]
|
||||||
(*skippedNodes)[index] = skippedNode{
|
(*skippedNodes)[index] = skippedNode{
|
||||||
path: prefix + path, // 记录跳过的路径
|
path: pathAtNode, // 记录进入当前节点时的剩余路径
|
||||||
node: &node{ // 复制当前节点的状态
|
node: n,
|
||||||
path: n.path,
|
|
||||||
wildChild: n.wildChild,
|
|
||||||
nType: n.nType,
|
|
||||||
priority: n.priority,
|
|
||||||
children: n.children,
|
|
||||||
handlers: n.handlers,
|
|
||||||
fullPath: n.fullPath,
|
|
||||||
},
|
|
||||||
paramsCount: globalParamsCount, // 记录当前参数计数
|
paramsCount: globalParamsCount, // 记录当前参数计数
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -492,6 +504,9 @@ walk: // 外部循环用于遍历路由树
|
||||||
continue walk // 继续外部循环
|
continue walk // 继续外部循环
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
backtrackToWildChild = false
|
||||||
|
}
|
||||||
|
|
||||||
if !n.wildChild {
|
if !n.wildChild {
|
||||||
// 如果路径在循环结束时不等于 '/' 且当前节点没有子节点
|
// 如果路径在循环结束时不等于 '/' 且当前节点没有子节点
|
||||||
|
|
@ -507,6 +522,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
|
*value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount // 恢复参数计数
|
globalParamsCount = skippedNode.paramsCount // 恢复参数计数
|
||||||
|
backtrackToWildChild = true
|
||||||
continue walk // 继续外部循环
|
continue walk // 继续外部循环
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
i := len(*value.params)
|
i := len(*value.params)
|
||||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||||
val := path[:end] // 提取参数值
|
val := path[:end] // 提取参数值
|
||||||
if unescape { // 如果需要进行 URL 解码
|
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||||
if v, err := url.QueryUnescape(val); err == nil {
|
if v, err := url.QueryUnescape(val); err == nil {
|
||||||
val = v // 解码成功则更新值
|
val = v // 解码成功则更新值
|
||||||
}
|
}
|
||||||
|
|
@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
i := len(*value.params)
|
i := len(*value.params)
|
||||||
*value.params = (*value.params)[:i+1] // 扩展切片
|
*value.params = (*value.params)[:i+1] // 扩展切片
|
||||||
val := path // 参数值是剩余的整个路径
|
val := path // 参数值是剩余的整个路径
|
||||||
if unescape { // 如果需要进行 URL 解码
|
if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) {
|
||||||
if v, err := url.QueryUnescape(path); err == nil {
|
if v, err := url.QueryUnescape(path); err == nil {
|
||||||
val = v // 解码成功则更新值
|
val = v // 解码成功则更新值
|
||||||
}
|
}
|
||||||
|
|
@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount
|
globalParamsCount = skippedNode.paramsCount
|
||||||
|
backtrackToWildChild = true
|
||||||
continue walk
|
continue walk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
}
|
}
|
||||||
|
|
||||||
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
|
// 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == '/' { // 如果索引中包含 '/'
|
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||||
n = n.children[i] // 移动到对应的子节点
|
n = n.children[i] // 移动到对应的子节点
|
||||||
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||||
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
|
(n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数
|
||||||
|
|
@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
*value.params = (*value.params)[:skippedNode.paramsCount]
|
*value.params = (*value.params)[:skippedNode.paramsCount]
|
||||||
}
|
}
|
||||||
globalParamsCount = skippedNode.paramsCount
|
globalParamsCount = skippedNode.paramsCount
|
||||||
|
backtrackToWildChild = true
|
||||||
continue walk
|
continue walk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树
|
||||||
// 它还可以选择修复尾部斜杠.
|
// 它还可以选择修复尾部斜杠.
|
||||||
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
|
// 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功.
|
||||||
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
|
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
|
||||||
const stackBufSize = 128 // 栈上缓冲区的默认大小
|
return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash)
|
||||||
|
}
|
||||||
|
|
||||||
// 在常见情况下使用栈上静态大小的缓冲区.
|
func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) {
|
||||||
// 如果路径太长, 则在堆上分配缓冲区.
|
if buf != nil {
|
||||||
buf := make([]byte, 0, stackBufSize)
|
buf = buf[:0]
|
||||||
if length := len(path) + 1; length > stackBufSize {
|
}
|
||||||
buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区
|
if cap(buf) < len(path)+1 {
|
||||||
|
buf = make([]byte, 0, len(path)+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
ciPath := n.findCaseInsensitivePathRec(
|
ciPath := n.findCaseInsensitivePathRec(
|
||||||
|
|
@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
// 未找到处理函数.
|
// 未找到处理函数.
|
||||||
// 尝试通过添加尾部斜杠来修复路径
|
// 尝试通过添加尾部斜杠来修复路径
|
||||||
if fixTrailingSlash {
|
if fixTrailingSlash {
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == '/' { // 如果索引中包含 '/'
|
if n.indices[i] == '/' { // 如果索引中包含 '/'
|
||||||
n = n.children[i] // 移动到对应的子节点
|
n = n.children[i] // 移动到对应的子节点
|
||||||
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数
|
||||||
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
|
(n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数
|
||||||
|
|
@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树
|
||||||
if rb[0] != 0 {
|
if rb[0] != 0 {
|
||||||
// 旧 rune 未处理完
|
// 旧 rune 未处理完
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 继续处理子节点
|
// 继续处理子节点
|
||||||
n = n.children[i]
|
n = n.children[i]
|
||||||
npLen = len(n.path)
|
npLen = len(n.path)
|
||||||
|
|
@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树
|
||||||
rb = shiftNRuneBytes(rb, off)
|
rb = shiftNRuneBytes(rb, off)
|
||||||
|
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
// 小写匹配
|
// 小写匹配
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
|
// 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在
|
||||||
if out := n.children[i].findCaseInsensitivePathRec(
|
if out := n.children[i].findCaseInsensitivePathRec(
|
||||||
path, ciPath, rb, fixTrailingSlash,
|
path, ciPath, rb, fixTrailingSlash,
|
||||||
|
|
@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树
|
||||||
rb = shiftNRuneBytes(rb, off)
|
rb = shiftNRuneBytes(rb, off)
|
||||||
|
|
||||||
idxc := rb[0]
|
idxc := rb[0]
|
||||||
for i, c := range []byte(n.indices) {
|
for i := 0; i < len(n.indices); i++ {
|
||||||
// 大写匹配
|
// 大写匹配
|
||||||
if c == idxc {
|
if n.indices[i] == idxc {
|
||||||
// 继续处理子节点
|
// 继续处理子节点
|
||||||
n = n.children[i]
|
n = n.children[i]
|
||||||
npLen = len(n.path)
|
npLen = len(n.path)
|
||||||
|
|
@ -852,7 +872,7 @@ walk: // 外部循环用于遍历路由树
|
||||||
return nil // 未找到, 返回 nil
|
return nil // 未找到, 返回 nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n = n.children[0] // 移动到通配符子节点(通常是唯一一个)
|
n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾
|
||||||
switch n.nType {
|
switch n.nType {
|
||||||
case param: // 参数节点
|
case param: // 参数节点
|
||||||
// 查找参数结束位置('/' 或路径末尾)
|
// 查找参数结束位置('/' 或路径末尾)
|
||||||
|
|
|
||||||
94
tree_test.go
94
tree_test.go
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Used as a workaround since we can't compare functions or their addresses
|
// Used as a workaround since we can't compare functions or their addresses
|
||||||
|
|
@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode {
|
||||||
return &ps
|
return &ps
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
resultCh := make(chan nodeValue, 1)
|
||||||
|
go func() {
|
||||||
|
resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case value := <-resultCh:
|
||||||
|
return value
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path)
|
||||||
|
return nodeValue{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
|
func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) {
|
||||||
unescape := false
|
unescape := false
|
||||||
if len(unescapes) >= 1 {
|
if len(unescapes) >= 1 {
|
||||||
|
|
@ -901,6 +919,34 @@ func TestTreeInvalidNodeType(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) {
|
||||||
|
tree := &node{}
|
||||||
|
routes := [...]string{
|
||||||
|
"/:user/:repo/info/refs",
|
||||||
|
"/healthz",
|
||||||
|
"/api/db/data",
|
||||||
|
"/api/db/sum",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
tree.addRoute(route, fakeHandler(route))
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("unexpected panic while looking up missing path: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil {
|
||||||
|
t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil {
|
||||||
|
t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTreeInvalidParamsType(t *testing.T) {
|
func TestTreeInvalidParamsType(t *testing.T) {
|
||||||
tree := &node{}
|
tree := &node{}
|
||||||
// add a child with wildcard
|
// add a child with wildcard
|
||||||
|
|
@ -1076,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) {
|
||||||
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
|
t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routes []string
|
||||||
|
requestPath string
|
||||||
|
wantFullPath string
|
||||||
|
wantParams Params
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "param route after static dead end",
|
||||||
|
routes: []string{"/foo/bar", "/foo/:id/details"},
|
||||||
|
requestPath: "/foo/bar/details",
|
||||||
|
wantFullPath: "/foo/:id/details",
|
||||||
|
wantParams: Params{{Key: "id", Value: "bar"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "catch-all route after static dead end",
|
||||||
|
routes: []string{"/foo/bar", "/foo/:id/*rest"},
|
||||||
|
requestPath: "/foo/bar/baz.txt",
|
||||||
|
wantFullPath: "/foo/:id/*rest",
|
||||||
|
wantParams: Params{
|
||||||
|
{Key: "id", Value: "bar"},
|
||||||
|
{Key: "rest", Value: "/baz.txt"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tree := &node{}
|
||||||
|
for _, route := range tt.routes {
|
||||||
|
tree.addRoute(route, fakeHandler(route))
|
||||||
|
}
|
||||||
|
|
||||||
|
value := getValueWithTimeout(t, tree, tt.requestPath, false)
|
||||||
|
if value.handlers == nil {
|
||||||
|
t.Fatalf("expected handlers for %q", tt.requestPath)
|
||||||
|
}
|
||||||
|
if value.fullPath != tt.wantFullPath {
|
||||||
|
t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath)
|
||||||
|
}
|
||||||
|
if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) {
|
||||||
|
t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue