diff --git a/context.go b/context.go index 0620c41..a7d77e2 100644 --- a/context.go +++ b/context.go @@ -330,14 +330,27 @@ func (c *Context) ShouldBindJSON(obj interface{}) error { if c.Request.Body == nil { return errors.New("request body is empty") } - /* - decoder := json.NewDecoder(c.Request.Body) - if err := decoder.Decode(obj); err != nil { - return fmt.Errorf("json binding error: %w", err) + // defer c.Request.Body.Close() // 通常由调用方或中间件确保关闭,但如果这里是唯一消耗点,可以考虑 + + var reader io.Reader = c.Request.Body + if c.engine != nil && c.engine.MaxRequestBodySize > 0 { + if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize { + return fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize) } - */ - err := json.UnmarshalRead(c.Request.Body, obj) + // 注意:http.MaxBytesReader(nil, ...) 中的 nil ResponseWriter 参数意味着当超出限制时, + // MaxBytesReader 会直接返回错误,而不会尝试写入 HTTP 错误响应。这对于 API 来说是合适的。 + reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize) + } + + err := json.UnmarshalRead(reader, obj) if err != nil { + // 检查错误类型是否为 http.MaxBytesError + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err) + } + // 检查是否是 json 相关的错误,可能需要更细致的错误处理 + // 例如,json.SyntaxError, json.UnmarshalTypeError 等 return fmt.Errorf("json binding error: %w", err) } return nil @@ -441,13 +454,31 @@ func (c *Context) GetReqBody() io.ReadCloser { // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { if c.Request.Body == nil { - return nil, nil + return nil, nil // 或者返回一个错误: errors.New("request body is nil") } defer c.Request.Body.Close() // 确保请求体被关闭 - data, err := io.ReadAll(c.Request.Body) + + var reader io.Reader = c.Request.Body + if c.engine != nil && c.engine.MaxRequestBodySize > 0 { + if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize { + err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize) + c.AddError(err) + return nil, err + } + reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize) + } + + data, err := io.ReadAll(reader) if err != nil { - c.AddError(fmt.Errorf("failed to read request body: %w", err)) - return nil, fmt.Errorf("failed to read request body: %w", err) + // 检查错误类型是否为 http.MaxBytesError,如果是,则表示超出了限制 + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err) + } else { + err = fmt.Errorf("failed to read request body: %w", err) + } + c.AddError(err) + return nil, err } return data, nil } @@ -455,13 +486,30 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { // 类似 GetReqBodyFull, 返回 *bytes.Buffer func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { if c.Request.Body == nil { - return nil, nil + return nil, nil // 或者返回一个错误: errors.New("request body is nil") } defer c.Request.Body.Close() // 确保请求体被关闭 - data, err := io.ReadAll(c.Request.Body) + + var reader io.Reader = c.Request.Body + if c.engine != nil && c.engine.MaxRequestBodySize > 0 { + if c.Request.ContentLength != -1 && c.Request.ContentLength > c.engine.MaxRequestBodySize { + err := fmt.Errorf("request body size (%d bytes) exceeds configured limit (%d bytes)", c.Request.ContentLength, c.engine.MaxRequestBodySize) + c.AddError(err) + return nil, err + } + reader = http.MaxBytesReader(nil, c.Request.Body, c.engine.MaxRequestBodySize) + } + + data, err := io.ReadAll(reader) if err != nil { - c.AddError(fmt.Errorf("failed to read request body: %w", err)) - return nil, fmt.Errorf("failed to read request body: %w", err) + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + err = fmt.Errorf("request body size exceeds configured limit (%d bytes): %w", c.engine.MaxRequestBodySize, err) + } else { + err = fmt.Errorf("failed to read request body: %w", err) + } + c.AddError(err) + return nil, err } return bytes.NewBuffer(data), nil } diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..5250788 --- /dev/null +++ b/context_test.go @@ -0,0 +1,238 @@ +package touka + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type TestJSON struct { + Name string `json:"name"` + Value int `json:"value"` +} + +func TestGetReqBodyFull_Limit(t *testing.T) { + smallLimit := int64(10) + largeBody := "this is a body larger than 10 bytes" + smallBody := "small" + + // Scenario 1: Request body larger than limit + t.Run("BodyLargerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + _, err := c.GetReqBodyFull() + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit)) + }) + + // Scenario 2: ContentLength header larger than limit + t.Run("ContentLengthLargerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) // Actual body is small + req.ContentLength = smallLimit + 1 + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + _, err := c.GetReqBodyFull() + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit)) + }) + + // Scenario 3: Request body smaller than limit + t.Run("BodySmallerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) + req.ContentLength = int64(len(smallBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + bodyBytes, err := c.GetReqBodyFull() + assert.NoError(t, err) + assert.Equal(t, smallBody, string(bodyBytes)) + }) + + // Scenario 4: Request body slightly larger than limit, but no ContentLength + // http.MaxBytesReader will still catch this + t.Run("BodySlightlyLargerNoContentLength", func(t *testing.T) { + slightlyLargeBody := "elevenbytes" // 11 bytes + req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeBody)) + // No ContentLength header + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) // Limit is 10 + + _, err := c.GetReqBodyFull() + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit)) + }) +} + +func TestShouldBindJSON_Limit(t *testing.T) { + smallLimit := int64(20) + validJSON := `{"name":"test","value":1}` // approx 25 bytes, check exact + largeJSON := `{"name":"this is a very long name","value":12345}` + smallValidJSON := `{"name":"s","v":1}` // small enough + + // Scenario 1: JSON body larger than limit + t.Run("JSONLargerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON)) + req.Header.Set("Content-Type", "application/json") + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit)) + }) + + // Scenario 2: ContentLength header larger than limit for JSON + t.Run("ContentLengthLargerThanLimitJSON", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(smallValidJSON)) // Actual body is small + req.Header.Set("Content-Type", "application/json") + req.ContentLength = smallLimit + 1 + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit)) + }) + + // Scenario 3: JSON body smaller than limit + t.Run("JSONSmallerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(validJSON)) + req.Header.Set("Content-Type", "application/json") + // Set a limit that is larger than the validJSON + engineLimit := int64(len(validJSON) + 5) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(engineLimit) + + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.NoError(t, err) + assert.Equal(t, "test", data.Name) + assert.Equal(t, 1, data.Value) + }) + + // Scenario 4: JSON body (no content length) slightly larger than limit + t.Run("JSONSlightlyLargerNoContentLength", func(t *testing.T) { + // This JSON is `{"name":"abcde","value":1}` which is 24 bytes. Limit is 20. + slightlyLargeJSON := `{"name":"abcde","value":1}` + req, _ := http.NewRequest("POST", "/", strings.NewReader(slightlyLargeJSON)) + req.Header.Set("Content-Type", "application/json") + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) // Limit is 20 + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit)) + }) +} + +func TestMaxRequestBodySize_Disabled(t *testing.T) { + largeBody := strings.Repeat("a", 20*1024*1024) // 20MB body + largeJSON := `{"name":"` + strings.Repeat("b", 5*1024*1024) + `","value":1}` // Large JSON + + // Scenario 1: GetReqBodyFull with MaxRequestBodySize = 0 + t.Run("GetReqBodyFull_DisabledZero", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(0) // Disable limit + + bodyBytes, err := c.GetReqBodyFull() + assert.NoError(t, err) + assert.Equal(t, largeBody, string(bodyBytes)) + }) + + // Scenario 2: GetReqBodyFull with MaxRequestBodySize = -1 + t.Run("GetReqBodyFull_DisabledNegative", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(-1) // Disable limit + + bodyBytes, err := c.GetReqBodyFull() + assert.NoError(t, err) + assert.Equal(t, largeBody, string(bodyBytes)) + }) + + // Scenario 3: ShouldBindJSON with MaxRequestBodySize = 0 + t.Run("ShouldBindJSON_DisabledZero", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON)) + req.Header.Set("Content-Type", "application/json") + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(0) // Disable limit + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(data.Name, "bbb")) // Just check prefix of large name + assert.Equal(t, 1, data.Value) + }) + + // Scenario 4: ShouldBindJSON with MaxRequestBodySize = -1 + t.Run("ShouldBindJSON_DisabledNegative", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeJSON)) + req.Header.Set("Content-Type", "application/json") + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(-1) // Disable limit + + var data TestJSON + err := c.ShouldBindJSON(&data) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(data.Name, "bbb")) + assert.Equal(t, 1, data.Value) + }) +} + +// TestGetReqBodyBuffer_Limit (Optional, as logic is very similar to GetReqBodyFull) +// You can add tests for GetReqBodyBuffer if you want explicit coverage, +// but its core limiting logic is identical to GetReqBodyFull. +func TestGetReqBodyBuffer_Limit(t *testing.T) { + smallLimit := int64(10) + largeBody := "this is a body larger than 10 bytes" + smallBody := "small" + + // Scenario 1: Request body larger than limit + t.Run("BufferBodyLargerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(largeBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + _, err := c.GetReqBodyBuffer() + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size exceeds configured limit (%d bytes)", smallLimit)) + }) + + // Scenario 2: ContentLength header larger than limit + t.Run("BufferContentLengthLargerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) + req.ContentLength = smallLimit + 1 + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + _, err := c.GetReqBodyBuffer() + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("request body size (%d bytes) exceeds configured limit (%d bytes)", smallLimit+1, smallLimit)) + }) + + // Scenario 3: Request body smaller than limit + t.Run("BufferBodySmallerThanLimit", func(t *testing.T) { + req, _ := http.NewRequest("POST", "/", strings.NewReader(smallBody)) + req.ContentLength = int64(len(smallBody)) + c, engine := CreateTestContextWithRequest(httptest.NewRecorder(), req) + engine.SetMaxRequestBodySize(smallLimit) + + buffer, err := c.GetReqBodyBuffer() + assert.NoError(t, err) + assert.Equal(t, smallBody, buffer.String()) + }) +} diff --git a/engine.go b/engine.go index ff47fc3..4922664 100644 --- a/engine.go +++ b/engine.go @@ -74,6 +74,8 @@ type Engine struct { // 如果设置了此回调,它将优先于 ServerConfigurator 被用于 HTTPS 服务器 // 如果未设置,HTTPS 服务器将回退使用 ServerConfigurator (如果已设置) TLSServerConfigurator func(*http.Server) + + MaxRequestBodySize int64 // 限制读取Body的最大字节数 } type ErrorHandle struct { @@ -87,12 +89,39 @@ type ErrorHandler func(c *Context, code int, err error) func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { case <-c.Request.Context().Done(): - + // 客户端断开连接,无需进一步处理 return default: + // 检查响应是否已经写入 if c.Writer.Written() { return } + + // 收集错误信息用于日志记录 + primaryErrStr := "none" + if err != nil { + primaryErrStr = err.Error() + } + + var collectedErrors []string + for _, e := range c.GetErrors() { + collectedErrors = append(collectedErrors, e.Error()) + } + collectedErrorsStr := strings.Join(collectedErrors, "; ") + if collectedErrorsStr == "" { + collectedErrorsStr = "none" + } + + // 记录错误日志 + logMessage := fmt.Sprintf("[Touka ErrorHandler] Request: [%s] %s | Primary Error: %s | Collected Errors: %s", + c.Request.Method, c.Request.URL.Path, primaryErrStr, collectedErrorsStr) + + if c.engine != nil && c.engine.LogReco != nil { + c.engine.LogReco.Error(logMessage) + } else { + log.Println(logMessage) // Fallback to standard logger + } + // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { @@ -160,6 +189,7 @@ func New() *Engine { noRoutes: make(HandlersChain, 0), ServerConfigurator: nil, TLSServerConfigurator: nil, + MaxRequestBodySize: 10 * 1024 * 1024, // 默认 10MB } //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -189,6 +219,11 @@ func Default() *Engine { // === 外部操作方法 === +// SetMaxRequestBodySize 设置读取Body的最大字节数 +func (engine *Engine) SetMaxRequestBodySize(size int64) { + engine.MaxRequestBodySize = size +} + // SetServerConfigurator 设置一个函数,该函数将在任何 HTTP 或 HTTPS 服务器 // (通过 RunShutdown, RunTLS, RunTLSRedir) 启动前被调用, // 允许用户对底层的 *http.Server 实例进行自定义配置