mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 08:51:11 +08:00
Merge pull request #36 from infinite-iroha/dev
add maxBytesReader & ctxMerge
This commit is contained in:
commit
5d2ab04b6b
4 changed files with 276 additions and 8 deletions
57
context.go
57
context.go
|
|
@ -58,6 +58,9 @@ type Context struct {
|
|||
engine *Engine
|
||||
|
||||
sameSite http.SameSite
|
||||
|
||||
// 请求体Body大小限制
|
||||
MaxRequestBodySize int64
|
||||
}
|
||||
|
||||
// --- Context 相关方法实现 ---
|
||||
|
|
@ -83,6 +86,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
|||
c.formCache = nil // 清空表单数据缓存
|
||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
|
||||
}
|
||||
|
||||
|
|
@ -208,6 +212,11 @@ func (c *Context) MustGet(key string) interface{} {
|
|||
panic("Key \"" + key + "\" does not exist in context.")
|
||||
}
|
||||
|
||||
// SetMaxRequestBodySize
|
||||
func (c *Context) SetMaxRequestBodySize(size int64) {
|
||||
c.MaxRequestBodySize = size
|
||||
}
|
||||
|
||||
// Query 从 URL 查询参数中获取值
|
||||
// 懒加载解析查询参数,并进行缓存
|
||||
func (c *Context) Query(key string) string {
|
||||
|
|
@ -434,8 +443,28 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := copyb.ReadAll(c.Request.Body)
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := copyb.ReadAll(limitBytesReader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
@ -448,8 +477,28 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
|||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := copyb.ReadAll(c.Request.Body)
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := copyb.ReadAll(limitBytesReader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ type Engine struct {
|
|||
// 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器
|
||||
// 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置)
|
||||
TLSServerConfigurator func(*http.Server)
|
||||
|
||||
// GlobalMaxRequestBodySize 全局请求体Body大小限制
|
||||
GlobalMaxRequestBodySize int64
|
||||
}
|
||||
|
||||
type ErrorHandle struct {
|
||||
|
|
@ -175,6 +178,7 @@ func New() *Engine {
|
|||
noRoutes: make(HandlersChain, 0),
|
||||
ServerConfigurator: nil,
|
||||
TLSServerConfigurator: nil,
|
||||
GlobalMaxRequestBodySize: -1,
|
||||
}
|
||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||
engine.SetDefaultProtocols()
|
||||
|
|
@ -294,6 +298,11 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) {
|
|||
engine.useDefaultProtocols = false
|
||||
}
|
||||
|
||||
// 配置全局Req Body大小限制
|
||||
func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) {
|
||||
engine.GlobalMaxRequestBodySize = size
|
||||
}
|
||||
|
||||
// 配置Req IP来源 Headers
|
||||
func (engine *Engine) SetRemoteIPHeaders(headers []string) {
|
||||
engine.RemoteIPHeaders = headers
|
||||
|
|
|
|||
92
maxreader.go
Normal file
92
maxreader.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ErrBodyTooLarge 是当读取的字节数超过 MaxBytesReader 设置的限制时返回的错误.
|
||||
// 将其定义为可导出的变量, 方便调用方使用 errors.Is 进行判断.
|
||||
var ErrBodyTooLarge = fmt.Errorf("body too large")
|
||||
|
||||
// maxBytesReader 是一个实现了 io.ReadCloser 接口的结构体.
|
||||
// 它包装了另一个 io.ReadCloser, 并限制了从其中读取的最大字节数.
|
||||
type maxBytesReader struct {
|
||||
// r 是底层的 io.ReadCloser.
|
||||
r io.ReadCloser
|
||||
// n 是允许读取的最大字节数.
|
||||
n int64
|
||||
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||
read atomic.Int64
|
||||
}
|
||||
|
||||
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
||||
//
|
||||
// 如果 r 为 nil, 会 panic.
|
||||
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
|
||||
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||
if r == nil {
|
||||
panic("NewMaxBytesReader called with a nil reader")
|
||||
}
|
||||
// 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||
if n < 0 {
|
||||
return r
|
||||
}
|
||||
return &maxBytesReader{
|
||||
r: r,
|
||||
n: n,
|
||||
}
|
||||
}
|
||||
|
||||
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
||||
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||
readSoFar := mbr.read.Load()
|
||||
|
||||
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
|
||||
if readSoFar >= mbr.n {
|
||||
return 0, ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// 计算当前还可以读取多少字节.
|
||||
remaining := mbr.n - readSoFar
|
||||
|
||||
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
||||
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
||||
if int64(len(p)) > remaining {
|
||||
p = p[:remaining]
|
||||
}
|
||||
|
||||
// 从底层 Reader 读取数据.
|
||||
n, err := mbr.r.Read(p)
|
||||
|
||||
// 如果实际读取到了数据, 更新原子计数器.
|
||||
if n > 0 {
|
||||
readSoFar = mbr.read.Add(int64(n))
|
||||
}
|
||||
|
||||
// 如果底层 Read 返回错误 (例如 io.EOF).
|
||||
if err != nil {
|
||||
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
||||
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
||||
return n, err
|
||||
}
|
||||
|
||||
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
||||
// 这是处理"跨越"限制情况的关键.
|
||||
if readSoFar > mbr.n {
|
||||
// 返回实际读取的字节数 n, 并附上超限错误.
|
||||
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
||||
return n, ErrBodyTooLarge
|
||||
}
|
||||
|
||||
// 一切正常, 返回读取的字节数和 nil 错误.
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||
func (mbr *maxBytesReader) Close() error {
|
||||
return mbr.r.Close()
|
||||
}
|
||||
118
mergectx.go
Normal file
118
mergectx.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||
type mergedContext struct {
|
||||
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
||||
context.Context
|
||||
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
||||
parents []context.Context
|
||||
// 用于手动取消此 mergedContext 的函数.
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// MergeCtx 创建并返回一个新的 context.Context.
|
||||
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||
// 自动被取消 (逻辑或关系).
|
||||
//
|
||||
// 新的 context 会继承:
|
||||
// - Deadline: 所有父 context 中最早的截止时间.
|
||||
// - Value: 按传入顺序从第一个能找到值的父 context 中获取值.
|
||||
func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.CancelFunc) {
|
||||
if len(parents) == 0 {
|
||||
return context.WithCancel(context.Background())
|
||||
}
|
||||
if len(parents) == 1 {
|
||||
return context.WithCancel(parents[0])
|
||||
}
|
||||
|
||||
var earliestDeadline time.Time
|
||||
for _, p := range parents {
|
||||
if deadline, ok := p.Deadline(); ok {
|
||||
if earliestDeadline.IsZero() || deadline.Before(earliestDeadline) {
|
||||
earliestDeadline = deadline
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var baseCtx context.Context
|
||||
var baseCancel context.CancelFunc
|
||||
if !earliestDeadline.IsZero() {
|
||||
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
||||
} else {
|
||||
baseCtx, baseCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
mc := &mergedContext{
|
||||
Context: baseCtx,
|
||||
parents: parents,
|
||||
cancel: baseCancel,
|
||||
}
|
||||
|
||||
// 启动一个监控 goroutine.
|
||||
go func() {
|
||||
defer mc.cancel()
|
||||
|
||||
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
||||
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
||||
select {
|
||||
case <-orDone(mc.parents...):
|
||||
case <-mc.Context.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
return mc, mc.cancel
|
||||
}
|
||||
|
||||
// Value 返回当前Ctx Value
|
||||
func (mc *mergedContext) Value(key any) any {
|
||||
return mc.Context.Value(key)
|
||||
}
|
||||
|
||||
// Deadline 实现了 context.Context 的 Deadline 方法.
|
||||
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return mc.Context.Deadline()
|
||||
}
|
||||
|
||||
// Done 实现了 context.Context 的 Done 方法.
|
||||
func (mc *mergedContext) Done() <-chan struct{} {
|
||||
return mc.Context.Done()
|
||||
}
|
||||
|
||||
// Err 实现了 context.Context 的 Err 方法.
|
||||
func (mc *mergedContext) Err() error {
|
||||
return mc.Context.Err()
|
||||
}
|
||||
|
||||
// orDone 是一个辅助函数, 返回一个 channel.
|
||||
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
|
||||
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
|
||||
func orDone(contexts ...context.Context) <-chan struct{} {
|
||||
done := make(chan struct{})
|
||||
|
||||
var once sync.Once
|
||||
closeDone := func() {
|
||||
once.Do(func() {
|
||||
close(done)
|
||||
})
|
||||
}
|
||||
|
||||
// 为每个父 context 启动一个 goroutine.
|
||||
for _, ctx := range contexts {
|
||||
go func(c context.Context) {
|
||||
select {
|
||||
case <-c.Done():
|
||||
closeDone()
|
||||
case <-done:
|
||||
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
|
||||
}
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue