feat: Add request body size limit and enhance error logging

This commit introduces two main improvements to the Touka web framework:

1.  **Configurable Request Body Size Limit:**
    - Added `MaxRequestBodySize int64` to `touka.Engine` (default 10MB).
    - You can customize this via `engine.SetMaxRequestBodySize()`.
    - The context methods `GetReqBodyFull()`, `GetReqBodyBuffer()`, and `ShouldBindJSON()` now adhere to this limit. They check `Content-Length` upfront and use `http.MaxBytesReader` to prevent reading excessively large request bodies into memory, enhancing protection against potential DoS attacks or high memory usage.
    - Added comprehensive unit tests in `context_test.go` for this feature, covering scenarios where the limit is active, disabled, and exceeded.

2.  **Enhanced Error Logging in Default Handler:**
    - The `defaultErrorHandle` in `engine.go` now logs not only the primary error passed to it but also any additional errors collected in `Context.Errors` (via `c.AddError()`).
    - This provides more comprehensive diagnostic information in the logs without altering the JSON error response structure sent to the client, ensuring backward compatibility.

These changes aim to improve the framework's robustness, memory safety, and debuggability.
This commit is contained in:
google-labs-jules[bot] 2025-06-20 06:35:01 +00:00
parent 543b3165ca
commit 82099e26ee
3 changed files with 336 additions and 15 deletions

View file

@ -330,14 +330,27 @@ func (c *Context) ShouldBindJSON(obj interface{}) error {
if c.Request.Body == nil { if c.Request.Body == nil {
return errors.New("request body is empty") return errors.New("request body is empty")
} }
/* // defer c.Request.Body.Close() // 通常由调用方或中间件确保关闭,但如果这里是唯一消耗点,可以考虑
decoder := json.NewDecoder(c.Request.Body)
if err := decoder.Decode(obj); err != nil { var reader io.Reader = c.Request.Body
return fmt.Errorf("json binding error: %w", err) if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
return fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
} }
*/ // 注意http.MaxBytesReader(nil, ...) 中的 nil ResponseWriter 参数意味着当超出限制时,
err := json.UnmarshalRead(c.Request.Body, obj) // MaxBytesReader 会直接返回错误,而不会尝试写入 HTTP 错误响应。这对于 API 来说是合适的。
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
}
err := json.UnmarshalRead(reader, obj)
if err != nil { if err != nil {
// 检查错误类型是否为 http.MaxBytesError
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
return fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
}
// 检查是否是 json 相关的错误,可能需要更细致的错误处理
// 例如json.SyntaxError, json.UnmarshalTypeError 等
return fmt.Errorf("json binding error: %w", err) return fmt.Errorf("json binding error: %w", err)
} }
return nil return nil
@ -441,13 +454,31 @@ func (c *Context) GetReqBody() io.ReadCloser {
// 注意:请求体只能读取一次 // 注意:请求体只能读取一次
func (c *Context) GetReqBodyFull() ([]byte, error) { func (c *Context) GetReqBodyFull() ([]byte, error) {
if c.Request.Body == nil { if c.Request.Body == nil {
return nil, nil return nil, nil // 或者返回一个错误: errors.New("request body is nil")
} }
defer c.Request.Body.Close() // 确保请求体被关闭 defer c.Request.Body.Close() // 确保请求体被关闭
data, err := io.ReadAll(c.Request.Body)
var reader io.Reader = c.Request.Body
if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
c.AddError(err)
return nil, err
}
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
}
data, err := io.ReadAll(reader)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) // 检查错误类型是否为 http.MaxBytesError如果是则表示超出了限制
return nil, fmt.Errorf("failed to read request body: %w", err) var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
} else {
err = fmt.Errorf("failed to read request body: %w", err)
}
c.AddError(err)
return nil, err
} }
return data, nil return data, nil
} }
@ -455,13 +486,30 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
// 类似 GetReqBodyFull, 返回 *bytes.Buffer // 类似 GetReqBodyFull, 返回 *bytes.Buffer
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
if c.Request.Body == nil { if c.Request.Body == nil {
return nil, nil return nil, nil // 或者返回一个错误: errors.New("request body is nil")
} }
defer c.Request.Body.Close() // 确保请求体被关闭 defer c.Request.Body.Close() // 确保请求体被关闭
data, err := io.ReadAll(c.Request.Body)
var reader io.Reader = c.Request.Body
if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
c.AddError(err)
return nil, err
}
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
}
data, err := io.ReadAll(reader)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) var maxBytesErr *http.MaxBytesError
return nil, fmt.Errorf("failed to read request body: %w", err) if errors.As(err, &maxBytesErr) {
err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
} else {
err = fmt.Errorf("failed to read request body: %w", err)
}
c.AddError(err)
return nil, err
} }
return bytes.NewBuffer(data), nil return bytes.NewBuffer(data), nil
} }

238
context_test.go Normal file
View file

@ -0,0 +1,238 @@
package touka
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
type TestJSON struct {
Name string `json:"name"`
Value int `json:"value"`
}
func TestGetReqBodyFull_Limit(t *testing.T) {
smallLimit := int64(10)
largeBody := "this is a body larger than 10 bytes"
smallBody := "small"
// Scenario 1: Request body larger than limit
t.Run("BodyLargerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
_, err := c.GetReqBodyFull()
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
})
// Scenario 2: ContentLength header larger than limit
t.Run("ContentLengthLargerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) // Actual body is small
req.ContentLength = smallLimit + 1
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
_, err := c.GetReqBodyFull()
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
})
// Scenario 3: Request body smaller than limit
t.Run("BodySmallerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
req.ContentLength = int64(len(smallBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
bodyBytes, err := c.GetReqBodyFull()
assert.NoError(t, err)
assert.Equal(t, smallBody, string(bodyBytes))
})
// Scenario 4: Request body slightly larger than limit, but no ContentLength
// http.MaxBytesReader will still catch this
t.Run("BodySlightlyLargerNoContentLength", func(t *testing.T) {
slightlyLargeBody := "elevenbytes" // 11 bytes
req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeBody))
// No ContentLength header
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit) // Limit is 10
_, err := c.GetReqBodyFull()
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
})
}
func TestShouldBindJSON_Limit(t *testing.T) {
smallLimit := int64(20)
validJSON := `{"name":"test","value":1}` // approx 25 bytes, check exact
largeJSON := `{"name":"this is a very long name","value":12345}`
smallValidJSON := `{"name":"s","v":1}` // small enough
// Scenario 1: JSON body larger than limit
t.Run("JSONLargerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
req.Header.Set("Content-Type", "application/json")
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
})
// Scenario 2: ContentLength header larger than limit for JSON
t.Run("ContentLengthLargerThanLimitJSON", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallValidJSON)) // Actual body is small
req.Header.Set("Content-Type", "application/json")
req.ContentLength = smallLimit + 1
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
})
// Scenario 3: JSON body smaller than limit
t.Run("JSONSmallerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(validJSON))
req.Header.Set("Content-Type", "application/json")
// Set a limit that is larger than the validJSON
engineLimit := int64(len(validJSON) + 5)
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(engineLimit)
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.NoError(t, err)
assert.Equal(t, "test", data.Name)
assert.Equal(t, 1, data.Value)
})
// Scenario 4: JSON body (no content length) slightly larger than limit
t.Run("JSONSlightlyLargerNoContentLength", func(t *testing.T) {
// This JSON is `{"name":"abcde","value":1}` which is 24 bytes. Limit is 20.
slightlyLargeJSON := `{"name":"abcde","value":1}`
req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeJSON))
req.Header.Set("Content-Type", "application/json")
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit) // Limit is 20
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
})
}
func TestMaxRequestBodySize_Disabled(t *testing.T) {
largeBody := strings.Repeat("a", 20*1024*1024) // 20MB body
largeJSON := `{"name":"` + strings.Repeat("b", 5*1024*1024) + `","value":1}` // Large JSON
// Scenario 1: GetReqBodyFull with MaxRequestBodySize = 0
t.Run("GetReqBodyFull_DisabledZero", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(0) // Disable limit
bodyBytes, err := c.GetReqBodyFull()
assert.NoError(t, err)
assert.Equal(t, largeBody, string(bodyBytes))
})
// Scenario 2: GetReqBodyFull with MaxRequestBodySize = -1
t.Run("GetReqBodyFull_DisabledNegative", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(-1) // Disable limit
bodyBytes, err := c.GetReqBodyFull()
assert.NoError(t, err)
assert.Equal(t, largeBody, string(bodyBytes))
})
// Scenario 3: ShouldBindJSON with MaxRequestBodySize = 0
t.Run("ShouldBindJSON_DisabledZero", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
req.Header.Set("Content-Type", "application/json")
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(0) // Disable limit
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.NoError(t, err)
assert.True(t, strings.HasPrefix(data.Name, "bbb")) // Just check prefix of large name
assert.Equal(t, 1, data.Value)
})
// Scenario 4: ShouldBindJSON with MaxRequestBodySize = -1
t.Run("ShouldBindJSON_DisabledNegative", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
req.Header.Set("Content-Type", "application/json")
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(-1) // Disable limit
var data TestJSON
err := c.ShouldBindJSON(&data)
assert.NoError(t, err)
assert.True(t, strings.HasPrefix(data.Name, "bbb"))
assert.Equal(t, 1, data.Value)
})
}
// TestGetReqBodyBuffer_Limit (Optional, as logic is very similar to GetReqBodyFull)
// You can add tests for GetReqBodyBuffer if you want explicit coverage,
// but its core limiting logic is identical to GetReqBodyFull.
func TestGetReqBodyBuffer_Limit(t *testing.T) {
smallLimit := int64(10)
largeBody := "this is a body larger than 10 bytes"
smallBody := "small"
// Scenario 1: Request body larger than limit
t.Run("BufferBodyLargerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
_, err := c.GetReqBodyBuffer()
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
})
// Scenario 2: ContentLength header larger than limit
t.Run("BufferContentLengthLargerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
req.ContentLength = smallLimit + 1
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
_, err := c.GetReqBodyBuffer()
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
})
// Scenario 3: Request body smaller than limit
t.Run("BufferBodySmallerThanLimit", func(t *testing.T) {
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
req.ContentLength = int64(len(smallBody))
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
engine.SetMaxRequestBodySize(smallLimit)
buffer, err := c.GetReqBodyBuffer()
assert.NoError(t, err)
assert.Equal(t, smallBody, buffer.String())
})
}

View file

@ -74,6 +74,8 @@ type Engine struct {
// 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器 // 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器
// 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置) // 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置)
TLSServerConfigurator func(*http.Server) TLSServerConfigurator func(*http.Server)
MaxRequestBodySize int64 // 限制读取Body的最大字节数
} }
type ErrorHandle struct { type ErrorHandle struct {
@ -87,12 +89,39 @@ type ErrorHandler func(c *Context, code int, err error)
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
select { select {
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
// 客户端断开连接,无需进一步处理
return return
default: default:
// 检查响应是否已经写入
if c.Writer.Written() { if c.Writer.Written() {
return return
} }
// 收集错误信息用于日志记录
primaryErrStr := "none"
if err != nil {
primaryErrStr = err.Error()
}
var collectedErrors []string
for _, e := range c.GetErrors() {
collectedErrors = append(collectedErrors, e.Error())
}
collectedErrorsStr := strings.Join(collectedErrors, "; ")
if collectedErrorsStr == "" {
collectedErrorsStr = "none"
}
// 记录错误日志
logMessage := fmt.Sprintf("[Touka ErrorHandler] Request: [%s] %s | Primary Error: %s | Collected Errors: %s",
c.Request.Method, c.Request.URL.Path, primaryErrStr, collectedErrorsStr)
if c.engine != nil && c.engine.LogReco != nil {
c.engine.LogReco.Error(logMessage)
} else {
log.Println(logMessage) // Fallback to standard logger
}
// 输出json 状态码与状态码对应描述 // 输出json 状态码与状态码对应描述
var errMsg string var errMsg string
if err != nil { if err != nil {
@ -160,6 +189,7 @@ func New() *Engine {
noRoutes: make(HandlersChain, 0), noRoutes: make(HandlersChain, 0),
ServerConfigurator: nil, ServerConfigurator: nil,
TLSServerConfigurator: nil, TLSServerConfigurator: nil,
MaxRequestBodySize: 10 * 1024 * 1024, // 默认 10MB
} }
//engine.SetProtocols(GetDefaultProtocolsConfig()) //engine.SetProtocols(GetDefaultProtocolsConfig())
engine.SetDefaultProtocols() engine.SetDefaultProtocols()
@ -189,6 +219,11 @@ func Default() *Engine {
// === 外部操作方法 === // === 外部操作方法 ===
// SetMaxRequestBodySize 设置读取Body的最大字节数
func (engine *Engine) SetMaxRequestBodySize(size int64) {
engine.MaxRequestBodySize = size
}
// SetServerConfigurator 设置一个函数,该函数将在任何 HTTP 或 HTTPS 服务器 // SetServerConfigurator 设置一个函数,该函数将在任何 HTTP 或 HTTPS 服务器
// (通过 RunShutdown, RunTLS, RunTLSRedir) 启动前被调用, // (通过 RunShutdown, RunTLS, RunTLSRedir) 启动前被调用,
// 允许用户对底层的 *http.Server 实例进行自定义配置 // 允许用户对底层的 *http.Server 实例进行自定义配置