From e4d3eed379cb58c1bbcc915d776a3c9ccda6e796 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:44:55 +0800 Subject: [PATCH 1/4] feat: redesign server startup around Run options Replace the old RunShutdown and RunTLS style entry points with a single Run(opts...) API for v1. Add focused startup semantics tests, keep TLS and graceful shutdown independent, ensure sibling servers are cleaned up on startup failure, and update docs to match the new option-based startup model. --- README.md | 4 +- about-touka.md | 22 +- docs/advanced.md | 36 ++- docs/introduction.md | 2 +- docs/quickstart.md | 6 +- docs/reverse-proxy.md | 4 +- docs/routing.md | 2 +- docs/sse.md | 2 +- docs/static-files.md | 2 +- engine.go | 22 +- protocols_test.go | 29 +-- serve.go | 585 ++++++++++++++++++++++-------------------- serve_test.go | 196 ++++++++++++++ 13 files changed, 577 insertions(+), 335 deletions(-) diff --git a/README.md b/README.md index a7b99fd..e2eaec8 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ func main() { c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) }) - // 启动服务器 (支持优雅关闭) + // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) log.Println("Touka Server starting on :8080...") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Touka server failed to start: %v", err) } } diff --git a/about-touka.md b/about-touka.md index 86a056f..b3a16b4 100644 --- a/about-touka.md +++ b/about-touka.md @@ -70,13 +70,13 @@ func main() { r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB // ... 其他配置 - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` #### 1.3. 服务器生命周期管理 -Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 +Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 ```go func main() { @@ -90,11 +90,11 @@ func main() { fmt.Println("自定义的 HTTP 服务器配置已应用") }) - // 启动服务器,并支持优雅关闭 - // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 - // 第二个参数是优雅关闭的超时时间 + // 启动服务器,并通过 Run 选项启用优雅关闭 + // Run(...) 会阻塞当前 goroutine + // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 fmt.Println("服务器启动于 :8080") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("服务器启动失败: %v", err) } } @@ -187,7 +187,7 @@ func main() { } } - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } func AuthMiddleware() touka.HandlerFunc { @@ -313,7 +313,7 @@ func main() { }) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } // templates/index.html @@ -400,7 +400,7 @@ func main() { c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -483,7 +483,7 @@ func main() { // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 r.StaticDir("/files", "./non-existent-dir") - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -546,7 +546,7 @@ func main() { // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 r.StaticFS("/", http.FS(subFS)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/advanced.md b/docs/advanced.md index a7cb9a2..7e6a417 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) { Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): ```go -// 使用默认协议配置(仅 HTTP/1.1) +// 使用默认协议配置 +// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集, +// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。 r.SetDefaultProtocols() // 自定义协议配置 @@ -57,33 +59,49 @@ r.SetProtocols(&touka.ProtocolsConfig{ ### 启动方式 -Touka 提供了多种服务器启动方式: +Touka 统一通过 `Run(opts...)` 启动服务器: ```go // 1. 简单启动(无优雅停机) -r.Run(":8080") +r.Run(touka.WithAddr(":8080")) // 2. 带优雅停机的启动 -r.RunShutdown(":8080", 10*time.Second) +r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) // 3. 带上下文的优雅停机 ctx, cancel := context.WithCancel(context.Background()) -r.RunShutdownWithContext(":8080", ctx, 10*time.Second) +defer cancel() +r.Run( + touka.WithAddr(":8080"), + touka.WithGracefulShutdown(10*time.Second), + touka.WithShutdownContext(ctx), +) // 4. HTTPS 启动 tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, // 其他 TLS 配置... } -r.RunTLS(":443", tlsConfig, 10*time.Second) +// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithGracefulShutdownDefault(), +) // 5. HTTPS + HTTP 重定向 -r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) +// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect(":80"), + touka.WithGracefulShutdown(10*time.Second), +) ``` ## 优雅停机 (Graceful Shutdown) -在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 ```go r := touka.Default() @@ -91,7 +109,7 @@ r := touka.Default() // 监听 SIGINT 和 SIGTERM 信号 // 如果在 10 秒内未处理完,则强制关闭 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal("服务器退出异常:", err) } ``` diff --git a/docs/introduction.md b/docs/introduction.md index 94a7310..87c3e40 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 -3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 +3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/quickstart.md b/docs/quickstart.md index 94f7433..2911732 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -46,7 +46,7 @@ func main() { // 4. 启动服务器并监听 8080 端口 log.Println("Touka server is running on :8080") - if err := r.Run(":8080"); err != nil { + if err := r.Run(touka.WithAddr(":8080")); err != nil { log.Fatalf("Server failed: %v", err) } } @@ -66,11 +66,11 @@ go run main.go ## 优雅停机 -在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 +在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。 ```go // 等待 10 秒以处理剩余请求 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Server forced to shutdown: %v", err) } ``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 7d05290..cb4b2a3 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -28,7 +28,7 @@ func main() { Target: target, })) - _ = r.Run(":8080") + _ = r.Run(touka.WithAddr(":8080")) } ``` @@ -497,7 +497,7 @@ func main() { }, })) - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal(err) } } diff --git a/docs/routing.md b/docs/routing.md index 223081a..70a24dc 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -142,7 +142,7 @@ func main() { r := touka.Default() fsroot, _ := fs.Sub(content, "dist") r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/sse.md b/docs/sse.md index 1b44521..a003be9 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) { 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 -**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 +**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。 diff --git a/docs/static-files.md b/docs/static-files.md index a2138cd..b1f06a8 100644 --- a/docs/static-files.md +++ b/docs/static-files.md @@ -39,7 +39,7 @@ func main() { // 您也可以使用 StaticFS 服务根路径 // r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/engine.go b/engine.go index f9d233a..2849ffa 100644 --- a/engine.go +++ b/engine.go @@ -404,11 +404,18 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { }() } -// applyDefaultServerConfig 应用框架的默认配置到 http.Server -func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { - if engine.serverProtocols != nil { - srv.Protocols = engine.serverProtocols - if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { +func cloneServerProtocols(protocols *http.Protocols) *http.Protocols { + if protocols == nil { + return nil + } + cloned := *protocols + return &cloned +} + +func applyServerProtocols(srv *http.Server, protocols *http.Protocols) { + if protocols != nil { + srv.Protocols = cloneServerProtocols(protocols) + if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() { if err := configureHTTP2ExtendedConnectServer(srv); err != nil { panic(err) } @@ -416,6 +423,11 @@ func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { } } +// applyDefaultServerConfig 应用框架的默认配置到 http.Server +func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) +} + // 配置全局Req Body大小限制 func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { engine.GlobalMaxRequestBodySize = size diff --git a/protocols_test.go b/protocols_test.go index 73f16e9..0e2bf1f 100644 --- a/protocols_test.go +++ b/protocols_test.go @@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) { } } -func TestRunTLSProtocolInheritance(t *testing.T) { +func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { engine := New() - // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv) + srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if !srv.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") } - // 模拟用户设置了自定义协议后调用 RunTLS + // 模拟用户设置了自定义协议后进入 TLS 运行模式 engine = New() engine.SetProtocols(&ProtocolsConfig{ Http1: true, Http2: false, // 用户明确不想要 HTTP/2 }) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv2 := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv2) + srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if srv2.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") } } diff --git a/serve.go b/serve.go index 1825b32..2c8c73b 100644 --- a/serve.go +++ b/serve.go @@ -21,45 +21,119 @@ import ( "github.com/fenthope/reco" ) -// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 const defaultShutdownTimeout = 5 * time.Second -// --- 内部辅助函数 --- +type runMode uint8 -// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" -func resolveAddress(addr []string) string { - switch len(addr) { - case 0: - return ":8080" - case 1: - return addr[0] - default: - panic("too many parameters provided for server address") +const ( + runModeHTTP runMode = iota + runModeHTTPS + runModeHTTPSRedirect +) + +type runConfig struct { + addr string + httpRedirectAddr string + tlsConfig *tls.Config + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool +} + +type RunOption interface { + apply(*runConfig) error +} + +type runOptionFunc func(*runConfig) error + +func (f runOptionFunc) apply(cfg *runConfig) error { + return f(cfg) +} + +func defaultRunConfig() runConfig { + return runConfig{ + addr: ":8080", + shutdownTimeout: defaultShutdownTimeout, + mode: runModeHTTP, } } -// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 -func getShutdownTimeout(timeouts []time.Duration) time.Duration { - if len(timeouts) > 0 && timeouts[0] > 0 { - return timeouts[0] - } - return defaultShutdownTimeout +func WithAddr(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("run address must not be empty") + } + cfg.addr = addr + return nil + }) +} + +func WithTLS(tlsConfig *tls.Config) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil") + } + cfg.tlsConfig = tlsConfig + if cfg.mode == runModeHTTP { + cfg.mode = runModeHTTPS + } + return nil + }) +} + +func WithHTTPRedirect(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("http redirect address must not be empty") + } + cfg.httpRedirectAddr = addr + cfg.mode = runModeHTTPSRedirect + return nil + }) +} + +func WithGracefulShutdown(timeout time.Duration) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownTimeoutSet = true + if timeout > 0 { + cfg.shutdownTimeout = timeout + } else { + cfg.shutdownTimeout = defaultShutdownTimeout + } + return nil + }) +} + +func WithGracefulShutdownDefault() RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownDefaultSet = true + cfg.shutdownTimeout = defaultShutdownTimeout + return nil + }) +} + +func WithShutdownContext(ctx context.Context) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if ctx == nil { + return errors.New("shutdown context must not be nil") + } + cfg.gracefulCtx = ctx + return nil + }) } -// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器. func serveServer(srv *http.Server, serveTLS bool) error { if serveTLS { - // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, - // ListenAndServeTLS 的前两个参数可以为空字符串 return srv.ListenAndServeTLS("", "") } - return srv.ListenAndServe() } -// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, -// 并处理其启动失败的致命错误 -// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") func runServer(serverType string, srv *http.Server, serveTLS bool) { go func() { protocol := "http" @@ -70,284 +144,90 @@ func runServer(serverType string, srv *http.Server, serveTLS bool) { log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) err := serveServer(srv, serveTLS) - - // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), - // 则认为是一个严重错误,并终止程序 if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Touka %s server failed: %v", serverType, err) } }() } -// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 -// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 -func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - <-quit // 阻塞,直到接收到上述信号之一 - log.Println("Shutting down Touka server(s)...") - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() +func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { + if tlsConfig == nil { + return nil } - - // 创建一个带超时的上下文,用于 Shutdown - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(ctx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() // 等待所有服务器的关闭 goroutine 完成 - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil + return tlsConfig.Clone() } -func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - - // 启动服务器 - serverStopped := make(chan error, 1) - for _, srv := range servers { - go func(s *http.Server) { - serverStopped <- s.ListenAndServe() - }(srv) +func parseHTTPSPort(addr string) (string, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("https address %q must include a port: %w", addr, err) } + return port, nil +} - select { - case <-ctx.Done(): - // Context 被取消 (例如,通过外部取消函数) - log.Println("Context cancelled, shutting down Touka server(s)...") - case err := <-serverStopped: - // 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("Touka HTTP server failed: %w", err) +func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) { + if serveTLS { + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + return } - log.Println("Touka HTTP server stopped gracefully.") - return nil // 服务器已自行优雅关闭,无需进一步处理 - case <-quit: - // 接收到操作系统信号 - log.Println("Shutting down Touka server(s) due to OS signal...") } - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() - } - - // 创建一个带超时的上下文,用于 Shutdown - shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(shutdownCtx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil -} - -// --- 公共 Run 方法 --- - -// Run 启动一个不支持优雅关闭的 HTTP 服务器 -// 这是一个阻塞调用,主要用于简单的场景或快速测试 -// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 -func (engine *Engine) Run(addr ...string) error { - address := resolveAddress(addr) - srv := &http.Server{Addr: address, Handler: engine} - - // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) - return srv.ListenAndServe() } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) +func applyRedirectServerConfig(engine *Engine, srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - - runServer("HTTP", srv, false) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, +func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols { + if engine == nil { + return nil } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) + if serveTLS && engine.useDefaultProtocols { + protocols := &http.Protocols{} + protocols.SetHTTP1(true) + protocols.SetHTTP2(true) + return protocols } - - return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco) + return cloneServerProtocols(engine.serverProtocols) } -// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 -func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLS") - } - - // 配置 HTTP/2 支持 (如果使用默认配置) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认在 TLS 上启用 HTTP/2 - }) - } - - srv := &http.Server{ - Addr: addr, +func buildMainServer(engine *Engine, cfg runConfig) *http.Server { + serveTLS := cfg.mode != runModeHTTP + server := &http.Server{ + Addr: cfg.addr, Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { + TLSConfig: cloneTLSConfig(cfg.tlsConfig), + } + if cfg.graceful { + server.BaseContext = func(net.Listener) context.Context { return engine.shutdownCtx - }, + } + server.RegisterOnShutdown(engine.shutdownCancel) } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - engine.applyDefaultServerConfig(srv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(srv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } - - runServer("HTTPS", srv, true) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) + applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS)) + applyMainServerConfig(engine, server, serveTLS) + return server } -// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 -func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLS(addr, tlsConfig, timeouts...) -} - -// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 -func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLSRedir") +func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { + httpsPort, err := parseHTTPSPort(httpsAddr) + if err != nil { + return nil, err } - // --- HTTPS 服务器 --- - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) - } - httpsSrv := &http.Server{ - Addr: httpsAddr, - Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - httpsSrv.RegisterOnShutdown(engine.shutdownCancel) - engine.applyDefaultServerConfig(httpsSrv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(httpsSrv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpsSrv) - } - - // --- HTTP 重定向服务器 --- redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.Host) if err != nil { host = r.Host } - _, httpsPort, err := net.SplitHostPort(httpsAddr) - if err != nil { - // 如果 httpsAddr 没有端口,这是一个配置错误 - - log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) - } - targetURL := "https://" + host - // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 if httpsPort != "443" { targetURL = "https://" + net.JoinHostPort(host, httpsPort) } @@ -355,22 +235,175 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con http.Redirect(w, r, targetURL, http.StatusMovedPermanently) }) - httpSrv := &http.Server{ - Addr: httpAddr, - Handler: redirectHandler, - } - engine.applyDefaultServerConfig(httpSrv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpSrv) - } - // --- 启动服务器和优雅关闭 --- - runServer("HTTPS", httpsSrv, true) - runServer("HTTP Redirect", httpSrv, false) - return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) + server := &http.Server{Addr: httpAddr, Handler: redirectHandler} + applyRedirectServerConfig(engine, server) + return server, nil } -// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 -func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) +func validateRunConfig(cfg runConfig) error { + if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { + return errors.New("WithHTTPRedirect requires WithTLS") + } + if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { + return errors.New("https mode requires WithTLS") + } + if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect { + cfg.mode = runModeHTTPSRedirect + } + if cfg.gracefulCtx != nil && !cfg.graceful { + return errors.New("WithShutdownContext requires graceful shutdown") + } + return nil +} + +func effectiveShutdownTimeout(cfg runConfig) time.Duration { + if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet { + if cfg.shutdownTimeout > 0 { + return cfg.shutdownTimeout + } + } + return defaultShutdownTimeout +} + +func closeLoggerAsync(logger *reco.Logger) { + if logger == nil { + return + } + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() +} + +func shutdownServers(servers []*http.Server, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) + } + }(srv) + } + + wg.Wait() + close(errChan) + + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) + } + return nil +} + +func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(quit) + + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLS[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + select { + case err := <-serverStopped: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + log.Println("Touka server stopped gracefully.") + return nil + case <-quit: + log.Println("Shutting down Touka server(s) due to OS signal...") + case <-shutdownCtx.Done(): + log.Println("Context cancelled, shutting down Touka server(s)...") + } + + closeLoggerAsync(logger) + if err := shutdownServers(servers, timeout); err != nil { + return err + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// Run starts the engine with the provided startup options. +// +// Default behavior with no options: +// - HTTP only +// - listens on :8080 +// - no graceful shutdown orchestration +// +// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable +// signal-aware graceful shutdown and request-context cancellation semantics. +// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown. +func (engine *Engine) Run(opts ...RunOption) error { + cfg := defaultRunConfig() + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(&cfg); err != nil { + return err + } + } + if cfg.httpRedirectAddr != "" { + cfg.mode = runModeHTTPSRedirect + } else if cfg.tlsConfig != nil { + cfg.mode = runModeHTTPS + } + if err := validateRunConfig(cfg); err != nil { + return err + } + + serveTLS := cfg.mode != runModeHTTP + + mainServer := buildMainServer(engine, cfg) + servers := []*http.Server{mainServer} + serveTLSFlags := []bool{serveTLS} + if cfg.mode == runModeHTTPSRedirect { + redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) + if err != nil { + return err + } + servers = append(servers, redirectServer) + serveTLSFlags = append(serveTLSFlags, false) + } + + if !cfg.graceful { + if len(servers) > 1 { + runServer("HTTPS", servers[0], true) + log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) + return serveServer(servers[1], false) + } + + protocolLabel := "HTTP" + if serveTLS { + protocolLabel = "HTTPS" + } + log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr) + return serveServer(mainServer, serveTLS) + } + + shutdownCtx := context.Background() + if cfg.gracefulCtx != nil { + shutdownCtx = cfg.gracefulCtx + } + return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx) } diff --git a/serve_test.go b/serve_test.go index 6092f7b..6ecbeba 100644 --- a/serve_test.go +++ b/serve_test.go @@ -7,6 +7,8 @@ import ( "io" "net" "net/http" + "net/http/httptest" + "strings" "testing" "time" ) @@ -79,3 +81,197 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) } } + +func TestRunRejectsRedirectWithoutTLS(t *testing.T) { + engine := New() + err := engine.Run(WithHTTPRedirect(":80")) + if err == nil { + t.Fatal("expected redirect mode without TLS to fail") + } +} + +func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { + cfg := defaultRunConfig() + if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { + t.Fatalf("apply graceful default option: %v", err) + } + if !cfg.graceful { + t.Fatal("expected graceful shutdown to be enabled") + } + if cfg.shutdownTimeout != defaultShutdownTimeout { + t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout) + } +} + +func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { + cfg := defaultRunConfig() + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} + if err := WithTLS(tlsConfig).apply(&cfg); err != nil { + t.Fatalf("apply TLS option: %v", err) + } + if cfg.mode != runModeHTTPS { + t.Fatalf("expected HTTPS mode, got %v", cfg.mode) + } + if cfg.graceful { + t.Fatal("expected TLS option to remain independent from graceful shutdown") + } + if cfg.tlsConfig != tlsConfig { + t.Fatal("expected TLS config to be preserved in run config") + } +} + +func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { + engine := New() + if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { + t.Fatal("expected redirect server builder to reject https address without port") + } +} + +func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { + cfg := defaultRunConfig() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := WithShutdownContext(ctx).apply(&cfg); err != nil { + t.Fatalf("apply shutdown context option: %v", err) + } + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected shutdown context without graceful shutdown to fail validation") + } +} + +func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { + engine := New() + server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) + if server.BaseContext == nil { + t.Fatal("expected graceful main server to set BaseContext") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for base context check: %v", err) + } + defer listener.Close() + if got := server.BaseContext(listener); got != engine.shutdownCtx { + t.Fatal("expected graceful main server to use engine shutdown context") + } +} + +func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) { + engine := New() + serverConfigured := false + tlsConfigured := false + engine.SetServerConfigurator(func(s *http.Server) { + serverConfigured = true + s.ReadTimeout = time.Second + }) + engine.SetTLSServerConfigurator(func(s *http.Server) { + tlsConfigured = true + s.IdleTimeout = time.Second + }) + + server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !tlsConfigured { + t.Fatal("expected TLS configurator to run for HTTPS main server") + } + if serverConfigured { + t.Fatal("expected generic server configurator to be skipped when TLS configurator is set") + } + if server.IdleTimeout != time.Second { + t.Fatal("expected TLS configurator changes to be applied to HTTPS main server") + } +} + +func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { + engine := New() + configured := false + engine.SetServerConfigurator(func(s *http.Server) { + configured = true + s.ReadTimeout = time.Second + }) + + server, err := buildRedirectServer(engine, ":443", ":80") + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + if !configured { + t.Fatal("expected redirect server to use generic server configurator") + } + if server.ReadTimeout != time.Second { + t.Fatal("expected redirect server configurator changes to be applied") + } +} + +func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { + engine := New() + httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !httpsServer.Protocols.HTTP2() { + t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings") + } + + httpServer := buildMainServer(engine, defaultRunConfig()) + if httpServer.Protocols.HTTP2() { + t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled") + } +} + +func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, ":443", ":80") + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + mainServer := &http.Server{Addr: occupiedAddr, Handler: engine} + + err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background()) + if err == nil { + t.Fatal("expected gracefulServe to fail when one server cannot bind") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err) + } + + conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond) + if dialErr == nil { + conn.Close() + t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr) + } + if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") { + t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) + } +} From e2cf08d5ddc659e67077ad52953ce50bc8d8694d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:49:13 +0800 Subject: [PATCH 2/4] feat: add redirect host selection options Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors. --- docs/advanced.md | 98 ++++++++++++++++++++++++++ engine.go | 10 +-- serve.go | 178 +++++++++++++++++++++++++++++++++++++++++------ serve_test.go | 163 +++++++++++++++++++++++++++++++++++++++++-- touka.go | 8 +-- 5 files changed, 422 insertions(+), 35 deletions(-) diff --git a/docs/advanced.md b/docs/advanced.md index 7e6a417..eb44c2d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -97,8 +97,106 @@ r.Run( touka.WithHTTPRedirect(":80"), touka.WithGracefulShutdown(10*time.Second), ) + +// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) + +// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) ``` +### HTTPS Redirect Host 策略 + +`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 + +可用的 redirect 子选项: + +- `touka.WithUseHeaderHost(true|false)` +- `touka.WithRedirectHostHeaders([]string{...})` +- `touka.WithRedirectHost("example.com")` + +#### 模式一:使用请求输入侧的 host + +当 `WithUseHeaderHost(true)` 时: + +- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` +- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 +- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) +``` + +#### 模式二:使用配置的固定 host + +当 `WithUseHeaderHost(false)` 时: + +- 不读取 `Request.Host` +- 不读取 `WithRedirectHostHeaders(...)` +- 必须配置 `WithRedirectHost("example.com")` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) +``` + +#### 严格校验规则 + +以下组合会直接返回配置错误: + +- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` +- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` +- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` +- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` +- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` + +#### 优先级关系 + +1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 +2. `WithUseHeaderHost(...)` 决定 host 来源模式 +3. 当 `WithUseHeaderHost(true)` 时: + - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 + - 未配置时使用 `Request.Host` +4. 当 `WithUseHeaderHost(false)` 时: + - 只使用 `WithRedirectHost(...)` + +**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 + ## 优雅停机 (Graceful Shutdown) 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 diff --git a/engine.go b/engine.go index 2849ffa..b0723e7 100644 --- a/engine.go +++ b/engine.go @@ -626,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { +func (engine *Engine) Use(middleware ...HandlerFunc) Router { engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.rebuildFallbackChains() return engine @@ -695,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -704,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 IRouter 接口,允许嵌套分组 +// 它也实现了 Router 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -713,7 +713,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { +func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { group.Handlers = append(group.Handlers, middleware...) return group } @@ -759,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), diff --git a/serve.go b/serve.go index 2c8c73b..b2ba358 100644 --- a/serve.go +++ b/serve.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -32,15 +33,19 @@ const ( ) type runConfig struct { - addr string - httpRedirectAddr string - tlsConfig *tls.Config - graceful bool - shutdownTimeout time.Duration - gracefulCtx context.Context - mode runMode - shutdownDefaultSet bool - shutdownTimeoutSet bool + addr string + httpRedirectAddr string + tlsConfig *tls.Config + redirectHost string + redirectHostHeaders []string + useHeaderHost bool + useHeaderHostSet bool + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool } type RunOption interface { @@ -58,9 +63,20 @@ func defaultRunConfig() runConfig { addr: ":8080", shutdownTimeout: defaultShutdownTimeout, mode: runModeHTTP, + useHeaderHost: true, } } +type HTTPRedirectOption interface { + applyRedirect(*runConfig) error +} + +type redirectOptionFunc func(*runConfig) error + +func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { + return f(cfg) +} + func WithAddr(addr string) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { @@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption { }) } -func WithHTTPRedirect(addr string) RunOption { +func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { return errors.New("http redirect address must not be empty") } cfg.httpRedirectAddr = addr cfg.mode = runModeHTTPSRedirect + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyRedirect(cfg); err != nil { + return err + } + } + return nil + }) +} + +func WithUseHeaderHost(enabled bool) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.useHeaderHost = enabled + cfg.useHeaderHostSet = true + return nil + }) +} + +func WithRedirectHost(host string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + if host == "" { + return errors.New("redirect host must not be empty") + } + cfg.redirectHost = host + return nil + }) +} + +func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] + for _, header := range headers { + trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) + if trimmed != "" { + cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) + } + } return nil }) } @@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server { return server } -func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { +func firstRedirectHeaderHost(r *http.Request, headers []string) string { + if r == nil { + return "" + } + for _, header := range headers { + value := strings.TrimSpace(r.Header.Get(header)) + if value == "" { + continue + } + if comma := strings.IndexByte(value, ','); comma >= 0 { + value = strings.TrimSpace(value[:comma]) + } + if value != "" { + return value + } + } + return "" +} + +func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { + if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return "", http.StatusInternalServerError, false + } + return cfg.redirectHost, 0, true + } + + if len(cfg.redirectHostHeaders) > 0 { + host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true + } + + if r == nil { + return "", http.StatusUpgradeRequired, false + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true +} + +func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { + httpsAddr := cfg.addr + httpAddr := cfg.httpRedirectAddr httpsPort, err := parseHTTPSPort(httpsAddr) if err != nil { return nil, err } redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host + host, statusCode, ok := redirectTargetHost(r, cfg) + if !ok { + http.Error(w, http.StatusText(statusCode), statusCode) + return + } + + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost } targetURL := "https://" + host @@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error { if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { return errors.New("https mode requires WithTLS") } - if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect { - cfg.mode = runModeHTTPSRedirect - } if cfg.gracefulCtx != nil && !cfg.graceful { return errors.New("WithShutdownContext requires graceful shutdown") } + if len(cfg.redirectHostHeaders) > 0 { + if !cfg.useHeaderHostSet || !cfg.useHeaderHost { + return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") + } + } + if cfg.useHeaderHostSet && cfg.useHeaderHost { + if cfg.redirectHost != "" { + return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") + } + } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") + } + if len(cfg.redirectHostHeaders) > 0 { + return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") + } + } return nil } @@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error { wg.Add(1) go func(s *http.Server) { defer wg.Done() - if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.Shutdown(ctx); err != nil { errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) } }(srv) @@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error { servers := []*http.Server{mainServer} serveTLSFlags := []bool{serveTLS} if cfg.mode == runModeHTTPSRedirect { - redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) + redirectServer, err := buildRedirectServer(engine, cfg) if err != nil { return err } @@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error { if !cfg.graceful { if len(servers) > 1 { - runServer("HTTPS", servers[0], true) - log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) - return serveServer(servers[1], false) + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLSFlags[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + err := <-serverStopped + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + return err } protocolLabel := "HTTP" diff --git a/serve_test.go b/serve_test.go index 6ecbeba..2bdddc5 100644 --- a/serve_test.go +++ b/serve_test.go @@ -90,6 +90,18 @@ func TestRunRejectsRedirectWithoutTLS(t *testing.T) { } } +func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { + engine := New() + err := engine.Run( + WithAddr(":443"), + WithTLS(&tls.Config{}), + WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), + ) + if err == nil { + t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") + } +} + func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { cfg := defaultRunConfig() if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { @@ -122,7 +134,7 @@ func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { engine := New() - if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { + if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { t.Fatal("expected redirect server builder to reject https address without port") } } @@ -139,6 +151,40 @@ func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { } } +func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { + cfg := defaultRunConfig() + cfg.httpRedirectAddr = ":80" + if err := validateRunConfig(cfg); err != nil { + t.Fatalf("validate run config: %v", err) + } + if cfg.mode != runModeHTTP { + t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) + } +} + +func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = false + cfg.useHeaderHostSet = true + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected configured host mode without redirect host to fail validation") + } +} + +func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = true + cfg.useHeaderHostSet = true + cfg.redirectHost = "configured.example" + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected redirect host to be rejected when header host mode is enabled") + } +} + func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { engine := New() server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) @@ -189,7 +235,7 @@ func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { s.ReadTimeout = time.Second }) - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -216,7 +262,7 @@ func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { engine := New() - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -234,6 +280,84 @@ func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { } } +func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + req.Header.Set("X-First-Host", "first.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUpgradeRequired { + t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) + } +} + +func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: false, + useHeaderHostSet: true, + redirectHost: "configured.example", + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { occupied, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -252,7 +376,7 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { } engine := New() - redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) + redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -275,3 +399,34 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) } } + +func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + err = engine.Run( + WithAddr(occupiedAddr), + WithTLS(&tls.Config{}), + WithHTTPRedirect(redirectAddr), + ) + if err == nil { + t.Fatal("expected non-graceful TLS redirect startup to return bind error") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) + } +} diff --git a/touka.go b/touka.go index dd529cb..4ad81da 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type IRouter interface { - Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 - Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 +// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type Router interface { + Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 + Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc) From 9e57f5a5f56d5ab1b3bc6c981c948f710c67e2cf Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:00:58 +0800 Subject: [PATCH 3/4] fix: stop redirect siblings on shutdown Make the non-graceful HTTPS redirect path shut down all sibling servers after any server returns, so cleanup stays consistent with the graceful path and partial shutdowns do not leave the redirect listener running. --- serve.go | 9 ++++++--- serve_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/serve.go b/serve.go index b2ba358..386eaf5 100644 --- a/serve.go +++ b/serve.go @@ -518,13 +518,16 @@ func (engine *Engine) Run(opts ...RunOption) error { } err := <-serverStopped - if err != nil && !errors.Is(err, http.ErrServerClosed) { - if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return errors.Join(err, shutdownErr) } + return shutdownErr + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { return err } - return err + return nil } protocolLabel := "HTTP" diff --git a/serve_test.go b/serve_test.go index 2bdddc5..8de14c3 100644 --- a/serve_test.go +++ b/serve_test.go @@ -2,9 +2,15 @@ package touka import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "io" + "math/big" "net" "net/http" "net/http/httptest" @@ -13,6 +19,41 @@ import ( "time" ) +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "127.0.0.1"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("create self-signed cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("parse self-signed cert: %v", err) + } + return cert +} + func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { From 121679b44e160aab44e9fa8a98d99002a26c8010 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:31:10 +0800 Subject: [PATCH 4/4] fix: preserve IPv6 brackets in redirects Re-wrap bare IPv6 hosts after stripping ports so HTTPS redirect URLs stay valid. Add a regression test covering bracketed IPv6 hosts in redirect responses. --- serve.go | 3 +++ serve_test.go | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/serve.go b/serve.go index 386eaf5..0fc83f9 100644 --- a/serve.go +++ b/serve.go @@ -332,6 +332,9 @@ func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { if parsedHost, _, err := net.SplitHostPort(host); err == nil { host = parsedHost + if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { + host = "[" + host + "]" + } } targetURL := "https://" + host diff --git a/serve_test.go b/serve_test.go index 8de14c3..c717653 100644 --- a/serve_test.go +++ b/serve_test.go @@ -399,6 +399,26 @@ func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t * } } +func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil) + req.Host = "[::1]:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" { + t.Fatalf("unexpected IPv6 redirect location: %q", location) + } +} + func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { occupied, err := net.Listen("tcp", "127.0.0.1:0") if err != nil {