Fix FileText status code and unify request body size limits

- FileText: now respects the provided status code instead of defaulting to 200 OK
- Request body limits: prepareRequestBody() is now only called when MaxRequestBodySize > 0
  - ShouldBindJSON, ShouldBindWANF, ShouldBindGOB, ShouldBindForm, GetReqBody, PostForm
    all now use the original c.Request.Body path when no limit is configured
- maxBytesReader: fixed exact-limit boundary case where body size == limit was
  incorrectly rejected
- Added regression tests for FileText status codes and body limit behavior

All existing tests pass, and new tests verify the corrected behavior.
This commit is contained in:
wjqserver 2026-03-31 16:38:04 +08:00
parent ef965f4a6a
commit 64e2ad9e7b
3 changed files with 252 additions and 56 deletions

View file

@ -44,6 +44,8 @@ type Context struct {
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
index int8 // 当前执行到处理链的哪个位置 index int8 // 当前执行到处理链的哪个位置
requestBodyPrepared bool
mu sync.RWMutex mu sync.RWMutex
Keys map[string]any // 用于在中间件之间传递数据 Keys map[string]any // 用于在中间件之间传递数据
@ -102,6 +104,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
c.requestBodyPrepared = false
if cap(c.SkippedNodes) > 0 { if cap(c.SkippedNodes) > 0 {
c.SkippedNodes = c.SkippedNodes[:0] c.SkippedNodes = c.SkippedNodes[:0]
@ -237,6 +240,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
c.MaxRequestBodySize = size c.MaxRequestBodySize = size
} }
func (c *Context) prepareRequestBody() io.ReadCloser {
if c.Request == nil || c.Request.Body == nil {
return nil
}
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
return c.Request.Body
}
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
c.requestBodyPrepared = true
return c.Request.Body
}
// Query 从 URL 查询参数中获取值 // Query 从 URL 查询参数中获取值
// 懒加载解析查询参数,并进行缓存 // 懒加载解析查询参数,并进行缓存
func (c *Context) Query(key string) string { func (c *Context) Query(key string) string {
@ -258,7 +273,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
// 懒加载解析表单数据,并进行缓存 // 懒加载解析表单数据,并进行缓存
func (c *Context) PostForm(key string) string { func (c *Context) PostForm(key string) string {
if c.formCache == nil { if c.formCache == nil {
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded if c.MaxRequestBodySize > 0 {
c.prepareRequestBody()
contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
switch mediaType {
case "multipart/form-data":
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
default:
if err := c.Request.ParseForm(); err != nil {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
}
} else {
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
if !errors.Is(err, http.ErrNotMultipart) {
c.AddError(fmt.Errorf("parse form error: %w", err))
c.formCache = make(url.Values)
return ""
}
}
}
c.formCache = c.Request.PostForm c.formCache = c.Request.PostForm
} }
return c.formCache.Get(key) return c.formCache.Get(key)
@ -338,8 +385,11 @@ func (c *Context) FileText(code int, filePath string) {
} }
c.SetHeader("Content-Type", "text/plain; charset=utf-8") c.SetHeader("Content-Type", "text/plain; charset=utf-8")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
c.SetBodyStream(file, int(fileInfo.Size())) c.Writer.WriteHeader(code)
if _, err := iox.Copy(c.Writer, file); err != nil {
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
}
} }
/* /*
@ -557,10 +607,16 @@ func (c *Context) Redirect(code int, location string) {
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象 // ShouldBindJSON 尝试将请求体绑定到 JSON 对象
func (c *Context) ShouldBindJSON(obj any) error { func (c *Context) ShouldBindJSON(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
err := json.UnmarshalRead(c.Request.Body, obj) err := json.UnmarshalRead(body, obj)
if err != nil { if err != nil {
return fmt.Errorf("json binding error: %w", err) return fmt.Errorf("json binding error: %w", err)
} }
@ -569,10 +625,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
func (c *Context) ShouldBindWANF(obj any) error { func (c *Context) ShouldBindWANF(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder, err := wanf.NewStreamDecoder(c.Request.Body) decoder, err := wanf.NewStreamDecoder(body)
if err != nil { if err != nil {
return fmt.Errorf("failed to create WANF decoder: %w", err) return fmt.Errorf("failed to create WANF decoder: %w", err)
} }
@ -585,10 +647,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
func (c *Context) ShouldBindGOB(obj any) error { func (c *Context) ShouldBindGOB(obj any) error {
if c.Request.Body == nil { var body io.ReadCloser
if c.MaxRequestBodySize > 0 {
body = c.prepareRequestBody()
} else {
body = c.Request.Body
}
if body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
decoder := gob.NewDecoder(c.Request.Body) decoder := gob.NewDecoder(body)
if err := decoder.Decode(obj); err != nil { if err := decoder.Decode(obj); err != nil {
return fmt.Errorf("GOB binding error: %w", err) return fmt.Errorf("GOB binding error: %w", err)
} }
@ -705,6 +773,10 @@ func setFieldValue(field reflect.Value, values []string) error {
// ShouldBindForm 尝试将表单数据绑定到结构体 // ShouldBindForm 尝试将表单数据绑定到结构体
// 支持 application/x-www-form-urlencoded 和 multipart/form-data // 支持 application/x-www-form-urlencoded 和 multipart/form-data
func (c *Context) ShouldBindForm(obj any) error { func (c *Context) ShouldBindForm(obj any) error {
if c.MaxRequestBodySize > 0 {
c.prepareRequestBody()
}
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType) mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
@ -713,7 +785,7 @@ func (c *Context) ShouldBindForm(obj any) error {
switch mediaType { switch mediaType {
case "multipart/form-data": case "multipart/form-data":
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
return fmt.Errorf("parse multipart form error: %w", err) return fmt.Errorf("parse multipart form error: %w", err)
} }
case "application/x-www-form-urlencoded": case "application/x-www-form-urlencoded":
@ -727,6 +799,7 @@ func (c *Context) ShouldBindForm(obj any) error {
if err := bindForm(c.Request.Form, obj); err != nil { if err := bindForm(c.Request.Form, obj); err != nil {
return fmt.Errorf("form binding error: %w", err) return fmt.Errorf("form binding error: %w", err)
} }
c.formCache = c.Request.PostForm
return nil return nil
} }
@ -827,37 +900,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBody() io.ReadCloser { func (c *Context) GetReqBody() io.ReadCloser {
if c.MaxRequestBodySize > 0 {
return c.prepareRequestBody()
}
if c.Request == nil || c.Request.Body == nil {
return nil
}
return c.Request.Body return c.Request.Body
} }
// GetReqBodyFull 读取并返回请求体的所有内容 // GetReqBodyFull 读取并返回请求体的所有内容
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBodyFull() ([]byte, error) { func (c *Context) GetReqBodyFull() ([]byte, error) {
if c.Request.Body == nil { body := c.GetReqBody()
if body == nil {
return nil, nil return nil, nil
} }
var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() { defer func() {
err := limitBytesReader.Close() err := body.Close()
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err)) c.AddError(fmt.Errorf("failed to close request body: %w", err))
} }
}() }()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader) data, err := iox.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -867,31 +933,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
// 类似 GetReqBodyFull, 返回 *bytes.Buffer // 类似 GetReqBodyFull, 返回 *bytes.Buffer
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
if c.Request.Body == nil { body := c.GetReqBody()
if body == nil {
return nil, nil return nil, nil
} }
var limitBytesReader io.ReadCloser
if c.MaxRequestBodySize > 0 {
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
defer func() { defer func() {
err := limitBytesReader.Close() err := body.Close()
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err)) c.AddError(fmt.Errorf("failed to close request body: %w", err))
} }
}() }()
} else {
limitBytesReader = c.Request.Body
defer func() {
err := limitBytesReader.Close()
if err != nil {
c.AddError(fmt.Errorf("failed to close request body: %w", err))
}
}()
}
data, err := iox.ReadAll(limitBytesReader) data, err := iox.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)

125
context_bodylimit_test.go Normal file
View file

@ -0,0 +1,125 @@
package touka
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
t.Helper()
dir := t.TempDir()
filePath := filepath.Join(dir, "hello.txt")
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
t.Fatalf("write temp file: %v", err)
}
rr := httptest.NewRecorder()
c, _ := CreateTestContext(rr)
c.FileText(http.StatusCreated, filePath)
if rr.Code != http.StatusCreated {
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
}
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
t.Fatalf("unexpected content type: %q", got)
}
if body := rr.Body.String(); body != "hello touka" {
t.Fatalf("unexpected body: %q", body)
}
}
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("expected exact limit read to succeed, got %v", err)
}
if string(data) != "abcd" {
t.Fatalf("unexpected data: %q", string(data))
}
}
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
t.Helper()
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
defer reader.Close()
_, err := io.ReadAll(reader)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader(`{"name":"abcdef"}`)
req := httptest.NewRequest(http.MethodPost, "/json", body)
req.Header.Set("Content-Type", "application/json")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(8)
var payload struct {
Name string `json:"name"`
}
err := c.ShouldBindJSON(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
var payload struct {
Name string `form:"name"`
}
err := c.ShouldBindForm(&payload)
if !errors.Is(err, ErrBodyTooLarge) {
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
}
}
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
t.Helper()
body := strings.NewReader("name=abcdef")
req := httptest.NewRequest(http.MethodPost, "/form", body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
c.SetMaxRequestBodySize(4)
if got := c.PostForm("name"); got != "" {
t.Fatalf("expected empty value on over-limit form body, got %q", got)
}
if len(c.Errors) == 0 {
t.Fatal("expected parse error to be recorded")
}
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
}
}

View file

@ -46,11 +46,29 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
func (mbr *maxBytesReader) Read(p []byte) (int, error) { func (mbr *maxBytesReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
readSoFar := mbr.read.Load() readSoFar := mbr.read.Load()
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. if readSoFar > mbr.n {
if readSoFar >= mbr.n { return 0, ErrBodyTooLarge
}
// 当已恰好读满限制时, 需要探测底层是否还有额外数据.
// 如果下一次读取立即 EOF, 说明请求体大小恰好等于限制, 属于合法情况.
if readSoFar == mbr.n {
var probe [1]byte
n, err := mbr.r.Read(probe[:])
if n > 0 {
mbr.read.Add(int64(n))
return 0, ErrBodyTooLarge
}
if err != nil {
return 0, err
}
return 0, ErrBodyTooLarge return 0, ErrBodyTooLarge
} }