diff --git a/engine.go b/engine.go index eeafe6e..6a7a668 100644 --- a/engine.go +++ b/engine.go @@ -18,8 +18,8 @@ import ( "github.com/fenthope/reco" ) -// Last 返回链中的最后一个处理函数。 -// 如果链为空,则返回 nil。 +// Last 返回链中的最后一个处理函数 +// 如果链为空,则返回 nil func (c HandlersChain) Last() HandlerFunc { if len(c) > 0 { return c[len(c)-1] @@ -27,37 +27,37 @@ func (c HandlersChain) Last() HandlerFunc { return nil } -// Engine 是 Touka 框架的核心,负责路由注册、中间件管理和请求分发。 -// 它实现了 http.Handler 接口,可以直接用于 http.ListenAndServe。 +// Engine 是 Touka 框架的核心,负责路由注册、中间件管理和请求分发 +// 它实现了 http.Handler 接口,可以直接用于 http.ListenAndServe type Engine struct { methodTrees methodTrees // 存储所有HTTP方法的路由树 - pool sync.Pool // Context Pool 用于复用 Context 对象,提高性能。 + pool sync.Pool // Context Pool 用于复用 Context 对象,提高性能 - globalHandlers HandlersChain // 全局中间件,应用于所有路由。 + globalHandlers HandlersChain // 全局中间件,应用于所有路由 - maxParams uint16 // 记录所有路由中最大的参数数量,用于优化 Params 切片的分配。 + maxParams uint16 // 记录所有路由中最大的参数数量,用于优化 Params 切片的分配 - // 可配置项,用于控制框架行为,参考 Gin + // 可配置项,用于控制框架行为,参考 Gin RedirectTrailingSlash bool // 是否自动重定向带尾部斜杠的路径到不带尾部斜杠的路径 (e.g. /foo/ -> /foo) RedirectFixedPath bool // 是否自动修复路径中的大小写错误 (e.g. /Foo -> /foo) HandleMethodNotAllowed bool // 是否启用 MethodNotAllowed 处理器 ForwardByClientIP bool // 是否信任 X-Forwarded-For 等头部获取客户端 IP - RemoteIPHeaders []string // 用于获取客户端 IP 的头部列表,例如 {"X-Forwarded-For", "X-Real-IP"} - // TrustedProxies []string // 可信代理 IP 列表,用于判断是否使用 X-Forwarded-For 等头部 (预留接口) + RemoteIPHeaders []string // 用于获取客户端 IP 的头部列表,例如 {"X-Forwarded-For", "X-Real-IP"} + // TrustedProxies []string // 可信代理 IP 列表,用于判断是否使用 X-Forwarded-For 等头部 (预留接口) - HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求。 + HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 LogReco *reco.Logger - HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 + HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 errorHandle ErrorHandle // 错误处理 noRoute HandlerFunc // NoRoute 处理器 - noRoutes HandlersChain // NoRoutes 处理器链 (如果 noRoute 未设置,则使用此链) + noRoutes HandlersChain // NoRoutes 处理器链 (如果 noRoute 未设置,则使用此链) unMatchFS UnMatchFS // 未匹配下的处理 unMatchFileServer http.Handler // 处理handle @@ -65,6 +65,15 @@ type Engine struct { serverProtocols *http.Protocols //服务协议 Protocols ProtocolsConfig //协议版本配置 useDefaultProtocols bool //是否使用默认协议 + + // ServerConfigurator 允许在服务器启动前对其进行自定义配置 + // 例如,设置 ReadTimeout, WriteTimeout 等 + ServerConfigurator func(*http.Server) + + // TLSServerConfigurator 允许在 HTTPS 服务器启动前进行自定义配置 + // 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器 + // 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置) + TLSServerConfigurator func(*http.Server) } type ErrorHandle struct { @@ -124,7 +133,7 @@ type ProtocolsConfig struct { Http2_Cleartext bool // 是否启用 H2C } -// New 创建并返回一个 Engine 实例。 +// New 创建并返回一个 Engine 实例 func New() *Engine { engine := &Engine{ methodTrees: make(methodTrees, 0, 9), // 常见的HTTP方法有9个 (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, CONNECT, TRACE) @@ -143,22 +152,24 @@ func New() *Engine { unMatchFS: UnMatchFS{ ServeUnmatchedAsFS: false, }, - noRoute: nil, - noRoutes: make(HandlersChain, 0), + noRoute: nil, + noRoutes: make(HandlersChain, 0), + ServerConfigurator: nil, + TLSServerConfigurator: nil, } //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() engine.SetLoggerCfg(defaultLogRecoConfig) - // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 + // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 engine.pool.New = func() interface{} { return &Context{ - Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter + Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配 Keys: make(map[string]interface{}), Errors: make([]error, 0), - ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 + ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 HTTPClient: engine.HTTPClient, - engine: engine, // Context 持有 Engine 引用,方便访问 Engine 的配置 + engine: engine, // Context 持有 Engine 引用,方便访问 Engine 的配置 } } @@ -174,6 +185,19 @@ func Default() *Engine { // === 外部操作方法 === +// SetServerConfigurator 设置一个函数,该函数将在任何 HTTP 或 HTTPS 服务器 +// (通过 RunShutdown, RunTLS, RunTLSRedir) 启动前被调用, +// 允许用户对底层的 *http.Server 实例进行自定义配置 +func (engine *Engine) SetServerConfigurator(fn func(*http.Server)) { + engine.ServerConfigurator = fn +} + +// SetTLSServerConfigurator 设置一个函数,该函数将专门用于配置 HTTPS 服务器 +// 如果设置了此函数,它将覆盖通用的 ServerConfigurator +func (engine *Engine) SetTLSServerConfigurator(fn func(*http.Server)) { + engine.TLSServerConfigurator = fn +} + // SetLogger传入实例 func (engine *Engine) SetLogger(logger *reco.Logger) { engine.LogReco = logger @@ -241,27 +265,27 @@ func (engine *Engine) SetRemoteIPHeaders(headers []string) { engine.RemoteIPHeaders = headers } -// SetForwardByClientIP 设置是否信任 X-Forwarded-For 等头部获取客户端 IP。 +// SetForwardByClientIP 设置是否信任 X-Forwarded-For 等头部获取客户端 IP func (engine *Engine) SetForwardByClientIP(enable bool) { engine.ForwardByClientIP = enable } -// SetHTTPClient 设置 Engine 使用的 httpc.Client。 +// SetHTTPClient 设置 Engine 使用的 httpc.Client func (engine *Engine) SetHTTPClient(client *httpc.Client) { if client != nil { engine.HTTPClient = client } } -// registerMethodTree 内部方法,用于获取或注册对应 HTTP 方法的路由树根节点。 -// 如果该方法没有对应的树,则创建一个新的树。 +// registerMethodTree 内部方法,用于获取或注册对应 HTTP 方法的路由树根节点 +// 如果该方法没有对应的树,则创建一个新的树 func (engine *Engine) registerMethodTree(method string) *node { for _, tree := range engine.methodTrees { if tree.method == method { return tree.root } } - // 如果没有找到,则创建一个新的方法树并添加到列表中 + // 如果没有找到,则创建一个新的方法树并添加到列表中 root := &node{ nType: root, // 根节点类型 fullPath: "/", // 根路径 @@ -270,9 +294,9 @@ func (engine *Engine) registerMethodTree(method string) *node { return root } -// addRoute 将一个路由及处理函数链添加到路由树中。 -// 这是框架内部路由注册的核心逻辑。 -// groupPath 用于记录路由所属的分组路径。 +// addRoute 将一个路由及处理函数链添加到路由树中 +// 这是框架内部路由注册的核心逻辑 +// groupPath 用于记录路由所属的分组路径 func (engine *Engine) addRoute(method, absolutePath, groupPath string, handlers HandlersChain) { // relativePath 更名为 absolutePath if absolutePath == "" { panic("absolute path must not be empty") @@ -281,7 +305,7 @@ func (engine *Engine) addRoute(method, absolutePath, groupPath string, handlers panic("handlers must not be empty") } - // 检查并更新 maxParams,使用 absolutePath + // 检查并更新 maxParams,使用 absolutePath if n := countParams(absolutePath); n > engine.maxParams { engine.maxParams = n } @@ -302,10 +326,10 @@ func (engine *Engine) addRoute(method, absolutePath, groupPath string, handlers }) } -// getHandlerName 辅助函数,用于获取 HandlerFunc 的名称。 -// 注意:这只是一个简单的反射实现,对于匿名函数或闭包,可能返回不可读的名称。 +// getHandlerName 辅助函数,用于获取 HandlerFunc 的名称 +// 注意:这只是一个简单的反射实现,对于匿名函数或闭包,可能返回不可读的名称 func getHandlerName(h HandlerFunc) string { - //return reflect.TypeOf(h).Name() // 对于具名函数,返回函数名。对于匿名函数,可能返回空字符串或类似 func123 这样的名称。 + //return reflect.TypeOf(h).Name() // 对于具名函数,返回函数名对于匿名函数,可能返回空字符串或类似 func123 这样的名称 // 更精确的获取函数名需要 import "runtime" // pc := reflect.ValueOf(h).Pointer() // f := runtime.FuncForPC(pc) @@ -320,8 +344,8 @@ func getHandlerName(h HandlerFunc) string { } -// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口。 -// 每个传入的 HTTP 请求都会调用此方法。 +// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口 +// 每个传入的 HTTP 请求都会调用此方法 func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // 从 Context Pool 中获取一个 Context 对象进行复用 c := engine.pool.Get().(*Context) @@ -330,12 +354,12 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // 执行请求处理 engine.handleRequest(c) - // 将 Context 对象放回 Context Pool,以供下次复用 + // 将 Context 对象放回 Context Pool,以供下次复用 engine.pool.Put(c) } -// handleRequest 负责根据请求查找路由并执行相应的处理函数链。 -// 这是路由查找和执行的核心逻辑。 +// handleRequest 负责根据请求查找路由并执行相应的处理函数链 +// 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { httpMethod := c.Request.Method requestPath := c.Request.URL.Path @@ -344,8 +368,8 @@ func (engine *Engine) handleRequest(c *Context) { rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 if rootNode != nil { // 查找匹配的节点和处理函数 - // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 - // skippedNodes 内部使用,因此无需从外部传入已分配的 slice + // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 + // skippedNodes 内部使用,因此无需从外部传入已分配的 slice var skippedNodes []skippedNode // 用于回溯的跳过节点 // 直接在 rootNode 上调用 getValue 方法 value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 @@ -358,7 +382,7 @@ func (engine *Engine) handleRequest(c *Context) { return } - // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) + // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ @@ -385,19 +409,19 @@ func (engine *Engine) handleRequest(c *Context) { // 组合全局中间件和路由处理函数 handlers := engine.globalHandlers - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 + // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 // 则在全局中间件之后添加 MethodNotAllowed 处理器 if engine.HandleMethodNotAllowed { handlers = append(handlers, MethodNotAllowed()) } - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed + // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { handlers = append(handlers, unMatchFSHandle()) } - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS + // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS // 则在处理链的最后添加 NoRoute 处理器 if engine.noRoute != nil { handlers = append(handlers, engine.noRoute) @@ -418,7 +442,7 @@ func unMatchFSHandle() HandlerFunc { engine := c.engine // 确保 engine.unMatchFileServer 存在 if !engine.unMatchFS.ServeUnmatchedAsFS || engine.unMatchFileServer == nil { - c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链 + c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链 return } if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { @@ -457,31 +481,31 @@ func MethodNotAllowed() HandlerFunc { engine := c.engine // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 allowedMethods := []string{} for _, treeIter := range engine.methodTrees { var tempSkippedNodes []skippedNode - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } } if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) c.Status(http.StatusOK) return } } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } - var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 if value.handlers != nil { // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) @@ -512,8 +536,8 @@ func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoutes = handlerFuncs } -// combineHandlers 组合多个处理函数链为一个。 -// 这是构建完整处理链(全局中间件 + 组中间件 + 路由处理函数)的关键。 +// combineHandlers 组合多个处理函数链为一个 +// 这是构建完整处理链(全局中间件 + 组中间件 + 路由处理函数)的关键 func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) HandlersChain { finalSize := len(h1) + len(h2) mergedHandlers := make(HandlersChain, finalSize) @@ -522,15 +546,15 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle return mergedHandlers } -// Use 将全局中间件添加到 Engine。 -// 这些中间件将应用于所有注册的路由。 +// Use 将全局中间件添加到 Engine +// 这些中间件将应用于所有注册的路由 func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) return engine } -// Handle 注册通用 HTTP 方法的路由。 -// 这是所有具体 HTTP 方法注册的基础方法。 +// Handle 注册通用 HTTP 方法的路由 +// 这是所有具体 HTTP 方法注册的基础方法 func (engine *Engine) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) { absolutePath := path.Join("/", relativePath) // 修正:统一使用 path.Join 进行路径拼接 // 修正:将全局中间件与此路由的处理函数合并 @@ -538,42 +562,42 @@ func (engine *Engine) Handle(httpMethod, relativePath string, handlers ...Handle engine.addRoute(httpMethod, absolutePath, "/", fullHandlers) } -// GET 注册 GET 方法的路由。 +// GET 注册 GET 方法的路由 func (engine *Engine) GET(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodGet, relativePath, handlers...) } -// POST 注册 POST 方法的路由。 +// POST 注册 POST 方法的路由 func (engine *Engine) POST(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodPost, relativePath, handlers...) } -// PUT 注册 PUT 方法的路由。 +// PUT 注册 PUT 方法的路由 func (engine *Engine) PUT(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodPut, relativePath, handlers...) } -// DELETE 注册 DELETE 方法的路由。 +// DELETE 注册 DELETE 方法的路由 func (engine *Engine) DELETE(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodDelete, relativePath, handlers...) } -// PATCH 注册 PATCH 方法的路由。 +// PATCH 注册 PATCH 方法的路由 func (engine *Engine) PATCH(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodPatch, relativePath, handlers...) } -// HEAD 注册 HEAD 方法的路由。 +// HEAD 注册 HEAD 方法的路由 func (engine *Engine) HEAD(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodHead, relativePath, handlers...) } -// OPTIONS 注册 OPTIONS 方法的路由。 +// OPTIONS 注册 OPTIONS 方法的路由 func (engine *Engine) OPTIONS(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodOptions, relativePath, handlers...) } -// ANY 注册所有常见 HTTP 方法的路由。 +// ANY 注册所有常见 HTTP 方法的路由 func (engine *Engine) ANY(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodGet, relativePath, handlers...) engine.Handle(http.MethodPost, relativePath, handlers...) @@ -584,13 +608,13 @@ func (engine *Engine) ANY(relativePath string, handlers ...HandlerFunc) { engine.Handle(http.MethodOptions, relativePath, handlers...) } -// GetRouterInfo 返回所有已注册的路由信息。 +// GetRouterInfo 返回所有已注册的路由信息 func (engine *Engine) GetRouterInfo() []RouteInfo { return engine.routesInfo } -// Group 创建一个新的路由组。 -// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起。 +// Group 创建一个新的路由组 +// 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 @@ -599,30 +623,30 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } } -// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由。 -// 它也实现了 IRouter 接口,允许嵌套分组。 +// RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 +// 它也实现了 IRouter 接口,允许嵌套分组 type RouterGroup struct { - Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 + Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 - engine *Engine // 指向 Engine 实例,用于注册路由到全局路由树 + engine *Engine // 指向 Engine 实例,用于注册路由到全局路由树 } -// Use 将中间件应用于当前路由组。 -// 这些中间件将应用于当前组及其子组的所有路由。 +// Use 将中间件应用于当前路由组 +// 这些中间件将应用于当前组及其子组的所有路由 func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { group.Handlers = append(group.Handlers, middleware...) return group } -// Handle 注册通用 HTTP 方法的路由到当前组。 -// 路径是相对于当前组的 basePath。 +// Handle 注册通用 HTTP 方法的路由到当前组 +// 路径是相对于当前组的 basePath func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) { absolutePath := path.Join(group.basePath, relativePath) fullHandlers := group.engine.combineHandlers(group.Handlers, handlers) group.engine.addRoute(httpMethod, absolutePath, group.basePath, fullHandlers) } -// GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, ANY 方法与 Engine 类似,只是通过 Group 的 Handle 方法注册。 +// GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, ANY 方法与 Engine 类似,只是通过 Group 的 Handle 方法注册 func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) { group.Handle(http.MethodGet, relativePath, handlers...) } @@ -654,7 +678,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { group.Handle(http.MethodOptions, relativePath, handlers...) } -// Group 为当前组创建一个新的子组。 +// Group 为当前组创建一个新的子组 func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), @@ -672,7 +696,7 @@ func (engine *Engine) StaticDir(relativePath, rootPath string) { relativePath = path.Clean(relativePath) rootPath = path.Clean(rootPath) - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 if !strings.HasSuffix(relativePath, "/") { relativePath += "/" } @@ -680,17 +704,17 @@ func (engine *Engine) StaticDir(relativePath, rootPath string) { // 创建一个文件系统处理器 fileServer := http.FileServer(http.Dir(rootPath)) - // 注册一个捕获所有路径的路由,使用自定义处理器 - // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD + // 注册一个捕获所有路径的路由,使用自定义处理器 + // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD // 我们可以通过在处理函数内部检查方法来限制 engine.ANY(relativePath+"*filepath", func(c *Context) { // 检查是否是 GET 或 HEAD 方法 if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 if engine.HandleMethodNotAllowed { c.Next() } else { - // 否则,返回 405 Method Not Allowed + // 否则,返回 405 Method Not Allowed engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) } return @@ -703,15 +727,15 @@ func (engine *Engine) StaticDir(relativePath, rootPath string) { // 构造文件服务器需要处理的请求路径 // FileServer 会将请求路径与 http.Dir 的根路径结合 - // 我们需要移除相对路径前缀,只保留文件路径部分 - // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" + // 我们需要移除相对路径前缀,只保留文件路径部分 + // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" // FileServer 需要的路径是 "/js/app.js" - // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" + // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" // 所以直接使用 filepath 即可 c.Request.URL.Path = filepath // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 - // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 + // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 ecw := AcquireErrorCapturingResponseWriter(c) defer ReleaseErrorCapturingResponseWriter(ecw) @@ -719,13 +743,13 @@ func (engine *Engine) StaticDir(relativePath, rootPath string) { // 调用 FileServer 处理请求 fileServer.ServeHTTP(ecw, c.Request) - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler ecw.processAfterFileServer() - // 恢复原始请求路径,以便后续中间件或日志记录使用 + // 恢复原始请求路径,以便后续中间件或日志记录使用 c.Request.URL.Path = requestPath - // 中止处理链,因为 FileServer 已经处理了响应 + // 中止处理链,因为 FileServer 已经处理了响应 c.Abort() }) } @@ -736,7 +760,7 @@ func (group *RouterGroup) StaticDir(relativePath, rootPath string) { relativePath = path.Clean(relativePath) rootPath = path.Clean(rootPath) - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 if !strings.HasSuffix(relativePath, "/") { relativePath += "/" } @@ -744,17 +768,17 @@ func (group *RouterGroup) StaticDir(relativePath, rootPath string) { // 创建一个文件系统处理器 fileServer := http.FileServer(http.Dir(rootPath)) - // 注册一个捕获所有路径的路由,使用自定义处理器 - // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD + // 注册一个捕获所有路径的路由,使用自定义处理器 + // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD // 我们可以通过在处理函数内部检查方法来限制 group.ANY(relativePath+"*filepath", func(c *Context) { // 检查是否是 GET 或 HEAD 方法 if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 if group.engine.HandleMethodNotAllowed { c.Next() } else { - // 否则,返回 405 Method Not Allowed + // 否则,返回 405 Method Not Allowed group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) } return @@ -767,15 +791,15 @@ func (group *RouterGroup) StaticDir(relativePath, rootPath string) { // 构造文件服务器需要处理的请求路径 // FileServer 会将请求路径与 http.Dir 的根路径结合 - // 我们需要移除相对路径前缀,只保留文件路径部分 - // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" + // 我们需要移除相对路径前缀,只保留文件路径部分 + // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" // FileServer 需要的路径是 "/js/app.js" - // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" + // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" // 所以直接使用 filepath 即可 c.Request.URL.Path = filepath // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 - // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 + // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 ecw := AcquireErrorCapturingResponseWriter(c) defer ReleaseErrorCapturingResponseWriter(ecw) @@ -783,13 +807,13 @@ func (group *RouterGroup) StaticDir(relativePath, rootPath string) { // 调用 FileServer 处理请求 fileServer.ServeHTTP(ecw, c.Request) - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler ecw.processAfterFileServer() - // 恢复原始请求路径,以便后续中间件或日志记录使用 + // 恢复原始请求路径,以便后续中间件或日志记录使用 c.Request.URL.Path = requestPath - // 中止处理链,因为 FileServer 已经处理了响应 + // 中止处理链,因为 FileServer 已经处理了响应 c.Abort() }) } @@ -800,7 +824,7 @@ func (engine *Engine) StaticFile(relativePath, filePath string) { relativePath = path.Clean(relativePath) filePath = path.Clean(filePath) - // 创建一个文件系统处理器,指向包含目标文件的目录 + // 创建一个文件系统处理器,指向包含目标文件的目录 // http.Dir 需要一个目录路径 dir := path.Dir(filePath) fileName := path.Base(filePath) @@ -809,11 +833,11 @@ func (engine *Engine) StaticFile(relativePath, filePath string) { FileHandle := func(c *Context) { // 检查是否是 GET 或 HEAD 方法 if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 if engine.HandleMethodNotAllowed { c.Next() } else { - // 否则,返回 405 Method Not Allowed + // 否则,返回 405 Method Not Allowed engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) } return @@ -823,7 +847,7 @@ func (engine *Engine) StaticFile(relativePath, filePath string) { // 构造文件服务器需要处理的请求路径 // FileServer 会将请求路径与 http.Dir 的根路径结合 - // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 + // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头 // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 @@ -833,13 +857,13 @@ func (engine *Engine) StaticFile(relativePath, filePath string) { // 调用 FileServer 处理请求 fileServer.ServeHTTP(ecw, c.Request) - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler ecw.processAfterFileServer() // 恢复原始请求路径 c.Request.URL.Path = requestPath - // 中止处理链,因为 FileServer 已经处理了响应 + // 中止处理链,因为 FileServer 已经处理了响应 c.Abort() } @@ -856,7 +880,7 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) { relativePath = path.Clean(relativePath) filePath = path.Clean(filePath) - // 创建一个文件系统处理器,指向包含目标文件的目录 + // 创建一个文件系统处理器,指向包含目标文件的目录 // http.Dir 需要一个目录路径 dir := path.Dir(filePath) fileName := path.Base(filePath) @@ -865,11 +889,11 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) { FileHandle := func(c *Context) { // 检查是否是 GET 或 HEAD 方法 if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 if group.engine.HandleMethodNotAllowed { c.Next() } else { - // 否则,返回 405 Method Not Allowed + // 否则,返回 405 Method Not Allowed group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) } return @@ -879,7 +903,7 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) { // 构造文件服务器需要处理的请求路径 // FileServer 会将请求路径与 http.Dir 的根路径结合 - // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 + // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头 // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 @@ -889,13 +913,13 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) { // 调用 FileServer 处理请求 fileServer.ServeHTTP(ecw, c.Request) - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler ecw.processAfterFileServer() // 恢复原始请求路径 c.Request.URL.Path = requestPath - // 中止处理链,因为 FileServer 已经处理了响应 + // 中止处理链,因为 FileServer 已经处理了响应 c.Abort() } @@ -931,7 +955,7 @@ var MethodsSet = map[string]struct{}{ } // HandleFunc 注册一个或多个 HTTP 方法的路由 -// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) // relativePath 是相对于当前组或 Engine 的路径 // handlers 是处理函数链 func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { @@ -944,7 +968,7 @@ func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers } // HandleFunc 注册一个或多个 HTTP 方法的路由 -// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) // relativePath 是相对于当前组或 Engine 的路径 // handlers 是处理函数链 func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { @@ -961,11 +985,11 @@ func FileServer(fs http.FileSystem) HandlerFunc { return func(c *Context) { // 检查是否是 GET 或 HEAD 方法 if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 if c.engine.HandleMethodNotAllowed { c.Next() } else { - // 否则,返回 405 Method Not Allowed + // 否则,返回 405 Method Not Allowed c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) } return @@ -978,10 +1002,10 @@ func FileServer(fs http.FileSystem) HandlerFunc { // 调用 http.FileServer 处理请求 http.FileServer(fs).ServeHTTP(ecw, c.Request) - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler ecw.processAfterFileServer() - // 中止处理链,因为 FileServer 已经处理了响应 + // 中止处理链,因为 FileServer 已经处理了响应 c.Abort() } } diff --git a/serve.go b/serve.go index ab4aec3..885f8c3 100644 --- a/serve.go +++ b/serve.go @@ -17,222 +17,247 @@ import ( "github.com/fenthope/reco" ) -const defaultShutdownTimeout = 5 * time.Second // 定义默认的优雅关闭超时时间 +// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 +const defaultShutdownTimeout = 5 * time.Second -// resolveAddress 辅助函数,处理传入的地址参数。 +// --- 内部辅助函数 --- + +// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" func resolveAddress(addr []string) string { switch len(addr) { case 0: - return ":8080" // 默认端口 + return ":8080" case 1: return addr[0] default: - panic("too many parameters for Run method") // 参数过多则报错 + panic("too many parameters provided for server address") } } -// Run 启动 HTTP 服务器。 -// 接受一个可选的地址参数,如果未提供则默认为 ":8080"。 -func (engine *Engine) Run(addr ...string) (err error) { - address := resolveAddress(addr) // 解析服务器地址 - log.Printf("Touka server listening on %s\n", address) - err = http.ListenAndServe(address, engine) // 启动 HTTP 服务器 - return -} - -// getShutdownTimeout 解析可选的超时参数,如果未提供或无效,则返回默认超时。 +// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 func getShutdownTimeout(timeouts []time.Duration) time.Duration { - var timeout time.Duration - if len(timeouts) > 0 { - timeout = timeouts[0] - if timeout <= 0 { - log.Printf("Warning: Provided shutdown timeout (%v) is non-positive. Using default timeout %v.\n", timeout, defaultShutdownTimeout) - timeout = defaultShutdownTimeout - } - } else { - timeout = defaultShutdownTimeout + if len(timeouts) > 0 && timeouts[0] > 0 { + return timeouts[0] } - return timeout + return defaultShutdownTimeout } -// handleGracefulShutdown 处理一个或多个 http.Server 实例的优雅关闭。 -// 它监听操作系统信号,并在指定超时时间内尝试关闭所有服务器。 +// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, +// 并处理其启动失败的致命错误 +// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") +func runServer(serverType string, srv *http.Server) { + go func() { + var err error + protocol := "http" + if srv.TLSConfig != nil { + protocol = "https" + } + + log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) + + if srv.TLSConfig != nil { + // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, + // ListenAndServeTLS 的前两个参数可以为空字符串 + err = srv.ListenAndServeTLS("", "") + } else { + err = srv.ListenAndServe() + } + + // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), + // 则认为是一个严重错误,并终止程序 + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Touka %s server failed: %v", serverType, err) + } + }() +} + +// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 +// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { + // 创建一个 channel 来接收操作系统信号 quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 + <-quit // 阻塞,直到接收到上述信号之一 log.Println("Shutting down Touka server(s)...") - go func() { - log.Println("Touka Logger Clossing...") - CloseLogger(logger) - }() + // 关闭日志记录器 + if logger != nil { + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() + } + // 创建一个带超时的上下文,用于 Shutdown ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() var wg sync.WaitGroup - var errs []error - var errsMutex sync.Mutex // 保护 errs 切片 + errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel + // 并发地关闭所有服务器 for _, srv := range servers { - srv := srv // capture loop variable wg.Add(1) - go func() { + go func(s *http.Server) { defer wg.Done() - if err := srv.Shutdown(ctx); err != nil { - errsMutex.Lock() - if err == context.DeadlineExceeded { - log.Printf("Server %s shutdown timed out after %v.\n", srv.Addr, timeout) - errs = append(errs, fmt.Errorf("server %s shutdown timed out", srv.Addr)) - } else { - log.Printf("Server %s forced to shutdown: %v\n", srv.Addr, err) - errs = append(errs, fmt.Errorf("server %s forced to shutdown: %w", srv.Addr, err)) - } - errsMutex.Unlock() + if err := s.Shutdown(ctx); err != nil { + // 将错误发送到 channel + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) } - }() - } - wg.Wait() // 等待所有服务器的关闭 Goroutine 完成 - - if len(errs) > 0 { - return errors.Join(errs...) // 返回所有收集到的错误 + }(srv) } + wg.Wait() // 等待所有服务器的关闭 goroutine 完成 + close(errChan) // 关闭 channel,以便可以安全地遍历它 + + // 收集所有关闭过程中发生的错误 + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 + } log.Println("Touka server(s) exited gracefully.") return nil } -// RunShutdown 启动 HTTP 服务器并支持优雅关闭。 -// 它监听操作系统信号 (SIGINT, SIGTERM),并在指定超时时间内优雅地关闭服务器。 -// addr: 服务器监听的地址,例如 ":8080"。 -// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 -func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { - timeout := getShutdownTimeout(timeouts) +// --- 公共 Run 方法 --- - srv := &http.Server{ - Addr: addr, - Handler: engine, // Engine 实现了 http.Handler 接口 +// Run 启动一个不支持优雅关闭的 HTTP 服务器 +// 这是一个阻塞调用,主要用于简单的场景或快速测试 +// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 +func (engine *Engine) Run(addr ...string) error { + address := resolveAddress(addr) + srv := &http.Server{Addr: address, Handler: engine} + + // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 + //engine.applyDefaultServerConfig(srv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) } - - // 启动服务器在单独的 Goroutine 中运行,以便主 Goroutine 可以监听信号 - go func() { - log.Printf("Touka HTTP server listening on %s\n", addr) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Touka HTTP server listen error: %s\n", err) - } - }() - - return handleGracefulShutdown([]*http.Server{srv}, timeout, engine.LogReco) + log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) + return srv.ListenAndServe() } -// RunWithTLS 启动 HTTPS 服务器并支持优雅关闭。 -// 用户需自行创建并传入 *tls.Config 实例,以提供完整的 TLS 配置自由度。 -// addr: 服务器监听的地址,例如 ":8443"。 -// tlsConfig: 包含 TLS 证书、密钥及其他配置的 tls.Config 实例。 -// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 -func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunWithTLS") +// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 +func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { + srv := &http.Server{ + Addr: addr, + Handler: engine, + } + + // 应用框架的默认配置和用户提供的自定义配置 + //engine.applyDefaultServerConfig(srv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) + } + + runServer("HTTP", srv) + return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) +} + +// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 +func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunTLS") + } + + // 配置 HTTP/2 支持 (如果使用默认配置) + if engine.useDefaultProtocols { + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, // 默认在 TLS 上启用 HTTP/2 + }) } - timeout := getShutdownTimeout(timeouts) srv := &http.Server{ Addr: addr, Handler: engine, - TLSConfig: tlsConfig, // 使用用户传入的 tls.Config + TLSConfig: tlsConfig, } - if engine.useDefaultProtocols { - //加入HTTP2支持 - engine.SetProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认启用 HTTP/2 - Http2_Cleartext: false, - }) + // 应用框架的默认配置和用户提供的自定义配置 + // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator + //engine.applyDefaultServerConfig(srv) + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + } else if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) } - go func() { - log.Printf("Touka HTTPS server listening on %s\n", addr) - if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { - log.Fatalf("Touka HTTPS server listen error: %s\n", err) - } - }() - - return handleGracefulShutdown([]*http.Server{srv}, timeout, engine.LogReco) + runServer("HTTPS", srv) + return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } -// RunWithTLSRedir 启动 HTTP 和 HTTPS 服务器,并将所有 HTTP 请求重定向到 HTTPS。 -// httpAddr: HTTP 服务器监听的地址,例如 ":80"。 -// httpsAddr: HTTPS 服务器监听的地址,例如 ":443"。 -// tlsConfig: 包含 TLS 证书、密钥及其他配置的 tls.Config 实例,用于 HTTPS 服务器。 -// timeouts: 可选的超时时间,如果未提供,则默认为 5 秒。 -func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunWithTLSRedir") - } - timeout := getShutdownTimeout(timeouts) +// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 +func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + return engine.RunTLS(addr, tlsConfig, timeouts...) +} - // HTTPS Server +// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 +func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunTLSRedir") + } + + // --- HTTPS 服务器 --- + if engine.useDefaultProtocols { + engine.SetProtocols(&ProtocolsConfig{Http1: true, Http2: true}) + } httpsSrv := &http.Server{ Addr: httpsAddr, Handler: engine, - TLSConfig: tlsConfig, // 使用用户传入的 tls.Config + TLSConfig: tlsConfig, + } + //engine.applyDefaultServerConfig(httpsSrv) + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(httpsSrv) + } else if engine.ServerConfigurator != nil { + engine.ServerConfigurator(httpsSrv) } - if engine.useDefaultProtocols { - //加入HTTP2支持 - engine.SetProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认启用 HTTP/2 - Http2_Cleartext: false, - }) - } + // --- HTTP 重定向服务器 --- + redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } - // HTTP Server for redirection + _, httpsPort, err := net.SplitHostPort(httpsAddr) + if err != nil { + // 如果 httpsAddr 没有端口,这是一个配置错误 + + log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) + } + + targetURL := "https://" + host + // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 + if httpsPort != "443" { + targetURL = "https://" + net.JoinHostPort(host, httpsPort) + } + targetURL += r.URL.RequestURI() + + http.Redirect(w, r, targetURL, http.StatusMovedPermanently) + }) httpSrv := &http.Server{ - Addr: httpAddr, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 从 r.Host 提取 hostname,例如 "localhost:8080" -> "localhost" - hostOnly, _, err := net.SplitHostPort(r.Host) - if err != nil { // r.Host 可能没有端口,例如 "example.com" - hostOnly = r.Host - } - - // 从 httpsAddr 提取目标 HTTPS 端口,例如 ":443" -> "443" - _, targetHttpsPort, err := net.SplitHostPort(httpsAddr) - if err != nil { // httpsAddr 必须包含一个有效的端口 - log.Fatalf("Error: Invalid HTTPS address '%s' for redirection. Must specify a port (e.g., ':443').", httpsAddr) - } - - var redirectHost string - if targetHttpsPort == "443" { - redirectHost = hostOnly // 如果是默认 HTTPS 端口,则无需在 URL 中显式指定端口 - } else { - redirectHost = net.JoinHostPort(hostOnly, targetHttpsPort) // 否则,显式指定端口 - } - - // 构建目标 HTTPS URL - targetURL := "https://" + redirectHost + r.URL.RequestURI() - http.Redirect(w, r, targetURL, http.StatusMovedPermanently) // 301 Permanent Redirect - }), + Addr: httpAddr, + Handler: redirectHandler, + } + //engine.applyDefaultServerConfig(httpSrv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(httpSrv) } - // Start HTTPS server - go func() { - log.Printf("Touka HTTPS server listening on %s\n", httpsAddr) - if err := httpsSrv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { // 同样,传入空字符串 - log.Fatalf("Touka HTTPS server listen error: %s\n", err) - } - }() - - // Start HTTP redirect server - go func() { - log.Printf("Touka HTTP redirect server listening on %s\n", httpAddr) - if err := httpSrv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Touka HTTP redirect server listen error: %s\n", err) - } - }() - - return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, timeout, engine.LogReco) + // --- 启动服务器和优雅关闭 --- + runServer("HTTPS", httpsSrv) + runServer("HTTP Redirect", httpSrv) + return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) +} + +// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 +func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) }