mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors.
543 lines
13 KiB
Go
543 lines
13 KiB
Go
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
// Copyright 2024 WJQSERVER. All rights reserved.
|
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
|
package touka
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/fenthope/reco"
|
|
)
|
|
|
|
const defaultShutdownTimeout = 5 * time.Second
|
|
|
|
type runMode uint8
|
|
|
|
const (
|
|
runModeHTTP runMode = iota
|
|
runModeHTTPS
|
|
runModeHTTPSRedirect
|
|
)
|
|
|
|
type runConfig struct {
|
|
addr string
|
|
httpRedirectAddr string
|
|
tlsConfig *tls.Config
|
|
redirectHost string
|
|
redirectHostHeaders []string
|
|
useHeaderHost bool
|
|
useHeaderHostSet bool
|
|
graceful bool
|
|
shutdownTimeout time.Duration
|
|
gracefulCtx context.Context
|
|
mode runMode
|
|
shutdownDefaultSet bool
|
|
shutdownTimeoutSet bool
|
|
}
|
|
|
|
type RunOption interface {
|
|
apply(*runConfig) error
|
|
}
|
|
|
|
type runOptionFunc func(*runConfig) error
|
|
|
|
func (f runOptionFunc) apply(cfg *runConfig) error {
|
|
return f(cfg)
|
|
}
|
|
|
|
func defaultRunConfig() runConfig {
|
|
return runConfig{
|
|
addr: ":8080",
|
|
shutdownTimeout: defaultShutdownTimeout,
|
|
mode: runModeHTTP,
|
|
useHeaderHost: true,
|
|
}
|
|
}
|
|
|
|
type HTTPRedirectOption interface {
|
|
applyRedirect(*runConfig) error
|
|
}
|
|
|
|
type redirectOptionFunc func(*runConfig) error
|
|
|
|
func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error {
|
|
return f(cfg)
|
|
}
|
|
|
|
func WithAddr(addr string) RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
if addr == "" {
|
|
return errors.New("run address must not be empty")
|
|
}
|
|
cfg.addr = addr
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithTLS(tlsConfig *tls.Config) RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
if tlsConfig == nil {
|
|
return errors.New("tls.Config must not be nil")
|
|
}
|
|
cfg.tlsConfig = tlsConfig
|
|
if cfg.mode == runModeHTTP {
|
|
cfg.mode = runModeHTTPS
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
if addr == "" {
|
|
return errors.New("http redirect address must not be empty")
|
|
}
|
|
cfg.httpRedirectAddr = addr
|
|
cfg.mode = runModeHTTPSRedirect
|
|
for _, opt := range opts {
|
|
if opt == nil {
|
|
continue
|
|
}
|
|
if err := opt.applyRedirect(cfg); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithUseHeaderHost(enabled bool) HTTPRedirectOption {
|
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
|
cfg.useHeaderHost = enabled
|
|
cfg.useHeaderHostSet = true
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithRedirectHost(host string) HTTPRedirectOption {
|
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
|
if host == "" {
|
|
return errors.New("redirect host must not be empty")
|
|
}
|
|
cfg.redirectHost = host
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithRedirectHostHeaders(headers []string) HTTPRedirectOption {
|
|
return redirectOptionFunc(func(cfg *runConfig) error {
|
|
cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0]
|
|
for _, header := range headers {
|
|
trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header))
|
|
if trimmed != "" {
|
|
cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithGracefulShutdown(timeout time.Duration) RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
cfg.graceful = true
|
|
cfg.shutdownTimeoutSet = true
|
|
if timeout > 0 {
|
|
cfg.shutdownTimeout = timeout
|
|
} else {
|
|
cfg.shutdownTimeout = defaultShutdownTimeout
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithGracefulShutdownDefault() RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
cfg.graceful = true
|
|
cfg.shutdownDefaultSet = true
|
|
cfg.shutdownTimeout = defaultShutdownTimeout
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func WithShutdownContext(ctx context.Context) RunOption {
|
|
return runOptionFunc(func(cfg *runConfig) error {
|
|
if ctx == nil {
|
|
return errors.New("shutdown context must not be nil")
|
|
}
|
|
cfg.gracefulCtx = ctx
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func serveServer(srv *http.Server, serveTLS bool) error {
|
|
if serveTLS {
|
|
return srv.ListenAndServeTLS("", "")
|
|
}
|
|
return srv.ListenAndServe()
|
|
}
|
|
|
|
func runServer(serverType string, srv *http.Server, serveTLS bool) {
|
|
go func() {
|
|
protocol := "http"
|
|
if serveTLS {
|
|
protocol = "https"
|
|
}
|
|
|
|
log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr)
|
|
|
|
err := serveServer(srv, serveTLS)
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
log.Fatalf("Touka %s server failed: %v", serverType, err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config {
|
|
if tlsConfig == nil {
|
|
return nil
|
|
}
|
|
return tlsConfig.Clone()
|
|
}
|
|
|
|
func parseHTTPSPort(addr string) (string, error) {
|
|
_, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return "", fmt.Errorf("https address %q must include a port: %w", addr, err)
|
|
}
|
|
return port, nil
|
|
}
|
|
|
|
func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) {
|
|
if serveTLS {
|
|
if engine.TLSServerConfigurator != nil {
|
|
engine.TLSServerConfigurator(srv)
|
|
return
|
|
}
|
|
}
|
|
if engine.ServerConfigurator != nil {
|
|
engine.ServerConfigurator(srv)
|
|
}
|
|
}
|
|
|
|
func applyRedirectServerConfig(engine *Engine, srv *http.Server) {
|
|
applyServerProtocols(srv, engine.serverProtocols)
|
|
if engine.ServerConfigurator != nil {
|
|
engine.ServerConfigurator(srv)
|
|
}
|
|
}
|
|
|
|
func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols {
|
|
if engine == nil {
|
|
return nil
|
|
}
|
|
if serveTLS && engine.useDefaultProtocols {
|
|
protocols := &http.Protocols{}
|
|
protocols.SetHTTP1(true)
|
|
protocols.SetHTTP2(true)
|
|
return protocols
|
|
}
|
|
return cloneServerProtocols(engine.serverProtocols)
|
|
}
|
|
|
|
func buildMainServer(engine *Engine, cfg runConfig) *http.Server {
|
|
serveTLS := cfg.mode != runModeHTTP
|
|
server := &http.Server{
|
|
Addr: cfg.addr,
|
|
Handler: engine,
|
|
TLSConfig: cloneTLSConfig(cfg.tlsConfig),
|
|
}
|
|
if cfg.graceful {
|
|
server.BaseContext = func(net.Listener) context.Context {
|
|
return engine.shutdownCtx
|
|
}
|
|
server.RegisterOnShutdown(engine.shutdownCancel)
|
|
}
|
|
applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS))
|
|
applyMainServerConfig(engine, server, serveTLS)
|
|
return server
|
|
}
|
|
|
|
func firstRedirectHeaderHost(r *http.Request, headers []string) string {
|
|
if r == nil {
|
|
return ""
|
|
}
|
|
for _, header := range headers {
|
|
value := strings.TrimSpace(r.Header.Get(header))
|
|
if value == "" {
|
|
continue
|
|
}
|
|
if comma := strings.IndexByte(value, ','); comma >= 0 {
|
|
value = strings.TrimSpace(value[:comma])
|
|
}
|
|
if value != "" {
|
|
return value
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) {
|
|
if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
|
if cfg.redirectHost == "" {
|
|
return "", http.StatusInternalServerError, false
|
|
}
|
|
return cfg.redirectHost, 0, true
|
|
}
|
|
|
|
if len(cfg.redirectHostHeaders) > 0 {
|
|
host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders)
|
|
if host == "" {
|
|
return "", http.StatusUpgradeRequired, false
|
|
}
|
|
return host, 0, true
|
|
}
|
|
|
|
if r == nil {
|
|
return "", http.StatusUpgradeRequired, false
|
|
}
|
|
host := strings.TrimSpace(r.Host)
|
|
if host == "" {
|
|
return "", http.StatusUpgradeRequired, false
|
|
}
|
|
return host, 0, true
|
|
}
|
|
|
|
func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) {
|
|
httpsAddr := cfg.addr
|
|
httpAddr := cfg.httpRedirectAddr
|
|
httpsPort, err := parseHTTPSPort(httpsAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
host, statusCode, ok := redirectTargetHost(r, cfg)
|
|
if !ok {
|
|
http.Error(w, http.StatusText(statusCode), statusCode)
|
|
return
|
|
}
|
|
|
|
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
|
host = parsedHost
|
|
}
|
|
|
|
targetURL := "https://" + host
|
|
if httpsPort != "443" {
|
|
targetURL = "https://" + net.JoinHostPort(host, httpsPort)
|
|
}
|
|
targetURL += r.URL.RequestURI()
|
|
|
|
http.Redirect(w, r, targetURL, http.StatusMovedPermanently)
|
|
})
|
|
|
|
server := &http.Server{Addr: httpAddr, Handler: redirectHandler}
|
|
applyRedirectServerConfig(engine, server)
|
|
return server, nil
|
|
}
|
|
|
|
func validateRunConfig(cfg runConfig) error {
|
|
if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil {
|
|
return errors.New("WithHTTPRedirect requires WithTLS")
|
|
}
|
|
if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil {
|
|
return errors.New("https mode requires WithTLS")
|
|
}
|
|
if cfg.gracefulCtx != nil && !cfg.graceful {
|
|
return errors.New("WithShutdownContext requires graceful shutdown")
|
|
}
|
|
if len(cfg.redirectHostHeaders) > 0 {
|
|
if !cfg.useHeaderHostSet || !cfg.useHeaderHost {
|
|
return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)")
|
|
}
|
|
}
|
|
if cfg.useHeaderHostSet && cfg.useHeaderHost {
|
|
if cfg.redirectHost != "" {
|
|
return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)")
|
|
}
|
|
} else if cfg.useHeaderHostSet && !cfg.useHeaderHost {
|
|
if cfg.redirectHost == "" {
|
|
return errors.New("WithUseHeaderHost(false) requires WithRedirectHost")
|
|
}
|
|
if len(cfg.redirectHostHeaders) > 0 {
|
|
return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func effectiveShutdownTimeout(cfg runConfig) time.Duration {
|
|
if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet {
|
|
if cfg.shutdownTimeout > 0 {
|
|
return cfg.shutdownTimeout
|
|
}
|
|
}
|
|
return defaultShutdownTimeout
|
|
}
|
|
|
|
func closeLoggerAsync(logger *reco.Logger) {
|
|
if logger == nil {
|
|
return
|
|
}
|
|
go func() {
|
|
log.Println("Closing Touka logger...")
|
|
CloseLogger(logger)
|
|
}()
|
|
}
|
|
|
|
func shutdownServers(servers []*http.Server, timeout time.Duration) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
var wg sync.WaitGroup
|
|
errChan := make(chan error, len(servers))
|
|
for _, srv := range servers {
|
|
wg.Add(1)
|
|
go func(s *http.Server) {
|
|
defer wg.Done()
|
|
if err := s.Shutdown(ctx); err != nil {
|
|
errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err)
|
|
}
|
|
}(srv)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errChan)
|
|
|
|
var shutdownErrors []error
|
|
for err := range errChan {
|
|
shutdownErrors = append(shutdownErrors, err)
|
|
log.Printf("Shutdown error: %v", err)
|
|
}
|
|
if len(shutdownErrors) > 0 {
|
|
return errors.Join(shutdownErrors...)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error {
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
defer signal.Stop(quit)
|
|
|
|
serverStopped := make(chan error, len(servers))
|
|
for i, srv := range servers {
|
|
serveTLSFlag := serveTLS[i]
|
|
go func(server *http.Server, useTLS bool) {
|
|
serverStopped <- serveServer(server, useTLS)
|
|
}(srv, serveTLSFlag)
|
|
}
|
|
|
|
select {
|
|
case err := <-serverStopped:
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil {
|
|
return errors.Join(err, shutdownErr)
|
|
}
|
|
return err
|
|
}
|
|
log.Println("Touka server stopped gracefully.")
|
|
return nil
|
|
case <-quit:
|
|
log.Println("Shutting down Touka server(s) due to OS signal...")
|
|
case <-shutdownCtx.Done():
|
|
log.Println("Context cancelled, shutting down Touka server(s)...")
|
|
}
|
|
|
|
closeLoggerAsync(logger)
|
|
if err := shutdownServers(servers, timeout); err != nil {
|
|
return err
|
|
}
|
|
log.Println("Touka server(s) exited gracefully.")
|
|
return nil
|
|
}
|
|
|
|
// Run starts the engine with the provided startup options.
|
|
//
|
|
// Default behavior with no options:
|
|
// - HTTP only
|
|
// - listens on :8080
|
|
// - no graceful shutdown orchestration
|
|
//
|
|
// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable
|
|
// signal-aware graceful shutdown and request-context cancellation semantics.
|
|
// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown.
|
|
func (engine *Engine) Run(opts ...RunOption) error {
|
|
cfg := defaultRunConfig()
|
|
for _, opt := range opts {
|
|
if opt == nil {
|
|
continue
|
|
}
|
|
if err := opt.apply(&cfg); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if cfg.httpRedirectAddr != "" {
|
|
cfg.mode = runModeHTTPSRedirect
|
|
} else if cfg.tlsConfig != nil {
|
|
cfg.mode = runModeHTTPS
|
|
}
|
|
if err := validateRunConfig(cfg); err != nil {
|
|
return err
|
|
}
|
|
|
|
serveTLS := cfg.mode != runModeHTTP
|
|
|
|
mainServer := buildMainServer(engine, cfg)
|
|
servers := []*http.Server{mainServer}
|
|
serveTLSFlags := []bool{serveTLS}
|
|
if cfg.mode == runModeHTTPSRedirect {
|
|
redirectServer, err := buildRedirectServer(engine, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
servers = append(servers, redirectServer)
|
|
serveTLSFlags = append(serveTLSFlags, false)
|
|
}
|
|
|
|
if !cfg.graceful {
|
|
if len(servers) > 1 {
|
|
serverStopped := make(chan error, len(servers))
|
|
for i, srv := range servers {
|
|
serveTLSFlag := serveTLSFlags[i]
|
|
go func(server *http.Server, useTLS bool) {
|
|
serverStopped <- serveServer(server, useTLS)
|
|
}(srv, serveTLSFlag)
|
|
}
|
|
|
|
err := <-serverStopped
|
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil {
|
|
return errors.Join(err, shutdownErr)
|
|
}
|
|
return err
|
|
}
|
|
return err
|
|
}
|
|
|
|
protocolLabel := "HTTP"
|
|
if serveTLS {
|
|
protocolLabel = "HTTPS"
|
|
}
|
|
log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr)
|
|
return serveServer(mainServer, serveTLS)
|
|
}
|
|
|
|
shutdownCtx := context.Background()
|
|
if cfg.gracefulCtx != nil {
|
|
shutdownCtx = cfg.gracefulCtx
|
|
}
|
|
return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx)
|
|
}
|