mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
perf: fast-path default 404 and 405 responses
This commit is contained in:
parent
017bb13295
commit
7c37d4c38c
2 changed files with 83 additions and 0 deletions
37
engine.go
37
engine.go
|
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
"github.com/WJQSERVER-STUDIO/httpc"
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
"github.com/fenthope/reco"
|
"github.com/fenthope/reco"
|
||||||
|
"github.com/go-json-experiment/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Last 返回链中的最后一个处理函数
|
// Last 返回链中的最后一个处理函数
|
||||||
|
|
@ -132,6 +133,32 @@ type defaultErrorResponse struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error())
|
||||||
|
var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error())
|
||||||
|
|
||||||
|
func mustMarshalDefaultErrorBody(code int, errMsg string) []byte {
|
||||||
|
body, err := json.Marshal(defaultErrorResponse{
|
||||||
|
Code: code,
|
||||||
|
Message: http.StatusText(code),
|
||||||
|
Error: errMsg,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDefaultErrorJSON(c *Context, code int, body []byte) {
|
||||||
|
if c == nil || c.Writer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
c.Writer.WriteHeader(code)
|
||||||
|
_, _ = c.Writer.Write(body)
|
||||||
|
c.Writer.Flush()
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
|
var methodNotAllowedHandler HandlerFunc = func(c *Context) {
|
||||||
httpMethod := c.Request.Method
|
httpMethod := c.Request.Method
|
||||||
requestPath := routeLookupPath(c.Request)
|
requestPath := routeLookupPath(c.Request)
|
||||||
|
|
@ -191,6 +218,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是
|
||||||
if c.Writer.Written() {
|
if c.Writer.Written() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
switch {
|
||||||
|
case code == http.StatusNotFound && errors.Is(err, errNotFound):
|
||||||
|
writeDefaultErrorJSON(c, code, defaultNotFoundBody)
|
||||||
|
return
|
||||||
|
case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed):
|
||||||
|
writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
// 输出json 状态码与状态码对应描述
|
// 输出json 状态码与状态码对应描述
|
||||||
var errMsg string
|
var errMsg string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -139,3 +139,49 @@ func TestDefaultErrorHandleJSONShape(t *testing.T) {
|
||||||
t.Fatalf("unexpected error payload: %+v", body)
|
t.Fatalf("unexpected error payload: %+v", body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultMethodNotAllowedJSONShape(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err)
|
||||||
|
}
|
||||||
|
if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" {
|
||||||
|
t.Fatalf("unexpected error payload: %+v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
|
||||||
|
engine := New()
|
||||||
|
engine.SetErrorHandler(func(c *Context, code int, err error) {
|
||||||
|
c.Writer.Header().Set("X-Custom-Error", "1")
|
||||||
|
c.String(code, "custom:%v", err)
|
||||||
|
})
|
||||||
|
engine.GET("/users", func(c *Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil)
|
||||||
|
if rr.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code)
|
||||||
|
}
|
||||||
|
if got := rr.Header().Get("X-Custom-Error"); got != "1" {
|
||||||
|
t.Fatalf("expected custom error header, got %q", got)
|
||||||
|
}
|
||||||
|
if rr.Body.String() != "custom:method not allowed" {
|
||||||
|
t.Fatalf("expected custom error body, got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue