mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Fix FileText status code and unify request body size limits
- FileText: now respects the provided status code instead of defaulting to 200 OK
- Request body limits: prepareRequestBody() is now only called when MaxRequestBodySize > 0
- ShouldBindJSON, ShouldBindWANF, ShouldBindGOB, ShouldBindForm, GetReqBody, PostForm
all now use the original c.Request.Body path when no limit is configured
- maxBytesReader: fixed exact-limit boundary case where body size == limit was
incorrectly rejected
- Added regression tests for FileText status codes and body limit behavior
All existing tests pass, and new tests verify the corrected behavior.
This commit is contained in:
parent
ef965f4a6a
commit
64e2ad9e7b
3 changed files with 252 additions and 56 deletions
161
context.go
161
context.go
|
|
@ -44,6 +44,8 @@ type Context struct {
|
|||
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
||||
index int8 // 当前执行到处理链的哪个位置
|
||||
|
||||
requestBodyPrepared bool
|
||||
|
||||
mu sync.RWMutex
|
||||
Keys map[string]any // 用于在中间件之间传递数据
|
||||
|
||||
|
|
@ -102,6 +104,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
|||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||
c.requestBodyPrepared = false
|
||||
|
||||
if cap(c.SkippedNodes) > 0 {
|
||||
c.SkippedNodes = c.SkippedNodes[:0]
|
||||
|
|
@ -237,6 +240,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
|
|||
c.MaxRequestBodySize = size
|
||||
}
|
||||
|
||||
func (c *Context) prepareRequestBody() io.ReadCloser {
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return nil
|
||||
}
|
||||
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
|
||||
return c.Request.Body
|
||||
}
|
||||
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
c.requestBodyPrepared = true
|
||||
return c.Request.Body
|
||||
}
|
||||
|
||||
// Query 从 URL 查询参数中获取值
|
||||
// 懒加载解析查询参数,并进行缓存
|
||||
func (c *Context) Query(key string) string {
|
||||
|
|
@ -258,7 +273,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
|
|||
// 懒加载解析表单数据,并进行缓存
|
||||
func (c *Context) PostForm(key string) string {
|
||||
if c.formCache == nil {
|
||||
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
c.prepareRequestBody()
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
|
||||
switch mediaType {
|
||||
case "multipart/form-data":
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
if err := c.Request.ParseForm(); err != nil {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
if !errors.Is(err, http.ErrNotMultipart) {
|
||||
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||
c.formCache = make(url.Values)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
c.formCache = c.Request.PostForm
|
||||
}
|
||||
return c.formCache.Get(key)
|
||||
|
|
@ -338,8 +385,11 @@ func (c *Context) FileText(code int, filePath string) {
|
|||
}
|
||||
|
||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||
|
||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||||
c.Writer.WriteHeader(code)
|
||||
if _, err := iox.Copy(c.Writer, file); err != nil {
|
||||
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -557,10 +607,16 @@ func (c *Context) Redirect(code int, location string) {
|
|||
|
||||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||||
func (c *Context) ShouldBindJSON(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
err := json.UnmarshalRead(c.Request.Body, obj)
|
||||
err := json.UnmarshalRead(body, obj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json binding error: %w", err)
|
||||
}
|
||||
|
|
@ -569,10 +625,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
|
|||
|
||||
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
||||
func (c *Context) ShouldBindWANF(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
|
||||
decoder, err := wanf.NewStreamDecoder(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||||
}
|
||||
|
|
@ -585,10 +647,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
|
|||
|
||||
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
||||
func (c *Context) ShouldBindGOB(obj any) error {
|
||||
if c.Request.Body == nil {
|
||||
var body io.ReadCloser
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
body = c.prepareRequestBody()
|
||||
} else {
|
||||
body = c.Request.Body
|
||||
}
|
||||
if body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
decoder := gob.NewDecoder(c.Request.Body)
|
||||
decoder := gob.NewDecoder(body)
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return fmt.Errorf("GOB binding error: %w", err)
|
||||
}
|
||||
|
|
@ -705,6 +773,10 @@ func setFieldValue(field reflect.Value, values []string) error {
|
|||
// ShouldBindForm 尝试将表单数据绑定到结构体
|
||||
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
||||
func (c *Context) ShouldBindForm(obj any) error {
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
c.prepareRequestBody()
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
|
|
@ -713,7 +785,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
|||
|
||||
switch mediaType {
|
||||
case "multipart/form-data":
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||
return fmt.Errorf("parse multipart form error: %w", err)
|
||||
}
|
||||
case "application/x-www-form-urlencoded":
|
||||
|
|
@ -727,6 +799,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
|||
if err := bindForm(c.Request.Form, obj); err != nil {
|
||||
return fmt.Errorf("form binding error: %w", err)
|
||||
}
|
||||
c.formCache = c.Request.PostForm
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -827,37 +900,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
|
|||
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
||||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBody() io.ReadCloser {
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
return c.prepareRequestBody()
|
||||
}
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Request.Body
|
||||
}
|
||||
|
||||
// GetReqBodyFull 读取并返回请求体的所有内容
|
||||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||
if c.Request.Body == nil {
|
||||
body := c.GetReqBody()
|
||||
if body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer func() {
|
||||
err := body.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := iox.ReadAll(limitBytesReader)
|
||||
data, err := iox.ReadAll(body)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
@ -867,31 +933,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
|
||||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||
if c.Request.Body == nil {
|
||||
body := c.GetReqBody()
|
||||
if body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer func() {
|
||||
err := body.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
var limitBytesReader io.ReadCloser
|
||||
|
||||
if c.MaxRequestBodySize > 0 {
|
||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
limitBytesReader = c.Request.Body
|
||||
defer func() {
|
||||
err := limitBytesReader.Close()
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
data, err := iox.ReadAll(limitBytesReader)
|
||||
data, err := iox.ReadAll(body)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue