fix fileserver status chain

This commit is contained in:
wjqserver 2025-05-30 21:31:54 +08:00
parent 8ae88a77f0
commit 52b900db92
2 changed files with 99 additions and 13 deletions

69
ecw.go
View file

@ -1,6 +1,9 @@
package touka
import (
"bufio"
"errors"
"net"
"net/http"
"sync"
)
@ -44,9 +47,9 @@ func (ecw *errorCapturingResponseWriter) reset(w http.ResponseWriter, r *http.Re
// AcquireErrorCapturingResponseWriter 从对象池获取一个 errorCapturingResponseWriter 实例
// 必须在处理完成后调用 ReleaseErrorCapturingResponseWriter
func AcquireErrorCapturingResponseWriter(c *Context, eh ErrorHandler) *errorCapturingResponseWriter {
func AcquireErrorCapturingResponseWriter(c *Context) *errorCapturingResponseWriter {
ecw := errorResponseWriterPool.Get().(*errorCapturingResponseWriter)
ecw.reset(c.Writer, c.Request, c, eh) // 传入 Touka Context 的 Writer
ecw.reset(c.Writer, c.Request, c, c.engine.errorHandle.handler) // 传入 Touka Context 的 Writer
return ecw
}
@ -74,9 +77,9 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) {
if ecw.responseStarted {
return // 响应已开始, 忽略后续的 WriteHeader 调用
}
ecw.statusCode = statusCode // 总是记录 FileServer 意图的状态码
ecw.statusCode = statusCode
if statusCode >= http.StatusBadRequest {
if ecw.Status() >= 400 {
ecw.capturedErrorSignal = true
// 是一个错误状态码 (>=400), 激活错误信号
// 不会将这个 WriteHeader 传递给原始的 w, 等待 processAfterFileServer 处理
@ -108,7 +111,7 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) {
for k, v := range ecw.headerSnapshot {
ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值
}
ecw.w.WriteHeader(ecw.statusCode) // 发送实际的状态码 (可能是 200 或之前设置的 2xx)
ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx)
ecw.responseStarted = true
}
return ecw.w.Write(data) // 写入数据到原始 ResponseWriter
@ -133,7 +136,7 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() {
ecw.ctx.Next()
} else {
// 调用用户自定义的 ErrorHandlerFunc, 由它负责完整的错误响应
ecw.errorHandlerFunc(ecw.ctx, ecw.statusCode)
ecw.errorHandlerFunc(ecw.ctx, ecw.Status())
ecw.ctx.Abort()
}
}
@ -141,3 +144,57 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() {
// 如果 ecw.capturedErrorSignal && ecw.responseStarted, 表示在捕获错误信号之前,
// 成功路径的响应已经开始, 此时无法再进行错误处理覆盖
}
// Status 返回当前记录的状态码
func (ecw *errorCapturingResponseWriter) Status() int {
if ecw.statusCode == 0 && !ecw.responseStarted {
// 如果还没有显式设置状态码, 并且响应尚未开始,
// 则尝试从底层 ResponseWriter 获取状态码 (如果它实现了 Statuser)
if tw, ok := ecw.w.(ResponseWriter); ok {
return tw.Status()
}
// 否则, 默认返回 200 OK (Go HTTP server 的默认行为)
return http.StatusOK
}
return ecw.statusCode
}
// Size 返回已写入响应体的字节数
func (ecw *errorCapturingResponseWriter) Size() int {
// ecw 在捕获错误信号时会丢弃 FileServer 写入的数据, 所以 Size 应返回 0
if ecw.capturedErrorSignal {
return 0
}
// 否则, 尝试从底层 ResponseWriter 获取已写入的字节数
if tw, ok := ecw.w.(ResponseWriter); ok {
return tw.Size()
}
// 对于其他类型的 ResponseWriter, 无法可靠获取, 只能返回 0
return 0
}
// Written方式
func (ecw *errorCapturingResponseWriter) Written() bool {
// 如果响应已经通过这个包装器开始写入 (WriteHeader 或 Write 成功调用)
// 或者如果原始 ResponseWriter 已经标记为 Written (例如, 如果它是 touka.ResponseWriterImpl)
// 则认为响应已开始
if ecw.responseStarted {
return true
}
// 检查原始 ResponseWriter 是否已经写入
if tw, ok := ecw.w.(ResponseWriter); ok {
return tw.Written()
}
// 对于其他类型的 ResponseWriter, 无法可靠判断是否已写入, 只能依赖 responseStarted 标记
return false
}
// Hijack 实现 http.Hijacker 接口
// 它将 Hijack 调用委托给底层的 ResponseWriter
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := ecw.w.(http.Hijacker)
if !ok {
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
}
return hijacker.Hijack()
}