diff --git a/README.md b/README.md index a7b99fd..e2eaec8 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ func main() { c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) }) - // 启动服务器 (支持优雅关闭) + // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) log.Println("Touka Server starting on :8080...") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Touka server failed to start: %v", err) } } diff --git a/about-touka.md b/about-touka.md index 86a056f..b3a16b4 100644 --- a/about-touka.md +++ b/about-touka.md @@ -70,13 +70,13 @@ func main() { r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB // ... 其他配置 - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` #### 1.3. 服务器生命周期管理 -Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 +Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 ```go func main() { @@ -90,11 +90,11 @@ func main() { fmt.Println("自定义的 HTTP 服务器配置已应用") }) - // 启动服务器,并支持优雅关闭 - // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 - // 第二个参数是优雅关闭的超时时间 + // 启动服务器,并通过 Run 选项启用优雅关闭 + // Run(...) 会阻塞当前 goroutine + // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 fmt.Println("服务器启动于 :8080") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("服务器启动失败: %v", err) } } @@ -187,7 +187,7 @@ func main() { } } - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } func AuthMiddleware() touka.HandlerFunc { @@ -313,7 +313,7 @@ func main() { }) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } // templates/index.html @@ -400,7 +400,7 @@ func main() { c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -483,7 +483,7 @@ func main() { // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 r.StaticDir("/files", "./non-existent-dir") - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -546,7 +546,7 @@ func main() { // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 r.StaticFS("/", http.FS(subFS)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/advanced.md b/docs/advanced.md index a7cb9a2..eb44c2d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) { Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): ```go -// 使用默认协议配置(仅 HTTP/1.1) +// 使用默认协议配置 +// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集, +// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。 r.SetDefaultProtocols() // 自定义协议配置 @@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{ ### 启动方式 -Touka 提供了多种服务器启动方式: +Touka 统一通过 `Run(opts...)` 启动服务器: ```go // 1. 简单启动(无优雅停机) -r.Run(":8080") +r.Run(touka.WithAddr(":8080")) // 2. 带优雅停机的启动 -r.RunShutdown(":8080", 10*time.Second) +r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) // 3. 带上下文的优雅停机 ctx, cancel := context.WithCancel(context.Background()) -r.RunShutdownWithContext(":8080", ctx, 10*time.Second) +defer cancel() +r.Run( + touka.WithAddr(":8080"), + touka.WithGracefulShutdown(10*time.Second), + touka.WithShutdownContext(ctx), +) // 4. HTTPS 启动 tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, // 其他 TLS 配置... } -r.RunTLS(":443", tlsConfig, 10*time.Second) +// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithGracefulShutdownDefault(), +) // 5. HTTPS + HTTP 重定向 -r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) +// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect(":80"), + touka.WithGracefulShutdown(10*time.Second), +) + +// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) + +// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) ``` +### HTTPS Redirect Host 策略 + +`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 + +可用的 redirect 子选项: + +- `touka.WithUseHeaderHost(true|false)` +- `touka.WithRedirectHostHeaders([]string{...})` +- `touka.WithRedirectHost("example.com")` + +#### 模式一:使用请求输入侧的 host + +当 `WithUseHeaderHost(true)` 时: + +- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` +- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 +- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) +``` + +#### 模式二:使用配置的固定 host + +当 `WithUseHeaderHost(false)` 时: + +- 不读取 `Request.Host` +- 不读取 `WithRedirectHostHeaders(...)` +- 必须配置 `WithRedirectHost("example.com")` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) +``` + +#### 严格校验规则 + +以下组合会直接返回配置错误: + +- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` +- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` +- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` +- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` +- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` + +#### 优先级关系 + +1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 +2. `WithUseHeaderHost(...)` 决定 host 来源模式 +3. 当 `WithUseHeaderHost(true)` 时: + - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 + - 未配置时使用 `Request.Host` +4. 当 `WithUseHeaderHost(false)` 时: + - 只使用 `WithRedirectHost(...)` + +**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 + ## 优雅停机 (Graceful Shutdown) -在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 ```go r := touka.Default() @@ -91,7 +207,7 @@ r := touka.Default() // 监听 SIGINT 和 SIGTERM 信号 // 如果在 10 秒内未处理完,则强制关闭 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal("服务器退出异常:", err) } ``` diff --git a/docs/introduction.md b/docs/introduction.md index 94a7310..87c3e40 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 -3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 +3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/quickstart.md b/docs/quickstart.md index 94f7433..2911732 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -46,7 +46,7 @@ func main() { // 4. 启动服务器并监听 8080 端口 log.Println("Touka server is running on :8080") - if err := r.Run(":8080"); err != nil { + if err := r.Run(touka.WithAddr(":8080")); err != nil { log.Fatalf("Server failed: %v", err) } } @@ -66,11 +66,11 @@ go run main.go ## 优雅停机 -在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 +在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。 ```go // 等待 10 秒以处理剩余请求 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Server forced to shutdown: %v", err) } ``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 7d05290..cb4b2a3 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -28,7 +28,7 @@ func main() { Target: target, })) - _ = r.Run(":8080") + _ = r.Run(touka.WithAddr(":8080")) } ``` @@ -497,7 +497,7 @@ func main() { }, })) - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal(err) } } diff --git a/docs/routing.md b/docs/routing.md index 223081a..70a24dc 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -142,7 +142,7 @@ func main() { r := touka.Default() fsroot, _ := fs.Sub(content, "dist") r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/sse.md b/docs/sse.md index 1b44521..a003be9 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) { 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 -**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 +**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。 diff --git a/docs/static-files.md b/docs/static-files.md index a2138cd..b1f06a8 100644 --- a/docs/static-files.md +++ b/docs/static-files.md @@ -39,7 +39,7 @@ func main() { // 您也可以使用 StaticFS 服务根路径 // r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/engine.go b/engine.go index f9d233a..b0723e7 100644 --- a/engine.go +++ b/engine.go @@ -404,11 +404,18 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { }() } -// applyDefaultServerConfig 应用框架的默认配置到 http.Server -func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { - if engine.serverProtocols != nil { - srv.Protocols = engine.serverProtocols - if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { +func cloneServerProtocols(protocols *http.Protocols) *http.Protocols { + if protocols == nil { + return nil + } + cloned := *protocols + return &cloned +} + +func applyServerProtocols(srv *http.Server, protocols *http.Protocols) { + if protocols != nil { + srv.Protocols = cloneServerProtocols(protocols) + if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() { if err := configureHTTP2ExtendedConnectServer(srv); err != nil { panic(err) } @@ -416,6 +423,11 @@ func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { } } +// applyDefaultServerConfig 应用框架的默认配置到 http.Server +func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) +} + // 配置全局Req Body大小限制 func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { engine.GlobalMaxRequestBodySize = size @@ -614,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { +func (engine *Engine) Use(middleware ...HandlerFunc) Router { engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.rebuildFallbackChains() return engine @@ -683,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -692,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 IRouter 接口,允许嵌套分组 +// 它也实现了 Router 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -701,7 +713,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { +func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { group.Handlers = append(group.Handlers, middleware...) return group } @@ -747,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), diff --git a/protocols_test.go b/protocols_test.go index 73f16e9..0e2bf1f 100644 --- a/protocols_test.go +++ b/protocols_test.go @@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) { } } -func TestRunTLSProtocolInheritance(t *testing.T) { +func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { engine := New() - // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv) + srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if !srv.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") } - // 模拟用户设置了自定义协议后调用 RunTLS + // 模拟用户设置了自定义协议后进入 TLS 运行模式 engine = New() engine.SetProtocols(&ProtocolsConfig{ Http1: true, Http2: false, // 用户明确不想要 HTTP/2 }) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv2 := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv2) + srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if srv2.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") } } diff --git a/serve.go b/serve.go index 1825b32..0fc83f9 100644 --- a/serve.go +++ b/serve.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -21,45 +22,173 @@ import ( "github.com/fenthope/reco" ) -// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 const defaultShutdownTimeout = 5 * time.Second -// --- 内部辅助函数 --- +type runMode uint8 -// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" -func resolveAddress(addr []string) string { - switch len(addr) { - case 0: - return ":8080" - case 1: - return addr[0] - default: - panic("too many parameters provided for server address") +const ( + runModeHTTP runMode = iota + runModeHTTPS + runModeHTTPSRedirect +) + +type runConfig struct { + addr string + httpRedirectAddr string + tlsConfig *tls.Config + redirectHost string + redirectHostHeaders []string + useHeaderHost bool + useHeaderHostSet bool + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool +} + +type RunOption interface { + apply(*runConfig) error +} + +type runOptionFunc func(*runConfig) error + +func (f runOptionFunc) apply(cfg *runConfig) error { + return f(cfg) +} + +func defaultRunConfig() runConfig { + return runConfig{ + addr: ":8080", + shutdownTimeout: defaultShutdownTimeout, + mode: runModeHTTP, + useHeaderHost: true, } } -// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 -func getShutdownTimeout(timeouts []time.Duration) time.Duration { - if len(timeouts) > 0 && timeouts[0] > 0 { - return timeouts[0] - } - return defaultShutdownTimeout +type HTTPRedirectOption interface { + applyRedirect(*runConfig) error +} + +type redirectOptionFunc func(*runConfig) error + +func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { + return f(cfg) +} + +func WithAddr(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("run address must not be empty") + } + cfg.addr = addr + return nil + }) +} + +func WithTLS(tlsConfig *tls.Config) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil") + } + cfg.tlsConfig = tlsConfig + if cfg.mode == runModeHTTP { + cfg.mode = runModeHTTPS + } + return nil + }) +} + +func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("http redirect address must not be empty") + } + cfg.httpRedirectAddr = addr + cfg.mode = runModeHTTPSRedirect + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyRedirect(cfg); err != nil { + return err + } + } + return nil + }) +} + +func WithUseHeaderHost(enabled bool) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.useHeaderHost = enabled + cfg.useHeaderHostSet = true + return nil + }) +} + +func WithRedirectHost(host string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + if host == "" { + return errors.New("redirect host must not be empty") + } + cfg.redirectHost = host + return nil + }) +} + +func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] + for _, header := range headers { + trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) + if trimmed != "" { + cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) + } + } + return nil + }) +} + +func WithGracefulShutdown(timeout time.Duration) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownTimeoutSet = true + if timeout > 0 { + cfg.shutdownTimeout = timeout + } else { + cfg.shutdownTimeout = defaultShutdownTimeout + } + return nil + }) +} + +func WithGracefulShutdownDefault() RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownDefaultSet = true + cfg.shutdownTimeout = defaultShutdownTimeout + return nil + }) +} + +func WithShutdownContext(ctx context.Context) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if ctx == nil { + return errors.New("shutdown context must not be nil") + } + cfg.gracefulCtx = ctx + return nil + }) } -// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器. func serveServer(srv *http.Server, serveTLS bool) error { if serveTLS { - // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, - // ListenAndServeTLS 的前两个参数可以为空字符串 return srv.ListenAndServeTLS("", "") } - return srv.ListenAndServe() } -// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, -// 并处理其启动失败的致命错误 -// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") func runServer(serverType string, srv *http.Server, serveTLS bool) { go func() { protocol := "http" @@ -70,284 +199,145 @@ func runServer(serverType string, srv *http.Server, serveTLS bool) { log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) err := serveServer(srv, serveTLS) - - // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), - // 则认为是一个严重错误,并终止程序 if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Touka %s server failed: %v", serverType, err) } }() } -// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 -// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 -func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - <-quit // 阻塞,直到接收到上述信号之一 - log.Println("Shutting down Touka server(s)...") - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() +func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { + if tlsConfig == nil { + return nil } - - // 创建一个带超时的上下文,用于 Shutdown - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(ctx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() // 等待所有服务器的关闭 goroutine 完成 - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil + return tlsConfig.Clone() } -func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - - // 启动服务器 - serverStopped := make(chan error, 1) - for _, srv := range servers { - go func(s *http.Server) { - serverStopped <- s.ListenAndServe() - }(srv) +func parseHTTPSPort(addr string) (string, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("https address %q must include a port: %w", addr, err) } + return port, nil +} - select { - case <-ctx.Done(): - // Context 被取消 (例如,通过外部取消函数) - log.Println("Context cancelled, shutting down Touka server(s)...") - case err := <-serverStopped: - // 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("Touka HTTP server failed: %w", err) +func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) { + if serveTLS { + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + return } - log.Println("Touka HTTP server stopped gracefully.") - return nil // 服务器已自行优雅关闭,无需进一步处理 - case <-quit: - // 接收到操作系统信号 - log.Println("Shutting down Touka server(s) due to OS signal...") } - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() - } - - // 创建一个带超时的上下文,用于 Shutdown - shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(shutdownCtx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil -} - -// --- 公共 Run 方法 --- - -// Run 启动一个不支持优雅关闭的 HTTP 服务器 -// 这是一个阻塞调用,主要用于简单的场景或快速测试 -// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 -func (engine *Engine) Run(addr ...string) error { - address := resolveAddress(addr) - srv := &http.Server{Addr: address, Handler: engine} - - // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) - return srv.ListenAndServe() } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) +func applyRedirectServerConfig(engine *Engine, srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - - runServer("HTTP", srv, false) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, +func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols { + if engine == nil { + return nil } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) + if serveTLS && engine.useDefaultProtocols { + protocols := &http.Protocols{} + protocols.SetHTTP1(true) + protocols.SetHTTP2(true) + return protocols } - - return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco) + return cloneServerProtocols(engine.serverProtocols) } -// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 -func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLS") - } - - // 配置 HTTP/2 支持 (如果使用默认配置) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认在 TLS 上启用 HTTP/2 - }) - } - - srv := &http.Server{ - Addr: addr, +func buildMainServer(engine *Engine, cfg runConfig) *http.Server { + serveTLS := cfg.mode != runModeHTTP + server := &http.Server{ + Addr: cfg.addr, Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { + TLSConfig: cloneTLSConfig(cfg.tlsConfig), + } + if cfg.graceful { + server.BaseContext = func(net.Listener) context.Context { return engine.shutdownCtx - }, + } + server.RegisterOnShutdown(engine.shutdownCancel) } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - engine.applyDefaultServerConfig(srv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(srv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } - - runServer("HTTPS", srv, true) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) + applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS)) + applyMainServerConfig(engine, server, serveTLS) + return server } -// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 -func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLS(addr, tlsConfig, timeouts...) +func firstRedirectHeaderHost(r *http.Request, headers []string) string { + if r == nil { + return "" + } + for _, header := range headers { + value := strings.TrimSpace(r.Header.Get(header)) + if value == "" { + continue + } + if comma := strings.IndexByte(value, ','); comma >= 0 { + value = strings.TrimSpace(value[:comma]) + } + if value != "" { + return value + } + } + return "" } -// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 -func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLSRedir") +func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { + if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return "", http.StatusInternalServerError, false + } + return cfg.redirectHost, 0, true } - // --- HTTPS 服务器 --- - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) - } - httpsSrv := &http.Server{ - Addr: httpsAddr, - Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - httpsSrv.RegisterOnShutdown(engine.shutdownCancel) - engine.applyDefaultServerConfig(httpsSrv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(httpsSrv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpsSrv) + if len(cfg.redirectHostHeaders) > 0 { + host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true + } + + if r == nil { + return "", http.StatusUpgradeRequired, false + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true +} + +func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { + httpsAddr := cfg.addr + httpAddr := cfg.httpRedirectAddr + httpsPort, err := parseHTTPSPort(httpsAddr) + if err != nil { + return nil, err } - // --- HTTP 重定向服务器 --- redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host + host, statusCode, ok := redirectTargetHost(r, cfg) + if !ok { + http.Error(w, http.StatusText(statusCode), statusCode) + return } - _, httpsPort, err := net.SplitHostPort(httpsAddr) - if err != nil { - // 如果 httpsAddr 没有端口,这是一个配置错误 - - log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost + if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { + host = "[" + host + "]" + } } targetURL := "https://" + host - // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 if httpsPort != "443" { targetURL = "https://" + net.JoinHostPort(host, httpsPort) } @@ -355,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con http.Redirect(w, r, targetURL, http.StatusMovedPermanently) }) - httpSrv := &http.Server{ - Addr: httpAddr, - Handler: redirectHandler, - } - engine.applyDefaultServerConfig(httpSrv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpSrv) - } - // --- 启动服务器和优雅关闭 --- - runServer("HTTPS", httpsSrv, true) - runServer("HTTP Redirect", httpSrv, false) - return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) + server := &http.Server{Addr: httpAddr, Handler: redirectHandler} + applyRedirectServerConfig(engine, server) + return server, nil } -// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 -func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) +func validateRunConfig(cfg runConfig) error { + if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { + return errors.New("WithHTTPRedirect requires WithTLS") + } + if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { + return errors.New("https mode requires WithTLS") + } + if cfg.gracefulCtx != nil && !cfg.graceful { + return errors.New("WithShutdownContext requires graceful shutdown") + } + if len(cfg.redirectHostHeaders) > 0 { + if !cfg.useHeaderHostSet || !cfg.useHeaderHost { + return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") + } + } + if cfg.useHeaderHostSet && cfg.useHeaderHost { + if cfg.redirectHost != "" { + return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") + } + } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") + } + if len(cfg.redirectHostHeaders) > 0 { + return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") + } + } + return nil +} + +func effectiveShutdownTimeout(cfg runConfig) time.Duration { + if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet { + if cfg.shutdownTimeout > 0 { + return cfg.shutdownTimeout + } + } + return defaultShutdownTimeout +} + +func closeLoggerAsync(logger *reco.Logger) { + if logger == nil { + return + } + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() +} + +func shutdownServers(servers []*http.Server, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(ctx); err != nil { + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) + } + }(srv) + } + + wg.Wait() + close(errChan) + + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) + } + return nil +} + +func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(quit) + + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLS[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + select { + case err := <-serverStopped: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + log.Println("Touka server stopped gracefully.") + return nil + case <-quit: + log.Println("Shutting down Touka server(s) due to OS signal...") + case <-shutdownCtx.Done(): + log.Println("Context cancelled, shutting down Touka server(s)...") + } + + closeLoggerAsync(logger) + if err := shutdownServers(servers, timeout); err != nil { + return err + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// Run starts the engine with the provided startup options. +// +// Default behavior with no options: +// - HTTP only +// - listens on :8080 +// - no graceful shutdown orchestration +// +// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable +// signal-aware graceful shutdown and request-context cancellation semantics. +// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown. +func (engine *Engine) Run(opts ...RunOption) error { + cfg := defaultRunConfig() + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(&cfg); err != nil { + return err + } + } + if cfg.httpRedirectAddr != "" { + cfg.mode = runModeHTTPSRedirect + } else if cfg.tlsConfig != nil { + cfg.mode = runModeHTTPS + } + if err := validateRunConfig(cfg); err != nil { + return err + } + + serveTLS := cfg.mode != runModeHTTP + + mainServer := buildMainServer(engine, cfg) + servers := []*http.Server{mainServer} + serveTLSFlags := []bool{serveTLS} + if cfg.mode == runModeHTTPSRedirect { + redirectServer, err := buildRedirectServer(engine, cfg) + if err != nil { + return err + } + servers = append(servers, redirectServer) + serveTLSFlags = append(serveTLSFlags, false) + } + + if !cfg.graceful { + if len(servers) > 1 { + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLSFlags[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + err := <-serverStopped + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return errors.Join(err, shutdownErr) + } + return shutdownErr + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil + } + + protocolLabel := "HTTP" + if serveTLS { + protocolLabel = "HTTPS" + } + log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr) + return serveServer(mainServer, serveTLS) + } + + shutdownCtx := context.Background() + if cfg.gracefulCtx != nil { + shutdownCtx = cfg.gracefulCtx + } + return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx) } diff --git a/serve_test.go b/serve_test.go index 6092f7b..c717653 100644 --- a/serve_test.go +++ b/serve_test.go @@ -2,15 +2,58 @@ package touka import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "io" + "math/big" "net" "net/http" + "net/http/httptest" + "strings" "testing" "time" ) +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "127.0.0.1"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("create self-signed cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("parse self-signed cert: %v", err) + } + return cert +} + func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -79,3 +122,372 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) } } + +func TestRunRejectsRedirectWithoutTLS(t *testing.T) { + engine := New() + err := engine.Run(WithHTTPRedirect(":80")) + if err == nil { + t.Fatal("expected redirect mode without TLS to fail") + } +} + +func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { + engine := New() + err := engine.Run( + WithAddr(":443"), + WithTLS(&tls.Config{}), + WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), + ) + if err == nil { + t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") + } +} + +func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { + cfg := defaultRunConfig() + if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { + t.Fatalf("apply graceful default option: %v", err) + } + if !cfg.graceful { + t.Fatal("expected graceful shutdown to be enabled") + } + if cfg.shutdownTimeout != defaultShutdownTimeout { + t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout) + } +} + +func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { + cfg := defaultRunConfig() + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} + if err := WithTLS(tlsConfig).apply(&cfg); err != nil { + t.Fatalf("apply TLS option: %v", err) + } + if cfg.mode != runModeHTTPS { + t.Fatalf("expected HTTPS mode, got %v", cfg.mode) + } + if cfg.graceful { + t.Fatal("expected TLS option to remain independent from graceful shutdown") + } + if cfg.tlsConfig != tlsConfig { + t.Fatal("expected TLS config to be preserved in run config") + } +} + +func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { + engine := New() + if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { + t.Fatal("expected redirect server builder to reject https address without port") + } +} + +func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { + cfg := defaultRunConfig() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := WithShutdownContext(ctx).apply(&cfg); err != nil { + t.Fatalf("apply shutdown context option: %v", err) + } + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected shutdown context without graceful shutdown to fail validation") + } +} + +func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { + cfg := defaultRunConfig() + cfg.httpRedirectAddr = ":80" + if err := validateRunConfig(cfg); err != nil { + t.Fatalf("validate run config: %v", err) + } + if cfg.mode != runModeHTTP { + t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) + } +} + +func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = false + cfg.useHeaderHostSet = true + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected configured host mode without redirect host to fail validation") + } +} + +func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = true + cfg.useHeaderHostSet = true + cfg.redirectHost = "configured.example" + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected redirect host to be rejected when header host mode is enabled") + } +} + +func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { + engine := New() + server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) + if server.BaseContext == nil { + t.Fatal("expected graceful main server to set BaseContext") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for base context check: %v", err) + } + defer listener.Close() + if got := server.BaseContext(listener); got != engine.shutdownCtx { + t.Fatal("expected graceful main server to use engine shutdown context") + } +} + +func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) { + engine := New() + serverConfigured := false + tlsConfigured := false + engine.SetServerConfigurator(func(s *http.Server) { + serverConfigured = true + s.ReadTimeout = time.Second + }) + engine.SetTLSServerConfigurator(func(s *http.Server) { + tlsConfigured = true + s.IdleTimeout = time.Second + }) + + server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !tlsConfigured { + t.Fatal("expected TLS configurator to run for HTTPS main server") + } + if serverConfigured { + t.Fatal("expected generic server configurator to be skipped when TLS configurator is set") + } + if server.IdleTimeout != time.Second { + t.Fatal("expected TLS configurator changes to be applied to HTTPS main server") + } +} + +func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { + engine := New() + configured := false + engine.SetServerConfigurator(func(s *http.Server) { + configured = true + s.ReadTimeout = time.Second + }) + + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + if !configured { + t.Fatal("expected redirect server to use generic server configurator") + } + if server.ReadTimeout != time.Second { + t.Fatal("expected redirect server configurator changes to be applied") + } +} + +func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { + engine := New() + httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !httpsServer.Protocols.HTTP2() { + t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings") + } + + httpServer := buildMainServer(engine, defaultRunConfig()) + if httpServer.Protocols.HTTP2() { + t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled") + } +} + +func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + req.Header.Set("X-First-Host", "first.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUpgradeRequired { + t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) + } +} + +func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: false, + useHeaderHostSet: true, + redirectHost: "configured.example", + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil) + req.Host = "[::1]:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" { + t.Fatalf("unexpected IPv6 redirect location: %q", location) + } +} + +func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + mainServer := &http.Server{Addr: occupiedAddr, Handler: engine} + + err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background()) + if err == nil { + t.Fatal("expected gracefulServe to fail when one server cannot bind") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err) + } + + conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond) + if dialErr == nil { + conn.Close() + t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr) + } + if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") { + t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) + } +} + +func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + err = engine.Run( + WithAddr(occupiedAddr), + WithTLS(&tls.Config{}), + WithHTTPRedirect(redirectAddr), + ) + if err == nil { + t.Fatal("expected non-graceful TLS redirect startup to return bind error") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) + } +} diff --git a/touka.go b/touka.go index dd529cb..4ad81da 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type IRouter interface { - Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 - Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 +// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type Router interface { + Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 + Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc)