mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #76 from infinite-iroha/break/v1-fix-filetext-bodylimit
Break/v1 fix filetext bodylimit
This commit is contained in:
commit
c019f24e99
3 changed files with 320 additions and 86 deletions
141
context.go
141
context.go
|
|
@ -44,6 +44,8 @@ type Context struct {
|
||||||
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler)
|
||||||
index int8 // 当前执行到处理链的哪个位置
|
index int8 // 当前执行到处理链的哪个位置
|
||||||
|
|
||||||
|
requestBodyPrepared bool
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
Keys map[string]any // 用于在中间件之间传递数据
|
Keys map[string]any // 用于在中间件之间传递数据
|
||||||
|
|
||||||
|
|
@ -102,6 +104,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
|
||||||
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值
|
||||||
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式
|
||||||
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize
|
||||||
|
c.requestBodyPrepared = false
|
||||||
|
|
||||||
if cap(c.SkippedNodes) > 0 {
|
if cap(c.SkippedNodes) > 0 {
|
||||||
c.SkippedNodes = c.SkippedNodes[:0]
|
c.SkippedNodes = c.SkippedNodes[:0]
|
||||||
|
|
@ -237,6 +240,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) {
|
||||||
c.MaxRequestBodySize = size
|
c.MaxRequestBodySize = size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) prepareRequestBody() io.ReadCloser {
|
||||||
|
if c.Request == nil || c.Request.Body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 {
|
||||||
|
return c.Request.Body
|
||||||
|
}
|
||||||
|
c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
||||||
|
c.requestBodyPrepared = true
|
||||||
|
return c.Request.Body
|
||||||
|
}
|
||||||
|
|
||||||
// Query 从 URL 查询参数中获取值
|
// Query 从 URL 查询参数中获取值
|
||||||
// 懒加载解析查询参数,并进行缓存
|
// 懒加载解析查询参数,并进行缓存
|
||||||
func (c *Context) Query(key string) string {
|
func (c *Context) Query(key string) string {
|
||||||
|
|
@ -258,7 +273,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string {
|
||||||
// 懒加载解析表单数据,并进行缓存
|
// 懒加载解析表单数据,并进行缓存
|
||||||
func (c *Context) PostForm(key string) string {
|
func (c *Context) PostForm(key string) string {
|
||||||
if c.formCache == nil {
|
if c.formCache == nil {
|
||||||
c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch mediaType {
|
||||||
|
case "multipart/form-data":
|
||||||
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case "application/x-www-form-urlencoded":
|
||||||
|
if err := c.Request.ParseForm(); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
|
if !errors.Is(err, http.ErrNotMultipart) {
|
||||||
|
c.AddError(fmt.Errorf("parse form error: %w", err))
|
||||||
|
c.formCache = make(url.Values)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
c.formCache = c.Request.PostForm
|
c.formCache = c.Request.PostForm
|
||||||
}
|
}
|
||||||
return c.formCache.Get(key)
|
return c.formCache.Get(key)
|
||||||
|
|
@ -338,8 +385,11 @@ func (c *Context) FileText(code int, filePath string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
c.SetHeader("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size()))
|
||||||
c.SetBodyStream(file, int(fileInfo.Size()))
|
c.Writer.WriteHeader(code)
|
||||||
|
if _, err := iox.Copy(c.Writer, file); err != nil {
|
||||||
|
c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
@ -557,10 +607,16 @@ func (c *Context) Redirect(code int, location string) {
|
||||||
|
|
||||||
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
// ShouldBindJSON 尝试将请求体绑定到 JSON 对象
|
||||||
func (c *Context) ShouldBindJSON(obj any) error {
|
func (c *Context) ShouldBindJSON(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
err := json.UnmarshalRead(c.Request.Body, obj)
|
err := json.UnmarshalRead(body, obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("json binding error: %w", err)
|
return fmt.Errorf("json binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -569,10 +625,16 @@ func (c *Context) ShouldBindJSON(obj any) error {
|
||||||
|
|
||||||
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象
|
||||||
func (c *Context) ShouldBindWANF(obj any) error {
|
func (c *Context) ShouldBindWANF(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
decoder, err := wanf.NewStreamDecoder(c.Request.Body)
|
decoder, err := wanf.NewStreamDecoder(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
return fmt.Errorf("failed to create WANF decoder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -585,10 +647,16 @@ func (c *Context) ShouldBindWANF(obj any) error {
|
||||||
|
|
||||||
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象
|
||||||
func (c *Context) ShouldBindGOB(obj any) error {
|
func (c *Context) ShouldBindGOB(obj any) error {
|
||||||
if c.Request.Body == nil {
|
var body io.ReadCloser
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
body = c.prepareRequestBody()
|
||||||
|
} else {
|
||||||
|
body = c.Request.Body
|
||||||
|
}
|
||||||
|
if body == nil {
|
||||||
return errors.New("request body is empty")
|
return errors.New("request body is empty")
|
||||||
}
|
}
|
||||||
decoder := gob.NewDecoder(c.Request.Body)
|
decoder := gob.NewDecoder(body)
|
||||||
if err := decoder.Decode(obj); err != nil {
|
if err := decoder.Decode(obj); err != nil {
|
||||||
return fmt.Errorf("GOB binding error: %w", err)
|
return fmt.Errorf("GOB binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -705,6 +773,10 @@ func setFieldValue(field reflect.Value, values []string) error {
|
||||||
// ShouldBindForm 尝试将表单数据绑定到结构体
|
// ShouldBindForm 尝试将表单数据绑定到结构体
|
||||||
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
// 支持 application/x-www-form-urlencoded 和 multipart/form-data
|
||||||
func (c *Context) ShouldBindForm(obj any) error {
|
func (c *Context) ShouldBindForm(obj any) error {
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -713,7 +785,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
||||||
|
|
||||||
switch mediaType {
|
switch mediaType {
|
||||||
case "multipart/form-data":
|
case "multipart/form-data":
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
if err := c.Request.ParseMultipartForm(defaultMemory); err != nil {
|
||||||
return fmt.Errorf("parse multipart form error: %w", err)
|
return fmt.Errorf("parse multipart form error: %w", err)
|
||||||
}
|
}
|
||||||
case "application/x-www-form-urlencoded":
|
case "application/x-www-form-urlencoded":
|
||||||
|
|
@ -727,6 +799,7 @@ func (c *Context) ShouldBindForm(obj any) error {
|
||||||
if err := bindForm(c.Request.Form, obj); err != nil {
|
if err := bindForm(c.Request.Form, obj); err != nil {
|
||||||
return fmt.Errorf("form binding error: %w", err)
|
return fmt.Errorf("form binding error: %w", err)
|
||||||
}
|
}
|
||||||
|
c.formCache = c.Request.PostForm
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -827,37 +900,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) {
|
||||||
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
// GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体
|
||||||
// 注意:请求体只能读取一次
|
// 注意:请求体只能读取一次
|
||||||
func (c *Context) GetReqBody() io.ReadCloser {
|
func (c *Context) GetReqBody() io.ReadCloser {
|
||||||
|
if c.MaxRequestBodySize > 0 {
|
||||||
|
return c.prepareRequestBody()
|
||||||
|
}
|
||||||
|
if c.Request == nil || c.Request.Body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return c.Request.Body
|
return c.Request.Body
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetReqBodyFull 读取并返回请求体的所有内容
|
// GetReqBodyFull 读取并返回请求体的所有内容
|
||||||
// 注意:请求体只能读取一次
|
// 注意:请求体只能读取一次
|
||||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||||
if c.Request.Body == nil {
|
body := c.GetReqBody()
|
||||||
|
if body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var limitBytesReader io.ReadCloser
|
|
||||||
|
|
||||||
if c.MaxRequestBodySize > 0 {
|
|
||||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := limitBytesReader.Close()
|
err := body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
} else {
|
|
||||||
limitBytesReader = c.Request.Body
|
|
||||||
defer func() {
|
|
||||||
err := limitBytesReader.Close()
|
|
||||||
if err != nil {
|
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := iox.ReadAll(limitBytesReader)
|
data, err := iox.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
@ -867,31 +933,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||||
|
|
||||||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||||
if c.Request.Body == nil {
|
body := c.GetReqBody()
|
||||||
|
if body == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var limitBytesReader io.ReadCloser
|
|
||||||
|
|
||||||
if c.MaxRequestBodySize > 0 {
|
|
||||||
limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := limitBytesReader.Close()
|
err := body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
} else {
|
|
||||||
limitBytesReader = c.Request.Body
|
|
||||||
defer func() {
|
|
||||||
err := limitBytesReader.Close()
|
|
||||||
if err != nil {
|
|
||||||
c.AddError(fmt.Errorf("failed to close request body: %w", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := iox.ReadAll(limitBytesReader)
|
data, err := iox.ReadAll(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||||
|
|
|
||||||
174
context_bodylimit_test.go
Normal file
174
context_bodylimit_test.go
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type zeroNilThenEOFReader struct {
|
||||||
|
readCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) {
|
||||||
|
r.readCalls++
|
||||||
|
if r.readCalls == 1 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *zeroNilThenEOFReader) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTextUsesProvidedStatusCode(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
filePath := filepath.Join(dir, "hello.txt")
|
||||||
|
if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
c, _ := CreateTestContext(rr)
|
||||||
|
|
||||||
|
c.FileText(http.StatusCreated, filePath)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" {
|
||||||
|
t.Fatalf("unexpected content type: %q", got)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); body != "hello touka" {
|
||||||
|
t.Fatalf("unexpected body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderAllowsExactLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected exact limit read to succeed, got %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "abcd" {
|
||||||
|
t.Fatalf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderRejectsOverLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
_, err := io.ReadAll(reader)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n != 0 || err != nil {
|
||||||
|
t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = reader.Read(buf)
|
||||||
|
if n != 0 || !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected zero limit to leave body unlimited, got %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "abc" {
|
||||||
|
t.Fatalf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"name":"abcdef"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/json", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(8)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.ShouldBindJSON(&payload)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader("name=abcdef")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(4)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.ShouldBindForm(&payload)
|
||||||
|
if !errors.Is(err, ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrBodyTooLarge, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostFormHonorsMaxRequestBodySize(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := strings.NewReader("name=abcdef")
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/form", body)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||||
|
c.SetMaxRequestBodySize(4)
|
||||||
|
|
||||||
|
if got := c.PostForm("name"); got != "" {
|
||||||
|
t.Fatalf("expected empty value on over-limit form body, got %q", got)
|
||||||
|
}
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
t.Fatal("expected parse error to be recorded")
|
||||||
|
}
|
||||||
|
if !errors.Is(c.Errors[0], ErrBodyTooLarge) {
|
||||||
|
t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
67
maxreader.go
67
maxreader.go
|
|
@ -23,19 +23,21 @@ type maxBytesReader struct {
|
||||||
n int64
|
n int64
|
||||||
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
// read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数.
|
||||||
read atomic.Int64
|
read atomic.Int64
|
||||||
|
// emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读.
|
||||||
|
emptyAtLimit atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
// NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据,
|
||||||
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
// 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误.
|
||||||
//
|
//
|
||||||
// 如果 r 为 nil, 会 panic.
|
// 如果 r 为 nil, 会 panic.
|
||||||
// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r.
|
// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r.
|
||||||
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
panic("NewMaxBytesReader called with a nil reader")
|
panic("NewMaxBytesReader called with a nil reader")
|
||||||
}
|
}
|
||||||
// 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser.
|
// 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser.
|
||||||
if n < 0 {
|
if n <= 0 {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return &maxBytesReader{
|
return &maxBytesReader{
|
||||||
|
|
@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
|
|
||||||
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
// Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制.
|
||||||
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
func (mbr *maxBytesReader) Read(p []byte) (int, error) {
|
||||||
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
if len(p) == 0 {
|
||||||
readSoFar := mbr.read.Load()
|
return 0, nil
|
||||||
|
|
||||||
// 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误.
|
|
||||||
if readSoFar >= mbr.n {
|
|
||||||
return 0, ErrBodyTooLarge
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算当前还可以读取多少字节.
|
// 在函数开始时只加载一次原子变量, 减少后续的原子操作开销.
|
||||||
|
readSoFar := mbr.read.Load()
|
||||||
remaining := mbr.n - readSoFar
|
remaining := mbr.n - readSoFar
|
||||||
|
if remaining < 0 {
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
if remaining == 0 {
|
||||||
|
var probe [1]byte
|
||||||
|
n, err := mbr.r.Read(probe[:])
|
||||||
|
if n > 0 {
|
||||||
|
mbr.read.Add(int64(n))
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if mbr.emptyAtLimit.Swap(true) {
|
||||||
|
return 0, ErrBodyTooLarge
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
mbr.emptyAtLimit.Store(false)
|
||||||
|
|
||||||
// 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度.
|
// 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。
|
||||||
// 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数.
|
if int64(len(p))-1 > remaining {
|
||||||
if int64(len(p)) > remaining {
|
p = p[:remaining+1]
|
||||||
p = p[:remaining]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从底层 Reader 读取数据.
|
// 从底层 Reader 读取数据.
|
||||||
n, err := mbr.r.Read(p)
|
n, err := mbr.r.Read(p)
|
||||||
|
|
||||||
// 如果实际读取到了数据, 更新原子计数器.
|
if int64(n) <= remaining {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
readSoFar = mbr.read.Add(int64(n))
|
mbr.read.Add(int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果底层 Read 返回错误 (例如 io.EOF).
|
|
||||||
if err != nil {
|
|
||||||
// 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束.
|
|
||||||
// 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了.
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误.
|
// 读取结果跨过了限制,只向上层暴露允许的部分。
|
||||||
// 这是处理"跨越"限制情况的关键.
|
if remaining > 0 {
|
||||||
if readSoFar > mbr.n {
|
mbr.read.Add(remaining)
|
||||||
// 返回实际读取的字节数 n, 并附上超限错误.
|
|
||||||
// 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭.
|
|
||||||
return n, ErrBodyTooLarge
|
|
||||||
}
|
}
|
||||||
|
return int(remaining), ErrBodyTooLarge
|
||||||
// 一切正常, 返回读取的字节数和 nil 错误.
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
// Close 方法关闭底层的 ReadCloser, 保证资源释放.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue