From cb86cb935aacf4958158eef6668a1a052b202a22 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:26:18 +0800 Subject: [PATCH 1/2] add maxBytesReader & ctxMerge --- context.go | 45 +++++++++++++++++-- engine.go | 17 +++++-- maxreader.go | 92 ++++++++++++++++++++++++++++++++++++++ mergectx.go | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 270 insertions(+), 8 deletions(-) create mode 100644 maxreader.go create mode 100644 mergectx.go diff --git a/context.go b/context.go index 23fa92f..5a8ce2d 100644 --- a/context.go +++ b/context.go @@ -58,6 +58,9 @@ type Context struct { engine *Engine sameSite http.SameSite + + // 请求体Body大小限制 + MaxRequestBodySize int64 } // --- Context 相关方法实现 --- @@ -83,6 +86,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 + c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize // c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员 } @@ -208,6 +212,11 @@ func (c *Context) MustGet(key string) interface{} { panic("Key \"" + key + "\" does not exist in context.") } +// SetMaxRequestBodySize +func (c *Context) SetMaxRequestBodySize(size int64) { + c.MaxRequestBodySize = size +} + // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -434,8 +443,22 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { if c.Request.Body == nil { return nil, nil } - defer c.Request.Body.Close() // 确保请求体被关闭 - data, err := copyb.ReadAll(c.Request.Body) + + var limitBytesReader io.ReadCloser + + if c.MaxRequestBodySize > 0 { + limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } else { + limitBytesReader = c.Request.Body + } + + data, err := copyb.ReadAll(limitBytesReader) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -448,8 +471,22 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { if c.Request.Body == nil { return nil, nil } - defer c.Request.Body.Close() // 确保请求体被关闭 - data, err := copyb.ReadAll(c.Request.Body) + + var limitBytesReader io.ReadCloser + + if c.MaxRequestBodySize > 0 { + limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } else { + limitBytesReader = c.Request.Body + } + + data, err := copyb.ReadAll(limitBytesReader) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/engine.go b/engine.go index 3d39cf0..ef16741 100644 --- a/engine.go +++ b/engine.go @@ -74,6 +74,9 @@ type Engine struct { // 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器 // 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置) TLSServerConfigurator func(*http.Server) + + // GlobalMaxRequestBodySize 全局请求体Body大小限制 + GlobalMaxRequestBodySize int64 } type ErrorHandle struct { @@ -171,10 +174,11 @@ func New() *Engine { unMatchFS: UnMatchFS{ ServeUnmatchedAsFS: false, }, - noRoute: nil, - noRoutes: make(HandlersChain, 0), - ServerConfigurator: nil, - TLSServerConfigurator: nil, + noRoute: nil, + noRoutes: make(HandlersChain, 0), + ServerConfigurator: nil, + TLSServerConfigurator: nil, + GlobalMaxRequestBodySize: -1, } //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -294,6 +298,11 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) { engine.useDefaultProtocols = false } +// 配置全局Req Body大小限制 +func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { + engine.GlobalMaxRequestBodySize = size +} + // 配置Req IP来源 Headers func (engine *Engine) SetRemoteIPHeaders(headers []string) { engine.RemoteIPHeaders = headers diff --git a/maxreader.go b/maxreader.go new file mode 100644 index 0000000..96ff025 --- /dev/null +++ b/maxreader.go @@ -0,0 +1,92 @@ +package touka + +import ( + "fmt" + "io" + "sync/atomic" +) + +// ErrBodyTooLarge 是当读取的字节数超过 MaxBytesReader 设置的限制时返回的错误. +// 将其定义为可导出的变量, 方便调用方使用 errors.Is 进行判断. +var ErrBodyTooLarge = fmt.Errorf("body too large") + +// maxBytesReader 是一个实现了 io.ReadCloser 接口的结构体. +// 它包装了另一个 io.ReadCloser, 并限制了从其中读取的最大字节数. +type maxBytesReader struct { + // r 是底层的 io.ReadCloser. + r io.ReadCloser + // n 是允许读取的最大字节数. + n int64 + // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. + read atomic.Int64 +} + +// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, +// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. +// +// 如果 r 为 nil, 会 panic. +// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. +func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { + if r == nil { + panic("NewMaxBytesReader called with a nil reader") + } + // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. + if n < 0 { + return r + } + return &maxBytesReader{ + r: r, + n: n, + } +} + +// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. +func (mbr *maxBytesReader) Read(p []byte) (int, error) { + // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. + readSoFar := mbr.read.Load() + + // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. + if readSoFar >= mbr.n { + return 0, ErrBodyTooLarge + } + + // 计算当前还可以读取多少字节. + remaining := mbr.n - readSoFar + + // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. + // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. + if int64(len(p)) > remaining { + p = p[:remaining] + } + + // 从底层 Reader 读取数据. + n, err := mbr.r.Read(p) + + // 如果实际读取到了数据, 更新原子计数器. + if n > 0 { + readSoFar = mbr.read.Add(int64(n)) + } + + // 如果底层 Read 返回错误 (例如 io.EOF). + if err != nil { + // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. + // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. + return n, err + } + + // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. + // 这是处理"跨越"限制情况的关键. + if readSoFar > mbr.n { + // 返回实际读取的字节数 n, 并附上超限错误. + // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. + return n, ErrBodyTooLarge + } + + // 一切正常, 返回读取的字节数和 nil 错误. + return n, nil +} + +// Close 方法关闭底层的 ReadCloser, 保证资源释放. +func (mbr *maxBytesReader) Close() error { + return mbr.r.Close() +} diff --git a/mergectx.go b/mergectx.go new file mode 100644 index 0000000..e6c223d --- /dev/null +++ b/mergectx.go @@ -0,0 +1,124 @@ +package touka + +import ( + "context" + "sync" + "time" +) + +// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. +type mergedContext struct { + // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. + context.Context + // 保存了所有的父 context, 用于 Value() 方法的查找. + parents []context.Context + // 用于手动取消此 mergedContext 的函数. + cancel context.CancelFunc +} + +// MergeCtx 创建并返回一个新的 context.Context. +// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, +// 自动被取消 (逻辑或关系). +// +// 新的 context 会继承: +// - Deadline: 所有父 context 中最早的截止时间. +// - Value: 按传入顺序从第一个能找到值的父 context 中获取值. +func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.CancelFunc) { + if len(parents) == 0 { + return context.WithCancel(context.Background()) + } + if len(parents) == 1 { + return context.WithCancel(parents[0]) + } + + var earliestDeadline time.Time + for _, p := range parents { + if deadline, ok := p.Deadline(); ok { + if earliestDeadline.IsZero() || deadline.Before(earliestDeadline) { + earliestDeadline = deadline + } + } + } + + var baseCtx context.Context + var baseCancel context.CancelFunc + if !earliestDeadline.IsZero() { + baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) + } else { + baseCtx, baseCancel = context.WithCancel(context.Background()) + } + + mc := &mergedContext{ + Context: baseCtx, + parents: parents, + cancel: baseCancel, + } + + // 启动一个监控 goroutine. + go func() { + defer mc.cancel() + + // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. + // 同时监听 baseCtx.Done() 以便支持手动取消. + select { + case <-orDone(mc.parents...): + case <-mc.Context.Done(): + } + }() + + return mc, mc.cancel +} + +// Value 实现了 context.Context 的 Value 方法. +// 它会按顺序遍历所有父 context, 并返回第一个找到的非 nil 值. +func (mc *mergedContext) Value(key any) any { + for _, p := range mc.parents { + if v := p.Value(key); v != nil { + return v + } + } + return nil +} + +// Deadline 实现了 context.Context 的 Deadline 方法. +func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { + return mc.Context.Deadline() +} + +// Done 实现了 context.Context 的 Done 方法. +func (mc *mergedContext) Done() <-chan struct{} { + return mc.Context.Done() +} + +// Err 实现了 context.Context 的 Err 方法. +func (mc *mergedContext) Err() error { + return mc.Context.Err() +} + +// orDone 是一个辅助函数, 返回一个 channel. +// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. +// 这是一个非阻塞的、不会泄漏 goroutine 的实现. +func orDone(contexts ...context.Context) <-chan struct{} { + done := make(chan struct{}) + + var once sync.Once + closeDone := func() { + once.Do(func() { + close(done) + }) + } + + // 为每个父 context 启动一个 goroutine. + for _, ctx := range contexts { + go func(c context.Context) { + select { + case <-c.Done(): + closeDone() + case <-done: + // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. + } + }(ctx) + } + + return done +} From 49508b49c145a43b3b1b9bc808e7e2bea46dd1c6 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 9 Jul 2025 00:17:52 +0800 Subject: [PATCH 2/2] fix limitMaxSizeReader non use body close & fix mergeCtx Value --- context.go | 12 ++++++++++++ mergectx.go | 10 ++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index 5a8ce2d..6479a69 100644 --- a/context.go +++ b/context.go @@ -456,6 +456,12 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { }() } else { limitBytesReader = c.Request.Body + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() } data, err := copyb.ReadAll(limitBytesReader) @@ -484,6 +490,12 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { }() } else { limitBytesReader = c.Request.Body + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() } data, err := copyb.ReadAll(limitBytesReader) diff --git a/mergectx.go b/mergectx.go index e6c223d..4c91601 100644 --- a/mergectx.go +++ b/mergectx.go @@ -69,15 +69,9 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return mc, mc.cancel } -// Value 实现了 context.Context 的 Value 方法. -// 它会按顺序遍历所有父 context, 并返回第一个找到的非 nil 值. +// Value 返回当前Ctx Value func (mc *mergedContext) Value(key any) any { - for _, p := range mc.parents { - if v := p.Value(key); v != nil { - return v - } - } - return nil + return mc.Context.Value(key) } // Deadline 实现了 context.Context 的 Deadline 方法.