diff --git a/ecw.go b/ecw.go index 8f1417a..e43db3a 100644 --- a/ecw.go +++ b/ecw.go @@ -1,6 +1,9 @@ package touka import ( + "bufio" + "errors" + "net" "net/http" "sync" ) @@ -44,9 +47,9 @@ func (ecw *errorCapturingResponseWriter) reset(w http.ResponseWriter, r *http.Re // AcquireErrorCapturingResponseWriter 从对象池获取一个 errorCapturingResponseWriter 实例 // 必须在处理完成后调用 ReleaseErrorCapturingResponseWriter -func AcquireErrorCapturingResponseWriter(c *Context, eh ErrorHandler) *errorCapturingResponseWriter { +func AcquireErrorCapturingResponseWriter(c *Context) *errorCapturingResponseWriter { ecw := errorResponseWriterPool.Get().(*errorCapturingResponseWriter) - ecw.reset(c.Writer, c.Request, c, eh) // 传入 Touka Context 的 Writer + ecw.reset(c.Writer, c.Request, c, c.engine.errorHandle.handler) // 传入 Touka Context 的 Writer return ecw } @@ -74,9 +77,9 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) { if ecw.responseStarted { return // 响应已开始, 忽略后续的 WriteHeader 调用 } - ecw.statusCode = statusCode // 总是记录 FileServer 意图的状态码 + ecw.statusCode = statusCode - if statusCode >= http.StatusBadRequest { + if ecw.Status() >= 400 { ecw.capturedErrorSignal = true // 是一个错误状态码 (>=400), 激活错误信号 // 不会将这个 WriteHeader 传递给原始的 w, 等待 processAfterFileServer 处理 @@ -108,7 +111,7 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) { for k, v := range ecw.headerSnapshot { ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 } - ecw.w.WriteHeader(ecw.statusCode) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) + ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) ecw.responseStarted = true } return ecw.w.Write(data) // 写入数据到原始 ResponseWriter @@ -133,7 +136,7 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() { ecw.ctx.Next() } else { // 调用用户自定义的 ErrorHandlerFunc, 由它负责完整的错误响应 - ecw.errorHandlerFunc(ecw.ctx, ecw.statusCode) + ecw.errorHandlerFunc(ecw.ctx, ecw.Status()) ecw.ctx.Abort() } } @@ -141,3 +144,57 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() { // 如果 ecw.capturedErrorSignal && ecw.responseStarted, 表示在捕获错误信号之前, // 成功路径的响应已经开始, 此时无法再进行错误处理覆盖 } + +// Status 返回当前记录的状态码 +func (ecw *errorCapturingResponseWriter) Status() int { + if ecw.statusCode == 0 && !ecw.responseStarted { + // 如果还没有显式设置状态码, 并且响应尚未开始, + // 则尝试从底层 ResponseWriter 获取状态码 (如果它实现了 Statuser) + if tw, ok := ecw.w.(ResponseWriter); ok { + return tw.Status() + } + // 否则, 默认返回 200 OK (Go HTTP server 的默认行为) + return http.StatusOK + } + return ecw.statusCode +} + +// Size 返回已写入响应体的字节数 +func (ecw *errorCapturingResponseWriter) Size() int { + // ecw 在捕获错误信号时会丢弃 FileServer 写入的数据, 所以 Size 应返回 0 + if ecw.capturedErrorSignal { + return 0 + } + // 否则, 尝试从底层 ResponseWriter 获取已写入的字节数 + if tw, ok := ecw.w.(ResponseWriter); ok { + return tw.Size() + } + // 对于其他类型的 ResponseWriter, 无法可靠获取, 只能返回 0 + return 0 +} + +// Written方式 +func (ecw *errorCapturingResponseWriter) Written() bool { + // 如果响应已经通过这个包装器开始写入 (WriteHeader 或 Write 成功调用) + // 或者如果原始 ResponseWriter 已经标记为 Written (例如, 如果它是 touka.ResponseWriterImpl) + // 则认为响应已开始 + if ecw.responseStarted { + return true + } + // 检查原始 ResponseWriter 是否已经写入 + if tw, ok := ecw.w.(ResponseWriter); ok { + return tw.Written() + } + // 对于其他类型的 ResponseWriter, 无法可靠判断是否已写入, 只能依赖 responseStarted 标记 + return false +} + +// Hijack 实现 http.Hijacker 接口 +// 它将 Hijack 调用委托给底层的 ResponseWriter +func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := ecw.w.(http.Hijacker) + if !ok { + return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") + } + return hijacker.Hijack() +} diff --git a/engine.go b/engine.go index c3e4743..b97e90d 100644 --- a/engine.go +++ b/engine.go @@ -2,6 +2,7 @@ package touka import ( "context" + "log" "reflect" "runtime" "strings" @@ -52,7 +53,8 @@ type Engine struct { noRoute HandlerFunc - unMatchFS UnMatchFS // 未匹配下的处理 + unMatchFS UnMatchFS // 未匹配下的处理 + unMatchFileServer http.Handler // 处理handle serverProtocols *http.Protocols //服务协议 Protocols ProtocolsConfig //协议版本配置 @@ -73,6 +75,9 @@ func defaultErrorHandle(c *Context, code int) { // 检查客户端是否已断 return default: + if c.Writer.Written() { + return + } // 输出json 状态码与状态码对应描述 c.JSON(code, H{ "code": code, @@ -84,6 +89,22 @@ func defaultErrorHandle(c *Context, code int) { // 检查客户端是否已断 } } +// 默认errorhandle包装 避免竞争意外问题, 保证稳定性 +func defaultErrorWarp(handler ErrorHandler) ErrorHandler { + return func(c *Context, code int) { + select { + case <-c.Request.Context().Done(): + return + default: + if c.Writer.Written() { + log.Printf("errpage: response already started for status %d, skipping error page rendering", code) + return + } + } + handler(c, code) + } +} + type UnMatchFS struct { FSForUnmatched http.FileSystem ServeUnmatchedAsFS bool @@ -146,7 +167,7 @@ func Default() *Engine { // 设置自定义错误处理 func (engine *Engine) SetErrorHandler(handler ErrorHandler) { engine.errorHandle.useDefault = false - engine.errorHandle.handler = handler + engine.errorHandle.handler = defaultErrorWarp(handler) } // 获取一个默认错误处理handle @@ -159,8 +180,10 @@ func (engine *Engine) SetUnMatchFS(fs http.FileSystem) { if fs != nil { engine.unMatchFS.FSForUnmatched = fs engine.unMatchFS.ServeUnmatchedAsFS = true + engine.unMatchFileServer = http.FileServer(fs) } else { engine.unMatchFS.ServeUnmatchedAsFS = false + engine.unMatchFileServer = nil } } @@ -442,14 +465,20 @@ func (engine *Engine) handleRequest(c *Context) { func unMatchFSHandle() HandlerFunc { return func(c *Context) { engine := c.engine + // 确保 engine.unMatchFileServer 存在 + if !engine.unMatchFS.ServeUnmatchedAsFS || engine.unMatchFileServer == nil { + c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链 + return + } if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { // 使用 http.FileServer 处理未匹配的请求 - fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched) + //fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched) //ecw := newErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) - ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + ecw := AcquireErrorCapturingResponseWriter(c) defer ReleaseErrorCapturingResponseWriter(ecw) - fileServer.ServeHTTP(ecw, c.Request) + c.engine.unMatchFileServer.ServeHTTP(ecw, c.Request) ecw.processAfterFileServer() + c.Abort() return } else { if engine.noRoute == nil { @@ -726,7 +755,7 @@ func (engine *Engine) Static(relativePath, rootPath string) { // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 - ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler) + ecw := AcquireErrorCapturingResponseWriter(c) defer ReleaseErrorCapturingResponseWriter(ecw) // @@ -790,7 +819,7 @@ func (group *RouterGroup) Static(relativePath, rootPath string) { // 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码 // 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理 - ecw := AcquireErrorCapturingResponseWriter(c, group.engine.errorHandle.handler) + ecw := AcquireErrorCapturingResponseWriter(c) defer ReleaseErrorCapturingResponseWriter(ecw) //