touka/serve_test.go

492 lines
15 KiB
Go

package touka
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// generateSelfSignedCert returns a freshly generated, self-signed RSA-2048
// certificate for 127.0.0.1 whose validity window spans one hour on either
// side of the current time. Any failure aborts the calling test.
func generateSelfSignedCert(t *testing.T) tls.Certificate {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generate private key: %v", err)
	}

	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "127.0.0.1"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
	}

	// The template is passed as its own parent, so the certificate is
	// signed by the key it certifies.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create self-signed cert: %v", err)
	}

	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
	keyBlock := pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}

	pair, err := tls.X509KeyPair(pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock))
	if err != nil {
		t.Fatalf("parse self-signed cert: %v", err)
	}
	return pair
}
// TestServeServerHTTPModeIgnoresTLSConfig verifies that serveServer, called
// with false as its second argument (presumably the "use TLS" switch — confirm
// against serveServer), answers cleartext HTTP requests even though the
// server carries a non-nil TLSConfig, and that it returns
// http.ErrServerClosed after the server is shut down.
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
	// Reserve an ephemeral port, then release it so the server can bind it.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen on ephemeral port: %v", err)
	}
	addr := listener.Addr().String()
	if err := listener.Close(); err != nil {
		t.Fatalf("close temporary listener: %v", err)
	}

	srv := &http.Server{
		Addr: addr,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			_, _ = w.Write([]byte("ok"))
		}),
		// RunShutdown uses the HTTP startup path and must not let a shared
		// ServerConfigurator accidentally turn it into HTTPS.
		TLSConfig: &tls.Config{},
	}
	// Make sure the listener and the serving goroutine are released on every
	// exit path, including assertion failures. The error is ignored because
	// this is best-effort teardown; after a successful Shutdown there is
	// nothing left to close.
	t.Cleanup(func() { _ = srv.Close() })

	errCh := make(chan error, 1)
	go func() {
		errCh <- serveServer(srv, false)
	}()

	client := &http.Client{Timeout: 200 * time.Millisecond}
	requestURL := "http://" + addr
	deadline := time.Now().Add(3 * time.Second)

	// Poll until the server accepts a request or the deadline passes.
	var resp *http.Response
	for time.Now().Before(deadline) {
		resp, err = client.Get(requestURL)
		if err == nil {
			break
		}
		// If serveServer has already returned, further polling cannot
		// succeed; report its error right away instead of waiting out the
		// deadline.
		select {
		case serveErr := <-errCh:
			t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
		default:
		}
		time.Sleep(20 * time.Millisecond)
	}
	if err != nil {
		select {
		case serveErr := <-errCh:
			t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
		default:
			t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
		}
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read response body: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
	}
	if string(body) != "ok" {
		t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		t.Fatalf("shutdown server: %v", err)
	}
	if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
		t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
	}
}
// TestRunRejectsRedirectWithoutTLS checks that Run returns an error when an
// HTTP redirect option is supplied without any TLS option.
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
	engine := New()
	if runErr := engine.Run(WithHTTPRedirect(":80")); runErr == nil {
		t.Fatal("expected redirect mode without TLS to fail")
	}
}
// TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue checks
// that Run returns an error when redirect host headers are configured but no
// use-header-host option accompanies them.
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
	engine := New()

	addrOpt := WithAddr(":443")
	tlsOpt := WithTLS(&tls.Config{})
	redirectOpt := WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"}))

	if runErr := engine.Run(addrOpt, tlsOpt, redirectOpt); runErr == nil {
		t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
	}
}
// TestWithGracefulShutdownDefaultUsesDefaultTimeout applies
// WithGracefulShutdownDefault to a default run config and expects graceful
// mode to be enabled with the timeout set to defaultShutdownTimeout.
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
	cfg := defaultRunConfig()

	opt := WithGracefulShutdownDefault()
	if applyErr := opt.apply(&cfg); applyErr != nil {
		t.Fatalf("apply graceful default option: %v", applyErr)
	}

	switch {
	case !cfg.graceful:
		t.Fatal("expected graceful shutdown to be enabled")
	case cfg.shutdownTimeout != defaultShutdownTimeout:
		t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
	}
}
// TestWithTLSDoesNotRequireGracefulShutdown applies WithTLS to a default run
// config and expects HTTPS mode, graceful shutdown left off, and the very
// same *tls.Config pointer stored in the config.
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
	want := &tls.Config{MinVersion: tls.VersionTLS12}
	cfg := defaultRunConfig()

	opt := WithTLS(want)
	if applyErr := opt.apply(&cfg); applyErr != nil {
		t.Fatalf("apply TLS option: %v", applyErr)
	}

	switch {
	case cfg.mode != runModeHTTPS:
		t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
	case cfg.graceful:
		t.Fatal("expected TLS option to remain independent from graceful shutdown")
	case cfg.tlsConfig != want:
		// Pointer comparison: the option must keep the caller's config.
		t.Fatal("expected TLS config to be preserved in run config")
	}
}
// TestBuildRedirectServerRejectsHTTPSAddrWithoutPort expects
// buildRedirectServer to return an error when the main address
// ("example.com") carries no port.
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
	engine := New()
	cfg := runConfig{addr: "example.com", httpRedirectAddr: ":80"}

	_, buildErr := buildRedirectServer(engine, cfg)
	if buildErr == nil {
		t.Fatal("expected redirect server builder to reject https address without port")
	}
}
// TestValidateRunConfigRejectsShutdownContextWithoutGraceful applies only
// WithShutdownContext to a default run config and expects validateRunConfig
// to reject the result.
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
	cfg := defaultRunConfig()

	opt := WithShutdownContext(t.Context())
	if applyErr := opt.apply(&cfg); applyErr != nil {
		t.Fatalf("apply shutdown context option: %v", applyErr)
	}

	validateErr := validateRunConfig(cfg)
	if validateErr == nil {
		t.Fatal("expected shutdown context without graceful shutdown to fail validation")
	}
}
// TestValidateRunConfigDoesNotMutateMode sets a redirect address on a default
// run config, validates it, and expects both validation to succeed and the
// caller's mode to still be runModeHTTP afterwards.
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
	cfg := defaultRunConfig()
	cfg.httpRedirectAddr = ":80"

	validateErr := validateRunConfig(cfg)
	if validateErr != nil {
		t.Fatalf("validate run config: %v", validateErr)
	}

	if got := cfg.mode; got != runModeHTTP {
		t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", got)
	}
}
// TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost builds an
// HTTPS-redirect config with header-host mode explicitly disabled and no
// redirect host, and expects validateRunConfig to reject it.
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
	cfg := defaultRunConfig()
	cfg.mode, cfg.tlsConfig = runModeHTTPSRedirect, &tls.Config{}
	// Header-host mode is explicitly switched off, and redirectHost is left
	// at whatever defaultRunConfig provides.
	cfg.useHeaderHost, cfg.useHeaderHostSet = false, true

	validateErr := validateRunConfig(cfg)
	if validateErr == nil {
		t.Fatal("expected configured host mode without redirect host to fail validation")
	}
}
// TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled builds an
// HTTPS-redirect config that enables header-host mode and also sets a fixed
// redirect host, and expects validateRunConfig to reject the combination.
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
	cfg := defaultRunConfig()
	cfg.mode, cfg.tlsConfig = runModeHTTPSRedirect, &tls.Config{}
	cfg.useHeaderHost, cfg.useHeaderHostSet = true, true
	cfg.redirectHost = "configured.example"

	validateErr := validateRunConfig(cfg)
	if validateErr == nil {
		t.Fatal("expected redirect host to be rejected when header host mode is enabled")
	}
}
// TestBuildMainServerGracefulSetsBaseContextAndShutdownHook builds a graceful
// plain-HTTP main server and expects its BaseContext hook to be set and to
// return the engine's shutdownCtx. Only BaseContext is asserted here.
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
	engine := New()
	cfg := runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}

	built := buildMainServer(engine, cfg)
	if built.BaseContext == nil {
		t.Fatal("expected graceful main server to set BaseContext")
	}

	// BaseContext takes a listener argument, so open a throwaway one.
	ln, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("listen for base context check: %v", listenErr)
	}
	defer ln.Close()

	baseCtx := built.BaseContext(ln)
	if baseCtx != engine.shutdownCtx {
		t.Fatal("expected graceful main server to use engine shutdown context")
	}
}
// TestBuildMainServerTLSConfiguratorPrecedence registers both a generic and a
// TLS server configurator, builds an HTTPS main server, and expects only the
// TLS configurator to have run and its IdleTimeout change to be visible.
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
	var genericRan, tlsRan bool

	engine := New()
	engine.SetServerConfigurator(func(srv *http.Server) {
		genericRan = true
		srv.ReadTimeout = time.Second
	})
	engine.SetTLSServerConfigurator(func(srv *http.Server) {
		tlsRan = true
		srv.IdleTimeout = time.Second
	})

	cfg := runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}
	built := buildMainServer(engine, cfg)

	switch {
	case !tlsRan:
		t.Fatal("expected TLS configurator to run for HTTPS main server")
	case genericRan:
		t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
	case built.IdleTimeout != time.Second:
		t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
	}
}
// TestBuildRedirectServerUsesGenericConfigurator registers a generic server
// configurator and expects buildRedirectServer to invoke it and to return a
// server that carries the configurator's ReadTimeout.
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
	var ran bool

	engine := New()
	engine.SetServerConfigurator(func(srv *http.Server) {
		ran = true
		srv.ReadTimeout = time.Second
	})

	cfg := runConfig{addr: ":443", httpRedirectAddr: ":80"}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	switch {
	case !ran:
		t.Fatal("expected redirect server to use generic server configurator")
	case built.ReadTimeout != time.Second:
		t.Fatal("expected redirect server configurator changes to be applied")
	}
}
// TestTLSRunDoesNotMutateDefaultHTTPProtocols builds an HTTPS server and then
// a plain HTTP server from the same engine, expecting HTTP/2 to be reported
// as enabled on the first and disabled on the second.
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
	engine := New()

	tlsCfg := runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}
	if secure := buildMainServer(engine, tlsCfg); !secure.Protocols.HTTP2() {
		t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
	}

	// Built after the HTTPS server on purpose: the earlier build must not
	// have leaked its protocol settings into this one.
	if plain := buildMainServer(engine, defaultRunConfig()); plain.Protocols.HTTP2() {
		t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
	}
}
// TestBuildRedirectServerRedirectsWithoutGracefulMode drives the redirect
// server's handler with a request for example.com:80 and expects a 301 to the
// same path and query on https://example.com.
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
	const wantLocation = "https://example.com/plain/path?q=1"

	engine := New()
	cfg := runConfig{addr: ":443", httpRedirectAddr: ":80"}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	request := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
	request.Host = "example.com:80"
	recorder := httptest.NewRecorder()
	built.Handler.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusMovedPermanently {
		t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, got)
	}
	got := recorder.Header().Get("Location")
	if got != wantLocation {
		t.Fatalf("unexpected redirect location: %q", got)
	}
}
// TestBuildRedirectServerUsesConfiguredHeadersInOrder configures two redirect
// host headers, sends a request carrying both, and expects the redirect to
// target the host from the header listed first in the configuration.
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
	const wantLocation = "https://first.example/plain/path?q=1"

	engine := New()
	cfg := runConfig{
		addr:                ":443",
		httpRedirectAddr:    ":80",
		useHeaderHost:       true,
		useHeaderHostSet:    true,
		redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
	}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	request := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
	request.Host = "example.com:80"
	// Headers are set in the opposite order to the configured list.
	request.Header.Set("X-Forwarded-Host", "forwarded.example")
	request.Header.Set("X-First-Host", "first.example")
	recorder := httptest.NewRecorder()
	built.Handler.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusMovedPermanently {
		t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, got)
	}
	got := recorder.Header().Get("Location")
	if got != wantLocation {
		t.Fatalf("unexpected redirect location: %q", got)
	}
}
// TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss enables
// header-host mode with X-Forwarded-Host configured, sends a request without
// that header, and expects 426 Upgrade Required.
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
	engine := New()
	cfg := runConfig{
		addr:                ":443",
		httpRedirectAddr:    ":80",
		useHeaderHost:       true,
		useHeaderHostSet:    true,
		redirectHostHeaders: []string{"X-Forwarded-Host"},
	}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	// The request deliberately omits the configured host header.
	request := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
	request.Host = "example.com:80"
	recorder := httptest.NewRecorder()
	built.Handler.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusUpgradeRequired {
		t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, got)
	}
}
// TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled
// disables header-host mode, configures a fixed redirect host, and expects
// the redirect to target that host even though the request carries an
// X-Forwarded-Host header.
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
	const wantLocation = "https://configured.example/plain/path?q=1"

	engine := New()
	cfg := runConfig{
		addr:             ":443",
		httpRedirectAddr: ":80",
		useHeaderHost:    false,
		useHeaderHostSet: true,
		redirectHost:     "configured.example",
	}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	request := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
	request.Host = "example.com:80"
	// This header must not influence the redirect target.
	request.Header.Set("X-Forwarded-Host", "forwarded.example")
	recorder := httptest.NewRecorder()
	built.Handler.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusMovedPermanently {
		t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, got)
	}
	got := recorder.Header().Get("Location")
	if got != wantLocation {
		t.Fatalf("unexpected redirect location: %q", got)
	}
}
// TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL sends a request
// whose Host is the bracketed IPv6 literal "[::1]:80" and expects the
// redirect Location to keep the brackets around the address.
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
	const wantLocation = "https://[::1]/plain/path?q=1"

	engine := New()
	cfg := runConfig{addr: ":443", httpRedirectAddr: ":80"}
	built, buildErr := buildRedirectServer(engine, cfg)
	if buildErr != nil {
		t.Fatalf("build redirect server: %v", buildErr)
	}

	request := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
	request.Host = "[::1]:80"
	recorder := httptest.NewRecorder()
	built.Handler.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusMovedPermanently {
		t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, got)
	}
	got := recorder.Header().Get("Location")
	if got != wantLocation {
		t.Fatalf("unexpected IPv6 redirect location: %q", got)
	}
}
// TestGracefulServeShutsDownSiblingServersOnStartupFailure hands gracefulServe
// a main server whose address is already bound plus a redirect server on a
// free address. It expects an error naming the occupied address and expects
// the redirect address to refuse connections afterwards.
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
	// Hold this listener open for the whole test so the main server cannot
	// bind the address.
	blocker, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen on occupied addr: %v", err)
	}
	defer blocker.Close()
	occupiedAddr := blocker.Addr().String()

	// Discover a free address for the redirect server, then release it.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen for redirect addr: %v", err)
	}
	redirectAddr := probe.Addr().String()
	if closeErr := probe.Close(); closeErr != nil {
		t.Fatalf("close redirect addr probe: %v", closeErr)
	}

	engine := New()
	redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
	if err != nil {
		t.Fatalf("build redirect server: %v", err)
	}
	mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}

	servers := []*http.Server{mainServer, redirectServer}
	flags := []bool{false, false} // one flag per server; meaning is defined by gracefulServe
	serveErr := gracefulServe(servers, flags, 200*time.Millisecond, nil, context.Background())
	if serveErr == nil {
		t.Fatal("expected gracefulServe to fail when one server cannot bind")
	}
	if msg := serveErr.Error(); !strings.Contains(msg, occupiedAddr) {
		t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, serveErr)
	}

	conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
	if dialErr == nil {
		conn.Close()
		t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
	}
	dialMsg := dialErr.Error()
	refused := strings.Contains(dialMsg, "refused")
	reset := strings.Contains(dialMsg, "reset")
	if !(refused || reset) {
		t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
	}
}
// TestRunNonGracefulRedirectReturnsStartupError runs the engine with TLS and
// an HTTP redirect, without any graceful option, against a main address that
// is already bound. It expects Run to return an error naming that address.
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
	// Hold this listener open so the main server cannot bind the address.
	blocker, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen on occupied addr: %v", err)
	}
	defer blocker.Close()
	occupiedAddr := blocker.Addr().String()

	// Discover a free address for the redirect listener, then release it.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen for redirect addr: %v", err)
	}
	redirectAddr := probe.Addr().String()
	if closeErr := probe.Close(); closeErr != nil {
		t.Fatalf("close redirect addr probe: %v", closeErr)
	}

	engine := New()
	runErr := engine.Run(
		WithAddr(occupiedAddr),
		WithTLS(&tls.Config{}),
		WithHTTPRedirect(redirectAddr),
	)
	if runErr == nil {
		t.Fatal("expected non-graceful TLS redirect startup to return bind error")
	}
	if msg := runErr.Error(); !strings.Contains(msg, occupiedAddr) {
		t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, runErr)
	}
}