Compare commits

...

4 commits

Author SHA1 Message Date
WJQSERVER
fa8f044b81
Merge pull request #19 from infinite-iroha/dev
update methods
2025-06-11 11:43:34 +08:00
wjqserver
896182417f fix header writer after status issue 2025-06-11 11:24:54 +08:00
wjqserver
9a2aeef0d0 remove ANY form MethodsSet to avoid conflict 2025-06-11 11:23:15 +08:00
wjqserver
bb822599b9 update methods 2025-06-11 11:11:54 +08:00
5 changed files with 230 additions and 12 deletions

View file

@ -1,6 +1,7 @@
package touka
import (
"bytes"
"context"
"encoding/gob"
"errors"
@ -14,6 +15,7 @@ import (
"net/url"
"strings"
"sync"
"time"
"github.com/fenthope/reco"
"github.com/go-json-experiment/json"
@ -128,6 +130,72 @@ func (c *Context) Get(key string) (value interface{}, exists bool) {
return
}
// GetString 从 Context 中获取一个字符串值
// 这是一个线程安全的操作
func (c *Context) GetString(key string) (value string, exists bool) {
if val, exists := c.Get(key); exists {
if str, ok := val.(string); ok {
return str, true
}
}
return "", false
}
// GetInt 从 Context 中获取一个 int 值
// 这是一个线程安全的操作
func (c *Context) GetInt(key string) (value int, exists bool) {
if val, exists := c.Get(key); exists {
if i, ok := val.(int); ok {
return i, true
}
}
return 0, false
}
// GetBool 从 Context 中获取一个 bool 值
// 这是一个线程安全的操作
func (c *Context) GetBool(key string) (value bool, exists bool) {
if val, exists := c.Get(key); exists {
if b, ok := val.(bool); ok {
return b, true
}
}
return false, false
}
// GetFloat64 从 Context 中获取一个 float64 值
// 这是一个线程安全的操作
func (c *Context) GetFloat64(key string) (value float64, exists bool) {
if val, exists := c.Get(key); exists {
if f, ok := val.(float64); ok {
return f, true
}
}
return 0.0, false
}
// GetTime 从 Context 中获取一个 time.Time 值
// 这是一个线程安全的操作
func (c *Context) GetTime(key string) (value time.Time, exists bool) {
if val, exists := c.Get(key); exists {
if t, ok := val.(time.Time); ok {
return t, true
}
}
return time.Time{}, false
}
// GetDuration 从 Context 中获取一个 time.Duration 值
// 这是一个线程安全的操作
func (c *Context) GetDuration(key string) (value time.Duration, exists bool) {
if val, exists := c.Get(key); exists {
if d, ok := val.(time.Duration); ok {
return d, true
}
}
return 0, false
}
// MustGet 从 Context 中获取一个值,如果不存在则 panic
// 适用于确定值一定存在的场景
func (c *Context) MustGet(key string) interface{} {
@ -369,7 +437,6 @@ func (c *Context) GetReqBody() io.ReadCloser {
return c.Request.Body
}
// GetReqBodyFull
// GetReqBodyFull 读取并返回请求体的所有内容
// 注意:请求体只能读取一次
func (c *Context) GetReqBodyFull() ([]byte, error) {
@ -385,6 +452,20 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
return data, nil
}
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
if c.Request.Body == nil {
return nil, nil
}
defer c.Request.Body.Close() // 确保请求体被关闭
data, err := io.ReadAll(c.Request.Body)
if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err)
}
return bytes.NewBuffer(data), nil
}
// RequestIP 返回客户端的 IP 地址
// 它会根据 Engine 的配置 (ForwardByClientIP) 尝试从 X-Forwarded-For 或 X-Real-IP 等头部获取,
// 否则回退到 Request.RemoteAddr
@ -512,6 +593,50 @@ func (c *Context) GetLogger() *reco.Logger {
return c.engine.LogReco
}
// GetReqQueryString
// GetReqQueryString 返回请求的原始查询字符串
func (c *Context) GetReqQueryString() string {
return c.Request.URL.RawQuery
}
// SetBodyStream 设置响应体为一个 io.Reader并指定内容长度
// 如果 contentSize 为 -1则表示内容长度未知将使用 Transfer-Encoding: chunked
func (c *Context) SetBodyStream(reader io.Reader, contentSize int) {
// 如果指定了内容长度且大于等于 0则设置 Content-Length 头部
if contentSize >= 0 {
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentSize))
} else {
// 如果内容长度未知,移除 Content-Length 头部,通常会使用 Transfer-Encoding: chunked
c.Writer.Header().Del("Content-Length")
}
// 确保在写入数据前设置状态码
if !c.Writer.Written() {
c.Writer.WriteHeader(http.StatusOK) // 默认 200 OK
}
// 将 reader 的内容直接复制到 ResponseWriter
// ResponseWriter 实现了 io.Writer 接口
_, err := copyb.Copy(c.Writer, reader)
if err != nil {
c.AddError(fmt.Errorf("failed to write stream: %w", err))
// 注意:这里可能无法设置错误状态码,因为头部可能已经发送
// 可以在调用 SetBodyStream 之前检查错误,或者在中间件中处理 Context.Errors
}
}
// GetRequestURI 返回请求的原始 URI
func (c *Context) GetRequestURI() string {
return c.Request.RequestURI
}
// GetRequestURIPath 返回请求的原始 URI 的路径部分
func (c *Context) GetRequestURIPath() string {
return c.Request.URL.Path
}
// == cookie ===
// SetSameSite 设置响应的 SameSite cookie 属性
func (c *Context) SetSameSite(samesite http.SameSite) {
c.sameSite = samesite

103
engine.go
View file

@ -3,6 +3,7 @@ package touka
import (
"context"
"errors"
"fmt"
"log"
"reflect"
"runtime"
@ -147,7 +148,7 @@ func New() *Engine {
}
//engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols()
engine.SetLogger(defaultLogRecoConfig)
engine.SetLoggerCfg(defaultLogRecoConfig)
// 初始化 Context Pool为每个新 Context 实例提供一个构造函数
engine.pool.New = func() interface{} {
return &Context{
@ -173,8 +174,13 @@ func Default() *Engine {
// === 外部操作方法 ===
// 配置日志Logger
func (engine *Engine) SetLogger(logcfg reco.Config) {
// SetLogger传入实例
func (engine *Engine) SetLogger(logger *reco.Logger) {
engine.LogReco = logger
}
// 配置日志LoggerCfg
func (engine *Engine) SetLoggerCfg(logcfg reco.Config) {
engine.LogReco = NewLogger(logcfg)
}
@ -659,8 +665,9 @@ func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IR
// == 其他操作方式 ===
// Static FileServer 传入一个文件夹路径, 使用FileServer进行处理
func (engine *Engine) Static(relativePath, rootPath string) {
// StaticDir 传入一个文件夹路径, 使用FileServer进行处理
// r.StaticDir("/test/*filepath", "/var/www/test")
func (engine *Engine) StaticDir(relativePath, rootPath string) {
// 清理路径
relativePath = path.Clean(relativePath)
rootPath = path.Clean(rootPath)
@ -723,8 +730,8 @@ func (engine *Engine) Static(relativePath, rootPath string) {
})
}
// Group的Static
func (group *RouterGroup) Static(relativePath, rootPath string) {
// Group的StaticDir方式
func (group *RouterGroup) StaticDir(relativePath, rootPath string) {
// 清理路径
relativePath = path.Clean(relativePath)
rootPath = path.Clean(rootPath)
@ -896,5 +903,85 @@ func (group *RouterGroup) StaticFile(relativePath, filePath string) {
group.GET(relativePath, FileHandle)
group.HEAD(relativePath, FileHandle)
group.OPTIONS(relativePath, FileHandle)
}
// 维护一个Methods列表
var (
MethodGet = "GET"
MethodHead = "HEAD"
MethodPost = "POST"
MethodPut = "PUT"
MethodPatch = "PATCH"
MethodDelete = "DELETE"
MethodConnect = "CONNECT"
MethodOptions = "OPTIONS"
MethodTrace = "TRACE"
)
var MethodsSet = map[string]struct{}{
MethodGet: {},
MethodHead: {},
MethodPost: {},
MethodPut: {},
MethodPatch: {},
MethodDelete: {},
MethodConnect: {},
MethodOptions: {},
MethodTrace: {},
}
// HandleFunc 注册一个或多个 HTTP 方法的路由
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}
// relativePath 是相对于当前组或 Engine 的路径
// handlers 是处理函数链
func (engine *Engine) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
for _, method := range methods {
if _, ok := MethodsSet[method]; !ok {
panic("invalid method: " + method)
}
engine.Handle(method, relativePath, handlers...)
}
}
// HandleFunc 注册一个或多个 HTTP 方法的路由
// methods 参数是一个字符串切片,包含要注册的 HTTP 方法(例如 []string{"GET", "POST"}
// relativePath 是相对于当前组或 Engine 的路径
// handlers 是处理函数链
func (group *RouterGroup) HandleFunc(methods []string, relativePath string, handlers ...HandlerFunc) {
for _, method := range methods {
if _, ok := MethodsSet[method]; !ok {
panic("invalid method: " + method)
}
group.Handle(method, relativePath, handlers...)
}
}
// FileServer方式, 返回一个HandleFunc, 统一化处理
func FileServer(fs http.FileSystem) HandlerFunc {
return func(c *Context) {
// 检查是否是 GET 或 HEAD 方法
if c.Request.Method != http.MethodGet && c.Request.Method != http.MethodHead {
// 如果不是,且启用了 MethodNotAllowed 处理,则继续到 MethodNotAllowed 中间件
if c.engine.HandleMethodNotAllowed {
c.Next()
} else {
// 否则,返回 405 Method Not Allowed
c.engine.errorHandle.handler(c, http.StatusMethodNotAllowed, fmt.Errorf("Method %s is Not Allowed on FileServer", c.Request.Method))
}
return
}
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
ecw := AcquireErrorCapturingResponseWriter(c)
defer ReleaseErrorCapturingResponseWriter(ecw)
// 调用 http.FileServer 处理请求
http.FileServer(fs).ServeHTTP(ecw, c.Request)
// 在 FileServer 处理完成后,检查是否捕获到错误状态码,并调用 ErrorHandler
ecw.processAfterFileServer()
// 中止处理链,因为 FileServer 已经处理了响应
c.Abort()
}
}

2
go.mod
View file

@ -5,7 +5,7 @@ go 1.24.4
require (
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4
github.com/WJQSERVER-STUDIO/httpc v0.7.0
github.com/fenthope/reco v0.0.1
github.com/fenthope/reco v0.0.2
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8
)

4
go.sum
View file

@ -2,8 +2,8 @@ github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKU
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc=
github.com/WJQSERVER-STUDIO/httpc v0.7.0 h1:iHhqlxppJBjlmvsIjvLZKRbWXqSdbeSGGofjHGmqGJc=
github.com/WJQSERVER-STUDIO/httpc v0.7.0/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE=
github.com/fenthope/reco v0.0.1 h1:GYcuXCEKYoctD0dFkiBC+t0RMTOyOiujBCin8bbLR3Y=
github.com/fenthope/reco v0.0.1/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y=
github.com/fenthope/reco v0.0.2 h1:9+RdpZlQYuwQh00XGn8uFu7vcHeldlgufGnszXkyFwg=
github.com/fenthope/reco v0.0.2/go.mod h1:mDkGLHte5udWTIcjQTxrABRcf56SSdxBOCLgrRDwI/Y=
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8 h1:o8UqXPI6SVwQt04RGsqKp3qqmbOfTNMqDrWsc4O47kk=
github.com/go-json-experiment/json v0.0.0-20250517221953-25912455fbc8/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=

View file

@ -34,3 +34,9 @@ func CloseLogger(logger *reco.Logger) {
return
}
}
func (engine *Engine) CloseLogger() {
if engine.LogReco != nil {
CloseLogger(engine.LogReco)
}
}