mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat: add redirect host selection options
Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors.
This commit is contained in:
parent
e4d3eed379
commit
e2cf08d5dd
5 changed files with 422 additions and 35 deletions
178
serve.go
178
serve.go
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
|
@ -32,15 +33,19 @@ const (
|
|||
)
|
||||
|
||||
type runConfig struct {
|
||||
addr string
|
||||
httpRedirectAddr string
|
||||
tlsConfig *tls.Config
|
||||
graceful bool
|
||||
shutdownTimeout time.Duration
|
||||
gracefulCtx context.Context
|
||||
mode runMode
|
||||
shutdownDefaultSet bool
|
||||
shutdownTimeoutSet bool
|
||||
addr string
|
||||
httpRedirectAddr string
|
||||
tlsConfig *tls.Config
|
||||
redirectHost string
|
||||
redirectHostHeaders []string
|
||||
useHeaderHost bool
|
||||
useHeaderHostSet bool
|
||||
graceful bool
|
||||
shutdownTimeout time.Duration
|
||||
gracefulCtx context.Context
|
||||
mode runMode
|
||||
shutdownDefaultSet bool
|
||||
shutdownTimeoutSet bool
|
||||
}
|
||||
|
||||
type RunOption interface {
|
||||
|
|
@ -58,9 +63,20 @@ func defaultRunConfig() runConfig {
|
|||
addr: ":8080",
|
||||
shutdownTimeout: defaultShutdownTimeout,
|
||||
mode: runModeHTTP,
|
||||
useHeaderHost: true,
|
||||
}
|
||||
}
|
||||
|
||||
type HTTPRedirectOption interface {
|
||||
applyRedirect(*runConfig) error
|
||||
}
|
||||
|
||||
type redirectOptionFunc func(*runConfig) error
|
||||
|
||||
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
func WithAddr(addr string) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
|
|
@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption {
|
|||
})
|
||||
}
|
||||
|
||||
func WithHTTPRedirect(addr string) RunOption {
|
||||
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
||||
return runOptionFunc(func(cfg *runConfig) error {
|
||||
if addr == "" {
|
||||
return errors.New("http redirect address must not be empty")
|
||||
}
|
||||
cfg.httpRedirectAddr = addr
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
if err := opt.applyRedirect(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.useHeaderHost = enabled
|
||||
cfg.useHeaderHostSet = true
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHost(host string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
if host == "" {
|
||||
return errors.New("redirect host must not be empty")
|
||||
}
|
||||
cfg.redirectHost = host
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
||||
return redirectOptionFunc(func(cfg *runConfig) error {
|
||||
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
||||
for _, header := range headers {
|
||||
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
||||
if trimmed != "" {
|
||||
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
|||
return server
|
||||
}
|
||||
|
||||
func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) {
|
||||
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
for _, header := range headers {
|
||||
value := strings.TrimSpace(r.Header.Get(header))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
||||
value = strings.TrimSpace(value[:comma])
|
||||
}
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
||||
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return "", http.StatusInternalServerError, false
|
||||
}
|
||||
return cfg.redirectHost, 0, true
|
||||
}
|
||||
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
if r == nil {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
host := strings.TrimSpace(r.Host)
|
||||
if host == "" {
|
||||
return "", http.StatusUpgradeRequired, false
|
||||
}
|
||||
return host, 0, true
|
||||
}
|
||||
|
||||
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
||||
httpsAddr := cfg.addr
|
||||
httpAddr := cfg.httpRedirectAddr
|
||||
httpsPort, err := parseHTTPSPort(httpsAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
host, statusCode, ok := redirectTargetHost(r, cfg)
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(statusCode), statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = parsedHost
|
||||
}
|
||||
|
||||
targetURL := "https://" + host
|
||||
|
|
@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error {
|
|||
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
||||
return errors.New("https mode requires WithTLS")
|
||||
}
|
||||
if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect {
|
||||
cfg.mode = runModeHTTPSRedirect
|
||||
}
|
||||
if cfg.gracefulCtx != nil && !cfg.graceful {
|
||||
return errors.New("WithShutdownContext requires graceful shutdown")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
||||
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
||||
}
|
||||
}
|
||||
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
||||
if cfg.redirectHost != "" {
|
||||
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
||||
}
|
||||
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
||||
if cfg.redirectHost == "" {
|
||||
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
||||
}
|
||||
if len(cfg.redirectHostHeaders) > 0 {
|
||||
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
|||
wg.Add(1)
|
||||
go func(s *http.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if err := s.Shutdown(ctx); err != nil {
|
||||
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
||||
}
|
||||
}(srv)
|
||||
|
|
@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error {
|
|||
servers := []*http.Server{mainServer}
|
||||
serveTLSFlags := []bool{serveTLS}
|
||||
if cfg.mode == runModeHTTPSRedirect {
|
||||
redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr)
|
||||
redirectServer, err := buildRedirectServer(engine, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error {
|
|||
|
||||
if !cfg.graceful {
|
||||
if len(servers) > 1 {
|
||||
runServer("HTTPS", servers[0], true)
|
||||
log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr)
|
||||
return serveServer(servers[1], false)
|
||||
serverStopped := make(chan error, len(servers))
|
||||
for i, srv := range servers {
|
||||
serveTLSFlag := serveTLSFlags[i]
|
||||
go func(server *http.Server, useTLS bool) {
|
||||
serverStopped <- serveServer(server, useTLS)
|
||||
}(srv, serveTLSFlag)
|
||||
}
|
||||
|
||||
err := <-serverStopped
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
||||
return errors.Join(err, shutdownErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
protocolLabel := "HTTP"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue