mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 08:51:11 +08:00
feat: Add request body size limit and enhance error logging
This commit introduces two main improvements to the Touka web framework:
1. **Configurable Request Body Size Limit:**
- Added `MaxRequestBodySize int64` to `touka.Engine` (default 10MB).
- You can customize this via `engine.SetMaxRequestBodySize()`.
- The context methods `GetReqBodyFull()`, `GetReqBodyBuffer()`, and `ShouldBindJSON()` now adhere to this limit. They check `Content-Length` upfront and use `http.MaxBytesReader` to prevent reading excessively large request bodies into memory, enhancing protection against potential DoS attacks or high memory usage.
- Added comprehensive unit tests in `context_test.go` for this feature, covering scenarios where the limit is active, disabled, and exceeded.
2. **Enhanced Error Logging in Default Handler:**
- The `defaultErrorHandle` in `engine.go` now logs not only the primary error passed to it but also any additional errors collected in `Context.Errors` (via `c.AddError()`).
- This provides more comprehensive diagnostic information in the logs without altering the JSON error response structure sent to the client, ensuring backward compatibility.
These changes aim to improve the framework's robustness, memory safety, and debuggability.
This commit is contained in:
parent
543b3165ca
commit
82099e26ee
3 changed files with 336 additions and 15 deletions
76
context.go
76
context.go
|
|
@ -330,14 +330,27 @@ func (c *Context) ShouldBindJSON(obj interface{}) error {
|
|||
if c.Request.Body == nil {
|
||||
return errors.New("request body is empty")
|
||||
}
|
||||
/*
|
||||
decoder := json.NewDecoder(c.Request.Body)
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return fmt.Errorf("json binding error: %w", err)
|
||||
// defer c.Request.Body.Close() // 通常由调用方或中间件确保关闭,但如果这里是唯一消耗点,可以考虑
|
||||
|
||||
var reader io.Reader = c.Request.Body
|
||||
if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
|
||||
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
|
||||
return fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
|
||||
}
|
||||
*/
|
||||
err := json.UnmarshalRead(c.Request.Body, obj)
|
||||
// 注意:http.MaxBytesReader(nil, ...) 中的 nil ResponseWriter 参数意味着当超出限制时,
|
||||
// MaxBytesReader 会直接返回错误,而不会尝试写入 HTTP 错误响应。这对于 API 来说是合适的。
|
||||
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
|
||||
}
|
||||
|
||||
err := json.UnmarshalRead(reader, obj)
|
||||
if err != nil {
|
||||
// 检查错误类型是否为 http.MaxBytesError
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
return fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
|
||||
}
|
||||
// 检查是否是 json 相关的错误,可能需要更细致的错误处理
|
||||
// 例如,json.SyntaxError, json.UnmarshalTypeError 等
|
||||
return fmt.Errorf("json binding error: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -441,13 +454,31 @@ func (c *Context) GetReqBody() io.ReadCloser {
|
|||
// 注意:请求体只能读取一次
|
||||
func (c *Context) GetReqBodyFull() ([]byte, error) {
|
||||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
return nil, nil // 或者返回一个错误: errors.New("request body is nil")
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := io.ReadAll(c.Request.Body)
|
||||
|
||||
var reader io.Reader = c.Request.Body
|
||||
if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
|
||||
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
|
||||
err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
|
||||
c.AddError(err)
|
||||
return nil, err
|
||||
}
|
||||
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
// 检查错误类型是否为 http.MaxBytesError,如果是,则表示超出了限制
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
|
||||
} else {
|
||||
err = fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
c.AddError(err)
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
|
@ -455,13 +486,30 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
|
|||
// 类似 GetReqBodyFull, 返回 *bytes.Buffer
|
||||
func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
|
||||
if c.Request.Body == nil {
|
||||
return nil, nil
|
||||
return nil, nil // 或者返回一个错误: errors.New("request body is nil")
|
||||
}
|
||||
defer c.Request.Body.Close() // 确保请求体被关闭
|
||||
data, err := io.ReadAll(c.Request.Body)
|
||||
|
||||
var reader io.Reader = c.Request.Body
|
||||
if c.engine != nil && c.engine.MaxRequestBodySize > 0 {
|
||||
if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize {
|
||||
err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize)
|
||||
c.AddError(err)
|
||||
return nil, err
|
||||
}
|
||||
reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
c.AddError(fmt.Errorf("failed to read request body: %w", err))
|
||||
return nil, fmt.Errorf("failed to read request body: %w", err)
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err)
|
||||
} else {
|
||||
err = fmt.Errorf("failed to read request body: %w", err)
|
||||
}
|
||||
c.AddError(err)
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewBuffer(data), nil
|
||||
}
|
||||
|
|
|
|||
238
context_test.go
Normal file
238
context_test.go
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestJSON struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
func TestGetReqBodyFull_Limit(t *testing.T) {
|
||||
smallLimit := int64(10)
|
||||
largeBody := "this is a body larger than 10 bytes"
|
||||
smallBody := "small"
|
||||
|
||||
// Scenario 1: Request body larger than limit
|
||||
t.Run("BodyLargerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
_, err := c.GetReqBodyFull()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 2: ContentLength header larger than limit
|
||||
t.Run("ContentLengthLargerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) // Actual body is small
|
||||
req.ContentLength = smallLimit + 1
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
_, err := c.GetReqBodyFull()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 3: Request body smaller than limit
|
||||
t.Run("BodySmallerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
|
||||
req.ContentLength = int64(len(smallBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
bodyBytes, err := c.GetReqBodyFull()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, smallBody, string(bodyBytes))
|
||||
})
|
||||
|
||||
// Scenario 4: Request body slightly larger than limit, but no ContentLength
|
||||
// http.MaxBytesReader will still catch this
|
||||
t.Run("BodySlightlyLargerNoContentLength", func(t *testing.T) {
|
||||
slightlyLargeBody := "elevenbytes" // 11 bytes
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeBody))
|
||||
// No ContentLength header
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit) // Limit is 10
|
||||
|
||||
_, err := c.GetReqBodyFull()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldBindJSON_Limit(t *testing.T) {
|
||||
smallLimit := int64(20)
|
||||
validJSON := `{"name":"test","value":1}` // approx 25 bytes, check exact
|
||||
largeJSON := `{"name":"this is a very long name","value":12345}`
|
||||
smallValidJSON := `{"name":"s","v":1}` // small enough
|
||||
|
||||
// Scenario 1: JSON body larger than limit
|
||||
t.Run("JSONLargerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 2: ContentLength header larger than limit for JSON
|
||||
t.Run("ContentLengthLargerThanLimitJSON", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallValidJSON)) // Actual body is small
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.ContentLength = smallLimit + 1
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 3: JSON body smaller than limit
|
||||
t.Run("JSONSmallerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(validJSON))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// Set a limit that is larger than the validJSON
|
||||
engineLimit := int64(len(validJSON) + 5)
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(engineLimit)
|
||||
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", data.Name)
|
||||
assert.Equal(t, 1, data.Value)
|
||||
})
|
||||
|
||||
// Scenario 4: JSON body (no content length) slightly larger than limit
|
||||
t.Run("JSONSlightlyLargerNoContentLength", func(t *testing.T) {
|
||||
// This JSON is `{"name":"abcde","value":1}` which is 24 bytes. Limit is 20.
|
||||
slightlyLargeJSON := `{"name":"abcde","value":1}`
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeJSON))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit) // Limit is 20
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaxRequestBodySize_Disabled(t *testing.T) {
|
||||
largeBody := strings.Repeat("a", 20*1024*1024) // 20MB body
|
||||
largeJSON := `{"name":"` + strings.Repeat("b", 5*1024*1024) + `","value":1}` // Large JSON
|
||||
|
||||
// Scenario 1: GetReqBodyFull with MaxRequestBodySize = 0
|
||||
t.Run("GetReqBodyFull_DisabledZero", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(0) // Disable limit
|
||||
|
||||
bodyBytes, err := c.GetReqBodyFull()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeBody, string(bodyBytes))
|
||||
})
|
||||
|
||||
// Scenario 2: GetReqBodyFull with MaxRequestBodySize = -1
|
||||
t.Run("GetReqBodyFull_DisabledNegative", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(-1) // Disable limit
|
||||
|
||||
bodyBytes, err := c.GetReqBodyFull()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeBody, string(bodyBytes))
|
||||
})
|
||||
|
||||
// Scenario 3: ShouldBindJSON with MaxRequestBodySize = 0
|
||||
t.Run("ShouldBindJSON_DisabledZero", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(0) // Disable limit
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(data.Name, "bbb")) // Just check prefix of large name
|
||||
assert.Equal(t, 1, data.Value)
|
||||
})
|
||||
|
||||
// Scenario 4: ShouldBindJSON with MaxRequestBodySize = -1
|
||||
t.Run("ShouldBindJSON_DisabledNegative", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(-1) // Disable limit
|
||||
|
||||
var data TestJSON
|
||||
err := c.ShouldBindJSON(&data)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(data.Name, "bbb"))
|
||||
assert.Equal(t, 1, data.Value)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetReqBodyBuffer_Limit (Optional, as logic is very similar to GetReqBodyFull)
|
||||
// You can add tests for GetReqBodyBuffer if you want explicit coverage,
|
||||
// but its core limiting logic is identical to GetReqBodyFull.
|
||||
func TestGetReqBodyBuffer_Limit(t *testing.T) {
|
||||
smallLimit := int64(10)
|
||||
largeBody := "this is a body larger than 10 bytes"
|
||||
smallBody := "small"
|
||||
|
||||
// Scenario 1: Request body larger than limit
|
||||
t.Run("BufferBodyLargerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
_, err := c.GetReqBodyBuffer()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 2: ContentLength header larger than limit
|
||||
t.Run("BufferContentLengthLargerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
|
||||
req.ContentLength = smallLimit + 1
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
_, err := c.GetReqBodyBuffer()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit))
|
||||
})
|
||||
|
||||
// Scenario 3: Request body smaller than limit
|
||||
t.Run("BufferBodySmallerThanLimit", func(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody))
|
||||
req.ContentLength = int64(len(smallBody))
|
||||
c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req)
|
||||
engine.SetMaxRequestBodySize(smallLimit)
|
||||
|
||||
buffer, err := c.GetReqBodyBuffer()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, smallBody, buffer.String())
|
||||
})
|
||||
}
|
||||
37
engine.go
37
engine.go
|
|
@ -74,6 +74,8 @@ type Engine struct {
|
|||
// 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器
|
||||
// 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置)
|
||||
TLSServerConfigurator func(*http.Server)
|
||||
|
||||
MaxRequestBodySize int64 // 限制读取Body的最大字节数
|
||||
}
|
||||
|
||||
type ErrorHandle struct {
|
||||
|
|
@ -87,12 +89,39 @@ type ErrorHandler func(c *Context, code int, err error)
|
|||
func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
|
||||
// 客户端断开连接,无需进一步处理
|
||||
return
|
||||
default:
|
||||
// 检查响应是否已经写入
|
||||
if c.Writer.Written() {
|
||||
return
|
||||
}
|
||||
|
||||
// 收集错误信息用于日志记录
|
||||
primaryErrStr := "none"
|
||||
if err != nil {
|
||||
primaryErrStr = err.Error()
|
||||
}
|
||||
|
||||
var collectedErrors []string
|
||||
for _, e := range c.GetErrors() {
|
||||
collectedErrors = append(collectedErrors, e.Error())
|
||||
}
|
||||
collectedErrorsStr := strings.Join(collectedErrors, "; ")
|
||||
if collectedErrorsStr == "" {
|
||||
collectedErrorsStr = "none"
|
||||
}
|
||||
|
||||
// 记录错误日志
|
||||
logMessage := fmt.Sprintf("[Touka ErrorHandler] Request: [%s] %s | Primary Error: %s | Collected Errors: %s",
|
||||
c.Request.Method, c.Request.URL.Path, primaryErrStr, collectedErrorsStr)
|
||||
|
||||
if c.engine != nil && c.engine.LogReco != nil {
|
||||
c.engine.LogReco.Error(logMessage)
|
||||
} else {
|
||||
log.Println(logMessage) // Fallback to standard logger
|
||||
}
|
||||
|
||||
// 输出json 状态码与状态码对应描述
|
||||
var errMsg string
|
||||
if err != nil {
|
||||
|
|
@ -160,6 +189,7 @@ func New() *Engine {
|
|||
noRoutes: make(HandlersChain, 0),
|
||||
ServerConfigurator: nil,
|
||||
TLSServerConfigurator: nil,
|
||||
MaxRequestBodySize: 10 * 1024 * 1024, // 默认 10MB
|
||||
}
|
||||
//engine.SetProtocols(GetDefaultProtocolsConfig())
|
||||
engine.SetDefaultProtocols()
|
||||
|
|
@ -189,6 +219,11 @@ func Default() *Engine {
|
|||
|
||||
// === 外部操作方法 ===
|
||||
|
||||
// SetMaxRequestBodySize 设置读取Body的最大字节数
|
||||
func (engine *Engine) SetMaxRequestBodySize(size int64) {
|
||||
engine.MaxRequestBodySize = size
|
||||
}
|
||||
|
||||
// SetServerConfigurator 设置一个函数,该函数将在任何 HTTP 或 HTTPS 服务器
|
||||
// (通过 RunShutdown, RunTLS, RunTLSRedir) 启动前被调用,
|
||||
// 允许用户对底层的 *http.Server 实例进行自定义配置
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue