mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
492 lines
15 KiB
Go
492 lines
15 KiB
Go
package touka
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"errors"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// generateSelfSignedCert creates a throwaway, self-signed RSA-2048 server
// certificate for 127.0.0.1 and returns it as a tls.Certificate. The
// certificate is valid from one hour in the past to one hour in the future.
// Any failure aborts the calling test through t.Fatalf.
func generateSelfSignedCert(t *testing.T) tls.Certificate {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generate private key: %v", err)
	}

	var template x509.Certificate
	template.SerialNumber = big.NewInt(1)
	template.Subject = pkix.Name{CommonName: "127.0.0.1"}
	template.NotBefore = time.Now().Add(-time.Hour)
	template.NotAfter = time.Now().Add(time.Hour)
	template.KeyUsage = x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature
	template.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
	template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}

	// The template doubles as its own parent, which makes the result self-signed.
	rawCert, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create self-signed cert: %v", err)
	}

	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: rawCert}
	keyBlock := pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}

	pair, err := tls.X509KeyPair(pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock))
	if err != nil {
		t.Fatalf("parse self-signed cert: %v", err)
	}
	return pair
}
|
|
|
|
func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen on ephemeral port: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
if err := listener.Close(); err != nil {
|
|
t.Fatalf("close temporary listener: %v", err)
|
|
}
|
|
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = w.Write([]byte("ok"))
|
|
}),
|
|
// RunShutdown uses the HTTP startup path and must not let a shared
|
|
// ServerConfigurator accidentally turn it into HTTPS.
|
|
TLSConfig: &tls.Config{},
|
|
}
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- serveServer(srv, false)
|
|
}()
|
|
|
|
client := &http.Client{Timeout: 200 * time.Millisecond}
|
|
var resp *http.Response
|
|
requestURL := "http://" + addr
|
|
|
|
deadline := time.Now().Add(3 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
resp, err = client.Get(requestURL)
|
|
if err == nil {
|
|
break
|
|
}
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
if err != nil {
|
|
select {
|
|
case serveErr := <-errCh:
|
|
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr)
|
|
default:
|
|
t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err)
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("read response body: %v", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK)
|
|
}
|
|
if string(body) != "ok" {
|
|
t.Fatalf("unexpected body: got %q want %q", string(body), "ok")
|
|
}
|
|
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
|
t.Fatalf("shutdown server: %v", err)
|
|
}
|
|
|
|
if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
|
|
t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRunRejectsRedirectWithoutTLS(t *testing.T) {
|
|
engine := New()
|
|
err := engine.Run(WithHTTPRedirect(":80"))
|
|
if err == nil {
|
|
t.Fatal("expected redirect mode without TLS to fail")
|
|
}
|
|
}
|
|
|
|
func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) {
|
|
engine := New()
|
|
err := engine.Run(
|
|
WithAddr(":443"),
|
|
WithTLS(&tls.Config{}),
|
|
WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})),
|
|
)
|
|
if err == nil {
|
|
t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail")
|
|
}
|
|
}
|
|
|
|
func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
if err := WithGracefulShutdownDefault().apply(&cfg); err != nil {
|
|
t.Fatalf("apply graceful default option: %v", err)
|
|
}
|
|
if !cfg.graceful {
|
|
t.Fatal("expected graceful shutdown to be enabled")
|
|
}
|
|
if cfg.shutdownTimeout != defaultShutdownTimeout {
|
|
t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout)
|
|
}
|
|
}
|
|
|
|
func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
if err := WithTLS(tlsConfig).apply(&cfg); err != nil {
|
|
t.Fatalf("apply TLS option: %v", err)
|
|
}
|
|
if cfg.mode != runModeHTTPS {
|
|
t.Fatalf("expected HTTPS mode, got %v", cfg.mode)
|
|
}
|
|
if cfg.graceful {
|
|
t.Fatal("expected TLS option to remain independent from graceful shutdown")
|
|
}
|
|
if cfg.tlsConfig != tlsConfig {
|
|
t.Fatal("expected TLS config to be preserved in run config")
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
|
|
engine := New()
|
|
if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil {
|
|
t.Fatal("expected redirect server builder to reject https address without port")
|
|
}
|
|
}
|
|
|
|
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
ctx := t.Context()
|
|
if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
|
|
t.Fatalf("apply shutdown context option: %v", err)
|
|
}
|
|
if err := validateRunConfig(cfg); err == nil {
|
|
t.Fatal("expected shutdown context without graceful shutdown to fail validation")
|
|
}
|
|
}
|
|
|
|
func TestValidateRunConfigDoesNotMutateMode(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
cfg.httpRedirectAddr = ":80"
|
|
if err := validateRunConfig(cfg); err != nil {
|
|
t.Fatalf("validate run config: %v", err)
|
|
}
|
|
if cfg.mode != runModeHTTP {
|
|
t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode)
|
|
}
|
|
}
|
|
|
|
func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
cfg.mode = runModeHTTPSRedirect
|
|
cfg.tlsConfig = &tls.Config{}
|
|
cfg.useHeaderHost = false
|
|
cfg.useHeaderHostSet = true
|
|
if err := validateRunConfig(cfg); err == nil {
|
|
t.Fatal("expected configured host mode without redirect host to fail validation")
|
|
}
|
|
}
|
|
|
|
func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) {
|
|
cfg := defaultRunConfig()
|
|
cfg.mode = runModeHTTPSRedirect
|
|
cfg.tlsConfig = &tls.Config{}
|
|
cfg.useHeaderHost = true
|
|
cfg.useHeaderHostSet = true
|
|
cfg.redirectHost = "configured.example"
|
|
if err := validateRunConfig(cfg); err == nil {
|
|
t.Fatal("expected redirect host to be rejected when header host mode is enabled")
|
|
}
|
|
}
|
|
|
|
func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) {
|
|
engine := New()
|
|
server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP})
|
|
if server.BaseContext == nil {
|
|
t.Fatal("expected graceful main server to set BaseContext")
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen for base context check: %v", err)
|
|
}
|
|
defer listener.Close()
|
|
if got := server.BaseContext(listener); got != engine.shutdownCtx {
|
|
t.Fatal("expected graceful main server to use engine shutdown context")
|
|
}
|
|
}
|
|
|
|
func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) {
|
|
engine := New()
|
|
serverConfigured := false
|
|
tlsConfigured := false
|
|
engine.SetServerConfigurator(func(s *http.Server) {
|
|
serverConfigured = true
|
|
s.ReadTimeout = time.Second
|
|
})
|
|
engine.SetTLSServerConfigurator(func(s *http.Server) {
|
|
tlsConfigured = true
|
|
s.IdleTimeout = time.Second
|
|
})
|
|
|
|
server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
|
if !tlsConfigured {
|
|
t.Fatal("expected TLS configurator to run for HTTPS main server")
|
|
}
|
|
if serverConfigured {
|
|
t.Fatal("expected generic server configurator to be skipped when TLS configurator is set")
|
|
}
|
|
if server.IdleTimeout != time.Second {
|
|
t.Fatal("expected TLS configurator changes to be applied to HTTPS main server")
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) {
|
|
engine := New()
|
|
configured := false
|
|
engine.SetServerConfigurator(func(s *http.Server) {
|
|
configured = true
|
|
s.ReadTimeout = time.Second
|
|
})
|
|
|
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
if !configured {
|
|
t.Fatal("expected redirect server to use generic server configurator")
|
|
}
|
|
if server.ReadTimeout != time.Second {
|
|
t.Fatal("expected redirect server configurator changes to be applied")
|
|
}
|
|
}
|
|
|
|
func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) {
|
|
engine := New()
|
|
httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}})
|
|
if !httpsServer.Protocols.HTTP2() {
|
|
t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings")
|
|
}
|
|
|
|
httpServer := buildMainServer(engine, defaultRunConfig())
|
|
if httpServer.Protocols.HTTP2() {
|
|
t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled")
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) {
|
|
engine := New()
|
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
|
req.Host = "example.com:80"
|
|
rr := httptest.NewRecorder()
|
|
server.Handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusMovedPermanently {
|
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
|
}
|
|
if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) {
|
|
engine := New()
|
|
server, err := buildRedirectServer(engine, runConfig{
|
|
addr: ":443",
|
|
httpRedirectAddr: ":80",
|
|
useHeaderHost: true,
|
|
useHeaderHostSet: true,
|
|
redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
|
req.Host = "example.com:80"
|
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
|
req.Header.Set("X-First-Host", "first.example")
|
|
rr := httptest.NewRecorder()
|
|
server.Handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusMovedPermanently {
|
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
|
}
|
|
if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) {
|
|
engine := New()
|
|
server, err := buildRedirectServer(engine, runConfig{
|
|
addr: ":443",
|
|
httpRedirectAddr: ":80",
|
|
useHeaderHost: true,
|
|
useHeaderHostSet: true,
|
|
redirectHostHeaders: []string{"X-Forwarded-Host"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
|
req.Host = "example.com:80"
|
|
rr := httptest.NewRecorder()
|
|
server.Handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUpgradeRequired {
|
|
t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) {
|
|
engine := New()
|
|
server, err := buildRedirectServer(engine, runConfig{
|
|
addr: ":443",
|
|
httpRedirectAddr: ":80",
|
|
useHeaderHost: false,
|
|
useHeaderHostSet: true,
|
|
redirectHost: "configured.example",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil)
|
|
req.Host = "example.com:80"
|
|
req.Header.Set("X-Forwarded-Host", "forwarded.example")
|
|
rr := httptest.NewRecorder()
|
|
server.Handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusMovedPermanently {
|
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
|
}
|
|
if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" {
|
|
t.Fatalf("unexpected redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) {
|
|
engine := New()
|
|
server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil)
|
|
req.Host = "[::1]:80"
|
|
rr := httptest.NewRecorder()
|
|
server.Handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusMovedPermanently {
|
|
t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code)
|
|
}
|
|
if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" {
|
|
t.Fatalf("unexpected IPv6 redirect location: %q", location)
|
|
}
|
|
}
|
|
|
|
func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) {
|
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen on occupied addr: %v", err)
|
|
}
|
|
occupiedAddr := occupied.Addr().String()
|
|
defer occupied.Close()
|
|
|
|
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen for redirect addr: %v", err)
|
|
}
|
|
redirectAddr := redirectListener.Addr().String()
|
|
if err := redirectListener.Close(); err != nil {
|
|
t.Fatalf("close redirect addr probe: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr})
|
|
if err != nil {
|
|
t.Fatalf("build redirect server: %v", err)
|
|
}
|
|
mainServer := &http.Server{Addr: occupiedAddr, Handler: engine}
|
|
|
|
err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background())
|
|
if err == nil {
|
|
t.Fatal("expected gracefulServe to fail when one server cannot bind")
|
|
}
|
|
if !strings.Contains(err.Error(), occupiedAddr) {
|
|
t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err)
|
|
}
|
|
|
|
conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond)
|
|
if dialErr == nil {
|
|
conn.Close()
|
|
t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr)
|
|
}
|
|
if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") {
|
|
t.Fatalf("unexpected dial result after shutdown, got %v", dialErr)
|
|
}
|
|
}
|
|
|
|
func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) {
|
|
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen on occupied addr: %v", err)
|
|
}
|
|
occupiedAddr := occupied.Addr().String()
|
|
defer occupied.Close()
|
|
|
|
redirectListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen for redirect addr: %v", err)
|
|
}
|
|
redirectAddr := redirectListener.Addr().String()
|
|
if err := redirectListener.Close(); err != nil {
|
|
t.Fatalf("close redirect addr probe: %v", err)
|
|
}
|
|
|
|
engine := New()
|
|
err = engine.Run(
|
|
WithAddr(occupiedAddr),
|
|
WithTLS(&tls.Config{}),
|
|
WithHTTPRedirect(redirectAddr),
|
|
)
|
|
if err == nil {
|
|
t.Fatal("expected non-graceful TLS redirect startup to return bind error")
|
|
}
|
|
if !strings.Contains(err.Error(), occupiedAddr) {
|
|
t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err)
|
|
}
|
|
}
|