diff --git a/engine.go b/engine.go index f236624..c2eae91 100644 --- a/engine.go +++ b/engine.go @@ -319,11 +319,16 @@ func GetDefaultProtocolsConfig() *ProtocolsConfig { // 设置默认Protocols func (engine *Engine) SetDefaultProtocols() { engine.useDefaultProtocols = true - engine.SetProtocols(GetDefaultProtocolsConfig()) + engine.setProtocols(GetDefaultProtocolsConfig()) } // 设置Protocol func (engine *Engine) SetProtocols(config *ProtocolsConfig) { + engine.setProtocols(config) + engine.useDefaultProtocols = false +} + +func (engine *Engine) setProtocols(config *ProtocolsConfig) { engine.Protocols = *config engine.serverProtocols = &http.Protocols{} // 初始化指针 func() { @@ -333,7 +338,13 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) { p.SetUnencryptedHTTP2(config.Http2_Cleartext) *engine.serverProtocols = p // 将值赋给指针指向的结构体 }() - engine.useDefaultProtocols = false +} + +// applyDefaultServerConfig 应用框架的默认配置到 http.Server +func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { + if engine.serverProtocols != nil { + srv.Protocols = engine.serverProtocols + } } // 配置全局Req Body大小限制 diff --git a/protocols_test.go b/protocols_test.go new file mode 100644 index 0000000..73f16e9 --- /dev/null +++ b/protocols_test.go @@ -0,0 +1,111 @@ +package touka + +import ( + "crypto/tls" + "net/http" + "testing" +) + +func TestApplyDefaultServerConfig(t *testing.T) { + engine := New() + + // 1. 测试默认协议 + srv1 := &http.Server{} + engine.applyDefaultServerConfig(srv1) + + if srv1.Protocols == nil { + t.Fatal("srv1.Protocols should not be nil after applyDefaultServerConfig") + } + + // 默认配置是 Http1: true, Http2: false, Http2_Cleartext: false + if !srv1.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be enabled by default") + } + if srv1.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be disabled by default") + } + + // 2. 测试自定义协议 + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + Http2_Cleartext: true, + }) + + srv2 := &http.Server{} + engine.applyDefaultServerConfig(srv2) + + if srv2.Protocols == nil { + t.Fatal("srv2.Protocols should not be nil after applyDefaultServerConfig") + } + + if !srv2.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be enabled after SetProtocols") + } + if !srv2.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be enabled after SetProtocols") + } + if !srv2.Protocols.UnencryptedHTTP2() { + t.Error("Expected Unencrypted HTTP/2 to be enabled after SetProtocols") + } + + // 3. 再次更改协议并验证 + engine.SetProtocols(&ProtocolsConfig{ + Http1: false, + Http2: true, + Http2_Cleartext: false, + }) + + srv3 := &http.Server{} + engine.applyDefaultServerConfig(srv3) + + if srv3.Protocols == nil { + t.Fatal("srv3.Protocols should not be nil") + } + if srv3.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be disabled") + } + if !srv3.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be enabled") + } +} + +func TestRunTLSProtocolInheritance(t *testing.T) { + engine := New() + + // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) + } + + srv := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv) + + if !srv.Protocols.HTTP2() { + t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + } + + // 模拟用户设置了自定义协议后调用 RunTLS + engine = New() + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: false, // 用户明确不想要 HTTP/2 + }) + + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) + } + + srv2 := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv2) + + if srv2.Protocols.HTTP2() { + t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + } +} diff --git a/serve.go b/serve.go index 6a4cf2a..f3ddc5f 100644 --- a/serve.go +++ b/serve.go @@ -211,7 +211,7 @@ func (engine *Engine) Run(addr ...string) error { srv := &http.Server{Addr: address, Handler: engine} // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -231,7 +231,7 @@ func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error srv.RegisterOnShutdown(engine.shutdownCancel) // 应用框架的默认配置和用户提供的自定义配置 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -252,7 +252,7 @@ func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, t srv.RegisterOnShutdown(engine.shutdownCancel) // 应用框架的默认配置和用户提供的自定义配置 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -268,7 +268,7 @@ func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...tim // 配置 HTTP/2 支持 (如果使用默认配置) if engine.useDefaultProtocols { - engine.SetProtocols(&ProtocolsConfig{ + engine.setProtocols(&ProtocolsConfig{ Http1: true, Http2: true, // 默认在 TLS 上启用 HTTP/2 }) @@ -286,7 +286,7 @@ func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...tim // 应用框架的默认配置和用户提供的自定义配置 // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.TLSServerConfigurator != nil { engine.TLSServerConfigurator(srv) } else if engine.ServerConfigurator != nil { @@ -310,7 +310,7 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con // --- HTTPS 服务器 --- if engine.useDefaultProtocols { - engine.SetProtocols(&ProtocolsConfig{Http1: true, Http2: true}) + engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) } httpsSrv := &http.Server{ Addr: httpsAddr, @@ -321,7 +321,7 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con }, } httpsSrv.RegisterOnShutdown(engine.shutdownCancel) - //engine.applyDefaultServerConfig(httpsSrv) + engine.applyDefaultServerConfig(httpsSrv) if engine.TLSServerConfigurator != nil { engine.TLSServerConfigurator(httpsSrv) } else if engine.ServerConfigurator != nil { @@ -355,7 +355,7 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con Addr: httpAddr, Handler: redirectHandler, } - //engine.applyDefaultServerConfig(httpSrv) + engine.applyDefaultServerConfig(httpSrv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(httpSrv) }