From bb822599b9a14112a67ff64853cb23ed3e0ab5f4 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:11:54 +0800 Subject: [PATCH 1/3] update methods --- context.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++++++++- engine.go | 105 +++++++++++++++++++++++++++++++++++++++---- go.mod | 2 +- go.sum | 4 +- logreco.go | 6 +++ 5 files changed, 232 insertions(+), 12 deletions(-) diff --git a/context.go b/context.go index a15ae3c..dd0f1a3 100644 --- a/context.go +++ b/context.go @@ -1,6 +1,7 @@ package touka import ( + "bytes" "context" "encoding/gob" "errors" @@ -14,6 +15,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/fenthope/reco" "github.com/go-json-experiment/json" @@ -128,6 +130,72 @@ func (c *Context) Get(key string) (value interface{}, exists bool) { return } +// GetString 从 Context 中获取一个字符串值 +// 这是一个线程安全的操作 +func (c *Context) GetString(key string) (value string, exists bool) { + if val, exists := c.Get(key); exists { + if str, ok := val.(string); ok { + return str, true + } + } + return "", false +} + +// GetInt 从 Context 中获取一个 int 值 +// 这是一个线程安全的操作 +func (c *Context) GetInt(key string) (value int, exists bool) { + if val, exists := c.Get(key); exists { + if i, ok := val.(int); ok { + return i, true + } + } + return 0, false +} + +// GetBool 从 Context 中获取一个 bool 值 +// 这是一个线程安全的操作 +func (c *Context) GetBool(key string) (value bool, exists bool) { + if val, exists := c.Get(key); exists { + if b, ok := val.(bool); ok { + return b, true + } + } + return false, false +} + +// GetFloat64 从 Context 中获取一个 float64 值 +// 这是一个线程安全的操作 +func (c *Context) GetFloat64(key string) (value float64, exists bool) { + if val, exists := c.Get(key); exists { + if f, ok := val.(float64); ok { + return f, true + } + } + return 0.0, false +} + +// GetTime 从 Context 中获取一个 time.Time 值 +// 这是一个线程安全的操作 +func (c *Context) GetTime(key string) (value time.Time, exists bool) { + if val, exists := c.Get(key); exists { + if t, ok := val.(time.Time); ok { + return t, true + } + } + return time.Time{}, false +} + +// GetDuration 从 Context 中获取一个 time.Duration 值 +// 这是一个线程安全的操作 +func (c *Context) GetDuration(key string) (value time.Duration, exists bool) { + if val, exists := c.Get(key); exists { + if d, ok := val.(time.Duration); ok { + return d, true + } + } + return 0, false +} + // MustGet 从 Context 中获取一个值,如果不存在则 panic // 适用于确定值一定存在的场景 func (c *Context) MustGet(key string) interface{} { @@ -369,7 +437,6 @@ func (c *Context) GetReqBody() io.ReadCloser { return c.Request.Body } -// GetReqBodyFull // GetReqBodyFull 读取并返回请求体的所有内容 // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { @@ -385,6 +452,20 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { return data, nil } +// 类似 GetReqBodyFull, 返回 *bytes.Buffer +func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { + if c.Request.Body == nil { + return nil, nil + } + defer c.Request.Body.Close() // 确保请求体被关闭 + data, err := io.ReadAll(c.Request.Body) + if err != nil { + c.AddError(fmt.Errorf("failed to read request body: %w", err)) + return nil, fmt.Errorf("failed to read request body: %w", err) + } + return bytes.NewBuffer(data), nil +} + // RequestIP 返回客户端的 IP 地址 // 它会根据 Engine 的配置 (ForwardByClientIP) 尝试从 X-Forwarded-For 或 X-Real-IP 等头部获取, // 否则回退到 Request.RemoteAddr @@ -512,6 +593,50 @@ func (c *Context) GetLogger() *reco.Logger { return c.engine.LogReco } +// GetReqQueryString +// GetReqQueryString 返回请求的原始查询字符串 +func (c *Context) GetReqQueryString() string { + return c.Request.URL.RawQuery +} + +// SetBodyStream 设置响应体为一个 io.Reader,并指定内容长度 +// 如果 contentSize 为 -1,则表示内容长度未知,将使用 Transfer-Encoding: chunked +func (c *Context) SetBodyStream(reader io.Reader, contentSize int) { + // 确保在写入数据前设置状态码 + if !c.Writer.Written() { + c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK + } + + // 如果指定了内容长度且大于等于 0,则设置 Content-Length 头部 + if contentSize >= 0 { + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentSize)) + } else { + // 如果内容长度未知,移除 Content-Length 头部,通常会使用 Transfer-Encoding: chunked + c.Writer.Header().Del("Content-Length") + } + + // 将 reader 的内容直接复制到 ResponseWriter + // ResponseWriter 实现了 io.Writer 接口 + _, err := copyb.Copy(c.Writer, reader) + if err != nil { + c.AddError(fmt.Errorf("failed to write stream: %w", err)) + // 注意:这里可能无法设置错误状态码,因为头部可能已经发送 + // 可以在调用 SetBodyStream 之前检查错误,或者在中间件中处理 Context.Errors + } +} + +// GetRequestURI 返回请求的原始 URI +func (c *Context) GetRequestURI() string { + return c.Request.RequestURI +} + +// GetRequestURIPath 返回请求的原始 URI 的路径部分 +func (c *Context) GetRequestURIPath() string { + return c.Request.URL.Path +} + +// == cookie === + // SetSameSite 设置响应的 SameSite cookie 属性 func (c *Context) SetSameSite(samesite http.SameSite) { c.sameSite = samesite diff --git a/engine.go b/engine.go index da1d9ad..6ba8fee 100644 --- a/engine.go +++ b/engine.go @@ -3,6 +3,7 @@ package touka import ( "context" "errors" + "fmt" "log" "reflect" "runtime" @@ -147,7 +148,7 @@ func New() *Engine { } //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() - engine.SetLogger(defaultLogRecoConfig) + engine.SetLoggerCfg(defaultLogRecoConfig) // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 engine.pool.New = func() interface{} { return &Context{ @@ -173,8 +174,13 @@ func Default() *Engine { // === 外部操作方法 === -// 配置日志Logger -func (engine *Engine) SetLogger(logcfg reco.Config) { +// SetLogger传入实例 +func (engine *Engine) SetLogger(logger *reco.Logger) { + engine.LogReco = logger +} + +// 配置日志LoggerCfg +func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { engine.LogReco = NewLogger(logcfg) } @@ -659,8 +665,9 @@ func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IR // == 其他操作方式 === -// Static FileServer 传入一个文件夹路径, 使用FileServer进行处理 -func (engine *Engine) Static(relativePath, rootPath string) { +// StaticDir 传入一个文件夹路径, 使用FileServer进行处理 +// r.StaticDir("/test/*filepath", "/var/www/test") +func (engine *Engine) StaticDir(relativePath, rootPath string) { // 清理路径 relativePath = path.Clean(relativePath) rootPath = path.Clean(rootPath) @@ -723,8 +730,8 @@ func (engine *Engine) Static(relativePath, rootPath string) { }) } -// Group的Static -func (group *RouterGroup) Static(relativePath, rootPath string) { +// Group的StaticDir方式 +func (group *RouterGroup) StaticDir(relativePath, rootPath string) { // 清理路径 relativePath = path.Clean(relativePath) rootPath = path.Clean(rootPath) @@ -896,5 +903,87 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) { group.GET(relativePath, FileHandle) group.HEAD(relativePath, FileHandle) group.OPTIONS(relativePath, FileHandle) - +} + +// 维护一个Methods列表 +var ( + MethodGet = "GET" + MethodHead = "HEAD" + MethodPost = "POST" + MethodPut = "PUT" + MethodPatch = "PATCH" + MethodDelete = "DELETE" + MethodConnect = "CONNECT" + MethodOptions = "OPTIONS" + MethodTrace = "TRACE" + MethodAny = "ANY" +) + +var MethodsSet = map[string]struct{}{ + MethodGet: {}, + MethodHead: {}, + MethodPost: {}, + MethodPut: {}, + MethodPatch: {}, + MethodDelete: {}, + MethodConnect: {}, + MethodOptions: {}, + MethodTrace: {}, + MethodAny: {}, +} + +// HandleFunc 注册一个或多个 HTTP 方法的路由 +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// relativePath 是相对于当前组或 Engine 的路径 +// handlers 是处理函数链 +func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { + for _, method := range methods { + if _, ok := MethodsSet[method]; !ok { + panic("invalid method: " + method) + } + engine.Handle(method, relativePath, handlers...) + } +} + +// HandleFunc 注册一个或多个 HTTP 方法的路由 +// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}) +// relativePath 是相对于当前组或 Engine 的路径 +// handlers 是处理函数链 +func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) { + for _, method := range methods { + if _, ok := MethodsSet[method]; !ok { + panic("invalid method: " + method) + } + group.Handle(method, relativePath, handlers...) + } +} + +// FileServer方式, 返回一个HandleFunc, 统一化处理 +func FileServer(fs http.FileSystem) HandlerFunc { + return func(c *Context) { + // 检查是否是 GET 或 HEAD 方法 + if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead { + // 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件 + if c.engine.HandleMethodNotAllowed { + c.Next() + } else { + // 否则,返回 405 Method Not Allowed + c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method)) + } + return + } + + // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + // 调用 http.FileServer 处理请求 + http.FileServer(fs).ServeHTTP(ecw, c.Request) + + // 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler + ecw.processAfterFileServer() + + // 中止处理链,因为 FileServer 已经处理了响应 + c.Abort() + } } diff --git a/go.mod b/go.mod index 87fa154..167bd3e 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.4 require ( github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 github.com/WJQSERVER-STUDIO/httpc v0.7.0 - github.com/fenthope/reco v0.0.1 + github.com/fenthope/reco v0.0.2 github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 ) diff --git a/go.sum b/go.sum index 3ea14ae..fa1a547 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKU github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc= github.com/WJQSERVER-STUDIO/httpc v0.7.0 h1:iHhqlxppJBjlmvsIjvLZKRbWXqSdbeSGGofjHGmqGJc= github.com/WJQSERVER-STUDIO/httpc v0.7.0/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE= -github.com/fenthope/reco v0.0.1 h1:GYcuXCEKYoctD0dFkiBC+t0RMTOyOiujBCin8bbLR3Y= -github.com/fenthope/reco v0.0.1/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y= +github.com/fenthope/reco v0.0.2 h1:9+RdpZlQYuwQh00XGn8uFu7vcHeldlgufGnszXkyFwg= +github.com/fenthope/reco v0.0.2/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y= github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 h1:o8UqXPI6SVwQt04RGsqKp3qqmbOfTNMqDrWsc4O47kk= github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/logreco.go b/logreco.go index 7182c51..2fc1a7a 100644 --- a/logreco.go +++ b/logreco.go @@ -34,3 +34,9 @@ func CloseLogger(logger *reco.Logger) { return } } + +func (engine *Engine) CloseLogger() { + if engine.LogReco != nil { + CloseLogger(engine.LogReco) + } +} From 9a2aeef0d0b84febb30bd19b9ffc81a0e18fbf1c Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:23:15 +0800 Subject: [PATCH 2/3] remove ANY form MethodsSet to avoid conflict --- engine.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/engine.go b/engine.go index 6ba8fee..eeafe6e 100644 --- a/engine.go +++ b/engine.go @@ -916,7 +916,6 @@ var ( MethodConnect = "CONNECT" MethodOptions = "OPTIONS" MethodTrace = "TRACE" - MethodAny = "ANY" ) var MethodsSet = map[string]struct{}{ @@ -929,7 +928,6 @@ var MethodsSet = map[string]struct{}{ MethodConnect: {}, MethodOptions: {}, MethodTrace: {}, - MethodAny: {}, } // HandleFunc 注册一个或多个 HTTP 方法的路由 From 896182417f7c3144e646308cfb7000a09dfbe1cb Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:24:54 +0800 Subject: [PATCH 3/3] fix header writer after status issue --- context.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/context.go b/context.go index dd0f1a3..d4b9a5b 100644 --- a/context.go +++ b/context.go @@ -602,11 +602,6 @@ func (c *Context) GetReqQueryString() string { // SetBodyStream 设置响应体为一个 io.Reader,并指定内容长度 // 如果 contentSize 为 -1,则表示内容长度未知,将使用 Transfer-Encoding: chunked func (c *Context) SetBodyStream(reader io.Reader, contentSize int) { - // 确保在写入数据前设置状态码 - if !c.Writer.Written() { - c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK - } - // 如果指定了内容长度且大于等于 0,则设置 Content-Length 头部 if contentSize >= 0 { c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentSize)) @@ -615,6 +610,11 @@ func (c *Context) SetBodyStream(reader io.Reader, contentSize int) { c.Writer.Header().Del("Content-Length") } + // 确保在写入数据前设置状态码 + if !c.Writer.Written() { + c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK + } + // 将 reader 的内容直接复制到 ResponseWriter // ResponseWriter 实现了 io.Writer 接口 _, err := copyb.Copy(c.Writer, reader)