diff --git a/docs/advanced.md b/docs/advanced.md index 7e6a417..eb44c2d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -97,8 +97,106 @@ r.Run( touka.WithHTTPRedirect(":80"), touka.WithGracefulShutdown(10*time.Second), ) + +// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) + +// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) ``` +### HTTPS Redirect Host 策略 + +`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 + +可用的 redirect 子选项: + +- `touka.WithUseHeaderHost(true|false)` +- `touka.WithRedirectHostHeaders([]string{...})` +- `touka.WithRedirectHost("example.com")` + +#### 模式一:使用请求输入侧的 host + +当 `WithUseHeaderHost(true)` 时: + +- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` +- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 +- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) +``` + +#### 模式二:使用配置的固定 host + +当 `WithUseHeaderHost(false)` 时: + +- 不读取 `Request.Host` +- 不读取 `WithRedirectHostHeaders(...)` +- 必须配置 `WithRedirectHost("example.com")` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) +``` + +#### 严格校验规则 + +以下组合会直接返回配置错误: + +- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` +- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` +- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` +- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` +- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` + +#### 优先级关系 + +1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 +2. `WithUseHeaderHost(...)` 决定 host 来源模式 +3. 当 `WithUseHeaderHost(true)` 时: + - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 + - 未配置时使用 `Request.Host` +4. 当 `WithUseHeaderHost(false)` 时: + - 只使用 `WithRedirectHost(...)` + +**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 + ## 优雅停机 (Graceful Shutdown) 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 diff --git a/engine.go b/engine.go index 2849ffa..b0723e7 100644 --- a/engine.go +++ b/engine.go @@ -626,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { +func (engine *Engine) Use(middleware ...HandlerFunc) Router { engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.rebuildFallbackChains() return engine @@ -695,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -704,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 IRouter 接口,允许嵌套分组 +// 它也实现了 Router 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -713,7 +713,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { +func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { group.Handlers = append(group.Handlers, middleware...) return group } @@ -759,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), diff --git a/serve.go b/serve.go index 2c8c73b..b2ba358 100644 --- a/serve.go +++ b/serve.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -32,15 +33,19 @@ const ( ) type runConfig struct { - addr string - httpRedirectAddr string - tlsConfig *tls.Config - graceful bool - shutdownTimeout time.Duration - gracefulCtx context.Context - mode runMode - shutdownDefaultSet bool - shutdownTimeoutSet bool + addr string + httpRedirectAddr string + tlsConfig *tls.Config + redirectHost string + redirectHostHeaders []string + useHeaderHost bool + useHeaderHostSet bool + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool } type RunOption interface { @@ -58,9 +63,20 @@ func defaultRunConfig() runConfig { addr: ":8080", shutdownTimeout: defaultShutdownTimeout, mode: runModeHTTP, + useHeaderHost: true, } } +type HTTPRedirectOption interface { + applyRedirect(*runConfig) error +} + +type redirectOptionFunc func(*runConfig) error + +func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { + return f(cfg) +} + func WithAddr(addr string) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { @@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption { }) } -func WithHTTPRedirect(addr string) RunOption { +func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { return errors.New("http redirect address must not be empty") } cfg.httpRedirectAddr = addr cfg.mode = runModeHTTPSRedirect + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyRedirect(cfg); err != nil { + return err + } + } + return nil + }) +} + +func WithUseHeaderHost(enabled bool) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.useHeaderHost = enabled + cfg.useHeaderHostSet = true + return nil + }) +} + +func WithRedirectHost(host string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + if host == "" { + return errors.New("redirect host must not be empty") + } + cfg.redirectHost = host + return nil + }) +} + +func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] + for _, header := range headers { + trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) + if trimmed != "" { + cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) + } + } return nil }) } @@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server { return server } -func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { +func firstRedirectHeaderHost(r *http.Request, headers []string) string { + if r == nil { + return "" + } + for _, header := range headers { + value := strings.TrimSpace(r.Header.Get(header)) + if value == "" { + continue + } + if comma := strings.IndexByte(value, ','); comma >= 0 { + value = strings.TrimSpace(value[:comma]) + } + if value != "" { + return value + } + } + return "" +} + +func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { + if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return "", http.StatusInternalServerError, false + } + return cfg.redirectHost, 0, true + } + + if len(cfg.redirectHostHeaders) > 0 { + host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true + } + + if r == nil { + return "", http.StatusUpgradeRequired, false + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true +} + +func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { + httpsAddr := cfg.addr + httpAddr := cfg.httpRedirectAddr httpsPort, err := parseHTTPSPort(httpsAddr) if err != nil { return nil, err } redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host + host, statusCode, ok := redirectTargetHost(r, cfg) + if !ok { + http.Error(w, http.StatusText(statusCode), statusCode) + return + } + + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost } targetURL := "https://" + host @@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error { if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { return errors.New("https mode requires WithTLS") } - if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect { - cfg.mode = runModeHTTPSRedirect - } if cfg.gracefulCtx != nil && !cfg.graceful { return errors.New("WithShutdownContext requires graceful shutdown") } + if len(cfg.redirectHostHeaders) > 0 { + if !cfg.useHeaderHostSet || !cfg.useHeaderHost { + return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") + } + } + if cfg.useHeaderHostSet && cfg.useHeaderHost { + if cfg.redirectHost != "" { + return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") + } + } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") + } + if len(cfg.redirectHostHeaders) > 0 { + return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") + } + } return nil } @@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error { wg.Add(1) go func(s *http.Server) { defer wg.Done() - if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.Shutdown(ctx); err != nil { errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) } }(srv) @@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error { servers := []*http.Server{mainServer} serveTLSFlags := []bool{serveTLS} if cfg.mode == runModeHTTPSRedirect { - redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) + redirectServer, err := buildRedirectServer(engine, cfg) if err != nil { return err } @@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error { if !cfg.graceful { if len(servers) > 1 { - runServer("HTTPS", servers[0], true) - log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) - return serveServer(servers[1], false) + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLSFlags[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + err := <-serverStopped + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + return err } protocolLabel := "HTTP" diff --git a/serve_test.go b/serve_test.go index 6ecbeba..2bdddc5 100644 --- a/serve_test.go +++ b/serve_test.go @@ -90,6 +90,18 @@ func TestRunRejectsRedirectWithoutTLS(t *testing.T) { } } +func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { + engine := New() + err := engine.Run( + WithAddr(":443"), + WithTLS(&tls.Config{}), + WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), + ) + if err == nil { + t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") + } +} + func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { cfg := defaultRunConfig() if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { @@ -122,7 +134,7 @@ func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { engine := New() - if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { + if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { t.Fatal("expected redirect server builder to reject https address without port") } } @@ -139,6 +151,40 @@ func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { } } +func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { + cfg := defaultRunConfig() + cfg.httpRedirectAddr = ":80" + if err := validateRunConfig(cfg); err != nil { + t.Fatalf("validate run config: %v", err) + } + if cfg.mode != runModeHTTP { + t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) + } +} + +func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = false + cfg.useHeaderHostSet = true + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected configured host mode without redirect host to fail validation") + } +} + +func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = true + cfg.useHeaderHostSet = true + cfg.redirectHost = "configured.example" + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected redirect host to be rejected when header host mode is enabled") + } +} + func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { engine := New() server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) @@ -189,7 +235,7 @@ func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { s.ReadTimeout = time.Second }) - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -216,7 +262,7 @@ func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { engine := New() - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -234,6 +280,84 @@ func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { } } +func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + req.Header.Set("X-First-Host", "first.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUpgradeRequired { + t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) + } +} + +func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: false, + useHeaderHostSet: true, + redirectHost: "configured.example", + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { occupied, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -252,7 +376,7 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { } engine := New() - redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) + redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -275,3 +399,34 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) } } + +func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + err = engine.Run( + WithAddr(occupiedAddr), + WithTLS(&tls.Config{}), + WithHTTPRedirect(redirectAddr), + ) + if err == nil { + t.Fatal("expected non-graceful TLS redirect startup to return bind error") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) + } +} diff --git a/touka.go b/touka.go index dd529cb..4ad81da 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type IRouter interface { - Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 - Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 +// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type Router interface { + Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 + Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc)