mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 00:41:10 +08:00
add maxBytesReader & ctxMerge
This commit is contained in:
parent
17bab2dcfd
commit
cb86cb935a
4 changed files with 270 additions and 8 deletions
45
context.go
45
context.go
|
|
@ -58,6 +58,9 @@ type Context struct {
|
|||
engine *Engine
|
||||
|
||||
sameSite http.SameSite
|
||||
|
||||
// 请求体Body大小限制
|
||||
MaxRequestBodySize int64
|
||||
}
|
||||
|
||||
// --- Context 相关方法实现 ---
|
||||
|
|
@ -83,6 +86,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
|||
c.formCache = nil // 清空表单数据缓存
|
||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||
// c.HTTPClient 和 c.engine 保持不变,它们引用 Engine 实例的成员
|
||||
}
|
||||
|
||||
|
|
@ -208,6 +212,11 @@ func (c *Context) MustGet(key string) interface{} {
|
|||
panic("Key \"" + key + "\" does not exist in context.")
|
||||
}
|
||||
|
||||
// SetMaxRequestBodySize
|
||||
func (c *Context) SetMaxRequestBodySize(size int64) {
|
||||
c.MaxRequestBodySize = size
|
||||
}
|
||||
|
||||
// Query 从 URL 查询参数中获取值
|
||||
// 懒加载解析查询参数,并进行缓存
|
||||
func (c *Context) Query(key string) string {
|
||||
|
|
@ -434,8 +443,22 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := copyb.ReadAll(c.Request.Body)
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
}
|
||||
|
||||
data, err := copyb.ReadAll(limitBytesReader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
@ -448,8 +471,22 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
|||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := copyb.ReadAll(c.Request.Body)
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
}
|
||||
|
||||
data, err := copyb.ReadAll(limitBytesReader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue