feat: add redirect host selection options

Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors.
This commit is contained in:
wjqserver 2026-04-07 19:49:13 +08:00
parent e4d3eed379
commit e2cf08d5dd
5 changed files with 422 additions and 35 deletions

178
serve.go
View file

@ -14,6 +14,7 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
@ -32,15 +33,19 @@ const (
)
type runConfig struct {
addr string
httpRedirectAddr string
tlsConfig *tls.Config
graceful bool
shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
addr string
httpRedirectAddr string
tlsConfig *tls.Config
redirectHost string
redirectHostHeaders []string
useHeaderHost bool
useHeaderHostSet bool
graceful bool
shutdownTimeout time.Duration
gracefulCtx context.Context
mode runMode
shutdownDefaultSet bool
shutdownTimeoutSet bool
}
type RunOption interface {
@ -58,9 +63,20 @@ func defaultRunConfig() runConfig {
addr: ":8080",
shutdownTimeout: defaultShutdownTimeout,
mode: runModeHTTP,
useHeaderHost: true,
}
}
type HTTPRedirectOption interface {
applyRedirect(*runConfig) error
}
type redirectOptionFunc func(*runConfig) error
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
return f(cfg)
}
func WithAddr(addr string) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption {
})
}
func WithHTTPRedirect(addr string) RunOption {
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
return runOptionFunc(func(cfg *runConfig) error {
if addr == "" {
return errors.New("http redirect address must not be empty")
}
cfg.httpRedirectAddr = addr
cfg.mode = runModeHTTPSRedirect
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt.applyRedirect(cfg); err != nil {
return err
}
}
return nil
})
}
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.useHeaderHost = enabled
cfg.useHeaderHostSet = true
return nil
})
}
func WithRedirectHost(host string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
if host == "" {
return errors.New("redirect host must not be empty")
}
cfg.redirectHost = host
return nil
})
}
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
return redirectOptionFunc(func(cfg *runConfig) error {
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
for _, header := range headers {
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
if trimmed != "" {
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
}
}
return nil
})
}
@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
return server
}
func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) {
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
if r == nil {
return ""
}
for _, header := range headers {
value := strings.TrimSpace(r.Header.Get(header))
if value == "" {
continue
}
if comma := strings.IndexByte(value, ','); comma >= 0 {
value = strings.TrimSpace(value[:comma])
}
if value != "" {
return value
}
}
return ""
}
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return "", http.StatusInternalServerError, false
}
return cfg.redirectHost, 0, true
}
if len(cfg.redirectHostHeaders) > 0 {
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
if r == nil {
return "", http.StatusUpgradeRequired, false
}
host := strings.TrimSpace(r.Host)
if host == "" {
return "", http.StatusUpgradeRequired, false
}
return host, 0, true
}
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
httpsAddr := cfg.addr
httpAddr := cfg.httpRedirectAddr
httpsPort, err := parseHTTPSPort(httpsAddr)
if err != nil {
return nil, err
}
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
host, statusCode, ok := redirectTargetHost(r, cfg)
if !ok {
http.Error(w, http.StatusText(statusCode), statusCode)
return
}
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
host = parsedHost
}
targetURL := "https://" + host
@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error {
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
return errors.New("https mode requires WithTLS")
}
if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect {
cfg.mode = runModeHTTPSRedirect
}
if cfg.gracefulCtx != nil && !cfg.graceful {
return errors.New("WithShutdownContext requires graceful shutdown")
}
if len(cfg.redirectHostHeaders) > 0 {
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
}
}
if cfg.useHeaderHostSet && cfg.useHeaderHost {
if cfg.redirectHost != "" {
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
}
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
if cfg.redirectHost == "" {
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
}
if len(cfg.redirectHostHeaders) > 0 {
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
}
}
return nil
}
@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error {
wg.Add(1)
go func(s *http.Server) {
defer wg.Done()
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
if err := s.Shutdown(ctx); err != nil {
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
}
}(srv)
@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error {
servers := []*http.Server{mainServer}
serveTLSFlags := []bool{serveTLS}
if cfg.mode == runModeHTTPSRedirect {
redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr)
redirectServer, err := buildRedirectServer(engine, cfg)
if err != nil {
return err
}
@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error {
if !cfg.graceful {
if len(servers) > 1 {
runServer("HTTPS", servers[0], true)
log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr)
return serveServer(servers[1], false)
serverStopped := make(chan error, len(servers))
for i, srv := range servers {
serveTLSFlag := serveTLSFlags[i]
go func(server *http.Server, useTLS bool) {
serverStopped <- serveServer(server, useTLS)
}(srv, serveTLSFlag)
}
err := <-serverStopped
if err != nil && !errors.Is(err, http.ErrServerClosed) {
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
return errors.Join(err, shutdownErr)
}
return err
}
return err
}
protocolLabel := "HTTP"