feat: add redirect host selection options

Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors.
This commit is contained in:
wjqserver 2026-04-07 19:49:13 +08:00
parent e4d3eed379
commit e2cf08d5dd
5 changed files with 422 additions and 35 deletions

View file

@ -97,8 +97,106 @@ r.Run(
touka.WithHTTPRedirect(":80"), touka.WithHTTPRedirect(":80"),
touka.WithGracefulShutdown(10*time.Second), touka.WithGracefulShutdown(10*time.Second),
) )
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
``` ```
### HTTPS Redirect Host 策略
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
可用的 redirect 子选项:
- `touka.WithUseHeaderHost(true|false)`
- `touka.WithRedirectHostHeaders([]string{...})`
- `touka.WithRedirectHost("example.com")`
#### 模式一:使用请求输入侧的 host
`WithUseHeaderHost(true)` 时:
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header并使用第一个非空值
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(true),
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
),
)
```
#### 模式二:使用配置的固定 host
`WithUseHeaderHost(false)` 时:
- 不读取 `Request.Host`
- 不读取 `WithRedirectHostHeaders(...)`
- 必须配置 `WithRedirectHost("example.com")`
示例:
```go
r.Run(
touka.WithAddr(":443"),
touka.WithTLS(tlsConfig),
touka.WithHTTPRedirect(
":80",
touka.WithUseHeaderHost(false),
touka.WithRedirectHost("example.com"),
),
)
```
#### 严格校验规则
以下组合会直接返回配置错误:
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
#### 优先级关系
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
2. `WithUseHeaderHost(...)` 决定 host 来源模式
3. 当 `WithUseHeaderHost(true)` 时:
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
- 未配置时使用 `Request.Host`
4. 当 `WithUseHeaderHost(false)` 时:
- 只使用 `WithRedirectHost(...)`
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
## 优雅停机 (Graceful Shutdown) ## 优雅停机 (Graceful Shutdown)
在部署新版本时我们希望服务器停止接收新请求但能处理完当前正在进行的请求。启用优雅关闭后Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 在部署新版本时我们希望服务器停止接收新请求但能处理完当前正在进行的请求。启用优雅关闭后Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。

View file

@ -626,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
// Use 将全局中间件添加到 Engine // Use 将全局中间件添加到 Engine
// 这些中间件将应用于所有注册的路由 // 这些中间件将应用于所有注册的路由
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { func (engine *Engine) Use(middleware ...HandlerFunc) Router {
engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.globalHandlers = append(engine.globalHandlers, middleware...)
engine.rebuildFallbackChains() engine.rebuildFallbackChains()
return engine return engine
@ -695,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
// Group 创建一个新的路由组 // Group 创建一个新的路由组
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
basePath: resolveRoutePath("/", relativePath), basePath: resolveRoutePath("/", relativePath),
@ -704,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
} }
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
// 它也实现了 IRouter 接口,允许嵌套分组 // 它也实现了 Router 接口,允许嵌套分组
type RouterGroup struct { type RouterGroup struct {
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
basePath string // 组路径前缀 basePath string // 组路径前缀
@ -713,7 +713,7 @@ type RouterGroup struct {
// Use 将中间件应用于当前路由组 // Use 将中间件应用于当前路由组
// 这些中间件将应用于当前组及其子组的所有路由 // 这些中间件将应用于当前组及其子组的所有路由
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
group.Handlers = append(group.Handlers, middleware...) group.Handlers = append(group.Handlers, middleware...)
return group return group
} }
@ -759,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
} }
// Group 为当前组创建一个新的子组 // Group 为当前组创建一个新的子组
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
return &RouterGroup{ return &RouterGroup{
Handlers: group.engine.combineHandlers(group.Handlers, handlers), Handlers: group.engine.combineHandlers(group.Handlers, handlers),
basePath: resolveRoutePath(group.basePath, relativePath), basePath: resolveRoutePath(group.basePath, relativePath),

178
serve.go
View file

@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -32,15 +33,19 @@ const (
) )
type runConfig struct { type runConfig struct {
addr string addr string
httpRedirectAddr string httpRedirectAddr string
tlsConfig *tls.Config tlsConfig *tls.Config
graceful bool redirectHost string
shutdownTimeout time.Duration redirectHostHeaders []string
gracefulCtx context.Context useHeaderHost bool
mode runMode useHeaderHostSet bool
shutdownDefaultSet bool graceful bool
shutdownTimeoutSet bool shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
} }
type RunOption interface { type RunOption interface {
@ -58,9 +63,20 @@ func defaultRunConfig() runConfig {
addr: ":8080", addr: ":8080",
shutdownTimeout: defaultShutdownTimeout, shutdownTimeout: defaultShutdownTimeout,
mode: runModeHTTP, mode: runModeHTTP,
useHeaderHost: true,
} }
} }
type HTTPRedirectOption interface {
applyRedirect(*runConfig) error
}
type redirectOptionFunc func(*runConfig) error
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
return f(cfg)
}
func WithAddr(addr string) RunOption { func WithAddr(addr string) RunOption {
return runOptionFunc(func(cfg *runConfig) error { return runOptionFunc(func(cfg *runConfig) error {
if addr == "" { if addr == "" {
@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption {
}) })
} }
func WithHTTPRedirect(addr string) RunOption { func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
return runOptionFunc(func(cfg *runConfig) error { return runOptionFunc(func(cfg *runConfig) error {
if addr == "" { if addr == "" {
return errors.New("http redirect address must not be empty") return errors.New("http redirect address must not be empty")
} }
cfg.httpRedirectAddr = addr cfg.httpRedirectAddr = addr
cfg.mode = runModeHTTPSRedirect cfg.mode = runModeHTTPSRedirect
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.applyRedirect(cfg); err != nil {
return err
}
}
return nil
})
}
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.useHeaderHost = enabled
cfg.useHeaderHostSet = true
return nil
})
}
func WithRedirectHost(host string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
if host == "" {
return errors.New("redirect host must not be empty")
}
cfg.redirectHost = host
return nil
})
}
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
for _, header := range headers {
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
if trimmed != "" {
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
}
}
return nil return nil
}) })
} }
@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
return server return server
} }
func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { func firstRedirectHeaderHost(r *http.Request, headers []string) string {
if r == nil {
return ""
}
for _, header := range headers {
value := strings.TrimSpace(r.Header.Get(header))
if value == "" {
continue
}
if comma := strings.IndexByte(value, ','); comma >= 0 {
value = strings.TrimSpace(value[:comma])
}
if value != "" {
return value
}
}
return ""
}
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return "", http.StatusInternalServerError, false
}
return cfg.redirectHost, 0, true
}
if len(cfg.redirectHostHeaders) > 0 {
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
if r == nil {
return "", http.StatusUpgradeRequired, false
}
host := strings.TrimSpace(r.Host)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
httpsAddr := cfg.addr
httpAddr := cfg.httpRedirectAddr
httpsPort, err := parseHTTPSPort(httpsAddr) httpsPort, err := parseHTTPSPort(httpsAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host) host, statusCode, ok := redirectTargetHost(r, cfg)
if err != nil { if !ok {
host = r.Host http.Error(w, http.StatusText(statusCode), statusCode)
return
}
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
} }
targetURL := "https://" + host targetURL := "https://" + host
@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error {
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
return errors.New("https mode requires WithTLS") return errors.New("https mode requires WithTLS")
} }
if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect {
cfg.mode = runModeHTTPSRedirect
}
if cfg.gracefulCtx != nil && !cfg.graceful { if cfg.gracefulCtx != nil && !cfg.graceful {
return errors.New("WithShutdownContext requires graceful shutdown") return errors.New("WithShutdownContext requires graceful shutdown")
} }
if len(cfg.redirectHostHeaders) > 0 {
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
}
}
if cfg.useHeaderHostSet && cfg.useHeaderHost {
if cfg.redirectHost != "" {
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
}
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
}
if len(cfg.redirectHostHeaders) > 0 {
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
}
}
return nil return nil
} }
@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error {
wg.Add(1) wg.Add(1)
go func(s *http.Server) { go func(s *http.Server) {
defer wg.Done() defer wg.Done()
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := s.Shutdown(ctx); err != nil {
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
} }
}(srv) }(srv)
@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error {
servers := []*http.Server{mainServer} servers := []*http.Server{mainServer}
serveTLSFlags := []bool{serveTLS} serveTLSFlags := []bool{serveTLS}
if cfg.mode == runModeHTTPSRedirect { if cfg.mode == runModeHTTPSRedirect {
redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) redirectServer, err := buildRedirectServer(engine, cfg)
if err != nil { if err != nil {
return err return err
} }
@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error {
if !cfg.graceful { if !cfg.graceful {
if len(servers) > 1 { if len(servers) > 1 {
runServer("HTTPS", servers[0], true) serverStopped := make(chan error, len(servers))
log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) for i, srv := range servers {
return serveServer(servers[1], false) serveTLSFlag := serveTLSFlags[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
err := <-serverStopped
if err != nil && !errors.Is(err, http.ErrServerClosed) {
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
return errors.Join(err, shutdownErr)
}
return err
}
return err
} }
protocolLabel := "HTTP" protocolLabel := "HTTP"

View file

@ -90,6 +90,18 @@ func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
} }
} }
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
engine := New()
err := engine.Run(
WithAddr(":443"),
WithTLS(&tls.Config{}),
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
)
if err == nil {
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
}
}
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
cfg := defaultRunConfig() cfg := defaultRunConfig()
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
@ -122,7 +134,7 @@ func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
engine := New() engine := New()
if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
t.Fatal("expected redirect server builder to reject https address without port") t.Fatal("expected redirect server builder to reject https address without port")
} }
} }
@ -139,6 +151,40 @@ func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
} }
} }
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
cfg := defaultRunConfig()
cfg.httpRedirectAddr = ":80"
if err := validateRunConfig(cfg); err != nil {
t.Fatalf("validate run config: %v", err)
}
if cfg.mode != runModeHTTP {
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
}
}
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = false
cfg.useHeaderHostSet = true
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected configured host mode without redirect host to fail validation")
}
}
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
cfg := defaultRunConfig()
cfg.mode = runModeHTTPSRedirect
cfg.tlsConfig = &tls.Config{}
cfg.useHeaderHost = true
cfg.useHeaderHostSet = true
cfg.redirectHost = "configured.example"
if err := validateRunConfig(cfg); err == nil {
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
}
}
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
engine := New() engine := New()
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
@ -189,7 +235,7 @@ func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
s.ReadTimeout = time.Second s.ReadTimeout = time.Second
}) })
server, err := buildRedirectServer(engine, ":443", ":80") server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil { if err != nil {
t.Fatalf("build redirect server: %v", err) t.Fatalf("build redirect server: %v", err)
} }
@ -216,7 +262,7 @@ func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
engine := New() engine := New()
server, err := buildRedirectServer(engine, ":443", ":80") server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
if err != nil { if err != nil {
t.Fatalf("build redirect server: %v", err) t.Fatalf("build redirect server: %v", err)
} }
@ -234,6 +280,84 @@ func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
} }
} }
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
req.Header.Set("X-First-Host", "first.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: true,
useHeaderHostSet: true,
redirectHostHeaders: []string{"X-Forwarded-Host"},
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUpgradeRequired {
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
}
}
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
engine := New()
server, err := buildRedirectServer(engine, runConfig{
addr: ":443",
httpRedirectAddr: ":80",
useHeaderHost: false,
useHeaderHostSet: true,
redirectHost: "configured.example",
})
if err != nil {
t.Fatalf("build redirect server: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
req.Host = "example.com:80"
req.Header.Set("X-Forwarded-Host", "forwarded.example")
rr := httptest.NewRecorder()
server.Handler.ServeHTTP(rr, req)
if rr.Code != http.StatusMovedPermanently {
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
}
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
t.Fatalf("unexpected redirect location: %q", location)
}
}
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0") occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -252,7 +376,7 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
} }
engine := New() engine := New()
redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
if err != nil { if err != nil {
t.Fatalf("build redirect server: %v", err) t.Fatalf("build redirect server: %v", err)
} }
@ -275,3 +399,34 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
} }
} }
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on occupied addr: %v", err)
}
occupiedAddr := occupied.Addr().String()
defer occupied.Close()
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen for redirect addr: %v", err)
}
redirectAddr := redirectListener.Addr().String()
if err := redirectListener.Close(); err != nil {
t.Fatalf("close redirect addr probe: %v", err)
}
engine := New()
err = engine.Run(
WithAddr(occupiedAddr),
WithTLS(&tls.Config{}),
WithHTTPRedirect(redirectAddr),
)
if err == nil {
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
}
if !strings.Contains(err.Error(), occupiedAddr) {
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
}
}

View file

@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
// HandlersChain 定义处理函数链(中间件栈)的类型。 // HandlersChain 定义处理函数链(中间件栈)的类型。
type HandlersChain []HandlerFunc type HandlersChain []HandlerFunc
// IRouter 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。 // Router 定义了路由注册的接口提供路由分组和HTTP方法注册的能力。
type IRouter interface { type Router interface {
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
GET(relativePath string, handlers ...HandlerFunc) GET(relativePath string, handlers ...HandlerFunc)