mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #82 from infinite-iroha/break/v1-redesign-run-api
feat: redesign server startup around Run options
This commit is contained in:
commit
efa1e3fb3f
14 changed files with 1037 additions and 341 deletions
|
|
@ -59,9 +59,9 @@ func main() {
|
|||
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
|
||||
})
|
||||
|
||||
// 启动服务器 (支持优雅关闭)
|
||||
// 启动服务器(通过 WithGracefulShutdown 启用优雅关闭)
|
||||
log.Println("Touka Server starting on :8080...")
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("Touka server failed to start: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,13 +70,13 @@ func main() {
|
|||
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
|
||||
|
||||
// ... 其他配置
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
#### 1.3. 服务器生命周期管理
|
||||
|
||||
Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。
|
||||
Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。
|
||||
|
||||
```go
|
||||
func main() {
|
||||
|
|
@ -90,11 +90,11 @@ func main() {
|
|||
fmt.Println("自定义的 HTTP 服务器配置已应用")
|
||||
})
|
||||
|
||||
// 启动服务器,并支持优雅关闭
|
||||
// RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号
|
||||
// 第二个参数是优雅关闭的超时时间
|
||||
// 启动服务器,并通过 Run 选项启用优雅关闭
|
||||
// Run(...) 会阻塞当前 goroutine
|
||||
// WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒
|
||||
fmt.Println("服务器启动于 :8080")
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -187,7 +187,7 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
|
||||
func AuthMiddleware() touka.HandlerFunc {
|
||||
|
|
@ -313,7 +313,7 @@ func main() {
|
|||
})
|
||||
})
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
|
||||
// templates/index.html
|
||||
|
|
@ -400,7 +400,7 @@ func main() {
|
|||
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
|
||||
})
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -483,7 +483,7 @@ func main() {
|
|||
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
|
||||
r.StaticDir("/files", "./non-existent-dir")
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -546,7 +546,7 @@ func main() {
|
|||
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
|
||||
r.StaticFS("/", http.FS(subFS))
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
134
docs/advanced.md
134
docs/advanced.md
|
|
@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) {
|
|||
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext):
|
||||
|
||||
```go
|
||||
// 使用默认协议配置(仅 HTTP/1.1)
|
||||
// 使用默认协议配置
|
||||
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
|
||||
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
|
||||
r.SetDefaultProtocols()
|
||||
|
||||
// 自定义协议配置
|
||||
|
|
@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{
|
|||
|
||||
### 启动方式
|
||||
|
||||
Touka 提供了多种服务器启动方式:
|
||||
Touka 统一通过 `Run(opts...)` 启动服务器:
|
||||
|
||||
```go
|
||||
// 1. 简单启动(无优雅停机)
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
|
||||
// 2. 带优雅停机的启动
|
||||
r.RunShutdown(":8080", 10*time.Second)
|
||||
r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second))
|
||||
|
||||
// 3. 带上下文的优雅停机
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
r.RunShutdownWithContext(":8080", ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
r.Run(
|
||||
touka.WithAddr(":8080"),
|
||||
touka.WithGracefulShutdown(10*time.Second),
|
||||
touka.WithShutdownContext(ctx),
|
||||
)
|
||||
|
||||
// 4. HTTPS 启动
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
// 其他 TLS 配置...
|
||||
}
|
||||
r.RunTLS(":443", tlsConfig, 10*time.Second)
|
||||
// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithGracefulShutdownDefault(),
|
||||
)
|
||||
|
||||
// 5. HTTPS + HTTP 重定向
|
||||
r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second)
|
||||
// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(":80"),
|
||||
touka.WithGracefulShutdown(10*time.Second),
|
||||
)
|
||||
|
||||
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host)
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(true),
|
||||
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||
),
|
||||
)
|
||||
|
||||
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host)
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(false),
|
||||
touka.WithRedirectHost("example.com"),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
### HTTPS Redirect Host 策略
|
||||
|
||||
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
|
||||
|
||||
可用的 redirect 子选项:
|
||||
|
||||
- `touka.WithUseHeaderHost(true|false)`
|
||||
- `touka.WithRedirectHostHeaders([]string{...})`
|
||||
- `touka.WithRedirectHost("example.com")`
|
||||
|
||||
#### 模式一:使用请求输入侧的 host
|
||||
|
||||
当 `WithUseHeaderHost(true)` 时:
|
||||
|
||||
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
|
||||
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值
|
||||
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
|
||||
|
||||
示例:
|
||||
|
||||
```go
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(true),
|
||||
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### 模式二:使用配置的固定 host
|
||||
|
||||
当 `WithUseHeaderHost(false)` 时:
|
||||
|
||||
- 不读取 `Request.Host`
|
||||
- 不读取 `WithRedirectHostHeaders(...)`
|
||||
- 必须配置 `WithRedirectHost("example.com")`
|
||||
|
||||
示例:
|
||||
|
||||
```go
|
||||
r.Run(
|
||||
touka.WithAddr(":443"),
|
||||
touka.WithTLS(tlsConfig),
|
||||
touka.WithHTTPRedirect(
|
||||
":80",
|
||||
touka.WithUseHeaderHost(false),
|
||||
touka.WithRedirectHost("example.com"),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
#### 严格校验规则
|
||||
|
||||
以下组合会直接返回配置错误:
|
||||
|
||||
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
|
||||
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
|
||||
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
|
||||
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
|
||||
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
|
||||
|
||||
#### 优先级关系
|
||||
|
||||
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
|
||||
2. `WithUseHeaderHost(...)` 决定 host 来源模式
|
||||
3. 当 `WithUseHeaderHost(true)` 时:
|
||||
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
|
||||
- 未配置时使用 `Request.Host`
|
||||
4. 当 `WithUseHeaderHost(false)` 时:
|
||||
- 只使用 `WithRedirectHost(...)`
|
||||
|
||||
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
|
||||
|
||||
## 优雅停机 (Graceful Shutdown)
|
||||
|
||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。
|
||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
|
||||
|
||||
```go
|
||||
r := touka.Default()
|
||||
|
|
@ -91,7 +207,7 @@ r := touka.Default()
|
|||
|
||||
// 监听 SIGINT 和 SIGTERM 信号
|
||||
// 如果在 10 秒内未处理完,则强制关闭
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatal("服务器退出异常:", err)
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
|
|||
|
||||
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
|
||||
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
|
||||
3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。
|
||||
3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。
|
||||
|
||||
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func main() {
|
|||
|
||||
// 4. 启动服务器并监听 8080 端口
|
||||
log.Println("Touka server is running on :8080")
|
||||
if err := r.Run(":8080"); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080")); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -66,11 +66,11 @@ go run main.go
|
|||
|
||||
## 优雅停机
|
||||
|
||||
在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。
|
||||
在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。
|
||||
|
||||
```go
|
||||
// 等待 10 秒以处理剩余请求
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatalf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ func main() {
|
|||
Target: target,
|
||||
}))
|
||||
|
||||
_ = r.Run(":8080")
|
||||
_ = r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -497,7 +497,7 @@ func main() {
|
|||
},
|
||||
}))
|
||||
|
||||
if err := r.RunShutdown(":8080", 10*time.Second); err != nil {
|
||||
if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ func main() {
|
|||
r := touka.Default()
|
||||
fsroot, _ := fs.Sub(content, "dist")
|
||||
r.StaticFS("/", http.FS(fsroot))
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
|
|||
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
|
||||
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
|
||||
|
||||
**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。
|
||||
**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func main() {
|
|||
// 您也可以使用 StaticFS 服务根路径
|
||||
// r.StaticFS("/", http.FS(fsroot))
|
||||
|
||||
r.Run(":8080")
|
||||
r.Run(touka.WithAddr(":8080"))
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
32
engine.go
32
engine.go
|
|
@ -404,11 +404,18 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
|
|||
}()
|
||||
}
|
||||
|
||||
// applyDefaultServerConfig 应用框架的默认配置到 http.Server
|
||||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||
if engine.serverProtocols != nil {
|
||||
srv.Protocols = engine.serverProtocols
|
||||
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() {
|
||||
func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
|
||||
if protocols == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *protocols
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
|
||||
if protocols != nil {
|
||||
srv.Protocols = cloneServerProtocols(protocols)
|
||||
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
|
||||
if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -416,6 +423,11 @@ func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
|||
}
|
||||
}
|
||||
|
||||
// applyDefaultServerConfig 应用框架的默认配置到 http.Server
|
||||
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
|
||||
applyServerProtocols(srv, engine.serverProtocols)
|
||||
}
|
||||
|
||||
// 配置全局Req Body大小限制
|
||||
func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) {
|
||||
engine.GlobalMaxRequestBodySize = size
|
||||
|
|
@ -614,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
|
|||
|
||||
// Use 将全局中间件添加到 Engine
|
||||
// 这些中间件将应用于所有注册的路由
|
||||
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
||||
func (engine *Engine) Use(middleware ...HandlerFunc) Router {
|
||||
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
||||
engine.rebuildFallbackChains()
|
||||
return engine
|
||||
|
|
@ -683,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
|
|||
|
||||
// Group 创建一个新的路由组
|
||||
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||
return &RouterGroup{
|
||||
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
||||
basePath: resolveRoutePath("/", relativePath),
|
||||
|
|
@ -692,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
|
|||
}
|
||||
|
||||
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
||||
// 它也实现了 IRouter 接口,允许嵌套分组
|
||||
// 它也实现了 Router 接口,允许嵌套分组
|
||||
type RouterGroup struct {
|
||||
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
||||
basePath string // 组路径前缀
|
||||
|
|
@ -701,7 +713,7 @@ type RouterGroup struct {
|
|||
|
||||
// Use 将中间件应用于当前路由组
|
||||
// 这些中间件将应用于当前组及其子组的所有路由
|
||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter {
|
||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
|
||||
group.Handlers = append(group.Handlers, middleware...)
|
||||
return group
|
||||
}
|
||||
|
|
@ -747,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
|
|||
}
|
||||
|
||||
// Group 为当前组创建一个新的子组
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||
return &RouterGroup{
|
||||
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
||||
basePath: resolveRoutePath(group.basePath, relativePath),
|
||||
|
|
|
|||
|
|
@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRunTLSProtocolInheritance(t *testing.T) {
|
||||
func TestTLSRunDefaultsProtocolInheritance(t *testing.T) {
|
||||
engine := New()
|
||||
|
||||
// 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true,
|
||||
})
|
||||
}
|
||||
|
||||
srv := &http.Server{TLSConfig: &tls.Config{}}
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
|
||||
if !srv.Protocols.HTTP2() {
|
||||
t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config")
|
||||
t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config")
|
||||
}
|
||||
|
||||
// 模拟用户设置了自定义协议后调用 RunTLS
|
||||
// 模拟用户设置了自定义协议后进入 TLS 运行模式
|
||||
engine = New()
|
||||
engine.SetProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: false, // 用户明确不想要 HTTP/2
|
||||
})
|
||||
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true,
|
||||
})
|
||||
}
|
||||
|
||||
srv2 := &http.Server{TLSConfig: &tls.Config{}}
|
||||
engine.applyDefaultServerConfig(srv2)
|
||||
srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
|
||||
if srv2.Protocols.HTTP2() {
|
||||
t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously")
|
||||
t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
741
serve.go
741
serve.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -21,45 +22,173 @@ import (
|
|||
"github.com/fenthope/reco"
|
||||
)
|
||||
|
||||
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
|
||||
const defaultShutdownTimeout = 5 * time.Second
|
||||
|
||||
// --- 内部辅助函数 ---
|
||||
type runMode uint8
|
||||
|
||||
// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080"
|
||||
func resolveAddress(addr []string) string {
|
||||
switch len(addr) {
|
||||
case 0:
|
||||
return ":8080"
|
||||
case 1:
|
||||
return addr[0]
|
||||
default:
|
||||
panic("too many parameters provided for server address")
|
||||
const (
|
||||
runModeHTTP runMode = iota
|
||||
runModeHTTPS
|
||||
runModeHTTPSRedirect
|
||||
)
|
||||
|
||||
type runConfig struct {
|
||||
addr string
|
||||
httpRedirectAddr string
|
||||
tlsConfig *tls.Config
|
||||
redirectHost string
|
||||
redirectHostHeaders []string
|
||||
useHeaderHost bool
|
||||
useHeaderHostSet bool
|
||||
graceful bool
|
||||
shutdownTimeout time.Duration
|
||||
gracefulCtx context.Context
|
||||
mode runMode
|
||||
shutdownDefaultSet bool
|
||||
shutdownTimeoutSet bool
|
||||
}
|
||||
|
||||
type RunOption interface {
|
||||
apply(*runConfig) error
|
||||
}
|
||||
|
||||
type runOptionFunc func(*runConfig) error
|
||||
|
||||
func (f runOptionFunc) apply(cfg *runConfig) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
func defaultRunConfig() runConfig {
|
||||
return runConfig{
|
||||
addr: ":8080",
|
||||
shutdownTimeout: defaultShutdownTimeout,
|
||||
mode: runModeHTTP,
|
||||
useHeaderHost: true,
|
||||
}
|
||||
}
|
||||
|
||||
// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值
|
||||
func getShutdownTimeout(timeouts []time.Duration) time.Duration {
|
||||
if len(timeouts) > 0 && timeouts[0] > 0 {
|
||||
return timeouts[0]
|
||||
}
|
||||
return defaultShutdownTimeout
|
||||
type HTTPRedirectOption interface {
|
||||
applyRedirect(*runConfig) error
|
||||
}
|
||||
|
||||
type redirectOptionFunc func(*runConfig) error
|
||||
|
||||
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
func WithAddr(addr string) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
return errors.New("run address must not be empty")
|
||||
}
|
||||
cfg.addr = addr
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithTLS(tlsConfig *tls.Config) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil")
|
||||
}
|
||||
cfg.tlsConfig = tlsConfig
|
||||
if cfg.mode == runModeHTTP {
|
||||
cfg.mode = runModeHTTPS
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
return errors.New("http redirect address must not be empty")
|
||||
}
|
||||
cfg.httpRedirectAddr = addr
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
if err := opt.applyRedirect(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.useHeaderHost = enabled
|
||||
cfg.useHeaderHostSet = true
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHost(host string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
if host == "" {
|
||||
return errors.New("redirect host must not be empty")
|
||||
}
|
||||
cfg.redirectHost = host
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
||||
for _, header := range headers {
|
||||
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
||||
if trimmed != "" {
|
||||
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithGracefulShutdown(timeout time.Duration) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.graceful = true
|
||||
cfg.shutdownTimeoutSet = true
|
||||
if timeout > 0 {
|
||||
cfg.shutdownTimeout = timeout
|
||||
} else {
|
||||
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithGracefulShutdownDefault() RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.graceful = true
|
||||
cfg.shutdownDefaultSet = true
|
||||
cfg.shutdownTimeout = defaultShutdownTimeout
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithShutdownContext(ctx context.Context) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown context must not be nil")
|
||||
}
|
||||
cfg.gracefulCtx = ctx
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器.
|
||||
func serveServer(srv *http.Server, serveTLS bool) error {
|
||||
if serveTLS {
|
||||
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
|
||||
// ListenAndServeTLS 的前两个参数可以为空字符串
|
||||
return srv.ListenAndServeTLS("", "")
|
||||
}
|
||||
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server,
|
||||
// 并处理其启动失败的致命错误
|
||||
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS")
|
||||
func runServer(serverType string, srv *http.Server, serveTLS bool) {
|
||||
go func() {
|
||||
protocol := "http"
|
||||
|
|
@ -70,284 +199,145 @@ func runServer(serverType string, srv *http.Server, serveTLS bool) {
|
|||
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
|
||||
|
||||
err := serveServer(srv, serveTLS)
|
||||
|
||||
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
|
||||
// 则认为是一个严重错误,并终止程序
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("Touka %s server failed: %v", serverType, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器
|
||||
// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿
|
||||
func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error {
|
||||
// 创建一个 channel 来接收操作系统信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
||||
<-quit // 阻塞,直到接收到上述信号之一
|
||||
log.Println("Shutting down Touka server(s)...")
|
||||
|
||||
// 关闭日志记录器
|
||||
if logger != nil {
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建一个带超时的上下文,用于 Shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
||||
|
||||
// 并发地关闭所有服务器
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(ctx); err != nil {
|
||||
// 将错误发送到 channel
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
|
||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
||||
|
||||
// 收集所有关闭过程中发生的错误
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error {
|
||||
// 创建一个 channel 来接收操作系统信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
|
||||
|
||||
// 启动服务器
|
||||
serverStopped := make(chan error, 1)
|
||||
for _, srv := range servers {
|
||||
go func(s *http.Server) {
|
||||
serverStopped <- s.ListenAndServe()
|
||||
}(srv)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 被取消 (例如,通过外部取消函数)
|
||||
log.Println("Context cancelled, shutting down Touka server(s)...")
|
||||
case err := <-serverStopped:
|
||||
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("Touka HTTP server failed: %w", err)
|
||||
}
|
||||
log.Println("Touka HTTP server stopped gracefully.")
|
||||
return nil // 服务器已自行优雅关闭,无需进一步处理
|
||||
case <-quit:
|
||||
// 接收到操作系统信号
|
||||
log.Println("Shutting down Touka server(s) due to OS signal...")
|
||||
}
|
||||
|
||||
// 关闭日志记录器
|
||||
if logger != nil {
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建一个带超时的上下文,用于 Shutdown
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
|
||||
|
||||
// 并发地关闭所有服务器
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(shutdownCtx); err != nil {
|
||||
// 将错误发送到 channel
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan) // 关闭 channel,以便可以安全地遍历它
|
||||
|
||||
// 收集所有关闭过程中发生的错误
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 公共 Run 方法 ---
|
||||
|
||||
// Run 启动一个不支持优雅关闭的 HTTP 服务器
|
||||
// 这是一个阻塞调用,主要用于简单的场景或快速测试
|
||||
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
|
||||
func (engine *Engine) Run(addr ...string) error {
|
||||
address := resolveAddress(addr)
|
||||
srv := &http.Server{Addr: address, Handler: engine}
|
||||
|
||||
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
||||
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error {
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
runServer("HTTP", srv, false)
|
||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器
|
||||
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error {
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器
|
||||
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil for RunTLS")
|
||||
return nil
|
||||
}
|
||||
return tlsConfig.Clone()
|
||||
}
|
||||
|
||||
// 配置 HTTP/2 支持 (如果使用默认配置)
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{
|
||||
Http1: true,
|
||||
Http2: true, // 默认在 TLS 上启用 HTTP/2
|
||||
})
|
||||
func parseHTTPSPort(addr string) (string, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
|
||||
}
|
||||
return port, nil
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine,
|
||||
TLSConfig: tlsConfig,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
srv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
|
||||
// 应用框架的默认配置和用户提供的自定义配置
|
||||
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
|
||||
engine.applyDefaultServerConfig(srv)
|
||||
func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
|
||||
if serveTLS {
|
||||
if engine.TLSServerConfigurator != nil {
|
||||
engine.TLSServerConfigurator(srv)
|
||||
} else if engine.ServerConfigurator != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
|
||||
runServer("HTTPS", srv, true)
|
||||
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
}
|
||||
|
||||
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名
|
||||
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
return engine.RunTLS(addr, tlsConfig, timeouts...)
|
||||
func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
|
||||
applyServerProtocols(srv, engine.serverProtocols)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(srv)
|
||||
}
|
||||
}
|
||||
|
||||
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭
|
||||
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
if tlsConfig == nil {
|
||||
return errors.New("tls.Config must not be nil for RunTLSRedir")
|
||||
func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
if serveTLS && engine.useDefaultProtocols {
|
||||
protocols := &http.Protocols{}
|
||||
protocols.SetHTTP1(true)
|
||||
protocols.SetHTTP2(true)
|
||||
return protocols
|
||||
}
|
||||
return cloneServerProtocols(engine.serverProtocols)
|
||||
}
|
||||
|
||||
// --- HTTPS 服务器 ---
|
||||
if engine.useDefaultProtocols {
|
||||
engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true})
|
||||
}
|
||||
httpsSrv := &http.Server{
|
||||
Addr: httpsAddr,
|
||||
func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
||||
serveTLS := cfg.mode != runModeHTTP
|
||||
server := &http.Server{
|
||||
Addr: cfg.addr,
|
||||
Handler: engine,
|
||||
TLSConfig: tlsConfig,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
TLSConfig: cloneTLSConfig(cfg.tlsConfig),
|
||||
}
|
||||
if cfg.graceful {
|
||||
server.BaseContext = func(net.Listener) context.Context {
|
||||
return engine.shutdownCtx
|
||||
},
|
||||
}
|
||||
httpsSrv.RegisterOnShutdown(engine.shutdownCancel)
|
||||
engine.applyDefaultServerConfig(httpsSrv)
|
||||
if engine.TLSServerConfigurator != nil {
|
||||
engine.TLSServerConfigurator(httpsSrv)
|
||||
} else if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(httpsSrv)
|
||||
server.RegisterOnShutdown(engine.shutdownCancel)
|
||||
}
|
||||
applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
|
||||
applyMainServerConfig(engine, server, serveTLS)
|
||||
return server
|
||||
}
|
||||
|
||||
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
for _, header := range headers {
|
||||
value := strings.TrimSpace(r.Header.Get(header))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
||||
value = strings.TrimSpace(value[:comma])
|
||||
}
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
||||
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return "", http.StatusInternalServerError, false
|
||||
}
|
||||
return cfg.redirectHost, 0, true
|
||||
}
|
||||
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
host := strings.TrimSpace(r.Host)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
||||
httpsAddr := cfg.addr
|
||||
httpAddr := cfg.httpRedirectAddr
|
||||
httpsPort, err := parseHTTPSPort(httpsAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// --- HTTP 重定向服务器 ---
|
||||
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
host, statusCode, ok := redirectTargetHost(r, cfg)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
_, httpsPort, err := net.SplitHostPort(httpsAddr)
|
||||
if err != nil {
|
||||
// 如果 httpsAddr 没有端口,这是一个配置错误
|
||||
|
||||
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr)
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = parsedHost
|
||||
if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
|
||||
host = "[" + host + "]"
|
||||
}
|
||||
}
|
||||
|
||||
targetURL := "https://" + host
|
||||
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
|
||||
if httpsPort != "443" {
|
||||
targetURL = "https://" + net.JoinHostPort(host, httpsPort)
|
||||
}
|
||||
|
|
@ -355,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con
|
|||
|
||||
http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
|
||||
})
|
||||
httpSrv := &http.Server{
|
||||
Addr: httpAddr,
|
||||
Handler: redirectHandler,
|
||||
}
|
||||
engine.applyDefaultServerConfig(httpSrv)
|
||||
if engine.ServerConfigurator != nil {
|
||||
engine.ServerConfigurator(httpSrv)
|
||||
|
||||
server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
|
||||
applyRedirectServerConfig(engine, server)
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// --- 启动服务器和优雅关闭 ---
|
||||
runServer("HTTPS", httpsSrv, true)
|
||||
runServer("HTTP Redirect", httpSrv, false)
|
||||
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco)
|
||||
func validateRunConfig(cfg runConfig) error {
|
||||
if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
|
||||
return errors.New("WithHTTPRedirect requires WithTLS")
|
||||
}
|
||||
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
||||
return errors.New("https mode requires WithTLS")
|
||||
}
|
||||
if cfg.gracefulCtx != nil && !cfg.graceful {
|
||||
return errors.New("WithShutdownContext requires graceful shutdown")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
||||
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
||||
}
|
||||
}
|
||||
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
||||
if cfg.redirectHost != "" {
|
||||
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
||||
}
|
||||
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性
|
||||
func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error {
|
||||
return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...)
|
||||
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
|
||||
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
|
||||
if cfg.shutdownTimeout > 0 {
|
||||
return cfg.shutdownTimeout
|
||||
}
|
||||
}
|
||||
return defaultShutdownTimeout
|
||||
}
|
||||
|
||||
func closeLoggerAsync(logger *reco.Logger) {
|
||||
if logger == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
log.Println("Closing Touka logger...")
|
||||
CloseLogger(logger)
|
||||
}()
|
||||
}
|
||||
|
||||
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(servers))
|
||||
for _, srv := range servers {
|
||||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(ctx); err != nil {
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
var shutdownErrors []error
|
||||
for err := range errChan {
|
||||
shutdownErrors = append(shutdownErrors, err)
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
if len(shutdownErrors) > 0 {
|
||||
return errors.Join(shutdownErrors...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(quit)
|
||||
|
||||
serverStopped := make(chan error, len(servers))
|
||||
for i, srv := range servers {
|
||||
serveTLSFlag := serveTLS[i]
|
||||
go func(server *http.Server, useTLS bool) {
|
||||
serverStopped <- serveServer(server, useTLS)
|
||||
}(srv, serveTLSFlag)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverStopped:
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
|
||||
return errors.Join(err, shutdownErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Println("Touka server stopped gracefully.")
|
||||
return nil
|
||||
case <-quit:
|
||||
log.Println("Shutting down Touka server(s) due to OS signal...")
|
||||
case <-shutdownCtx.Done():
|
||||
log.Println("Context cancelled, shutting down Touka server(s)...")
|
||||
}
|
||||
|
||||
closeLoggerAsync(logger)
|
||||
if err := shutdownServers(servers, timeout); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("Touka server(s) exited gracefully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run starts the engine with the provided startup options.
|
||||
//
|
||||
// Default behavior with no options:
|
||||
// - HTTP only
|
||||
// - listens on :8080
|
||||
// - no graceful shutdown orchestration
|
||||
//
|
||||
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
|
||||
// signal-aware graceful shutdown and request-context cancellation semantics.
|
||||
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
|
||||
func (engine *Engine) Run(opts ...RunOption) error {
|
||||
cfg := defaultRunConfig()
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
if err := opt.apply(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if cfg.httpRedirectAddr != "" {
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
} else if cfg.tlsConfig != nil {
|
||||
cfg.mode = runModeHTTPS
|
||||
}
|
||||
if err := validateRunConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serveTLS := cfg.mode != runModeHTTP
|
||||
|
||||
mainServer := buildMainServer(engine, cfg)
|
||||
servers := []*http.Server{mainServer}
|
||||
serveTLSFlags := []bool{serveTLS}
|
||||
if cfg.mode == runModeHTTPSRedirect {
|
||||
redirectServer, err := buildRedirectServer(engine, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
servers = append(servers, redirectServer)
|
||||
serveTLSFlags = append(serveTLSFlags, false)
|
||||
}
|
||||
|
||||
if !cfg.graceful {
|
||||
if len(servers) > 1 {
|
||||
serverStopped := make(chan error, len(servers))
|
||||
for i, srv := range servers {
|
||||
serveTLSFlag := serveTLSFlags[i]
|
||||
go func(server *http.Server, useTLS bool) {
|
||||
serverStopped <- serveServer(server, useTLS)
|
||||
}(srv, serveTLSFlag)
|
||||
}
|
||||
|
||||
err := <-serverStopped
|
||||
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return errors.Join(err, shutdownErr)
|
||||
}
|
||||
return shutdownErr
|
||||
}
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
protocolLabel := "HTTP"
|
||||
if serveTLS {
|
||||
protocolLabel = "HTTPS"
|
||||
}
|
||||
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
|
||||
return serveServer(mainServer, serveTLS)
|
||||
}
|
||||
|
||||
shutdownCtx := context.Background()
|
||||
if cfg.gracefulCtx != nil {
|
||||
shutdownCtx = cfg.gracefulCtx
|
||||
}
|
||||
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
|
||||
}
|
||||
|
|
|
|||
412
serve_test.go
412
serve_test.go
|
|
@ -2,15 +2,58 @@ package touka
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate private key: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "127.0.0.1"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("create self-signed cert: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse self-signed cert: %v", err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
|
|
@ -79,3 +122,372 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
|||
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
|
||||
engine := New()
|
||||
err := engine.Run(WithHTTPRedirect(":80"))
|
||||
if err == nil {
|
||||
t.Fatal("expected redirect mode without TLS to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
|
||||
engine := New()
|
||||
err := engine.Run(
|
||||
WithAddr(":443"),
|
||||
WithTLS(&tls.Config{}),
|
||||
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
||||
t.Fatalf("apply graceful default option: %v", err)
|
||||
}
|
||||
if !cfg.graceful {
|
||||
t.Fatal("expected graceful shutdown to be enabled")
|
||||
}
|
||||
if cfg.shutdownTimeout != defaultShutdownTimeout {
|
||||
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
|
||||
t.Fatalf("apply TLS option: %v", err)
|
||||
}
|
||||
if cfg.mode != runModeHTTPS {
|
||||
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
|
||||
}
|
||||
if cfg.graceful {
|
||||
t.Fatal("expected TLS option to remain independent from graceful shutdown")
|
||||
}
|
||||
if cfg.tlsConfig != tlsConfig {
|
||||
t.Fatal("expected TLS config to be preserved in run config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
||||
engine := New()
|
||||
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
|
||||
t.Fatal("expected redirect server builder to reject https address without port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
|
||||
t.Fatalf("apply shutdown context option: %v", err)
|
||||
}
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.httpRedirectAddr = ":80"
|
||||
if err := validateRunConfig(cfg); err != nil {
|
||||
t.Fatalf("validate run config: %v", err)
|
||||
}
|
||||
if cfg.mode != runModeHTTP {
|
||||
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
cfg.tlsConfig = &tls.Config{}
|
||||
cfg.useHeaderHost = false
|
||||
cfg.useHeaderHostSet = true
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected configured host mode without redirect host to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
|
||||
cfg := defaultRunConfig()
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
cfg.tlsConfig = &tls.Config{}
|
||||
cfg.useHeaderHost = true
|
||||
cfg.useHeaderHostSet = true
|
||||
cfg.redirectHost = "configured.example"
|
||||
if err := validateRunConfig(cfg); err == nil {
|
||||
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
||||
engine := New()
|
||||
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
||||
if server.BaseContext == nil {
|
||||
t.Fatal("expected graceful main server to set BaseContext")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for base context check: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
if got := server.BaseContext(listener); got != engine.shutdownCtx {
|
||||
t.Fatal("expected graceful main server to use engine shutdown context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
|
||||
engine := New()
|
||||
serverConfigured := false
|
||||
tlsConfigured := false
|
||||
engine.SetServerConfigurator(func(s *http.Server) {
|
||||
serverConfigured = true
|
||||
s.ReadTimeout = time.Second
|
||||
})
|
||||
engine.SetTLSServerConfigurator(func(s *http.Server) {
|
||||
tlsConfigured = true
|
||||
s.IdleTimeout = time.Second
|
||||
})
|
||||
|
||||
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
if !tlsConfigured {
|
||||
t.Fatal("expected TLS configurator to run for HTTPS main server")
|
||||
}
|
||||
if serverConfigured {
|
||||
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
|
||||
}
|
||||
if server.IdleTimeout != time.Second {
|
||||
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
|
||||
engine := New()
|
||||
configured := false
|
||||
engine.SetServerConfigurator(func(s *http.Server) {
|
||||
configured = true
|
||||
s.ReadTimeout = time.Second
|
||||
})
|
||||
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
if !configured {
|
||||
t.Fatal("expected redirect server to use generic server configurator")
|
||||
}
|
||||
if server.ReadTimeout != time.Second {
|
||||
t.Fatal("expected redirect server configurator changes to be applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
|
||||
engine := New()
|
||||
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
||||
if !httpsServer.Protocols.HTTP2() {
|
||||
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
|
||||
}
|
||||
|
||||
httpServer := buildMainServer(engine, defaultRunConfig())
|
||||
if httpServer.Protocols.HTTP2() {
|
||||
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: true,
|
||||
useHeaderHostSet: true,
|
||||
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||
req.Header.Set("X-First-Host", "first.example")
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: true,
|
||||
useHeaderHostSet: true,
|
||||
redirectHostHeaders: []string{"X-Forwarded-Host"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUpgradeRequired {
|
||||
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{
|
||||
addr: ":443",
|
||||
httpRedirectAddr: ":80",
|
||||
useHeaderHost: false,
|
||||
useHeaderHostSet: true,
|
||||
redirectHost: "configured.example",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||
req.Host = "example.com:80"
|
||||
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
|
||||
t.Fatalf("unexpected redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
|
||||
engine := New()
|
||||
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
|
||||
req.Host = "[::1]:80"
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||
}
|
||||
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
|
||||
t.Fatalf("unexpected IPv6 redirect location: %q", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen on occupied addr: %v", err)
|
||||
}
|
||||
occupiedAddr := occupied.Addr().String()
|
||||
defer occupied.Close()
|
||||
|
||||
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for redirect addr: %v", err)
|
||||
}
|
||||
redirectAddr := redirectListener.Addr().String()
|
||||
if err := redirectListener.Close(); err != nil {
|
||||
t.Fatalf("close redirect addr probe: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
|
||||
if err != nil {
|
||||
t.Fatalf("build redirect server: %v", err)
|
||||
}
|
||||
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
|
||||
|
||||
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected gracefulServe to fail when one server cannot bind")
|
||||
}
|
||||
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
|
||||
}
|
||||
|
||||
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
|
||||
if dialErr == nil {
|
||||
conn.Close()
|
||||
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
|
||||
}
|
||||
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
|
||||
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen on occupied addr: %v", err)
|
||||
}
|
||||
occupiedAddr := occupied.Addr().String()
|
||||
defer occupied.Close()
|
||||
|
||||
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen for redirect addr: %v", err)
|
||||
}
|
||||
redirectAddr := redirectListener.Addr().String()
|
||||
if err := redirectListener.Close(); err != nil {
|
||||
t.Fatalf("close redirect addr probe: %v", err)
|
||||
}
|
||||
|
||||
engine := New()
|
||||
err = engine.Run(
|
||||
WithAddr(occupiedAddr),
|
||||
WithTLS(&tls.Config{}),
|
||||
WithHTTPRedirect(redirectAddr),
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
8
touka.go
8
touka.go
|
|
@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
|
|||
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
||||
type HandlersChain []HandlerFunc
|
||||
|
||||
// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||
type IRouter interface {
|
||||
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组
|
||||
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组
|
||||
// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||
type Router interface {
|
||||
Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
|
||||
Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
|
||||
|
||||
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
||||
GET(relativePath string, handlers ...HandlerFunc)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue