diff --git a/engine.go b/engine.go index b0723e7..536d6e1 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "github.com/WJQSERVER-STUDIO/httpc" "github.com/fenthope/reco" + "github.com/go-json-experiment/json" ) // Last 返回链中的最后一个处理函数 @@ -132,6 +133,32 @@ type defaultErrorResponse struct { Error string `json:"error"` } +var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error()) +var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error()) + +func mustMarshalDefaultErrorBody(code int, errMsg string) []byte { + body, err := json.Marshal(defaultErrorResponse{ + Code: code, + Message: http.StatusText(code), + Error: errMsg, + }) + if err != nil { + panic(err) + } + return body +} + +func writeDefaultErrorJSON(c *Context, code int, body []byte) { + if c == nil || c.Writer == nil { + return + } + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.Writer.WriteHeader(code) + _, _ = c.Writer.Write(body) + c.Writer.Flush() + c.Abort() +} + var methodNotAllowedHandler HandlerFunc = func(c *Context) { httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -191,6 +218,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if c.Writer.Written() { return } + if len(c.Errors) == 0 { + switch { + case code == http.StatusNotFound && errors.Is(err, errNotFound): + writeDefaultErrorJSON(c, code, defaultNotFoundBody) + return + case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed): + writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody) + return + } + } // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { diff --git a/engine_test.go b/engine_test.go index 571f4b7..f6906b3 100644 --- a/engine_test.go +++ b/engine_test.go @@ -139,3 +139,49 @@ func TestDefaultErrorHandleJSONShape(t *testing.T) { t.Fatalf("unexpected error payload: %+v", body) } } + +func TestDefaultMethodNotAllowedJSONShape(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" { + t.Fatalf("unexpected error payload: %+v", body) + } +} + +func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { + engine := New() + engine.SetErrorHandler(func(c *Context, code int, err error) { + c.Writer.Header().Set("X-Custom-Error", "1") + c.String(code, "custom:%v", err) + }) + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-Custom-Error"); got != "1" { + t.Fatalf("expected custom error header, got %q", got) + } + if rr.Body.String() != "custom:method not allowed" { + t.Fatalf("expected custom error body, got %q", rr.Body.String()) + } +}