diff --git a/context.go b/context.go index 2e4d2bb..9c4ba7e 100644 --- a/context.go +++ b/context.go @@ -44,6 +44,8 @@ type Context struct { handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) index int8 // 当前执行到处理链的哪个位置 + requestBodyPrepared bool + mu sync.RWMutex Keys map[string]any // 用于在中间件之间传递数据 @@ -102,6 +104,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize + c.requestBodyPrepared = false if cap(c.SkippedNodes) > 0 { c.SkippedNodes = c.SkippedNodes[:0] @@ -237,6 +240,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) { c.MaxRequestBodySize = size } +func (c *Context) prepareRequestBody() io.ReadCloser { + if c.Request == nil || c.Request.Body == nil { + return nil + } + if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 { + return c.Request.Body + } + c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + c.requestBodyPrepared = true + return c.Request.Body +} + // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -258,7 +273,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string { // 懒加载解析表单数据,并进行缓存 func (c *Context) PostForm(key string) string { if c.formCache == nil { - c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + } + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + default: + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + if !errors.Is(err, http.ErrNotMultipart) { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + } + } c.formCache = c.Request.PostForm } return c.formCache.Get(key) @@ -338,8 +385,11 @@ func (c *Context) FileText(code int, filePath string) { } c.SetHeader("Content-Type", "text/plain; charset=utf-8") - - c.SetBodyStream(file, int(fileInfo.Size())) + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) + c.Writer.WriteHeader(code) + if _, err := iox.Copy(c.Writer, file); err != nil { + c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err)) + } } /* @@ -557,10 +607,16 @@ func (c *Context) Redirect(code int, location string) { // ShouldBindJSON 尝试将请求体绑定到 JSON 对象 func (c *Context) ShouldBindJSON(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - err := json.UnmarshalRead(c.Request.Body, obj) + err := json.UnmarshalRead(body, obj) if err != nil { return fmt.Errorf("json binding error: %w", err) } @@ -569,10 +625,16 @@ func (c *Context) ShouldBindJSON(obj any) error { // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder, err := wanf.NewStreamDecoder(c.Request.Body) + decoder, err := wanf.NewStreamDecoder(body) if err != nil { return fmt.Errorf("failed to create WANF decoder: %w", err) } @@ -585,10 +647,16 @@ func (c *Context) ShouldBindWANF(obj any) error { // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 func (c *Context) ShouldBindGOB(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder := gob.NewDecoder(c.Request.Body) + decoder := gob.NewDecoder(body) if err := decoder.Decode(obj); err != nil { return fmt.Errorf("GOB binding error: %w", err) } @@ -705,6 +773,10 @@ func setFieldValue(field reflect.Value, values []string) error { // ShouldBindForm 尝试将表单数据绑定到结构体 // 支持 application/x-www-form-urlencoded 和 multipart/form-data func (c *Context) ShouldBindForm(obj any) error { + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + } + contentType := c.Request.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -713,7 +785,7 @@ func (c *Context) ShouldBindForm(obj any) error { switch mediaType { case "multipart/form-data": - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { return fmt.Errorf("parse multipart form error: %w", err) } case "application/x-www-form-urlencoded": @@ -727,6 +799,7 @@ func (c *Context) ShouldBindForm(obj any) error { if err := bindForm(c.Request.Form, obj); err != nil { return fmt.Errorf("form binding error: %w", err) } + c.formCache = c.Request.PostForm return nil } @@ -827,37 +900,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) { // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // 注意:请求体只能读取一次 func (c *Context) GetReqBody() io.ReadCloser { + if c.MaxRequestBodySize > 0 { + return c.prepareRequestBody() + } + if c.Request == nil || c.Request.Body == nil { + return nil + } return c.Request.Body } // GetReqBodyFull 读取并返回请求体的所有内容 // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := iox.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -867,31 +933,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { // 类似 GetReqBodyFull, 返回 *bytes.Buffer func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := iox.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go new file mode 100644 index 0000000..1e7696a --- /dev/null +++ b/context_bodylimit_test.go @@ -0,0 +1,174 @@ +package touka + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +type zeroNilThenEOFReader struct { + readCalls int +} + +func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) { + r.readCalls++ + if r.readCalls == 1 { + return 0, nil + } + return 0, io.EOF +} + +func (r *zeroNilThenEOFReader) Close() error { + return nil +} + +func TestFileTextUsesProvidedStatusCode(t *testing.T) { + t.Helper() + + dir := t.TempDir() + filePath := filepath.Join(dir, "hello.txt") + if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil { + t.Fatalf("write temp file: %v", err) + } + + rr := httptest.NewRecorder() + c, _ := CreateTestContext(rr) + + c.FileText(http.StatusCreated, filePath) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) + } + if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("unexpected content type: %q", got) + } + if body := rr.Body.String(); body != "hello touka" { + t.Fatalf("unexpected body: %q", body) + } +} + +func TestMaxBytesReaderAllowsExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4) + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected exact limit read to succeed, got %v", err) + } + if string(data) != "abcd" { + t.Fatalf("unexpected data: %q", string(data)) + } +} + +func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4) + defer reader.Close() + + _, err := io.ReadAll(reader) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err) + } +} + +func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0) + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected zero limit to leave body unlimited, got %v", err) + } + if string(data) != "abc" { + t.Fatalf("unexpected data: %q", string(data)) + } +} + +func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader(`{"name":"abcdef"}`) + req := httptest.NewRequest(http.MethodPost, "/json", body) + req.Header.Set("Content-Type", "application/json") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(8) + + var payload struct { + Name string `json:"name"` + } + + err := c.ShouldBindJSON(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + var payload struct { + Name string `form:"name"` + } + + err := c.ShouldBindForm(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestPostFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + if got := c.PostForm("name"); got != "" { + t.Fatalf("expected empty value on over-limit form body, got %q", got) + } + if len(c.Errors) == 0 { + t.Fatal("expected parse error to be recorded") + } + if !errors.Is(c.Errors[0], ErrBodyTooLarge) { + t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0]) + } +} diff --git a/maxreader.go b/maxreader.go index c6201e6..4d3fb2c 100644 --- a/maxreader.go +++ b/maxreader.go @@ -23,19 +23,21 @@ type maxBytesReader struct { n int64 // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. read atomic.Int64 + // emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读. + emptyAtLimit atomic.Bool } // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // // 如果 r 为 nil, 会 panic. -// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. +// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r. func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { if r == nil { panic("NewMaxBytesReader called with a nil reader") } - // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. - if n < 0 { + // 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser. + if n <= 0 { return r } return &maxBytesReader{ @@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. func (mbr *maxBytesReader) Read(p []byte) (int, error) { - // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. - readSoFar := mbr.read.Load() - - // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. - if readSoFar >= mbr.n { - return 0, ErrBodyTooLarge + if len(p) == 0 { + return 0, nil } - // 计算当前还可以读取多少字节. + // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. + readSoFar := mbr.read.Load() remaining := mbr.n - readSoFar + if remaining < 0 { + return 0, ErrBodyTooLarge + } + if remaining == 0 { + var probe [1]byte + n, err := mbr.r.Read(probe[:]) + if n > 0 { + mbr.read.Add(int64(n)) + return 0, ErrBodyTooLarge + } + if err != nil { + return 0, err + } + if mbr.emptyAtLimit.Swap(true) { + return 0, ErrBodyTooLarge + } + return 0, nil + } + mbr.emptyAtLimit.Store(false) - // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. - // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. - if int64(len(p)) > remaining { - p = p[:remaining] + // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 + if int64(len(p))-1 > remaining { + p = p[:remaining+1] } // 从底层 Reader 读取数据. n, err := mbr.r.Read(p) - // 如果实际读取到了数据, 更新原子计数器. - if n > 0 { - readSoFar = mbr.read.Add(int64(n)) - } - - // 如果底层 Read 返回错误 (例如 io.EOF). - if err != nil { - // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. - // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. + if int64(n) <= remaining { + if n > 0 { + mbr.read.Add(int64(n)) + } return n, err } - // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. - // 这是处理"跨越"限制情况的关键. - if readSoFar > mbr.n { - // 返回实际读取的字节数 n, 并附上超限错误. - // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. - return n, ErrBodyTooLarge + // 读取结果跨过了限制,只向上层暴露允许的部分。 + if remaining > 0 { + mbr.read.Add(remaining) } - - // 一切正常, 返回读取的字节数和 nil 错误. - return n, nil + return int(remaining), ErrBodyTooLarge } // Close 方法关闭底层的 ReadCloser, 保证资源释放.