diff --git a/context.go b/context.go index 6479a69..23fa92f 100644 --- a/context.go +++ b/context.go @@ -58,9 +58,6 @@ type Context struct { engine *Engine sameSite http.SameSite - - // 请求体Body大小限制 - MaxRequestBodySize int64 } // --- Context 相关方法实现 --- @@ -86,7 +83,6 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 - c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize // c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员 } @@ -212,11 +208,6 @@ func (c *Context) MustGet(key string) interface{} { panic("Key \"" + key + "\" does not exist in context.") } -// SetMaxRequestBodySize -func (c *Context) SetMaxRequestBodySize(size int64) { - c.MaxRequestBodySize = size -} - // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -443,28 +434,8 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { if c.Request.Body == nil { return nil, nil } - - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := copyb.ReadAll(limitBytesReader) + defer c.Request.Body.Close() // 确保请求体被关闭 + data, err := copyb.ReadAll(c.Request.Body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -477,28 +448,8 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { if c.Request.Body == nil { return nil, nil } - - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := copyb.ReadAll(limitBytesReader) + defer c.Request.Body.Close() // 确保请求体被关闭 + data, err := copyb.ReadAll(c.Request.Body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/engine.go b/engine.go index a2e7b2a..3d39cf0 100644 --- a/engine.go +++ b/engine.go @@ -3,12 +3,14 @@ package touka import ( "context" "errors" + "fmt" "log" "reflect" "runtime" "strings" "net/http" + "path" "sync" @@ -57,8 +59,8 @@ type Engine struct { noRoute HandlerFunc // NoRoute 处理器 noRoutes HandlersChain // NoRoutes 处理器链 (如果 noRoute 未设置,则使用此链) - unMatchFS UnMatchFS // 未匹配下的处理 - UnMatchFSRoutes HandlersChain // UnMatch 处理器链, 用于扩展自由度, 在此局部链上, unMatchFS相关处理会在最后 + unMatchFS UnMatchFS // 未匹配下的处理 + unMatchFileServer http.Handler // 处理handle serverProtocols *http.Protocols //服务协议 Protocols ProtocolsConfig //协议版本配置 @@ -72,35 +74,6 @@ type Engine struct { // 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器 // 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置) TLSServerConfigurator func(*http.Server) - - // GlobalMaxRequestBodySize 全局请求体Body大小限制 - GlobalMaxRequestBodySize int64 -} - -// HandleFunc 注册一个或多个 HTTP 方法的路由 -// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) -// relativePath 是相对于当前组或 Engine 的路径 -// handlers 是处理函数链 -func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { - for _, method := range methods { - if _, ok := MethodsSet[method]; !ok { - panic("invalid method: " + method) - } - engine.Handle(method, relativePath, handlers...) - } -} - -// HandleFunc 注册一个或多个 HTTP 方法的路由 -// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) -// relativePath 是相对于当前组或 Engine 的路径 -// handlers 是处理函数链 -func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { - for _, method := range methods { - if _, ok := MethodsSet[method]; !ok { - panic("invalid method: " + method) - } - group.Handle(method, relativePath, handlers...) - } } type ErrorHandle struct { @@ -198,11 +171,10 @@ func New() *Engine { unMatchFS: UnMatchFS{ ServeUnmatchedAsFS: false, }, - noRoute: nil, - noRoutes: make(HandlersChain, 0), - ServerConfigurator: nil, - TLSServerConfigurator: nil, - GlobalMaxRequestBodySize: -1, + noRoute: nil, + noRoutes: make(HandlersChain, 0), + ServerConfigurator: nil, + TLSServerConfigurator: nil, } //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -281,22 +253,15 @@ func (engine *Engine) GetDefaultErrHandler() ErrorHandler { return defaultErrorHandle } -func (engine *Engine) SetUnMatchFS(fs http.FileSystem, handlers ...HandlerFunc) { - engine.SetUnMatchFSChain(fs, handlers...) -} - -func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerFunc) { +// 传入并配置unMatchFS +func (engine *Engine) SetUnMatchFS(fs http.FileSystem) { if fs != nil { engine.unMatchFS.FSForUnmatched = fs engine.unMatchFS.ServeUnmatchedAsFS = true - unMatchFileServer := GetStaticFSHandleFunc(http.FileServer(fs)) - combinedChain := make(HandlersChain, len(handlers)+1) - copy(combinedChain, handlers) - combinedChain[len(handlers)] = unMatchFileServer - engine.UnMatchFSRoutes = combinedChain + engine.unMatchFileServer = http.FileServer(fs) } else { engine.unMatchFS.ServeUnmatchedAsFS = false - engine.UnMatchFSRoutes = nil + engine.unMatchFileServer = nil } } @@ -329,11 +294,6 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) { engine.useDefaultProtocols = false } -// 配置全局Req Body大小限制 -func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { - engine.GlobalMaxRequestBodySize = size -} - // 配置Req IP来源 Headers func (engine *Engine) SetRemoteIPHeaders(headers []string) { engine.RemoteIPHeaders = headers @@ -418,6 +378,135 @@ func getHandlerName(h HandlerFunc) string { } +// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口 +// 每个传入的 HTTP 请求都会调用此方法 +func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // 从 Context Pool 中获取一个 Context 对象进行复用 + c := engine.pool.Get().(*Context) + c.reset(w, req) // 重置 Context 对象的状态以适应当前请求 + + // 执行请求处理 + engine.handleRequest(c) + + // 将 Context 对象放回 Context Pool,以供下次复用 + engine.pool.Put(c) +} + +// handleRequest 负责根据请求查找路由并执行相应的处理函数链 +// 这是路由查找和执行的核心逻辑 +func (engine *Engine) handleRequest(c *Context) { + httpMethod := c.Request.Method + requestPath := c.Request.URL.Path + + // 查找对应的路由树的根节点 + rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 + if rootNode != nil { + // 查找匹配的节点和处理函数 + // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 + // skippedNodes 内部使用,因此无需从外部传入已分配的 slice + var skippedNodes []skippedNode // 用于回溯的跳过节点 + // 直接在 rootNode 上调用 getValue 方法 + value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + + if value.handlers != nil { + //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 + c.handlers = value.handlers + c.Next() // 执行处理函数链 + //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 + return + } + + // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) + if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if value.tsr && engine.RedirectTrailingSlash { + // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ + redirectPath := requestPath + if len(requestPath) > 0 && requestPath[len(requestPath)-1] == '/' { + redirectPath = requestPath[:len(requestPath)-1] + } else { + redirectPath = requestPath + "/" + } + c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 + return + } + // 尝试不区分大小写的查找 + // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found && engine.RedirectFixedPath { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return + } + } + } + + // 构建处理链 + // 组合全局中间件和路由处理函数 + handlers := engine.globalHandlers + + // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 + // 则在全局中间件之后添加 MethodNotAllowed 处理器 + if engine.HandleMethodNotAllowed { + handlers = append(handlers, MethodNotAllowed()) + } + + // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed + // 则在处理链的最后添加 UnMatchFS 处理器 + if engine.unMatchFS.ServeUnmatchedAsFS { + handlers = append(handlers, unMatchFSHandle()) + } + + // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS + // 则在处理链的最后添加 NoRoute 处理器 + if engine.noRoute != nil { + handlers = append(handlers, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + handlers = append(handlers, engine.noRoutes...) + } + + handlers = append(handlers, NotFound()) + + c.handlers = handlers + c.Next() // 执行处理函数链 + //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 +} + +// UnMatchFS HandleFunc +func unMatchFSHandle() HandlerFunc { + return func(c *Context) { + engine := c.engine + // 确保 engine.unMatchFileServer 存在 + if !engine.unMatchFS.ServeUnmatchedAsFS || engine.unMatchFileServer == nil { + c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链 + return + } + if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { + // 使用 http.FileServer 处理未匹配的请求 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + c.engine.unMatchFileServer.ServeHTTP(ecw, c.Request) + ecw.processAfterFileServer() + c.Abort() + return + } else { + if engine.noRoute == nil { + // 若为OPTIONS + if c.Request.Method == http.MethodOptions { + //返回allow get + c.Writer.Header().Set("Allow", "GET") + c.Status(http.StatusOK) + c.Abort() + return + } else { + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + return + } + } else { + c.Next() + } + } + } +} + // 405中间件 func MethodNotAllowed() HandlerFunc { return func(c *Context) { @@ -632,98 +721,353 @@ func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IR } } -// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口 -// 每个传入的 HTTP 请求都会调用此方法 -func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // 从 Context Pool 中获取一个 Context 对象进行复用 - c := engine.pool.Get().(*Context) - c.reset(w, req) // 重置 Context 对象的状态以适应当前请求 +// == 其他操作方式 === - // 执行请求处理 - engine.handleRequest(c) +// StaticDir 传入一个文件夹路径, 使用FileServer进行处理 +// r.StaticDir("/test/*filepath", "/var/www/test") +func (engine *Engine) StaticDir(relativePath, rootPath string) { + // 清理路径 + relativePath = path.Clean(relativePath) + rootPath = path.Clean(rootPath) - // 将 Context 对象放回 Context Pool,以供下次复用 - engine.pool.Put(c) -} + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + if !strings.HasSuffix(relativePath, "/") { + relativePath += "/" + } -// handleRequest 负责根据请求查找路由并执行相应的处理函数链 -// 这是路由查找和执行的核心逻辑 -func (engine *Engine) handleRequest(c *Context) { - httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + // 创建一个文件系统处理器 + fileServer := http.FileServer(http.Dir(rootPath)) - // 查找对应的路由树的根节点 - rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 - if rootNode != nil { - // 查找匹配的节点和处理函数 - // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 - // skippedNodes 内部使用,因此无需从外部传入已分配的 slice - var skippedNodes []skippedNode // 用于回溯的跳过节点 - // 直接在 rootNode 上调用 getValue 方法 - value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 - - if value.handlers != nil { - //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 - c.handlers = value.handlers - c.Next() // 执行处理函数链 - //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 + // 注册一个捕获所有路径的路由,使用自定义处理器 + // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD + // 我们可以通过在处理函数内部检查方法来限制 + engine.ANY(relativePath+"*filepath", func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + } return } - // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 - if value.tsr && engine.RedirectTrailingSlash { - // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ - redirectPath := requestPath - if len(requestPath) > 0 && requestPath[len(requestPath)-1] == '/' { - redirectPath = requestPath[:len(requestPath)-1] - } else { - redirectPath = requestPath + "/" - } - c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 - return - } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return - } - } - } + requestPath := c.Request.URL.Path - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers + // 获取捕获到的文件路径参数 + filepath := c.Param("filepath") - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } + // 构造文件服务器需要处理的请求路径 + // FileServer 会将请求路径与 http.Dir 的根路径结合 + // 我们需要移除相对路径前缀,只保留文件路径部分 + // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" + // FileServer 需要的路径是 "/js/app.js" + // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" + // 所以直接使用 filepath 即可 + c.Request.URL.Path = filepath - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 - if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) - } + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } + // + // 调用 FileServer 处理请求 + fileServer.ServeHTTP(ecw, c.Request) - handlers = append(handlers, NotFound()) + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() - c.handlers = handlers - c.Next() // 执行处理函数链 - //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 + // 恢复原始请求路径,以便后续中间件或日志记录使用 + c.Request.URL.Path = requestPath + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + }) +} + +// Group的StaticDir方式 +func (group *RouterGroup) StaticDir(relativePath, rootPath string) { + // 清理路径 + relativePath = path.Clean(relativePath) + rootPath = path.Clean(rootPath) + + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + if !strings.HasSuffix(relativePath, "/") { + relativePath += "/" + } + + // 创建一个文件系统处理器 + fileServer := http.FileServer(http.Dir(rootPath)) + + // 注册一个捕获所有路径的路由,使用自定义处理器 + // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD + // 我们可以通过在处理函数内部检查方法来限制 + group.ANY(relativePath+"*filepath", func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if group.engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + } + return + } + + requestPath := c.Request.URL.Path + + // 获取捕获到的文件路径参数 + filepath := c.Param("filepath") + + // 构造文件服务器需要处理的请求路径 + // FileServer 会将请求路径与 http.Dir 的根路径结合 + // 我们需要移除相对路径前缀,只保留文件路径部分 + // 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js" + // FileServer 需要的路径是 "/js/app.js" + // 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js" + // 所以直接使用 filepath 即可 + c.Request.URL.Path = filepath + + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + // + // 调用 FileServer 处理请求 + fileServer.ServeHTTP(ecw, c.Request) + + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() + + // 恢复原始请求路径,以便后续中间件或日志记录使用 + c.Request.URL.Path = requestPath + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + }) +} + +// Static File 传入一个文件路径, 使用FileServer进行处理 +func (engine *Engine) StaticFile(relativePath, filePath string) { + // 清理路径 + relativePath = path.Clean(relativePath) + filePath = path.Clean(filePath) + + // 创建一个文件系统处理器,指向包含目标文件的目录 + // http.Dir 需要一个目录路径 + dir := path.Dir(filePath) + fileName := path.Base(filePath) + fileServer := http.FileServer(http.Dir(dir)) + + FileHandle := func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + } + return + } + + requestPath := c.Request.URL.Path + + // 构造文件服务器需要处理的请求路径 + // FileServer 会将请求路径与 http.Dir 的根路径结合 + // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 + c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头 + + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + // 调用 FileServer 处理请求 + fileServer.ServeHTTP(ecw, c.Request) + + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() + + // 恢复原始请求路径 + c.Request.URL.Path = requestPath + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + } + + // 注册一个精确匹配的路由 + engine.GET(relativePath, FileHandle) + engine.HEAD(relativePath, FileHandle) + engine.OPTIONS(relativePath, FileHandle) + +} + +// Group的StaticFile +func (group *RouterGroup) StaticFile(relativePath, filePath string) { + // 清理路径 + relativePath = path.Clean(relativePath) + filePath = path.Clean(filePath) + + // 创建一个文件系统处理器,指向包含目标文件的目录 + // http.Dir 需要一个目录路径 + dir := path.Dir(filePath) + fileName := path.Base(filePath) + fileServer := http.FileServer(http.Dir(dir)) + + FileHandle := func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if group.engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + } + return + } + + requestPath := c.Request.URL.Path + + // 构造文件服务器需要处理的请求路径 + // FileServer 会将请求路径与 http.Dir 的根路径结合 + // 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件 + c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头 + + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + // 调用 FileServer 处理请求 + fileServer.ServeHTTP(ecw, c.Request) + + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() + + // 恢复原始请求路径 + c.Request.URL.Path = requestPath + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + } + + // 注册一个精确匹配的路由 + group.GET(relativePath, FileHandle) + group.HEAD(relativePath, FileHandle) + group.OPTIONS(relativePath, FileHandle) +} + +// StaticFS +func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) { + // 清理路径 + relativePath = path.Clean(relativePath) + + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + if !strings.HasSuffix(relativePath, "/") { + relativePath += "/" + } + + // 注册一个捕获所有路径的路由,使用 FileServer 处理器 + engine.ANY(relativePath+"*filepath", FileServer(fs)) +} + +// Group的StaticFS +func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) { + // 清理路径 + relativePath = path.Clean(relativePath) + + // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 + if !strings.HasSuffix(relativePath, "/") { + relativePath += "/" + } + + // 注册一个捕获所有路径的路由,使用 FileServer 处理器 + group.ANY(relativePath+"*filepath", FileServer(fs)) +} + +// 维护一个Methods列表 +var ( + MethodGet = "GET" + MethodHead = "HEAD" + MethodPost = "POST" + MethodPut = "PUT" + MethodPatch = "PATCH" + MethodDelete = "DELETE" + MethodConnect = "CONNECT" + MethodOptions = "OPTIONS" + MethodTrace = "TRACE" +) + +var MethodsSet = map[string]struct{}{ + MethodGet: {}, + MethodHead: {}, + MethodPost: {}, + MethodPut: {}, + MethodPatch: {}, + MethodDelete: {}, + MethodConnect: {}, + MethodOptions: {}, + MethodTrace: {}, +} + +// HandleFunc 注册一个或多个 HTTP 方法的路由 +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// relativePath 是相对于当前组或 Engine 的路径 +// handlers 是处理函数链 +func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { + for _, method := range methods { + if _, ok := MethodsSet[method]; !ok { + panic("invalid method: " + method) + } + engine.Handle(method, relativePath, handlers...) + } +} + +// HandleFunc 注册一个或多个 HTTP 方法的路由 +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// relativePath 是相对于当前组或 Engine 的路径 +// handlers 是处理函数链 +func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { + for _, method := range methods { + if _, ok := MethodsSet[method]; !ok { + panic("invalid method: " + method) + } + group.Handle(method, relativePath, handlers...) + } +} + +// FileServer方式, 返回一个HandleFunc, 统一化处理 +func FileServer(fs http.FileSystem) HandlerFunc { + return func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if c.engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) + } + return + } + + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + // 调用 http.FileServer 处理请求 + http.FileServer(fs).ServeHTTP(ecw, c.Request) + + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + } } diff --git a/fileserver.go b/fileserver.go deleted file mode 100644 index 653818e..0000000 --- a/fileserver.go +++ /dev/null @@ -1,278 +0,0 @@ -package touka - -import ( - "errors" - "fmt" - "net/http" - "path" - "strings" -) - -// === FileServer相关 === - -var allowedFileServerMethods = map[string]struct{}{ - http.MethodGet: {}, - http.MethodHead: {}, -} - -// FileServer方式, 返回一个HandleFunc, 统一化处理 -func FileServer(fs http.FileSystem) HandlerFunc { - if fs == nil { - return func(c *Context) { - c.ErrorUseHandle(500, errors.New("Input FileSystem is nil")) - } - } - fileServerInstance := http.FileServer(fs) - return func(c *Context) { - FileServerHandleServe(c, fileServerInstance) - - // 中止处理链,因为 FileServer 已经处理了响应 - c.Abort() - } -} - -func FileServerHandleServe(c *Context, fsHandle http.Handler) { - if fsHandle == nil { - ErrInputFSisNil := errors.New("Input FileSystem Handle is nil") - c.AddError(ErrInputFSisNil) - // 500 - c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil) - return - } - - // 检查是否是 GET 或 HEAD 方法 - if _, ok := allowedFileServerMethods[c.Request.Method]; !ok { - // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 - if c.engine.HandleMethodNotAllowed { - c.Next() - } else { - if c.engine.noRoute == nil { - if c.Request.Method == http.MethodOptions { - //返回allow get - c.Writer.Header().Set("Allow", "GET, HEAD") - c.Status(http.StatusOK) - c.Abort() - return - } else { - // 否则,返回 405 Method Not Allowed - c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) - } - } else { - c.Next() - } - } - return - } - - // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 - ecw := AcquireErrorCapturingResponseWriter(c) - defer ReleaseErrorCapturingResponseWriter(ecw) - - // 调用 http.FileServer 处理请求 - fsHandle.ServeHTTP(ecw, c.Request) - - // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler - ecw.processAfterFileServer() -} - -// StaticDir 传入一个文件夹路径, 使用FileServer进行处理 -// r.StaticDir("/test/*filepath", "/var/www/test") -func (engine *Engine) StaticDir(relativePath, rootPath string) { - // 清理路径 - relativePath = path.Clean(relativePath) - rootPath = path.Clean(rootPath) - - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 - if !strings.HasSuffix(relativePath, "/") { - relativePath += "/" - } - - // 创建一个文件系统处理器 - fileServer := http.FileServer(http.Dir(rootPath)) - - // 注册一个捕获所有路径的路由,使用自定义处理器 - // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD - // 我们可以通过在处理函数内部检查方法来限制 - engine.ANY(relativePath+"*filepath", GetStaticDirHandleFunc(fileServer)) -} - -// Group的StaticDir方式 -func (group *RouterGroup) StaticDir(relativePath, rootPath string) { - // 清理路径 - relativePath = path.Clean(relativePath) - rootPath = path.Clean(rootPath) - - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 - if !strings.HasSuffix(relativePath, "/") { - relativePath += "/" - } - - // 创建一个文件系统处理器 - fileServer := http.FileServer(http.Dir(rootPath)) - - // 注册一个捕获所有路径的路由,使用自定义处理器 - // 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD - // 我们可以通过在处理函数内部检查方法来限制 - group.ANY(relativePath+"*filepath", GetStaticDirHandleFunc(fileServer)) -} - -// GetStaticDirHandleFunc -func (engine *Engine) GetStaticDirHandle(rootPath string) HandlerFunc { - // 清理路径 - rootPath = path.Clean(rootPath) - - // 创建一个文件系统处理器 - fileServer := http.FileServer(http.Dir(rootPath)) - - return GetStaticDirHandleFunc(fileServer) -} - -// GetStaticDirHandleFunc -func (group *RouterGroup) GetStaticDirHandle(rootPath string) HandlerFunc { // 清理路径 - return group.engine.GetStaticDirHandle(rootPath) -} - -// GetStaticDirHandle -func GetStaticDirHandleFunc(fsHandle http.Handler) HandlerFunc { - return func(c *Context) { - requestPath := c.Request.URL.Path - - // 获取捕获到的文件路径参数 - filepath := c.Param("filepath") - - // 构造文件服务器需要处理的请求路径 - c.Request.URL.Path = filepath - - FileServerHandleServe(c, fsHandle) - - // 恢复原始请求路径,以便后续中间件或日志记录使用 - c.Request.URL.Path = requestPath - - // 中止处理链,因为 FileServer 已经处理了响应 - c.Abort() - } -} - -// Static File 传入一个文件路径, 使用FileServer进行处理 -func (engine *Engine) StaticFile(relativePath, filePath string) { - // 清理路径 - relativePath = path.Clean(relativePath) - filePath = path.Clean(filePath) - - FileHandle := engine.GetStaticFileHandle(filePath) - - // 注册一个精确匹配的路由 - engine.GET(relativePath, FileHandle) - engine.HEAD(relativePath, FileHandle) - engine.OPTIONS(relativePath, FileHandle) - -} - -// Group的StaticFile -func (group *RouterGroup) StaticFile(relativePath, filePath string) { - // 清理路径 - relativePath = path.Clean(relativePath) - filePath = path.Clean(filePath) - - FileHandle := group.GetStaticFileHandle(filePath) - - // 注册一个精确匹配的路由 - group.GET(relativePath, FileHandle) - group.HEAD(relativePath, FileHandle) - group.OPTIONS(relativePath, FileHandle) -} - -// GetStaticFileHandleFunc -func (engine *Engine) GetStaticFileHandle(filePath string) HandlerFunc { - // 清理路径 - filePath = path.Clean(filePath) - - // 创建一个文件系统处理器,指向包含目标文件的目录 - dir := path.Dir(filePath) - fileName := path.Base(filePath) - fileServer := http.FileServer(http.Dir(dir)) - - return GetStaticFileHandleFunc(fileServer, fileName) -} - -// GetStaticFileHandleFunc -func (group *RouterGroup) GetStaticFileHandle(filePath string) HandlerFunc { - // 清理路径 - filePath = path.Clean(filePath) - - // 创建一个文件系统处理器,指向包含目标文件的目录 - dir := path.Dir(filePath) - fileName := path.Base(filePath) - fileServer := http.FileServer(http.Dir(dir)) - - return GetStaticFileHandleFunc(fileServer, fileName) -} - -// GetStaticFileHandleFunc -func GetStaticFileHandleFunc(fsHandle http.Handler, fileName string) HandlerFunc { - return func(c *Context) { - requestPath := c.Request.URL.Path - - // 构造文件服务器需要处理的请求路径 - c.Request.URL.Path = "/" + fileName - - FileServerHandleServe(c, fsHandle) - - // 恢复原始请求路径 - c.Request.URL.Path = requestPath - - // 中止处理链,因为 FileServer 已经处理了响应 - c.Abort() - } -} - -// StaticFS -func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) { - // 清理路径 - relativePath = path.Clean(relativePath) - - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 - if !strings.HasSuffix(relativePath, "/") { - relativePath += "/" - } - - fileServer := http.FileServer(fs) - engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) -} - -// Group的StaticFS -func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) { - // 清理路径 - relativePath = path.Clean(relativePath) - - // 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径 - if !strings.HasSuffix(relativePath, "/") { - relativePath += "/" - } - - fileServer := http.FileServer(fs) - group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer)) -} - -// GetStaticFSHandleFunc -func GetStaticFSHandleFunc(fsHandle http.Handler) HandlerFunc { - return func(c *Context) { - - FileServerHandleServe(c, fsHandle) - - // 中止处理链,因为 FileServer 已经处理了响应 - c.Abort() - } -} - -// GetStaticFSHandleFunc -func (engine *Engine) GetStaticFSHandle(fs http.FileSystem) HandlerFunc { - fileServer := http.FileServer(fs) - return GetStaticFSHandleFunc(fileServer) -} - -// GetStaticFSHandleFunc -func (group *RouterGroup) GetStaticFSHandle(fs http.FileSystem) HandlerFunc { - fileServer := http.FileServer(fs) - return GetStaticFSHandleFunc(fileServer) -} diff --git a/maxreader.go b/maxreader.go deleted file mode 100644 index 96ff025..0000000 --- a/maxreader.go +++ /dev/null @@ -1,92 +0,0 @@ -package touka - -import ( - "fmt" - "io" - "sync/atomic" -) - -// ErrBodyTooLarge 是当读取的字节数超过 MaxBytesReader 设置的限制时返回的错误. -// 将其定义为可导出的变量, 方便调用方使用 errors.Is 进行判断. -var ErrBodyTooLarge = fmt.Errorf("body too large") - -// maxBytesReader 是一个实现了 io.ReadCloser 接口的结构体. -// 它包装了另一个 io.ReadCloser, 并限制了从其中读取的最大字节数. -type maxBytesReader struct { - // r 是底层的 io.ReadCloser. - r io.ReadCloser - // n 是允许读取的最大字节数. - n int64 - // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. - read atomic.Int64 -} - -// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, -// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. -// -// 如果 r 为 nil, 会 panic. -// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. -func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { - if r == nil { - panic("NewMaxBytesReader called with a nil reader") - } - // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. - if n < 0 { - return r - } - return &maxBytesReader{ - r: r, - n: n, - } -} - -// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. -func (mbr *maxBytesReader) Read(p []byte) (int, error) { - // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. - readSoFar := mbr.read.Load() - - // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. - if readSoFar >= mbr.n { - return 0, ErrBodyTooLarge - } - - // 计算当前还可以读取多少字节. - remaining := mbr.n - readSoFar - - // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. - // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. - if int64(len(p)) > remaining { - p = p[:remaining] - } - - // 从底层 Reader 读取数据. - n, err := mbr.r.Read(p) - - // 如果实际读取到了数据, 更新原子计数器. - if n > 0 { - readSoFar = mbr.read.Add(int64(n)) - } - - // 如果底层 Read 返回错误 (例如 io.EOF). - if err != nil { - // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. - // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. - return n, err - } - - // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. - // 这是处理"跨越"限制情况的关键. - if readSoFar > mbr.n { - // 返回实际读取的字节数 n, 并附上超限错误. - // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. - return n, ErrBodyTooLarge - } - - // 一切正常, 返回读取的字节数和 nil 错误. - return n, nil -} - -// Close 方法关闭底层的 ReadCloser, 保证资源释放. -func (mbr *maxBytesReader) Close() error { - return mbr.r.Close() -} diff --git a/mergectx.go b/mergectx.go deleted file mode 100644 index 4c91601..0000000 --- a/mergectx.go +++ /dev/null @@ -1,118 +0,0 @@ -package touka - -import ( - "context" - "sync" - "time" -) - -// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. -type mergedContext struct { - // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. - context.Context - // 保存了所有的父 context, 用于 Value() 方法的查找. - parents []context.Context - // 用于手动取消此 mergedContext 的函数. - cancel context.CancelFunc -} - -// MergeCtx 创建并返回一个新的 context.Context. -// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). -// -// 新的 context 会继承: -// - Deadline: 所有父 context 中最早的截止时间. -// - Value: 按传入顺序从第一个能找到值的父 context 中获取值. -func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.CancelFunc) { - if len(parents) == 0 { - return context.WithCancel(context.Background()) - } - if len(parents) == 1 { - return context.WithCancel(parents[0]) - } - - var earliestDeadline time.Time - for _, p := range parents { - if deadline, ok := p.Deadline(); ok { - if earliestDeadline.IsZero() || deadline.Before(earliestDeadline) { - earliestDeadline = deadline - } - } - } - - var baseCtx context.Context - var baseCancel context.CancelFunc - if !earliestDeadline.IsZero() { - baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) - } else { - baseCtx, baseCancel = context.WithCancel(context.Background()) - } - - mc := &mergedContext{ - Context: baseCtx, - parents: parents, - cancel: baseCancel, - } - - // 启动一个监控 goroutine. - go func() { - defer mc.cancel() - - // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. - // 同时监听 baseCtx.Done() 以便支持手动取消. - select { - case <-orDone(mc.parents...): - case <-mc.Context.Done(): - } - }() - - return mc, mc.cancel -} - -// Value 返回当前Ctx Value -func (mc *mergedContext) Value(key any) any { - return mc.Context.Value(key) -} - -// Deadline 实现了 context.Context 的 Deadline 方法. -func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { - return mc.Context.Deadline() -} - -// Done 实现了 context.Context 的 Done 方法. -func (mc *mergedContext) Done() <-chan struct{} { - return mc.Context.Done() -} - -// Err 实现了 context.Context 的 Err 方法. -func (mc *mergedContext) Err() error { - return mc.Context.Err() -} - -// orDone 是一个辅助函数, 返回一个 channel. -// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. -// 这是一个非阻塞的、不会泄漏 goroutine 的实现. -func orDone(contexts ...context.Context) <-chan struct{} { - done := make(chan struct{}) - - var once sync.Once - closeDone := func() { - once.Do(func() { - close(done) - }) - } - - // 为每个父 context 启动一个 goroutine. - for _, ctx := range contexts { - go func(c context.Context) { - select { - case <-c.Done(): - closeDone() - case <-done: - // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. - } - }(ctx) - } - - return done -} diff --git a/touka.go b/touka.go index a9f9a1c..ba8400d 100644 --- a/touka.go +++ b/touka.go @@ -42,28 +42,3 @@ type RouteInfo struct { Handler string // 处理函数名称 Group string // 路由分组 } - -// 维护一个Methods列表 -var ( - MethodGet = "GET" - MethodHead = "HEAD" - MethodPost = "POST" - MethodPut = "PUT" - MethodPatch = "PATCH" - MethodDelete = "DELETE" - MethodConnect = "CONNECT" - MethodOptions = "OPTIONS" - MethodTrace = "TRACE" -) - -var MethodsSet = map[string]struct{}{ - MethodGet: {}, - MethodHead: {}, - MethodPost: {}, - MethodPut: {}, - MethodPatch: {}, - MethodDelete: {}, - MethodConnect: {}, - MethodOptions: {}, - MethodTrace: {}, -}