mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 17:01:11 +08:00
Compare commits
6 commits
780e640253
...
a6171241ce
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6171241ce | ||
|
|
336b8ad958 | ||
|
|
989eb34c4c | ||
|
|
5d2ab04b6b | ||
|
|
49508b49c1 | ||
|
|
cb86cb935a |
6 changed files with 693 additions and 475 deletions
57
context.go
57
context.go
|
|
@ -58,6 +58,9 @@ type Context struct {
|
||||||
engine *Engine
|
engine *Engine
|
||||||
|
|
||||||
sameSite http.SameSite
|
sameSite http.SameSite
|
||||||
|
|
||||||
|
// 请求体Body大小限制
|
||||||
|
MaxRequestBodySize int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Context 相关方法实现 ---
|
// --- Context 相关方法实现 ---
|
||||||
|
|
@ -83,6 +86,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||||||
c.formCache = nil // 清空表单数据缓存
|
c.formCache = nil // 清空表单数据缓存
|
||||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||||
|
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||||
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
|
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -208,6 +212,11 @@ func (c *Context) MustGet(key string) interface{} {
|
||||||
panic("Key \"" + key + "\" does not exist in context.")
|
panic("Key \"" + key + "\" does not exist in context.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMaxRequestBodySize
|
||||||
|
func (c *Context) SetMaxRequestBodySize(size int64) {
|
||||||
|
c.MaxRequestBodySize = size
|
||||||
|
}
|
||||||
|
|
||||||
// Query 从 URL 查询参数中获取值
|
// Query 从 URL 查询参数中获取值
|
||||||
// 懒加载解析查询参数,并进行缓存
|
// 懒加载解析查询参数,并进行缓存
|
||||||
func (c *Context) Query(key string) string {
|
func (c *Context) Query(key string) string {
|
||||||
|
|
@ -434,8 +443,28 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||||
if c.Request.Body == nil {
|
if c.Request.Body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
|
||||||
data, err := copyb.ReadAll(c.Request.Body)
|
var limitBytesReader io.ReadCloser
|
||||||
|
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||||
|
defer func() {
|
||||||
|
err := limitBytesReader.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
limitBytesReader = c.Request.Body
|
||||||
|
defer func() {
|
||||||
|
err := limitBytesReader.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := copyb.ReadAll(limitBytesReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
@ -448,8 +477,28 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||||
if c.Request.Body == nil {
|
if c.Request.Body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
|
||||||
data, err := copyb.ReadAll(c.Request.Body)
|
var limitBytesReader io.ReadCloser
|
||||||
|
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||||
|
defer func() {
|
||||||
|
err := limitBytesReader.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
limitBytesReader = c.Request.Body
|
||||||
|
defer func() {
|
||||||
|
err := limitBytesReader.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := copyb.ReadAll(limitBytesReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
|
||||||
598
engine.go
598
engine.go
|
|
@ -3,14 +3,12 @@ package touka
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
|
||||||
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
|
@ -59,8 +57,8 @@ type Engine struct {
|
||||||
noRoute HandlerFunc // NoRoute 处理器
|
noRoute HandlerFunc // NoRoute 处理器
|
||||||
noRoutes HandlersChain // NoRoutes 处理器链 (如果 noRoute 未设置,则使用此链)
|
noRoutes HandlersChain // NoRoutes 处理器链 (如果 noRoute 未设置,则使用此链)
|
||||||
|
|
||||||
unMatchFS UnMatchFS // 未匹配下的处理
|
unMatchFS UnMatchFS // 未匹配下的处理
|
||||||
unMatchFileServer http.Handler // 处理handle
|
UnMatchFSRoutes HandlersChain // UnMatch 处理器链, 用于扩展自由度, 在此局部链上, unMatchFS相关处理会在最后
|
||||||
|
|
||||||
serverProtocols *http.Protocols //服务协议
|
serverProtocols *http.Protocols //服务协议
|
||||||
Protocols ProtocolsConfig //协议版本配置
|
Protocols ProtocolsConfig //协议版本配置
|
||||||
|
|
@ -74,6 +72,35 @@ type Engine struct {
|
||||||
// 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器
|
// 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器
|
||||||
// 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置)
|
// 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置)
|
||||||
TLSServerConfigurator func(*http.Server)
|
TLSServerConfigurator func(*http.Server)
|
||||||
|
|
||||||
|
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
||||||
|
GlobalMaxRequestBodySize int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||||
|
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
||||||
|
// relativePath 是相对于当前组或 Engine 的路径
|
||||||
|
// handlers 是处理函数链
|
||||||
|
func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
||||||
|
for _, method := range methods {
|
||||||
|
if _, ok := MethodsSet[method]; !ok {
|
||||||
|
panic("invalid method: " + method)
|
||||||
|
}
|
||||||
|
engine.Handle(method, relativePath, handlers...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
||||||
|
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
||||||
|
// relativePath 是相对于当前组或 Engine 的路径
|
||||||
|
// handlers 是处理函数链
|
||||||
|
func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
||||||
|
for _, method := range methods {
|
||||||
|
if _, ok := MethodsSet[method]; !ok {
|
||||||
|
panic("invalid method: " + method)
|
||||||
|
}
|
||||||
|
group.Handle(method, relativePath, handlers...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrorHandle struct {
|
type ErrorHandle struct {
|
||||||
|
|
@ -171,10 +198,11 @@ func New() *Engine {
|
||||||
unMatchFS: UnMatchFS{
|
unMatchFS: UnMatchFS{
|
||||||
ServeUnmatchedAsFS: false,
|
ServeUnmatchedAsFS: false,
|
||||||
},
|
},
|
||||||
noRoute: nil,
|
noRoute: nil,
|
||||||
noRoutes: make(HandlersChain, 0),
|
noRoutes: make(HandlersChain, 0),
|
||||||
ServerConfigurator: nil,
|
ServerConfigurator: nil,
|
||||||
TLSServerConfigurator: nil,
|
TLSServerConfigurator: nil,
|
||||||
|
GlobalMaxRequestBodySize: -1,
|
||||||
}
|
}
|
||||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||||
engine.SetDefaultProtocols()
|
engine.SetDefaultProtocols()
|
||||||
|
|
@ -253,15 +281,22 @@ func (engine *Engine) GetDefaultErrHandler() ErrorHandler {
|
||||||
return defaultErrorHandle
|
return defaultErrorHandle
|
||||||
}
|
}
|
||||||
|
|
||||||
// 传入并配置unMatchFS
|
func (engine *Engine) SetUnMatchFS(fs http.FileSystem, handlers ...HandlerFunc) {
|
||||||
func (engine *Engine) SetUnMatchFS(fs http.FileSystem) {
|
engine.SetUnMatchFSChain(fs, handlers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerFunc) {
|
||||||
if fs != nil {
|
if fs != nil {
|
||||||
engine.unMatchFS.FSForUnmatched = fs
|
engine.unMatchFS.FSForUnmatched = fs
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = true
|
engine.unMatchFS.ServeUnmatchedAsFS = true
|
||||||
engine.unMatchFileServer = http.FileServer(fs)
|
unMatchFileServer := GetStaticFSHandleFunc(http.FileServer(fs))
|
||||||
|
combinedChain := make(HandlersChain, len(handlers)+1)
|
||||||
|
copy(combinedChain, handlers)
|
||||||
|
combinedChain[len(handlers)] = unMatchFileServer
|
||||||
|
engine.UnMatchFSRoutes = combinedChain
|
||||||
} else {
|
} else {
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = false
|
engine.unMatchFS.ServeUnmatchedAsFS = false
|
||||||
engine.unMatchFileServer = nil
|
engine.UnMatchFSRoutes = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -294,6 +329,11 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) {
|
||||||
engine.useDefaultProtocols = false
|
engine.useDefaultProtocols = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 配置全局Req Body大小限制
|
||||||
|
func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) {
|
||||||
|
engine.GlobalMaxRequestBodySize = size
|
||||||
|
}
|
||||||
|
|
||||||
// 配置Req IP来源 Headers
|
// 配置Req IP来源 Headers
|
||||||
func (engine *Engine) SetRemoteIPHeaders(headers []string) {
|
func (engine *Engine) SetRemoteIPHeaders(headers []string) {
|
||||||
engine.RemoteIPHeaders = headers
|
engine.RemoteIPHeaders = headers
|
||||||
|
|
@ -378,135 +418,6 @@ func getHandlerName(h HandlerFunc) string {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口
|
|
||||||
// 每个传入的 HTTP 请求都会调用此方法
|
|
||||||
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
||||||
// 从 Context Pool 中获取一个 Context 对象进行复用
|
|
||||||
c := engine.pool.Get().(*Context)
|
|
||||||
c.reset(w, req) // 重置 Context 对象的状态以适应当前请求
|
|
||||||
|
|
||||||
// 执行请求处理
|
|
||||||
engine.handleRequest(c)
|
|
||||||
|
|
||||||
// 将 Context 对象放回 Context Pool,以供下次复用
|
|
||||||
engine.pool.Put(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
|
||||||
// 这是路由查找和执行的核心逻辑
|
|
||||||
func (engine *Engine) handleRequest(c *Context) {
|
|
||||||
httpMethod := c.Request.Method
|
|
||||||
requestPath := c.Request.URL.Path
|
|
||||||
|
|
||||||
// 查找对应的路由树的根节点
|
|
||||||
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
|
||||||
if rootNode != nil {
|
|
||||||
// 查找匹配的节点和处理函数
|
|
||||||
// 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量
|
|
||||||
// skippedNodes 内部使用,因此无需从外部传入已分配的 slice
|
|
||||||
var skippedNodes []skippedNode // 用于回溯的跳过节点
|
|
||||||
// 直接在 rootNode 上调用 getValue 方法
|
|
||||||
value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码
|
|
||||||
|
|
||||||
if value.handlers != nil {
|
|
||||||
//c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数
|
|
||||||
c.handlers = value.handlers
|
|
||||||
c.Next() // 执行处理函数链
|
|
||||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
|
||||||
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
|
||||||
if value.tsr && engine.RedirectTrailingSlash {
|
|
||||||
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
|
||||||
redirectPath := requestPath
|
|
||||||
if len(requestPath) > 0 && requestPath[len(requestPath)-1] == '/' {
|
|
||||||
redirectPath = requestPath[:len(requestPath)-1]
|
|
||||||
} else {
|
|
||||||
redirectPath = requestPath + "/"
|
|
||||||
}
|
|
||||||
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 尝试不区分大小写的查找
|
|
||||||
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法
|
|
||||||
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
|
|
||||||
if found && engine.RedirectFixedPath {
|
|
||||||
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建处理链
|
|
||||||
// 组合全局中间件和路由处理函数
|
|
||||||
handlers := engine.globalHandlers
|
|
||||||
|
|
||||||
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
|
|
||||||
// 则在全局中间件之后添加 MethodNotAllowed 处理器
|
|
||||||
if engine.HandleMethodNotAllowed {
|
|
||||||
handlers = append(handlers, MethodNotAllowed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
|
|
||||||
// 则在处理链的最后添加 UnMatchFS 处理器
|
|
||||||
if engine.unMatchFS.ServeUnmatchedAsFS {
|
|
||||||
handlers = append(handlers, unMatchFSHandle())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
|
|
||||||
// 则在处理链的最后添加 NoRoute 处理器
|
|
||||||
if engine.noRoute != nil {
|
|
||||||
handlers = append(handlers, engine.noRoute)
|
|
||||||
} else if len(engine.noRoutes) > 0 {
|
|
||||||
handlers = append(handlers, engine.noRoutes...)
|
|
||||||
}
|
|
||||||
|
|
||||||
handlers = append(handlers, NotFound())
|
|
||||||
|
|
||||||
c.handlers = handlers
|
|
||||||
c.Next() // 执行处理函数链
|
|
||||||
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnMatchFS HandleFunc
|
|
||||||
func unMatchFSHandle() HandlerFunc {
|
|
||||||
return func(c *Context) {
|
|
||||||
engine := c.engine
|
|
||||||
// 确保 engine.unMatchFileServer 存在
|
|
||||||
if !engine.unMatchFS.ServeUnmatchedAsFS || engine.unMatchFileServer == nil {
|
|
||||||
c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
|
||||||
// 使用 http.FileServer 处理未匹配的请求
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
c.engine.unMatchFileServer.ServeHTTP(ecw, c.Request)
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
if engine.noRoute == nil {
|
|
||||||
// 若为OPTIONS
|
|
||||||
if c.Request.Method == http.MethodOptions {
|
|
||||||
//返回allow get
|
|
||||||
c.Writer.Header().Set("Allow", "GET")
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 405中间件
|
// 405中间件
|
||||||
func MethodNotAllowed() HandlerFunc {
|
func MethodNotAllowed() HandlerFunc {
|
||||||
return func(c *Context) {
|
return func(c *Context) {
|
||||||
|
|
@ -721,353 +632,98 @@ func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// == 其他操作方式 ===
|
// ServeHTTP 实现了 http.Handler 接口,是 Engine 处理所有 HTTP 请求的入口
|
||||||
|
// 每个传入的 HTTP 请求都会调用此方法
|
||||||
|
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
// 从 Context Pool 中获取一个 Context 对象进行复用
|
||||||
|
c := engine.pool.Get().(*Context)
|
||||||
|
c.reset(w, req) // 重置 Context 对象的状态以适应当前请求
|
||||||
|
|
||||||
// StaticDir 传入一个文件夹路径, 使用FileServer进行处理
|
// 执行请求处理
|
||||||
// r.StaticDir("/test/*filepath", "/var/www/test")
|
engine.handleRequest(c)
|
||||||
func (engine *Engine) StaticDir(relativePath, rootPath string) {
|
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
rootPath = path.Clean(rootPath)
|
|
||||||
|
|
||||||
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
// 将 Context 对象放回 Context Pool,以供下次复用
|
||||||
if !strings.HasSuffix(relativePath, "/") {
|
engine.pool.Put(c)
|
||||||
relativePath += "/"
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 创建一个文件系统处理器
|
// handleRequest 负责根据请求查找路由并执行相应的处理函数链
|
||||||
fileServer := http.FileServer(http.Dir(rootPath))
|
// 这是路由查找和执行的核心逻辑
|
||||||
|
func (engine *Engine) handleRequest(c *Context) {
|
||||||
|
httpMethod := c.Request.Method
|
||||||
|
requestPath := c.Request.URL.Path
|
||||||
|
|
||||||
// 注册一个捕获所有路径的路由,使用自定义处理器
|
// 查找对应的路由树的根节点
|
||||||
// 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD
|
rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型
|
||||||
// 我们可以通过在处理函数内部检查方法来限制
|
if rootNode != nil {
|
||||||
engine.ANY(relativePath+"*filepath", func(c *Context) {
|
// 查找匹配的节点和处理函数
|
||||||
// 检查是否是 GET 或 HEAD 方法
|
// 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量
|
||||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
// skippedNodes 内部使用,因此无需从外部传入已分配的 slice
|
||||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
var skippedNodes []skippedNode // 用于回溯的跳过节点
|
||||||
if engine.HandleMethodNotAllowed {
|
// 直接在 rootNode 上调用 getValue 方法
|
||||||
c.Next()
|
value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码
|
||||||
} else {
|
|
||||||
// 否则,返回 405 Method Not Allowed
|
if value.handlers != nil {
|
||||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
//c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数
|
||||||
}
|
c.handlers = value.handlers
|
||||||
|
c.Next() // 执行处理函数链
|
||||||
|
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
requestPath := c.Request.URL.Path
|
// 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复)
|
||||||
|
if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向
|
||||||
// 获取捕获到的文件路径参数
|
if value.tsr && engine.RedirectTrailingSlash {
|
||||||
filepath := c.Param("filepath")
|
// 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/
|
||||||
|
redirectPath := requestPath
|
||||||
// 构造文件服务器需要处理的请求路径
|
if len(requestPath) > 0 && requestPath[len(requestPath)-1] == '/' {
|
||||||
// FileServer 会将请求路径与 http.Dir 的根路径结合
|
redirectPath = requestPath[:len(requestPath)-1]
|
||||||
// 我们需要移除相对路径前缀,只保留文件路径部分
|
} else {
|
||||||
// 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js"
|
redirectPath = requestPath + "/"
|
||||||
// FileServer 需要的路径是 "/js/app.js"
|
}
|
||||||
// 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js"
|
c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向
|
||||||
// 所以直接使用 filepath 即可
|
return
|
||||||
c.Request.URL.Path = filepath
|
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
|
||||||
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
|
|
||||||
//
|
|
||||||
// 调用 FileServer 处理请求
|
|
||||||
fileServer.ServeHTTP(ecw, c.Request)
|
|
||||||
|
|
||||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
|
|
||||||
// 恢复原始请求路径,以便后续中间件或日志记录使用
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
|
|
||||||
// 中止处理链,因为 FileServer 已经处理了响应
|
|
||||||
c.Abort()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Group的StaticDir方式
|
|
||||||
func (group *RouterGroup) StaticDir(relativePath, rootPath string) {
|
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
rootPath = path.Clean(rootPath)
|
|
||||||
|
|
||||||
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
|
||||||
if !strings.HasSuffix(relativePath, "/") {
|
|
||||||
relativePath += "/"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建一个文件系统处理器
|
|
||||||
fileServer := http.FileServer(http.Dir(rootPath))
|
|
||||||
|
|
||||||
// 注册一个捕获所有路径的路由,使用自定义处理器
|
|
||||||
// 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD
|
|
||||||
// 我们可以通过在处理函数内部检查方法来限制
|
|
||||||
group.ANY(relativePath+"*filepath", func(c *Context) {
|
|
||||||
// 检查是否是 GET 或 HEAD 方法
|
|
||||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
|
||||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
|
||||||
if group.engine.HandleMethodNotAllowed {
|
|
||||||
c.Next()
|
|
||||||
} else {
|
|
||||||
// 否则,返回 405 Method Not Allowed
|
|
||||||
group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
}
|
}
|
||||||
return
|
// 尝试不区分大小写的查找
|
||||||
}
|
// 直接在 rootNode 上调用 findCaseInsensitivePath 方法
|
||||||
|
ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash)
|
||||||
requestPath := c.Request.URL.Path
|
if found && engine.RedirectFixedPath {
|
||||||
|
c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径
|
||||||
// 获取捕获到的文件路径参数
|
return
|
||||||
filepath := c.Param("filepath")
|
|
||||||
|
|
||||||
// 构造文件服务器需要处理的请求路径
|
|
||||||
// FileServer 会将请求路径与 http.Dir 的根路径结合
|
|
||||||
// 我们需要移除相对路径前缀,只保留文件路径部分
|
|
||||||
// 例如,如果 relativePath 是 "/static/",请求是 "/static/js/app.js"
|
|
||||||
// FileServer 需要的路径是 "/js/app.js"
|
|
||||||
// 这里的 filepath 参数已经包含了 "/" 前缀,例如 "/js/app.js"
|
|
||||||
// 所以直接使用 filepath 即可
|
|
||||||
c.Request.URL.Path = filepath
|
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
|
||||||
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
|
|
||||||
//
|
|
||||||
// 调用 FileServer 处理请求
|
|
||||||
fileServer.ServeHTTP(ecw, c.Request)
|
|
||||||
|
|
||||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
|
|
||||||
// 恢复原始请求路径,以便后续中间件或日志记录使用
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
|
|
||||||
// 中止处理链,因为 FileServer 已经处理了响应
|
|
||||||
c.Abort()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Static File 传入一个文件路径, 使用FileServer进行处理
|
|
||||||
func (engine *Engine) StaticFile(relativePath, filePath string) {
|
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
filePath = path.Clean(filePath)
|
|
||||||
|
|
||||||
// 创建一个文件系统处理器,指向包含目标文件的目录
|
|
||||||
// http.Dir 需要一个目录路径
|
|
||||||
dir := path.Dir(filePath)
|
|
||||||
fileName := path.Base(filePath)
|
|
||||||
fileServer := http.FileServer(http.Dir(dir))
|
|
||||||
|
|
||||||
FileHandle := func(c *Context) {
|
|
||||||
// 检查是否是 GET 或 HEAD 方法
|
|
||||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
|
||||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
|
||||||
if engine.HandleMethodNotAllowed {
|
|
||||||
c.Next()
|
|
||||||
} else {
|
|
||||||
// 否则,返回 405 Method Not Allowed
|
|
||||||
engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requestPath := c.Request.URL.Path
|
|
||||||
|
|
||||||
// 构造文件服务器需要处理的请求路径
|
|
||||||
// FileServer 会将请求路径与 http.Dir 的根路径结合
|
|
||||||
// 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件
|
|
||||||
c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头
|
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
|
|
||||||
// 调用 FileServer 处理请求
|
|
||||||
fileServer.ServeHTTP(ecw, c.Request)
|
|
||||||
|
|
||||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
|
|
||||||
// 恢复原始请求路径
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
|
|
||||||
// 中止处理链,因为 FileServer 已经处理了响应
|
|
||||||
c.Abort()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册一个精确匹配的路由
|
// 构建处理链
|
||||||
engine.GET(relativePath, FileHandle)
|
// 组合全局中间件和路由处理函数
|
||||||
engine.HEAD(relativePath, FileHandle)
|
handlers := engine.globalHandlers
|
||||||
engine.OPTIONS(relativePath, FileHandle)
|
|
||||||
|
|
||||||
}
|
// 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由
|
||||||
|
// 则在全局中间件之后添加 MethodNotAllowed 处理器
|
||||||
// Group的StaticFile
|
if engine.HandleMethodNotAllowed {
|
||||||
func (group *RouterGroup) StaticFile(relativePath, filePath string) {
|
handlers = append(handlers, MethodNotAllowed())
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
filePath = path.Clean(filePath)
|
|
||||||
|
|
||||||
// 创建一个文件系统处理器,指向包含目标文件的目录
|
|
||||||
// http.Dir 需要一个目录路径
|
|
||||||
dir := path.Dir(filePath)
|
|
||||||
fileName := path.Base(filePath)
|
|
||||||
fileServer := http.FileServer(http.Dir(dir))
|
|
||||||
|
|
||||||
FileHandle := func(c *Context) {
|
|
||||||
// 检查是否是 GET 或 HEAD 方法
|
|
||||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
|
||||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
|
||||||
if group.engine.HandleMethodNotAllowed {
|
|
||||||
c.Next()
|
|
||||||
} else {
|
|
||||||
// 否则,返回 405 Method Not Allowed
|
|
||||||
group.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed"))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
requestPath := c.Request.URL.Path
|
|
||||||
|
|
||||||
// 构造文件服务器需要处理的请求路径
|
|
||||||
// FileServer 会将请求路径与 http.Dir 的根路径结合
|
|
||||||
// 我们需要将请求路径设置为文件名,以便 FileServer 找到正确的文件
|
|
||||||
c.Request.URL.Path = "/" + fileName // FileServer 期望路径以 / 开头
|
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
|
|
||||||
// 调用 FileServer 处理请求
|
|
||||||
fileServer.ServeHTTP(ecw, c.Request)
|
|
||||||
|
|
||||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
|
|
||||||
// 恢复原始请求路径
|
|
||||||
c.Request.URL.Path = requestPath
|
|
||||||
|
|
||||||
// 中止处理链,因为 FileServer 已经处理了响应
|
|
||||||
c.Abort()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册一个精确匹配的路由
|
// 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed
|
||||||
group.GET(relativePath, FileHandle)
|
// 则在处理链的最后添加 UnMatchFS 处理器
|
||||||
group.HEAD(relativePath, FileHandle)
|
if engine.unMatchFS.ServeUnmatchedAsFS {
|
||||||
group.OPTIONS(relativePath, FileHandle)
|
/*
|
||||||
}
|
var unMatchFSHandle = c.engine.unMatchFileServer
|
||||||
|
handlers = append(handlers, unMatchFSHandle)
|
||||||
// StaticFS
|
*/
|
||||||
func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) {
|
handlers = append(handlers, engine.UnMatchFSRoutes...)
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
|
|
||||||
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
|
||||||
if !strings.HasSuffix(relativePath, "/") {
|
|
||||||
relativePath += "/"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册一个捕获所有路径的路由,使用 FileServer 处理器
|
// 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS
|
||||||
engine.ANY(relativePath+"*filepath", FileServer(fs))
|
// 则在处理链的最后添加 NoRoute 处理器
|
||||||
}
|
if engine.noRoute != nil {
|
||||||
|
handlers = append(handlers, engine.noRoute)
|
||||||
// Group的StaticFS
|
} else if len(engine.noRoutes) > 0 {
|
||||||
func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
|
handlers = append(handlers, engine.noRoutes...)
|
||||||
// 清理路径
|
|
||||||
relativePath = path.Clean(relativePath)
|
|
||||||
|
|
||||||
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
|
||||||
if !strings.HasSuffix(relativePath, "/") {
|
|
||||||
relativePath += "/"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册一个捕获所有路径的路由,使用 FileServer 处理器
|
handlers = append(handlers, NotFound())
|
||||||
group.ANY(relativePath+"*filepath", FileServer(fs))
|
|
||||||
}
|
c.handlers = handlers
|
||||||
|
c.Next() // 执行处理函数链
|
||||||
// 维护一个Methods列表
|
//c.Writer.Flush() // 确保所有缓冲的响应数据被发送
|
||||||
var (
|
|
||||||
MethodGet = "GET"
|
|
||||||
MethodHead = "HEAD"
|
|
||||||
MethodPost = "POST"
|
|
||||||
MethodPut = "PUT"
|
|
||||||
MethodPatch = "PATCH"
|
|
||||||
MethodDelete = "DELETE"
|
|
||||||
MethodConnect = "CONNECT"
|
|
||||||
MethodOptions = "OPTIONS"
|
|
||||||
MethodTrace = "TRACE"
|
|
||||||
)
|
|
||||||
|
|
||||||
var MethodsSet = map[string]struct{}{
|
|
||||||
MethodGet: {},
|
|
||||||
MethodHead: {},
|
|
||||||
MethodPost: {},
|
|
||||||
MethodPut: {},
|
|
||||||
MethodPatch: {},
|
|
||||||
MethodDelete: {},
|
|
||||||
MethodConnect: {},
|
|
||||||
MethodOptions: {},
|
|
||||||
MethodTrace: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
|
||||||
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
|
||||||
// relativePath 是相对于当前组或 Engine 的路径
|
|
||||||
// handlers 是处理函数链
|
|
||||||
func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
|
||||||
for _, method := range methods {
|
|
||||||
if _, ok := MethodsSet[method]; !ok {
|
|
||||||
panic("invalid method: " + method)
|
|
||||||
}
|
|
||||||
engine.Handle(method, relativePath, handlers...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleFunc 注册一个或多个 HTTP 方法的路由
|
|
||||||
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"})
|
|
||||||
// relativePath 是相对于当前组或 Engine 的路径
|
|
||||||
// handlers 是处理函数链
|
|
||||||
func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
|
|
||||||
for _, method := range methods {
|
|
||||||
if _, ok := MethodsSet[method]; !ok {
|
|
||||||
panic("invalid method: " + method)
|
|
||||||
}
|
|
||||||
group.Handle(method, relativePath, handlers...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileServer方式, 返回一个HandleFunc, 统一化处理
|
|
||||||
func FileServer(fs http.FileSystem) HandlerFunc {
|
|
||||||
return func(c *Context) {
|
|
||||||
// 检查是否是 GET 或 HEAD 方法
|
|
||||||
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
|
|
||||||
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
|
||||||
if c.engine.HandleMethodNotAllowed {
|
|
||||||
c.Next()
|
|
||||||
} else {
|
|
||||||
// 否则,返回 405 Method Not Allowed
|
|
||||||
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c)
|
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
|
||||||
|
|
||||||
// 调用 http.FileServer 处理请求
|
|
||||||
http.FileServer(fs).ServeHTTP(ecw, c.Request)
|
|
||||||
|
|
||||||
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
|
||||||
ecw.processAfterFileServer()
|
|
||||||
|
|
||||||
// 中止处理链,因为 FileServer 已经处理了响应
|
|
||||||
c.Abort()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
278
fileserver.go
Normal file
278
fileserver.go
Normal file
|
|
@ -0,0 +1,278 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// === FileServer相关 ===
|
||||||
|
|
||||||
|
var allowedFileServerMethods = map[string]struct{}{
|
||||||
|
http.MethodGet: {},
|
||||||
|
http.MethodHead: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileServer方式, 返回一个HandleFunc, 统一化处理
|
||||||
|
func FileServer(fs http.FileSystem) HandlerFunc {
|
||||||
|
if fs == nil {
|
||||||
|
return func(c *Context) {
|
||||||
|
c.ErrorUseHandle(500, errors.New("Input FileSystem is nil"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fileServerInstance := http.FileServer(fs)
|
||||||
|
return func(c *Context) {
|
||||||
|
FileServerHandleServe(c, fileServerInstance)
|
||||||
|
|
||||||
|
// 中止处理链,因为 FileServer 已经处理了响应
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FileServerHandleServe(c *Context, fsHandle http.Handler) {
|
||||||
|
if fsHandle == nil {
|
||||||
|
ErrInputFSisNil := errors.New("Input FileSystem Handle is nil")
|
||||||
|
c.AddError(ErrInputFSisNil)
|
||||||
|
// 500
|
||||||
|
c.ErrorUseHandle(http.StatusInternalServerError, ErrInputFSisNil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是 GET 或 HEAD 方法
|
||||||
|
if _, ok := allowedFileServerMethods[c.Request.Method]; !ok {
|
||||||
|
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
|
||||||
|
if c.engine.HandleMethodNotAllowed {
|
||||||
|
c.Next()
|
||||||
|
} else {
|
||||||
|
if c.engine.noRoute == nil {
|
||||||
|
if c.Request.Method == http.MethodOptions {
|
||||||
|
//返回allow get
|
||||||
|
c.Writer.Header().Set("Allow", "GET, HEAD")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// 否则,返回 405 Method Not Allowed
|
||||||
|
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
||||||
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
|
|
||||||
|
// 调用 http.FileServer 处理请求
|
||||||
|
fsHandle.ServeHTTP(ecw, c.Request)
|
||||||
|
|
||||||
|
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
|
||||||
|
ecw.processAfterFileServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticDir 传入一个文件夹路径, 使用FileServer进行处理
|
||||||
|
// r.StaticDir("/test/*filepath", "/var/www/test")
|
||||||
|
func (engine *Engine) StaticDir(relativePath, rootPath string) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
rootPath = path.Clean(rootPath)
|
||||||
|
|
||||||
|
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
||||||
|
if !strings.HasSuffix(relativePath, "/") {
|
||||||
|
relativePath += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建一个文件系统处理器
|
||||||
|
fileServer := http.FileServer(http.Dir(rootPath))
|
||||||
|
|
||||||
|
// 注册一个捕获所有路径的路由,使用自定义处理器
|
||||||
|
// 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD
|
||||||
|
// 我们可以通过在处理函数内部检查方法来限制
|
||||||
|
engine.ANY(relativePath+"*filepath", GetStaticDirHandleFunc(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group的StaticDir方式
|
||||||
|
func (group *RouterGroup) StaticDir(relativePath, rootPath string) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
rootPath = path.Clean(rootPath)
|
||||||
|
|
||||||
|
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
||||||
|
if !strings.HasSuffix(relativePath, "/") {
|
||||||
|
relativePath += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建一个文件系统处理器
|
||||||
|
fileServer := http.FileServer(http.Dir(rootPath))
|
||||||
|
|
||||||
|
// 注册一个捕获所有路径的路由,使用自定义处理器
|
||||||
|
// 注意:这里使用 ANY 方法,但 FileServer 通常只处理 GET 和 HEAD
|
||||||
|
// 我们可以通过在处理函数内部检查方法来限制
|
||||||
|
group.ANY(relativePath+"*filepath", GetStaticDirHandleFunc(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticDirHandleFunc
|
||||||
|
func (engine *Engine) GetStaticDirHandle(rootPath string) HandlerFunc {
|
||||||
|
// 清理路径
|
||||||
|
rootPath = path.Clean(rootPath)
|
||||||
|
|
||||||
|
// 创建一个文件系统处理器
|
||||||
|
fileServer := http.FileServer(http.Dir(rootPath))
|
||||||
|
|
||||||
|
return GetStaticDirHandleFunc(fileServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticDirHandleFunc
|
||||||
|
func (group *RouterGroup) GetStaticDirHandle(rootPath string) HandlerFunc { // 清理路径
|
||||||
|
return group.engine.GetStaticDirHandle(rootPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticDirHandle
|
||||||
|
func GetStaticDirHandleFunc(fsHandle http.Handler) HandlerFunc {
|
||||||
|
return func(c *Context) {
|
||||||
|
requestPath := c.Request.URL.Path
|
||||||
|
|
||||||
|
// 获取捕获到的文件路径参数
|
||||||
|
filepath := c.Param("filepath")
|
||||||
|
|
||||||
|
// 构造文件服务器需要处理的请求路径
|
||||||
|
c.Request.URL.Path = filepath
|
||||||
|
|
||||||
|
FileServerHandleServe(c, fsHandle)
|
||||||
|
|
||||||
|
// 恢复原始请求路径,以便后续中间件或日志记录使用
|
||||||
|
c.Request.URL.Path = requestPath
|
||||||
|
|
||||||
|
// 中止处理链,因为 FileServer 已经处理了响应
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Static File 传入一个文件路径, 使用FileServer进行处理
|
||||||
|
func (engine *Engine) StaticFile(relativePath, filePath string) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
filePath = path.Clean(filePath)
|
||||||
|
|
||||||
|
FileHandle := engine.GetStaticFileHandle(filePath)
|
||||||
|
|
||||||
|
// 注册一个精确匹配的路由
|
||||||
|
engine.GET(relativePath, FileHandle)
|
||||||
|
engine.HEAD(relativePath, FileHandle)
|
||||||
|
engine.OPTIONS(relativePath, FileHandle)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group的StaticFile
|
||||||
|
func (group *RouterGroup) StaticFile(relativePath, filePath string) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
filePath = path.Clean(filePath)
|
||||||
|
|
||||||
|
FileHandle := group.GetStaticFileHandle(filePath)
|
||||||
|
|
||||||
|
// 注册一个精确匹配的路由
|
||||||
|
group.GET(relativePath, FileHandle)
|
||||||
|
group.HEAD(relativePath, FileHandle)
|
||||||
|
group.OPTIONS(relativePath, FileHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFileHandleFunc
|
||||||
|
func (engine *Engine) GetStaticFileHandle(filePath string) HandlerFunc {
|
||||||
|
// 清理路径
|
||||||
|
filePath = path.Clean(filePath)
|
||||||
|
|
||||||
|
// 创建一个文件系统处理器,指向包含目标文件的目录
|
||||||
|
dir := path.Dir(filePath)
|
||||||
|
fileName := path.Base(filePath)
|
||||||
|
fileServer := http.FileServer(http.Dir(dir))
|
||||||
|
|
||||||
|
return GetStaticFileHandleFunc(fileServer, fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFileHandleFunc
|
||||||
|
func (group *RouterGroup) GetStaticFileHandle(filePath string) HandlerFunc {
|
||||||
|
// 清理路径
|
||||||
|
filePath = path.Clean(filePath)
|
||||||
|
|
||||||
|
// 创建一个文件系统处理器,指向包含目标文件的目录
|
||||||
|
dir := path.Dir(filePath)
|
||||||
|
fileName := path.Base(filePath)
|
||||||
|
fileServer := http.FileServer(http.Dir(dir))
|
||||||
|
|
||||||
|
return GetStaticFileHandleFunc(fileServer, fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFileHandleFunc
|
||||||
|
func GetStaticFileHandleFunc(fsHandle http.Handler, fileName string) HandlerFunc {
|
||||||
|
return func(c *Context) {
|
||||||
|
requestPath := c.Request.URL.Path
|
||||||
|
|
||||||
|
// 构造文件服务器需要处理的请求路径
|
||||||
|
c.Request.URL.Path = "/" + fileName
|
||||||
|
|
||||||
|
FileServerHandleServe(c, fsHandle)
|
||||||
|
|
||||||
|
// 恢复原始请求路径
|
||||||
|
c.Request.URL.Path = requestPath
|
||||||
|
|
||||||
|
// 中止处理链,因为 FileServer 已经处理了响应
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticFS
|
||||||
|
func (engine *Engine) StaticFS(relativePath string, fs http.FileSystem) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
|
||||||
|
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
||||||
|
if !strings.HasSuffix(relativePath, "/") {
|
||||||
|
relativePath += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fileServer := http.FileServer(fs)
|
||||||
|
engine.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group的StaticFS
|
||||||
|
func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
|
||||||
|
// 清理路径
|
||||||
|
relativePath = path.Clean(relativePath)
|
||||||
|
|
||||||
|
// 确保相对路径以 '/' 结尾,以便 FileServer 正确处理子路径
|
||||||
|
if !strings.HasSuffix(relativePath, "/") {
|
||||||
|
relativePath += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fileServer := http.FileServer(fs)
|
||||||
|
group.ANY(relativePath+"*filepath", GetStaticFSHandleFunc(fileServer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFSHandleFunc
|
||||||
|
func GetStaticFSHandleFunc(fsHandle http.Handler) HandlerFunc {
|
||||||
|
return func(c *Context) {
|
||||||
|
|
||||||
|
FileServerHandleServe(c, fsHandle)
|
||||||
|
|
||||||
|
// 中止处理链,因为 FileServer 已经处理了响应
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFSHandleFunc
|
||||||
|
func (engine *Engine) GetStaticFSHandle(fs http.FileSystem) HandlerFunc {
|
||||||
|
fileServer := http.FileServer(fs)
|
||||||
|
return GetStaticFSHandleFunc(fileServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStaticFSHandleFunc
|
||||||
|
func (group *RouterGroup) GetStaticFSHandle(fs http.FileSystem) HandlerFunc {
|
||||||
|
fileServer := http.FileServer(fs)
|
||||||
|
return GetStaticFSHandleFunc(fileServer)
|
||||||
|
}
|
||||||
92
maxreader.go
Normal file
92
maxreader.go
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrBodyTooLarge 是当读取的字节数超过 MaxBytesReader 设置的限制时返回的错误.
|
||||||
|
// 将其定义为可导出的变量, 方便调用方使用 errors.Is 进行判断.
|
||||||
|
var ErrBodyTooLarge = fmt.Errorf("body too large")
|
||||||
|
|
||||||
|
// maxBytesReader 是一个实现了 io.ReadCloser 接口的结构体.
|
||||||
|
// 它包装了另一个 io.ReadCloser, 并限制了从其中读取的最大字节数.
|
||||||
|
type maxBytesReader struct {
|
||||||
|
// r 是底层的 io.ReadCloser.
|
||||||
|
r io.ReadCloser
|
||||||
|
// n 是允许读取的最大字节数.
|
||||||
|
n int64
|
||||||
|
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||||
|
read atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||||
|
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
||||||
|
//
|
||||||
|
// 如果 r 为 nil, 会 panic.
|
||||||
|
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
|
||||||
|
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
|
if r == nil {
|
||||||
|
panic("NewMaxBytesReader called with a nil reader")
|
||||||
|
}
|
||||||
|
// 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||||
|
if n < 0 {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
return &maxBytesReader{
|
||||||
|
r: r,
|
||||||
|
n: n,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
||||||
|
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
||||||
|
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||||
|
readSoFar := mbr.read.Load()
|
||||||
|
|
||||||
|
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
|
||||||
|
if readSoFar >= mbr.n {
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算当前还可以读取多少字节.
|
||||||
|
remaining := mbr.n - readSoFar
|
||||||
|
|
||||||
|
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
||||||
|
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
||||||
|
if int64(len(p)) > remaining {
|
||||||
|
p = p[:remaining]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从底层 Reader 读取数据.
|
||||||
|
n, err := mbr.r.Read(p)
|
||||||
|
|
||||||
|
// 如果实际读取到了数据, 更新原子计数器.
|
||||||
|
if n > 0 {
|
||||||
|
readSoFar = mbr.read.Add(int64(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果底层 Read 返回错误 (例如 io.EOF).
|
||||||
|
if err != nil {
|
||||||
|
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
||||||
|
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
||||||
|
// 这是处理"跨越"限制情况的关键.
|
||||||
|
if readSoFar > mbr.n {
|
||||||
|
// 返回实际读取的字节数 n, 并附上超限错误.
|
||||||
|
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
||||||
|
return n, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
// 一切正常, 返回读取的字节数和 nil 错误.
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||||
|
func (mbr *maxBytesReader) Close() error {
|
||||||
|
return mbr.r.Close()
|
||||||
|
}
|
||||||
118
mergectx.go
Normal file
118
mergectx.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||||
|
type mergedContext struct {
|
||||||
|
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
||||||
|
context.Context
|
||||||
|
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
||||||
|
parents []context.Context
|
||||||
|
// 用于手动取消此 mergedContext 的函数.
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeCtx 创建并返回一个新的 context.Context.
|
||||||
|
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||||
|
// 自动被取消 (逻辑或关系).
|
||||||
|
//
|
||||||
|
// 新的 context 会继承:
|
||||||
|
// - Deadline: 所有父 context 中最早的截止时间.
|
||||||
|
// - Value: 按传入顺序从第一个能找到值的父 context 中获取值.
|
||||||
|
func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.CancelFunc) {
|
||||||
|
if len(parents) == 0 {
|
||||||
|
return context.WithCancel(context.Background())
|
||||||
|
}
|
||||||
|
if len(parents) == 1 {
|
||||||
|
return context.WithCancel(parents[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
var earliestDeadline time.Time
|
||||||
|
for _, p := range parents {
|
||||||
|
if deadline, ok := p.Deadline(); ok {
|
||||||
|
if earliestDeadline.IsZero() || deadline.Before(earliestDeadline) {
|
||||||
|
earliestDeadline = deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var baseCtx context.Context
|
||||||
|
var baseCancel context.CancelFunc
|
||||||
|
if !earliestDeadline.IsZero() {
|
||||||
|
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
||||||
|
} else {
|
||||||
|
baseCtx, baseCancel = context.WithCancel(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := &mergedContext{
|
||||||
|
Context: baseCtx,
|
||||||
|
parents: parents,
|
||||||
|
cancel: baseCancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动一个监控 goroutine.
|
||||||
|
go func() {
|
||||||
|
defer mc.cancel()
|
||||||
|
|
||||||
|
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
||||||
|
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
||||||
|
select {
|
||||||
|
case <-orDone(mc.parents...):
|
||||||
|
case <-mc.Context.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return mc, mc.cancel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value 返回当前Ctx Value
|
||||||
|
func (mc *mergedContext) Value(key any) any {
|
||||||
|
return mc.Context.Value(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deadline 实现了 context.Context 的 Deadline 方法.
|
||||||
|
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
return mc.Context.Deadline()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done 实现了 context.Context 的 Done 方法.
|
||||||
|
func (mc *mergedContext) Done() <-chan struct{} {
|
||||||
|
return mc.Context.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err 实现了 context.Context 的 Err 方法.
|
||||||
|
func (mc *mergedContext) Err() error {
|
||||||
|
return mc.Context.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// orDone 是一个辅助函数, 返回一个 channel.
|
||||||
|
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
|
||||||
|
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
|
||||||
|
func orDone(contexts ...context.Context) <-chan struct{} {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
var once sync.Once
|
||||||
|
closeDone := func() {
|
||||||
|
once.Do(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为每个父 context 启动一个 goroutine.
|
||||||
|
for _, ctx := range contexts {
|
||||||
|
go func(c context.Context) {
|
||||||
|
select {
|
||||||
|
case <-c.Done():
|
||||||
|
closeDone()
|
||||||
|
case <-done:
|
||||||
|
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
|
||||||
|
}
|
||||||
|
}(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return done
|
||||||
|
}
|
||||||
25
touka.go
25
touka.go
|
|
@ -42,3 +42,28 @@ type RouteInfo struct {
|
||||||
Handler string // 处理函数名称
|
Handler string // 处理函数名称
|
||||||
Group string // 路由分组
|
Group string // 路由分组
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 维护一个Methods列表
|
||||||
|
var (
|
||||||
|
MethodGet = "GET"
|
||||||
|
MethodHead = "HEAD"
|
||||||
|
MethodPost = "POST"
|
||||||
|
MethodPut = "PUT"
|
||||||
|
MethodPatch = "PATCH"
|
||||||
|
MethodDelete = "DELETE"
|
||||||
|
MethodConnect = "CONNECT"
|
||||||
|
MethodOptions = "OPTIONS"
|
||||||
|
MethodTrace = "TRACE"
|
||||||
|
)
|
||||||
|
|
||||||
|
var MethodsSet = map[string]struct{}{
|
||||||
|
MethodGet: {},
|
||||||
|
MethodHead: {},
|
||||||
|
MethodPost: {},
|
||||||
|
MethodPut: {},
|
||||||
|
MethodPatch: {},
|
||||||
|
MethodDelete: {},
|
||||||
|
MethodConnect: {},
|
||||||
|
MethodOptions: {},
|
||||||
|
MethodTrace: {},
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue