Merge pull request #82 from infinite-iroha/break/v1-redesign-run-api

feat: redesign server startup around Run options
This commit is contained in:
里見 灯花 2026-04-07 20:54:47 +08:00 committed by GitHub
commit efa1e3fb3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1037 additions and 341 deletions

View file

@ -59,9 +59,9 @@ func main() {
c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query)
}) })
// 启动服务器 (支持优雅关闭) // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭)
log.Println("Touka Server starting on :8080...") log.Println("Touka Server starting on :8080...")
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("Touka server failed to start: %v", err) log.Fatalf("Touka server failed to start: %v", err)
} }
} }

View file

@ -70,13 +70,13 @@ func main() {
r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB
// ... 其他配置 // ... 其他配置
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
#### 1.3. 服务器生命周期管理 #### 1.3. 服务器生命周期管理
Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。
```go ```go
func main() { func main() {
@ -90,11 +90,11 @@ func main() {
fmt.Println("自定义的 HTTP 服务器配置已应用") fmt.Println("自定义的 HTTP 服务器配置已应用")
}) })
// 启动服务器,并支持优雅关闭 // 启动服务器,并通过 Run 选项启用优雅关闭
// RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 // Run(...) 会阻塞当前 goroutine
// 第二个参数是优雅关闭的超时时间 // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒
fmt.Println("服务器启动于 :8080") fmt.Println("服务器启动于 :8080")
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("服务器启动失败: %v", err) log.Fatalf("服务器启动失败: %v", err)
} }
} }
@ -187,7 +187,7 @@ func main() {
} }
} }
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
func AuthMiddleware() touka.HandlerFunc { func AuthMiddleware() touka.HandlerFunc {
@ -313,7 +313,7 @@ func main() {
}) })
}) })
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
// templates/index.html // templates/index.html
@ -400,7 +400,7 @@ func main() {
c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID})
}) })
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -483,7 +483,7 @@ func main() {
// 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获
r.StaticDir("/files", "./non-existent-dir") r.StaticDir("/files", "./non-existent-dir")
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -546,7 +546,7 @@ func main() {
// 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录
r.StaticFS("/", http.FS(subFS)) r.StaticFS("/", http.FS(subFS))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

View file

@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) {
Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext Touka 支持配置 HTTP/1.1、HTTP/2 和 H2CHTTP/2 Cleartext
```go ```go
// 使用默认协议配置(仅 HTTP/1.1 // 使用默认协议配置
// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集,
// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。
r.SetDefaultProtocols() r.SetDefaultProtocols()
// 自定义协议配置 // 自定义协议配置
@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{
### 启动方式 ### 启动方式
Touka 提供了多种服务器启动方式 Touka 统一通过 `Run(opts...)` 启动服务器
```go ```go
// 1. 简单启动(无优雅停机) // 1. 简单启动(无优雅停机)
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
// 2. 带优雅停机的启动 // 2. 带优雅停机的启动
r.RunShutdown(":8080", 10*time.Second) r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second))
// 3. 带上下文的优雅停机 // 3. 带上下文的优雅停机
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
r.RunShutdownWithContext(":8080", ctx, 10*time.Second) defer cancel()
r.Run(
touka.WithAddr(":8080"),
touka.WithGracefulShutdown(10*time.Second),
touka.WithShutdownContext(ctx),
)
// 4. HTTPS 启动 // 4. HTTPS 启动
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
// 其他 TLS 配置... // 其他 TLS 配置...
} }
r.RunTLS(":443", tlsConfig, 10*time.Second) // WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithGracefulShutdownDefault(),
)
// 5. HTTPS + HTTP 重定向 // 5. HTTPS + HTTP 重定向
r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) // WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(":80"),
touka.WithGracefulShutdown(10*time.Second),
)
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
``` ```
### HTTPS Redirect Host 策略
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
可用的 redirect 子选项:
- `touka.WithUseHeaderHost(true|false)`
- `touka.WithRedirectHostHeaders([]string{...})`
- `touka.WithRedirectHost("example.com")`
#### 模式一:使用请求输入侧的 host
`WithUseHeaderHost(true)` 时:
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header并使用第一个非空值
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
```
#### 模式二:使用配置的固定 host
`WithUseHeaderHost(false)` 时:
- 不读取 `Request.Host`
- 不读取 `WithRedirectHostHeaders(...)`
- 必须配置 `WithRedirectHost("example.com")`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
```
#### 严格校验规则
以下组合会直接返回配置错误:
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
#### 优先级关系
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
2. `WithUseHeaderHost(...)` 决定 host 来源模式
3. 当 `WithUseHeaderHost(true)` 时:
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
- 未配置时使用 `Request.Host`
4. 当 `WithUseHeaderHost(false)` 时:
- 只使用 `WithRedirectHost(...)`
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
## 优雅停机 (Graceful Shutdown) ## 优雅停机 (Graceful Shutdown)
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
```go ```go
r := touka.Default() r := touka.Default()
@ -91,7 +207,7 @@ r := touka.Default()
// 监听 SIGINT 和 SIGTERM 信号 // 监听 SIGINT 和 SIGTERM 信号
// 如果在 10 秒内未处理完,则强制关闭 // 如果在 10 秒内未处理完,则强制关闭
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatal("服务器退出异常:", err) log.Fatal("服务器退出异常:", err)
} }
``` ```

View file

@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其
1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。
2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。
3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理 3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求
Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。

View file

@ -46,7 +46,7 @@ func main() {
// 4. 启动服务器并监听 8080 端口 // 4. 启动服务器并监听 8080 端口
log.Println("Touka server is running on :8080") log.Println("Touka server is running on :8080")
if err := r.Run(":8080"); err != nil { if err := r.Run(touka.WithAddr(":8080")); err != nil {
log.Fatalf("Server failed: %v", err) log.Fatalf("Server failed: %v", err)
} }
} }
@ -66,11 +66,11 @@ go run main.go
## 优雅停机 ## 优雅停机
在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成 在生产环境中,我们推荐`Run` 追加优雅关闭选项。启用后Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`
```go ```go
// 等待 10 秒以处理剩余请求 // 等待 10 秒以处理剩余请求
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatalf("Server forced to shutdown: %v", err) log.Fatalf("Server forced to shutdown: %v", err)
} }
``` ```

View file

@ -28,7 +28,7 @@ func main() {
Target: target, Target: target,
})) }))
_ = r.Run(":8080") _ = r.Run(touka.WithAddr(":8080"))
} }
``` ```
@ -497,7 +497,7 @@ func main() {
}, },
})) }))
if err := r.RunShutdown(":8080", 10*time.Second); err != nil { if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View file

@ -142,7 +142,7 @@ func main() {
r := touka.Default() r := touka.Default()
fsroot, _ := fs.Sub(content, "dist") fsroot, _ := fs.Sub(content, "dist")
r.StaticFS("/", http.FS(fsroot)) r.StaticFS("/", http.FS(fsroot))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

View file

@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) {
2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。
3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。
**注意:** 请务必使用 `RunShutdown``RunTLS``RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号 **注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)``touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文

View file

@ -39,7 +39,7 @@ func main() {
// 您也可以使用 StaticFS 服务根路径 // 您也可以使用 StaticFS 服务根路径
// r.StaticFS("/", http.FS(fsroot)) // r.StaticFS("/", http.FS(fsroot))
r.Run(":8080") r.Run(touka.WithAddr(":8080"))
} }
``` ```

View file

@ -404,11 +404,18 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) {
}() }()
} }
// applyDefaultServerConfig 应用框架的默认配置到 http.Server func cloneServerProtocols(protocols *http.Protocols) *http.Protocols {
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { if protocols == nil {
if engine.serverProtocols != nil { return nil
srv.Protocols = engine.serverProtocols }
if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { cloned := *protocols
return &cloned
}
func applyServerProtocols(srv *http.Server, protocols *http.Protocols) {
if protocols != nil {
srv.Protocols = cloneServerProtocols(protocols)
if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() {
if err := configureHTTP2ExtendedConnectServer(srv); err != nil { if err := configureHTTP2ExtendedConnectServer(srv); err != nil {
panic(err) panic(err)
} }
@ -416,6 +423,11 @@ func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
} }
} }
// applyDefaultServerConfig 应用框架的默认配置到 http.Server
func (engine *Engine) applyDefaultServerConfig(srv *http.Server) {
applyServerProtocols(srv, engine.serverProtocols)
}
// 配置全局Req Body大小限制 // 配置全局Req Body大小限制
func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) {
engine.GlobalMaxRequestBodySize = size engine.GlobalMaxRequestBodySize = size
@ -614,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// Use 将全局中间件添加到 Engine // Use 将全局中间件添加到 Engine
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { func (engine *Engine) Use(middleware ...HandlerFunc) Router {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains() engine.rebuildFallbackChains()
return engine return engine
@ -683,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
// Group 创建一个新的路由组 // Group 创建一个新的路由组
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
basePath: resolveRoutePath("/", relativePath), basePath: resolveRoutePath("/", relativePath),
@ -692,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
} }
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
// 它也实现了 IRouter 接口,允许嵌套分组 // 它也实现了 Router 接口,允许嵌套分组
type RouterGroup struct { type RouterGroup struct {
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
basePath string // 组路径前缀 basePath string // 组路径前缀
@ -701,7 +713,7 @@ type RouterGroup struct {
// Use 将中间件应用于当前路由组 // Use 将中间件应用于当前路由组
// 这些中间件将应用于当前组及其子组的所有路由 // 这些中间件将应用于当前组及其子组的所有路由
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
group.Handlers = append(group.Handlers, middleware...) group.Handlers = append(group.Handlers, middleware...)
return group return group
} }
@ -747,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
} }
// Group 为当前组创建一个新的子组 // Group 为当前组创建一个新的子组
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: group.engine.combineHandlers(group.Handlers, handlers), Handlers: group.engine.combineHandlers(group.Handlers, handlers),
basePath: resolveRoutePath(group.basePath, relativePath), basePath: resolveRoutePath(group.basePath, relativePath),

View file

@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) {
} }
} }
func TestRunTLSProtocolInheritance(t *testing.T) { func TestTLSRunDefaultsProtocolInheritance(t *testing.T) {
engine := New() engine := New()
// 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true,
})
}
srv := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv)
if !srv.Protocols.HTTP2() { if !srv.Protocols.HTTP2() {
t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config")
} }
// 模拟用户设置了自定义协议后调用 RunTLS // 模拟用户设置了自定义协议后进入 TLS 运行模式
engine = New() engine = New()
engine.SetProtocols(&ProtocolsConfig{ engine.SetProtocols(&ProtocolsConfig{
Http1: true, Http1: true,
Http2: false, // 用户明确不想要 HTTP/2 Http2: false, // 用户明确不想要 HTTP/2
}) })
if engine.useDefaultProtocols { srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true,
})
}
srv2 := &http.Server{TLSConfig: &tls.Config{}}
engine.applyDefaultServerConfig(srv2)
if srv2.Protocols.HTTP2() { if srv2.Protocols.HTTP2() {
t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols")
} }
} }

719
serve.go
View file

@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -21,45 +22,173 @@ import (
"github.com/fenthope/reco" "github.com/fenthope/reco"
) )
// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间
const defaultShutdownTimeout = 5 * time.Second const defaultShutdownTimeout = 5 * time.Second
// --- 内部辅助函数 --- type runMode uint8
// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" const (
func resolveAddress(addr []string) string { runModeHTTP runMode = iota
switch len(addr) { runModeHTTPS
case 0: runModeHTTPSRedirect
return ":8080" )
case 1:
return addr[0] type runConfig struct {
default: addr string
panic("too many parameters provided for server address") httpRedirectAddr string
tlsConfig *tls.Config
redirectHost string
redirectHostHeaders []string
useHeaderHost bool
useHeaderHostSet bool
graceful bool
shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
}
type RunOption interface {
apply(*runConfig) error
}
type runOptionFunc func(*runConfig) error
func (f runOptionFunc) apply(cfg *runConfig) error {
return f(cfg)
}
func defaultRunConfig() runConfig {
return runConfig{
addr: ":8080",
shutdownTimeout: defaultShutdownTimeout,
mode: runModeHTTP,
useHeaderHost: true,
} }
} }
// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 type HTTPRedirectOption interface {
func getShutdownTimeout(timeouts []time.Duration) time.Duration { applyRedirect(*runConfig) error
if len(timeouts) > 0 && timeouts[0] > 0 { }
return timeouts[0]
} type redirectOptionFunc func(*runConfig) error
return defaultShutdownTimeout
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
return f(cfg)
}
func WithAddr(addr string) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("run address must not be empty")
}
cfg.addr = addr
return nil
})
}
func WithTLS(tlsConfig *tls.Config) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if tlsConfig == nil {
return errors.New("tls.Config must not be nil")
}
cfg.tlsConfig = tlsConfig
if cfg.mode == runModeHTTP {
cfg.mode = runModeHTTPS
}
return nil
})
}
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("http redirect address must not be empty")
}
cfg.httpRedirectAddr = addr
cfg.mode = runModeHTTPSRedirect
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.applyRedirect(cfg); err != nil {
return err
}
}
return nil
})
}
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.useHeaderHost = enabled
cfg.useHeaderHostSet = true
return nil
})
}
func WithRedirectHost(host string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
if host == "" {
return errors.New("redirect host must not be empty")
}
cfg.redirectHost = host
return nil
})
}
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
for _, header := range headers {
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
if trimmed != "" {
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
}
}
return nil
})
}
func WithGracefulShutdown(timeout time.Duration) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownTimeoutSet = true
if timeout > 0 {
cfg.shutdownTimeout = timeout
} else {
cfg.shutdownTimeout = defaultShutdownTimeout
}
return nil
})
}
func WithGracefulShutdownDefault() RunOption {
return runOptionFunc(func(cfg *runConfig) error {
cfg.graceful = true
cfg.shutdownDefaultSet = true
cfg.shutdownTimeout = defaultShutdownTimeout
return nil
})
}
func WithShutdownContext(ctx context.Context) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if ctx == nil {
return errors.New("shutdown context must not be nil")
}
cfg.gracefulCtx = ctx
return nil
})
} }
// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器.
func serveServer(srv *http.Server, serveTLS bool) error { func serveServer(srv *http.Server, serveTLS bool) error {
if serveTLS { if serveTLS {
// 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置,
// ListenAndServeTLS 的前两个参数可以为空字符串
return srv.ListenAndServeTLS("", "") return srv.ListenAndServeTLS("", "")
} }
return srv.ListenAndServe() return srv.ListenAndServe()
} }
// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server,
// 并处理其启动失败的致命错误
// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS")
func runServer(serverType string, srv *http.Server, serveTLS bool) { func runServer(serverType string, srv *http.Server, serveTLS bool) {
go func() { go func() {
protocol := "http" protocol := "http"
@ -70,284 +199,145 @@ func runServer(serverType string, srv *http.Server, serveTLS bool) {
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
err := serveServer(srv, serveTLS) err := serveServer(srv, serveTLS)
// 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed),
// 则认为是一个严重错误,并终止程序
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Touka %s server failed: %v", serverType, err) log.Fatalf("Touka %s server failed: %v", serverType, err)
} }
}() }()
} }
// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 if tlsConfig == nil {
func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { return nil
// 创建一个 channel 来接收操作系统信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号
<-quit // 阻塞,直到接收到上述信号之一
log.Println("Shutting down Touka server(s)...")
// 关闭日志记录器
if logger != nil {
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
} }
return tlsConfig.Clone()
// 创建一个带超时的上下文,用于 Shutdown
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
// 并发地关闭所有服务器
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil {
// 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait() // 等待所有服务器的关闭 goroutine 完成
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
} }
func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { func parseHTTPSPort(addr string) (string, error) {
// 创建一个 channel 来接收操作系统信号 _, port, err := net.SplitHostPort(addr)
quit := make(chan os.Signal, 1) if err != nil {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
// 启动服务器
serverStopped := make(chan error, 1)
for _, srv := range servers {
go func(s *http.Server) {
serverStopped <- s.ListenAndServe()
}(srv)
} }
return port, nil
}
select { func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
case <-ctx.Done(): if serveTLS {
// Context 被取消 (例如,通过外部取消函数) if engine.TLSServerConfigurator != nil {
log.Println("Context cancelled, shutting down Touka server(s)...") engine.TLSServerConfigurator(srv)
case err := <-serverStopped: return
// 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("Touka HTTP server failed: %w", err)
} }
log.Println("Touka HTTP server stopped gracefully.")
return nil // 服务器已自行优雅关闭,无需进一步处理
case <-quit:
// 接收到操作系统信号
log.Println("Shutting down Touka server(s) due to OS signal...")
} }
// 关闭日志记录器
if logger != nil {
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
// 创建一个带超时的上下文,用于 Shutdown
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel
// 并发地关闭所有服务器
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(shutdownCtx); err != nil {
// 将错误发送到 channel
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan) // 关闭 channel,以便可以安全地遍历它
// 收集所有关闭过程中发生的错误
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// --- 公共 Run 方法 ---
// Run 启动一个不支持优雅关闭的 HTTP 服务器
// 这是一个阻塞调用,主要用于简单的场景或快速测试
// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法
func (engine *Engine) Run(addr ...string) error {
address := resolveAddress(addr)
srv := &http.Server{Addr: address, Handler: engine}
// 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil { if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv) engine.ServerConfigurator(srv)
} }
log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address)
return srv.ListenAndServe()
} }
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { applyServerProtocols(srv, engine.serverProtocols)
srv := &http.Server{
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
}
srv.RegisterOnShutdown(engine.shutdownCancel)
// 应用框架的默认配置和用户提供的自定义配置
engine.applyDefaultServerConfig(srv)
if engine.ServerConfigurator != nil { if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv) engine.ServerConfigurator(srv)
} }
runServer("HTTP", srv, false)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
} }
// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { if engine == nil {
srv := &http.Server{ return nil
Addr: addr,
Handler: engine,
BaseContext: func(l net.Listener) context.Context {
return engine.shutdownCtx
},
} }
srv.RegisterOnShutdown(engine.shutdownCancel) if serveTLS && engine.useDefaultProtocols {
protocols := &http.Protocols{}
// 应用框架的默认配置和用户提供的自定义配置 protocols.SetHTTP1(true)
engine.applyDefaultServerConfig(srv) protocols.SetHTTP2(true)
if engine.ServerConfigurator != nil { return protocols
engine.ServerConfigurator(srv)
} }
return cloneServerProtocols(engine.serverProtocols)
return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco)
} }
// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { serveTLS := cfg.mode != runModeHTTP
if tlsConfig == nil { server := &http.Server{
return errors.New("tls.Config must not be nil for RunTLS") Addr: cfg.addr,
}
// 配置 HTTP/2 支持 (如果使用默认配置)
if engine.useDefaultProtocols {
engine.setProtocols(&ProtocolsConfig{
Http1: true,
Http2: true, // 默认在 TLS 上启用 HTTP/2
})
}
srv := &http.Server{
Addr: addr,
Handler: engine, Handler: engine,
TLSConfig: tlsConfig, TLSConfig: cloneTLSConfig(cfg.tlsConfig),
BaseContext: func(l net.Listener) context.Context { }
if cfg.graceful {
server.BaseContext = func(net.Listener) context.Context {
return engine.shutdownCtx return engine.shutdownCtx
}, }
server.RegisterOnShutdown(engine.shutdownCancel)
} }
srv.RegisterOnShutdown(engine.shutdownCancel) applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
applyMainServerConfig(engine, server, serveTLS)
// 应用框架的默认配置和用户提供的自定义配置 return server
// 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator
engine.applyDefaultServerConfig(srv)
if engine.TLSServerConfigurator != nil {
engine.TLSServerConfigurator(srv)
} else if engine.ServerConfigurator != nil {
engine.ServerConfigurator(srv)
}
runServer("HTTPS", srv, true)
return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco)
} }
// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 func firstRedirectHeaderHost(r *http.Request, headers []string) string {
func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { if r == nil {
return engine.RunTLS(addr, tlsConfig, timeouts...) return ""
}
for _, header := range headers {
value := strings.TrimSpace(r.Header.Get(header))
if value == "" {
continue
}
if comma := strings.IndexByte(value, ','); comma >= 0 {
value = strings.TrimSpace(value[:comma])
}
if value != "" {
return value
}
}
return ""
} }
// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if tlsConfig == nil { if cfg.redirectHost == "" {
return errors.New("tls.Config must not be nil for RunTLSRedir") return "", http.StatusInternalServerError, false
}
return cfg.redirectHost, 0, true
} }
// --- HTTPS 服务器 --- if len(cfg.redirectHostHeaders) > 0 {
if engine.useDefaultProtocols { host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) if host == "" {
} return "", http.StatusUpgradeRequired, false
httpsSrv := &http.Server{ }
Addr: httpsAddr, return host, 0, true
Handler: engine, }
TLSConfig: tlsConfig,
BaseContext: func(l net.Listener) context.Context { if r == nil {
return engine.shutdownCtx return "", http.StatusUpgradeRequired, false
}, }
} host := strings.TrimSpace(r.Host)
httpsSrv.RegisterOnShutdown(engine.shutdownCancel) if host == "" {
engine.applyDefaultServerConfig(httpsSrv) return "", http.StatusUpgradeRequired, false
if engine.TLSServerConfigurator != nil { }
engine.TLSServerConfigurator(httpsSrv) return host, 0, true
} else if engine.ServerConfigurator != nil { }
engine.ServerConfigurator(httpsSrv)
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
httpsAddr := cfg.addr
httpAddr := cfg.httpRedirectAddr
httpsPort, err := parseHTTPSPort(httpsAddr)
if err != nil {
return nil, err
} }
// --- HTTP 重定向服务器 ---
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host) host, statusCode, ok := redirectTargetHost(r, cfg)
if err != nil { if !ok {
host = r.Host http.Error(w, http.StatusText(statusCode), statusCode)
return
} }
_, httpsPort, err := net.SplitHostPort(httpsAddr) if parsedHost, _, err := net.SplitHostPort(host); err == nil {
if err != nil { host = parsedHost
// 如果 httpsAddr 没有端口,这是一个配置错误 if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
host = "[" + host + "]"
log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) }
} }
targetURL := "https://" + host targetURL := "https://" + host
// 只有在非标准 HTTPS 端口 (443) 时才附加端口号
if httpsPort != "443" { if httpsPort != "443" {
targetURL = "https://" + net.JoinHostPort(host, httpsPort) targetURL = "https://" + net.JoinHostPort(host, httpsPort)
} }
@ -355,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con
http.Redirect(w, r, targetURL, http.StatusMovedPermanently) http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
}) })
httpSrv := &http.Server{
Addr: httpAddr,
Handler: redirectHandler,
}
engine.applyDefaultServerConfig(httpSrv)
if engine.ServerConfigurator != nil {
engine.ServerConfigurator(httpSrv)
}
// --- 启动服务器和优雅关闭 --- server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
runServer("HTTPS", httpsSrv, true) applyRedirectServerConfig(engine, server)
runServer("HTTP Redirect", httpSrv, false) return server, nil
return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco)
} }
// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 func validateRunConfig(cfg runConfig) error {
func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) return errors.New("WithHTTPRedirect requires WithTLS")
}
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
return errors.New("https mode requires WithTLS")
}
if cfg.gracefulCtx != nil && !cfg.graceful {
return errors.New("WithShutdownContext requires graceful shutdown")
}
if len(cfg.redirectHostHeaders) > 0 {
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
}
}
if cfg.useHeaderHostSet && cfg.useHeaderHost {
if cfg.redirectHost != "" {
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
}
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
}
if len(cfg.redirectHostHeaders) > 0 {
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
}
}
return nil
}
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
if cfg.shutdownTimeout > 0 {
return cfg.shutdownTimeout
}
}
return defaultShutdownTimeout
}
func closeLoggerAsync(logger *reco.Logger) {
if logger == nil {
return
}
go func() {
log.Println("Closing Touka logger...")
CloseLogger(logger)
}()
}
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, len(servers))
for _, srv := range servers {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil {
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
}
wg.Wait()
close(errChan)
var shutdownErrors []error
for err := range errChan {
shutdownErrors = append(shutdownErrors, err)
log.Printf("Shutdown error: %v", err)
}
if len(shutdownErrors) > 0 {
return errors.Join(shutdownErrors...)
}
return nil
}
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(quit)
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLS[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
select {
case err := <-serverStopped:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
return errors.Join(err, shutdownErr)
}
return err
}
log.Println("Touka server stopped gracefully.")
return nil
case <-quit:
log.Println("Shutting down Touka server(s) due to OS signal...")
case <-shutdownCtx.Done():
log.Println("Context cancelled, shutting down Touka server(s)...")
}
closeLoggerAsync(logger)
if err := shutdownServers(servers, timeout); err != nil {
return err
}
log.Println("Touka server(s) exited gracefully.")
return nil
}
// Run starts the engine with the provided startup options.
//
// Default behavior with no options:
// - HTTP only
// - listens on :8080
// - no graceful shutdown orchestration
//
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
// signal-aware graceful shutdown and request-context cancellation semantics.
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
func (engine *Engine) Run(opts ...RunOption) error {
cfg := defaultRunConfig()
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.apply(&cfg); err != nil {
return err
}
}
if cfg.httpRedirectAddr != "" {
cfg.mode = runModeHTTPSRedirect
} else if cfg.tlsConfig != nil {
cfg.mode = runModeHTTPS
}
if err := validateRunConfig(cfg); err != nil {
return err
}
serveTLS := cfg.mode != runModeHTTP
mainServer := buildMainServer(engine, cfg)
servers := []*http.Server{mainServer}
serveTLSFlags := []bool{serveTLS}
if cfg.mode == runModeHTTPSRedirect {
redirectServer, err := buildRedirectServer(engine, cfg)
if err != nil {
return err
}
servers = append(servers, redirectServer)
serveTLSFlags = append(serveTLSFlags, false)
}
if !cfg.graceful {
if len(servers) > 1 {
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLSFlags[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
err := <-serverStopped
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return errors.Join(err, shutdownErr)
}
return shutdownErr
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
protocolLabel := "HTTP"
if serveTLS {
protocolLabel = "HTTPS"
}
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
return serveServer(mainServer, serveTLS)
}
shutdownCtx := context.Background()
if cfg.gracefulCtx != nil {
shutdownCtx = cfg.gracefulCtx
}
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
} }

View file

@ -2,15 +2,58 @@ package touka
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors" "errors"
"io" "io"
"math/big"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
) )
func generateSelfSignedCert(t *testing.T) tls.Certificate {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate private key: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "127.0.0.1"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("create self-signed cert: %v", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("parse self-signed cert: %v", err)
}
return cert
}
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -79,3 +122,372 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
} }
} }
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
engine := New()
err := engine.Run(WithHTTPRedirect(":80"))
if err == nil {
t.Fatal("expected redirect mode without TLS to fail")
}
}
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
engine := New()
err := engine.Run(
WithAddr(":443"),
WithTLS(&tls.Config{}),
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
)
if err == nil {
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
}
}
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
cfg := defaultRunConfig()
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
t.Fatalf("apply graceful default option: %v", err)
}
if !cfg.graceful {
t.Fatal("expected graceful shutdown to be enabled")
}
if cfg.shutdownTimeout != defaultShutdownTimeout {
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
}
}
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
cfg := defaultRunConfig()
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
t.Fatalf("apply TLS option: %v", err)
}
if cfg.mode != runModeHTTPS {
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
}
if cfg.graceful {
t.Fatal("expected TLS option to remain independent from graceful shutdown")
}
if cfg.tlsConfig != tlsConfig {
t.Fatal("expected TLS config to be preserved in run config")
}
}
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
engine := New()
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
t.Fatal("expected redirect server builder to reject https address without port")
}
}
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
cfg := defaultRunConfig()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
t.Fatalf("apply shutdown context option: %v", err)
}
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
}
}
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
cfg := defaultRunConfig()
cfg.httpRedirectAddr = ":80"
if err := validateRunConfig(cfg); err != nil {
t.Fatalf("validate run config: %v", err)
}
if cfg.mode != runModeHTTP {
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
}
}
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = false
cfg.useHeaderHostSet = true
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected configured host mode without redirect host to fail validation")
}
}
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = true
cfg.useHeaderHostSet = true
cfg.redirectHost = "configured.example"
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
}
}
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
engine := New()
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
if server.BaseContext == nil {
t.Fatal("expected graceful main server to set BaseContext")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for base context check: %v", err)
}
defer listener.Close()
if got := server.BaseContext(listener); got != engine.shutdownCtx {
t.Fatal("expected graceful main server to use engine shutdown context")
}
}
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
engine := New()
serverConfigured := false
tlsConfigured := false
engine.SetServerConfigurator(func(s *http.Server) {
serverConfigured = true
s.ReadTimeout = time.Second
})
engine.SetTLSServerConfigurator(func(s *http.Server) {
tlsConfigured = true
s.IdleTimeout = time.Second
})
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !tlsConfigured {
t.Fatal("expected TLS configurator to run for HTTPS main server")
}
if serverConfigured {
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
}
if server.IdleTimeout != time.Second {
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
}
}
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
engine := New()
configured := false
engine.SetServerConfigurator(func(s *http.Server) {
configured = true
s.ReadTimeout = time.Second
})
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
if !configured {
t.Fatal("expected redirect server to use generic server configurator")
}
if server.ReadTimeout != time.Second {
t.Fatal("expected redirect server configurator changes to be applied")
}
}
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
engine := New()
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
if !httpsServer.Protocols.HTTP2() {
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
}
httpServer := buildMainServer(engine, defaultRunConfig())
if httpServer.Protocols.HTTP2() {
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
}
}
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
req.Header.Set("X-First-Host", "first.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUpgradeRequired {
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
}
}
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: false,
useHeaderHostSet: true,
redirectHost: "configured.example",
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
req.Host = "[::1]:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
t.Fatalf("unexpected IPv6 redirect location: %q", location)
}
}
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
if err == nil {
t.Fatal("expected gracefulServe to fail when one server cannot bind")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
}
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
if dialErr == nil {
conn.Close()
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
}
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
}
}
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
err = engine.Run(
WithAddr(occupiedAddr),
WithTLS(&tls.Config{}),
WithHTTPRedirect(redirectAddr),
)
if err == nil {
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
}
}

View file

@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
// HandlersChain 定义处理函数链(中间件栈)的类型。 // HandlersChain 定义处理函数链(中间件栈)的类型。
type HandlersChain []HandlerFunc type HandlersChain []HandlerFunc
// IRouter 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。 // Router 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。
type IRouter interface { type Router interface {
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
GET(relativePath string, handlers ...HandlerFunc) GET(relativePath string, handlers ...HandlerFunc)