mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #83 from infinite-iroha/break/v1-redirect-host-strategy
feat: add redirect host selection options
This commit is contained in:
commit
7cb777225f
5 changed files with 489 additions and 35 deletions
|
|
@ -97,8 +97,106 @@ r.Run(
|
||||||
touka.WithHTTPRedirect(":80"),
|
touka.WithHTTPRedirect(":80"),
|
||||||
touka.WithGracefulShutdown(10*time.Second),
|
touka.WithGracefulShutdown(10*time.Second),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host)
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(true),
|
||||||
|
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host)
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(false),
|
||||||
|
touka.WithRedirectHost("example.com"),
|
||||||
|
),
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### HTTPS Redirect Host 策略
|
||||||
|
|
||||||
|
`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。
|
||||||
|
|
||||||
|
可用的 redirect 子选项:
|
||||||
|
|
||||||
|
- `touka.WithUseHeaderHost(true|false)`
|
||||||
|
- `touka.WithRedirectHostHeaders([]string{...})`
|
||||||
|
- `touka.WithRedirectHost("example.com")`
|
||||||
|
|
||||||
|
#### 模式一:使用请求输入侧的 host
|
||||||
|
|
||||||
|
当 `WithUseHeaderHost(true)` 时:
|
||||||
|
|
||||||
|
- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host`
|
||||||
|
- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值
|
||||||
|
- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(true),
|
||||||
|
touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 模式二:使用配置的固定 host
|
||||||
|
|
||||||
|
当 `WithUseHeaderHost(false)` 时:
|
||||||
|
|
||||||
|
- 不读取 `Request.Host`
|
||||||
|
- 不读取 `WithRedirectHostHeaders(...)`
|
||||||
|
- 必须配置 `WithRedirectHost("example.com")`
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Run(
|
||||||
|
touka.WithAddr(":443"),
|
||||||
|
touka.WithTLS(tlsConfig),
|
||||||
|
touka.WithHTTPRedirect(
|
||||||
|
":80",
|
||||||
|
touka.WithUseHeaderHost(false),
|
||||||
|
touka.WithRedirectHost("example.com"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 严格校验规则
|
||||||
|
|
||||||
|
以下组合会直接返回配置错误:
|
||||||
|
|
||||||
|
- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)`
|
||||||
|
- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)`
|
||||||
|
- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)`
|
||||||
|
- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)`
|
||||||
|
- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)`
|
||||||
|
|
||||||
|
#### 优先级关系
|
||||||
|
|
||||||
|
1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式
|
||||||
|
2. `WithUseHeaderHost(...)` 决定 host 来源模式
|
||||||
|
3. 当 `WithUseHeaderHost(true)` 时:
|
||||||
|
- 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询
|
||||||
|
- 未配置时使用 `Request.Host`
|
||||||
|
4. 当 `WithUseHeaderHost(false)` 时:
|
||||||
|
- 只使用 `WithRedirectHost(...)`
|
||||||
|
|
||||||
|
**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。
|
||||||
|
|
||||||
## 优雅停机 (Graceful Shutdown)
|
## 优雅停机 (Graceful Shutdown)
|
||||||
|
|
||||||
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
|
在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。
|
||||||
|
|
|
||||||
10
engine.go
10
engine.go
|
|
@ -626,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle
|
||||||
|
|
||||||
// Use 将全局中间件添加到 Engine
|
// Use 将全局中间件添加到 Engine
|
||||||
// 这些中间件将应用于所有注册的路由
|
// 这些中间件将应用于所有注册的路由
|
||||||
func (engine *Engine) Use(middleware ...HandlerFunc) IRouter {
|
func (engine *Engine) Use(middleware ...HandlerFunc) Router {
|
||||||
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
engine.globalHandlers = append(engine.globalHandlers, middleware...)
|
||||||
engine.rebuildFallbackChains()
|
engine.rebuildFallbackChains()
|
||||||
return engine
|
return engine
|
||||||
|
|
@ -695,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo {
|
||||||
|
|
||||||
// Group 创建一个新的路由组
|
// Group 创建一个新的路由组
|
||||||
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起
|
||||||
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||||
return &RouterGroup{
|
return &RouterGroup{
|
||||||
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件
|
||||||
basePath: resolveRoutePath("/", relativePath),
|
basePath: resolveRoutePath("/", relativePath),
|
||||||
|
|
@ -704,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由
|
||||||
// 它也实现了 IRouter 接口,允许嵌套分组
|
// 它也实现了 Router 接口,允许嵌套分组
|
||||||
type RouterGroup struct {
|
type RouterGroup struct {
|
||||||
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由
|
||||||
basePath string // 组路径前缀
|
basePath string // 组路径前缀
|
||||||
|
|
@ -713,7 +713,7 @@ type RouterGroup struct {
|
||||||
|
|
||||||
// Use 将中间件应用于当前路由组
|
// Use 将中间件应用于当前路由组
|
||||||
// 这些中间件将应用于当前组及其子组的所有路由
|
// 这些中间件将应用于当前组及其子组的所有路由
|
||||||
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter {
|
func (group *RouterGroup) Use(middleware ...HandlerFunc) Router {
|
||||||
group.Handlers = append(group.Handlers, middleware...)
|
group.Handlers = append(group.Handlers, middleware...)
|
||||||
return group
|
return group
|
||||||
}
|
}
|
||||||
|
|
@ -759,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group 为当前组创建一个新的子组
|
// Group 为当前组创建一个新的子组
|
||||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter {
|
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router {
|
||||||
return &RouterGroup{
|
return &RouterGroup{
|
||||||
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
Handlers: group.engine.combineHandlers(group.Handlers, handlers),
|
||||||
basePath: resolveRoutePath(group.basePath, relativePath),
|
basePath: resolveRoutePath(group.basePath, relativePath),
|
||||||
|
|
|
||||||
166
serve.go
166
serve.go
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -35,6 +36,10 @@ type runConfig struct {
|
||||||
addr string
|
addr string
|
||||||
httpRedirectAddr string
|
httpRedirectAddr string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
redirectHost string
|
||||||
|
redirectHostHeaders []string
|
||||||
|
useHeaderHost bool
|
||||||
|
useHeaderHostSet bool
|
||||||
graceful bool
|
graceful bool
|
||||||
shutdownTimeout time.Duration
|
shutdownTimeout time.Duration
|
||||||
gracefulCtx context.Context
|
gracefulCtx context.Context
|
||||||
|
|
@ -58,9 +63,20 @@ func defaultRunConfig() runConfig {
|
||||||
addr: ":8080",
|
addr: ":8080",
|
||||||
shutdownTimeout: defaultShutdownTimeout,
|
shutdownTimeout: defaultShutdownTimeout,
|
||||||
mode: runModeHTTP,
|
mode: runModeHTTP,
|
||||||
|
useHeaderHost: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HTTPRedirectOption interface {
|
||||||
|
applyRedirect(*runConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type redirectOptionFunc func(*runConfig) error
|
||||||
|
|
||||||
|
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
||||||
|
return f(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
func WithAddr(addr string) RunOption {
|
func WithAddr(addr string) RunOption {
|
||||||
return runOptionFunc(func(cfg *runConfig) error {
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
|
|
@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithHTTPRedirect(addr string) RunOption {
|
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
||||||
return runOptionFunc(func(cfg *runConfig) error {
|
return runOptionFunc(func(cfg *runConfig) error {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
return errors.New("http redirect address must not be empty")
|
return errors.New("http redirect address must not be empty")
|
||||||
}
|
}
|
||||||
cfg.httpRedirectAddr = addr
|
cfg.httpRedirectAddr = addr
|
||||||
cfg.mode = runModeHTTPSRedirect
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := opt.applyRedirect(cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.useHeaderHost = enabled
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRedirectHost(host string) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
if host == "" {
|
||||||
|
return errors.New("redirect host must not be empty")
|
||||||
|
}
|
||||||
|
cfg.redirectHost = host
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
||||||
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||||
|
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
||||||
|
for _, header := range headers {
|
||||||
|
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
||||||
|
if trimmed != "" {
|
||||||
|
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -215,16 +270,71 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) {
|
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, header := range headers {
|
||||||
|
value := strings.TrimSpace(r.Header.Get(header))
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
||||||
|
value = strings.TrimSpace(value[:comma])
|
||||||
|
}
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
||||||
|
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost == "" {
|
||||||
|
return "", http.StatusInternalServerError, false
|
||||||
|
}
|
||||||
|
return cfg.redirectHost, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
||||||
|
if host == "" {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
return host, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == nil {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
host := strings.TrimSpace(r.Host)
|
||||||
|
if host == "" {
|
||||||
|
return "", http.StatusUpgradeRequired, false
|
||||||
|
}
|
||||||
|
return host, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
||||||
|
httpsAddr := cfg.addr
|
||||||
|
httpAddr := cfg.httpRedirectAddr
|
||||||
httpsPort, err := parseHTTPSPort(httpsAddr)
|
httpsPort, err := parseHTTPSPort(httpsAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, statusCode, ok := redirectTargetHost(r, cfg)
|
||||||
if err != nil {
|
if !ok {
|
||||||
host = r.Host
|
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = parsedHost
|
||||||
|
if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
|
||||||
|
host = "[" + host + "]"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL := "https://" + host
|
targetURL := "https://" + host
|
||||||
|
|
@ -248,12 +358,26 @@ func validateRunConfig(cfg runConfig) error {
|
||||||
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
||||||
return errors.New("https mode requires WithTLS")
|
return errors.New("https mode requires WithTLS")
|
||||||
}
|
}
|
||||||
if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect {
|
|
||||||
cfg.mode = runModeHTTPSRedirect
|
|
||||||
}
|
|
||||||
if cfg.gracefulCtx != nil && !cfg.graceful {
|
if cfg.gracefulCtx != nil && !cfg.graceful {
|
||||||
return errors.New("WithShutdownContext requires graceful shutdown")
|
return errors.New("WithShutdownContext requires graceful shutdown")
|
||||||
}
|
}
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
||||||
|
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost != "" {
|
||||||
|
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
||||||
|
}
|
||||||
|
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||||
|
if cfg.redirectHost == "" {
|
||||||
|
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
||||||
|
}
|
||||||
|
if len(cfg.redirectHostHeaders) > 0 {
|
||||||
|
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -286,7 +410,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(s *http.Server) {
|
go func(s *http.Server) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err := s.Shutdown(ctx); err != nil {
|
||||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||||
}
|
}
|
||||||
}(srv)
|
}(srv)
|
||||||
|
|
@ -378,7 +502,7 @@ func (engine *Engine) Run(opts ...RunOption) error {
|
||||||
servers := []*http.Server{mainServer}
|
servers := []*http.Server{mainServer}
|
||||||
serveTLSFlags := []bool{serveTLS}
|
serveTLSFlags := []bool{serveTLS}
|
||||||
if cfg.mode == runModeHTTPSRedirect {
|
if cfg.mode == runModeHTTPSRedirect {
|
||||||
redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr)
|
redirectServer, err := buildRedirectServer(engine, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -388,9 +512,25 @@ func (engine *Engine) Run(opts ...RunOption) error {
|
||||||
|
|
||||||
if !cfg.graceful {
|
if !cfg.graceful {
|
||||||
if len(servers) > 1 {
|
if len(servers) > 1 {
|
||||||
runServer("HTTPS", servers[0], true)
|
serverStopped := make(chan error, len(servers))
|
||||||
log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr)
|
for i, srv := range servers {
|
||||||
return serveServer(servers[1], false)
|
serveTLSFlag := serveTLSFlags[i]
|
||||||
|
go func(server *http.Server, useTLS bool) {
|
||||||
|
serverStopped <- serveServer(server, useTLS)
|
||||||
|
}(srv, serveTLSFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := <-serverStopped
|
||||||
|
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return errors.Join(err, shutdownErr)
|
||||||
|
}
|
||||||
|
return shutdownErr
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
protocolLabel := "HTTP"
|
protocolLabel := "HTTP"
|
||||||
|
|
|
||||||
224
serve_test.go
224
serve_test.go
|
|
@ -2,9 +2,15 @@ package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
@ -13,6 +19,41 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func generateSelfSignedCert(t *testing.T) tls.Certificate {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate private key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: "127.0.0.1"},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageServerAuth,
|
||||||
|
},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create self-signed cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse self-signed cert: %v", err)
|
||||||
|
}
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -90,6 +131,18 @@ func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
err := engine.Run(
|
||||||
|
WithAddr(":443"),
|
||||||
|
WithTLS(&tls.Config{}),
|
||||||
|
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
||||||
cfg := defaultRunConfig()
|
cfg := defaultRunConfig()
|
||||||
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
||||||
|
|
@ -122,7 +175,7 @@ func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
|
||||||
|
|
||||||
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
||||||
engine := New()
|
engine := New()
|
||||||
if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil {
|
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
|
||||||
t.Fatal("expected redirect server builder to reject https address without port")
|
t.Fatal("expected redirect server builder to reject https address without port")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -139,6 +192,40 @@ func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.httpRedirectAddr = ":80"
|
||||||
|
if err := validateRunConfig(cfg); err != nil {
|
||||||
|
t.Fatalf("validate run config: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.mode != runModeHTTP {
|
||||||
|
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
cfg.tlsConfig = &tls.Config{}
|
||||||
|
cfg.useHeaderHost = false
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
if err := validateRunConfig(cfg); err == nil {
|
||||||
|
t.Fatal("expected configured host mode without redirect host to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
|
||||||
|
cfg := defaultRunConfig()
|
||||||
|
cfg.mode = runModeHTTPSRedirect
|
||||||
|
cfg.tlsConfig = &tls.Config{}
|
||||||
|
cfg.useHeaderHost = true
|
||||||
|
cfg.useHeaderHostSet = true
|
||||||
|
cfg.redirectHost = "configured.example"
|
||||||
|
if err := validateRunConfig(cfg); err == nil {
|
||||||
|
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
||||||
engine := New()
|
engine := New()
|
||||||
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
||||||
|
|
@ -189,7 +276,7 @@ func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
|
||||||
s.ReadTimeout = time.Second
|
s.ReadTimeout = time.Second
|
||||||
})
|
})
|
||||||
|
|
||||||
server, err := buildRedirectServer(engine, ":443", ":80")
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("build redirect server: %v", err)
|
t.Fatalf("build redirect server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -216,7 +303,7 @@ func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
|
||||||
|
|
||||||
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
||||||
engine := New()
|
engine := New()
|
||||||
server, err := buildRedirectServer(engine, ":443", ":80")
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("build redirect server: %v", err)
|
t.Fatalf("build redirect server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -234,6 +321,104 @@ func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: true,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||||
|
req.Header.Set("X-First-Host", "first.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: true,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHostHeaders: []string{"X-Forwarded-Host"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUpgradeRequired {
|
||||||
|
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{
|
||||||
|
addr: ":443",
|
||||||
|
httpRedirectAddr: ":80",
|
||||||
|
useHeaderHost: false,
|
||||||
|
useHeaderHostSet: true,
|
||||||
|
redirectHost: "configured.example",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
||||||
|
req.Host = "example.com:80"
|
||||||
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build redirect server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
|
||||||
|
req.Host = "[::1]:80"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.Handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusMovedPermanently {
|
||||||
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
||||||
|
}
|
||||||
|
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
|
||||||
|
t.Fatalf("unexpected IPv6 redirect location: %q", location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -252,7 +437,7 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := New()
|
engine := New()
|
||||||
redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr)
|
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("build redirect server: %v", err)
|
t.Fatalf("build redirect server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -275,3 +460,34 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
||||||
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
|
||||||
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen on occupied addr: %v", err)
|
||||||
|
}
|
||||||
|
occupiedAddr := occupied.Addr().String()
|
||||||
|
defer occupied.Close()
|
||||||
|
|
||||||
|
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen for redirect addr: %v", err)
|
||||||
|
}
|
||||||
|
redirectAddr := redirectListener.Addr().String()
|
||||||
|
if err := redirectListener.Close(); err != nil {
|
||||||
|
t.Fatalf("close redirect addr probe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
err = engine.Run(
|
||||||
|
WithAddr(occupiedAddr),
|
||||||
|
WithTLS(&tls.Config{}),
|
||||||
|
WithHTTPRedirect(redirectAddr),
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), occupiedAddr) {
|
||||||
|
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
8
touka.go
8
touka.go
|
|
@ -22,10 +22,10 @@ type HandlerFunc func(*Context)
|
||||||
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
// HandlersChain 定义处理函数链(中间件栈)的类型。
|
||||||
type HandlersChain []HandlerFunc
|
type HandlersChain []HandlerFunc
|
||||||
|
|
||||||
// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。
|
||||||
type IRouter interface {
|
type Router interface {
|
||||||
Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组
|
Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组
|
||||||
Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组
|
Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组
|
||||||
|
|
||||||
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法
|
||||||
GET(relativePath string, handlers ...HandlerFunc)
|
GET(relativePath string, handlers ...HandlerFunc)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue