mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-02-03 08:51:11 +08:00
fix fileserver status chain
This commit is contained in:
parent
8ae88a77f0
commit
52b900db92
2 changed files with 99 additions and 13 deletions
69
ecw.go
69
ecw.go
|
|
@ -1,6 +1,9 @@
|
||||||
package touka
|
package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
@ -44,9 +47,9 @@ func (ecw *errorCapturingResponseWriter) reset(w http.ResponseWriter, r *http.Re
|
||||||
|
|
||||||
// AcquireErrorCapturingResponseWriter 从对象池获取一个 errorCapturingResponseWriter 实例
|
// AcquireErrorCapturingResponseWriter 从对象池获取一个 errorCapturingResponseWriter 实例
|
||||||
// 必须在处理完成后调用 ReleaseErrorCapturingResponseWriter
|
// 必须在处理完成后调用 ReleaseErrorCapturingResponseWriter
|
||||||
func AcquireErrorCapturingResponseWriter(c *Context, eh ErrorHandler) *errorCapturingResponseWriter {
|
func AcquireErrorCapturingResponseWriter(c *Context) *errorCapturingResponseWriter {
|
||||||
ecw := errorResponseWriterPool.Get().(*errorCapturingResponseWriter)
|
ecw := errorResponseWriterPool.Get().(*errorCapturingResponseWriter)
|
||||||
ecw.reset(c.Writer, c.Request, c, eh) // 传入 Touka Context 的 Writer
|
ecw.reset(c.Writer, c.Request, c, c.engine.errorHandle.handler) // 传入 Touka Context 的 Writer
|
||||||
return ecw
|
return ecw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -74,9 +77,9 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) {
|
||||||
if ecw.responseStarted {
|
if ecw.responseStarted {
|
||||||
return // 响应已开始, 忽略后续的 WriteHeader 调用
|
return // 响应已开始, 忽略后续的 WriteHeader 调用
|
||||||
}
|
}
|
||||||
ecw.statusCode = statusCode // 总是记录 FileServer 意图的状态码
|
ecw.statusCode = statusCode
|
||||||
|
|
||||||
if statusCode >= http.StatusBadRequest {
|
if ecw.Status() >= 400 {
|
||||||
ecw.capturedErrorSignal = true
|
ecw.capturedErrorSignal = true
|
||||||
// 是一个错误状态码 (>=400), 激活错误信号
|
// 是一个错误状态码 (>=400), 激活错误信号
|
||||||
// 不会将这个 WriteHeader 传递给原始的 w, 等待 processAfterFileServer 处理
|
// 不会将这个 WriteHeader 传递给原始的 w, 等待 processAfterFileServer 处理
|
||||||
|
|
@ -108,7 +111,7 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) {
|
||||||
for k, v := range ecw.headerSnapshot {
|
for k, v := range ecw.headerSnapshot {
|
||||||
ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值
|
ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值
|
||||||
}
|
}
|
||||||
ecw.w.WriteHeader(ecw.statusCode) // 发送实际的状态码 (可能是 200 或之前设置的 2xx)
|
ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx)
|
||||||
ecw.responseStarted = true
|
ecw.responseStarted = true
|
||||||
}
|
}
|
||||||
return ecw.w.Write(data) // 写入数据到原始 ResponseWriter
|
return ecw.w.Write(data) // 写入数据到原始 ResponseWriter
|
||||||
|
|
@ -133,7 +136,7 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() {
|
||||||
ecw.ctx.Next()
|
ecw.ctx.Next()
|
||||||
} else {
|
} else {
|
||||||
// 调用用户自定义的 ErrorHandlerFunc, 由它负责完整的错误响应
|
// 调用用户自定义的 ErrorHandlerFunc, 由它负责完整的错误响应
|
||||||
ecw.errorHandlerFunc(ecw.ctx, ecw.statusCode)
|
ecw.errorHandlerFunc(ecw.ctx, ecw.Status())
|
||||||
ecw.ctx.Abort()
|
ecw.ctx.Abort()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -141,3 +144,57 @@ func (ecw *errorCapturingResponseWriter) processAfterFileServer() {
|
||||||
// 如果 ecw.capturedErrorSignal && ecw.responseStarted, 表示在捕获错误信号之前,
|
// 如果 ecw.capturedErrorSignal && ecw.responseStarted, 表示在捕获错误信号之前,
|
||||||
// 成功路径的响应已经开始, 此时无法再进行错误处理覆盖
|
// 成功路径的响应已经开始, 此时无法再进行错误处理覆盖
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Status 返回当前记录的状态码
|
||||||
|
func (ecw *errorCapturingResponseWriter) Status() int {
|
||||||
|
if ecw.statusCode == 0 && !ecw.responseStarted {
|
||||||
|
// 如果还没有显式设置状态码, 并且响应尚未开始,
|
||||||
|
// 则尝试从底层 ResponseWriter 获取状态码 (如果它实现了 Statuser)
|
||||||
|
if tw, ok := ecw.w.(ResponseWriter); ok {
|
||||||
|
return tw.Status()
|
||||||
|
}
|
||||||
|
// 否则, 默认返回 200 OK (Go HTTP server 的默认行为)
|
||||||
|
return http.StatusOK
|
||||||
|
}
|
||||||
|
return ecw.statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size 返回已写入响应体的字节数
|
||||||
|
func (ecw *errorCapturingResponseWriter) Size() int {
|
||||||
|
// ecw 在捕获错误信号时会丢弃 FileServer 写入的数据, 所以 Size 应返回 0
|
||||||
|
if ecw.capturedErrorSignal {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// 否则, 尝试从底层 ResponseWriter 获取已写入的字节数
|
||||||
|
if tw, ok := ecw.w.(ResponseWriter); ok {
|
||||||
|
return tw.Size()
|
||||||
|
}
|
||||||
|
// 对于其他类型的 ResponseWriter, 无法可靠获取, 只能返回 0
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Written方式
|
||||||
|
func (ecw *errorCapturingResponseWriter) Written() bool {
|
||||||
|
// 如果响应已经通过这个包装器开始写入 (WriteHeader 或 Write 成功调用)
|
||||||
|
// 或者如果原始 ResponseWriter 已经标记为 Written (例如, 如果它是 touka.ResponseWriterImpl)
|
||||||
|
// 则认为响应已开始
|
||||||
|
if ecw.responseStarted {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 检查原始 ResponseWriter 是否已经写入
|
||||||
|
if tw, ok := ecw.w.(ResponseWriter); ok {
|
||||||
|
return tw.Written()
|
||||||
|
}
|
||||||
|
// 对于其他类型的 ResponseWriter, 无法可靠判断是否已写入, 只能依赖 responseStarted 标记
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack 实现 http.Hijacker 接口
|
||||||
|
// 它将 Hijack 调用委托给底层的 ResponseWriter
|
||||||
|
func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := ecw.w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface")
|
||||||
|
}
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
|
|
||||||
41
engine.go
41
engine.go
|
|
@ -2,6 +2,7 @@ package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -53,6 +54,7 @@ type Engine struct {
|
||||||
noRoute HandlerFunc
|
noRoute HandlerFunc
|
||||||
|
|
||||||
unMatchFS UnMatchFS // 未匹配下的处理
|
unMatchFS UnMatchFS // 未匹配下的处理
|
||||||
|
unMatchFileServer http.Handler // 处理handle
|
||||||
|
|
||||||
serverProtocols *http.Protocols //服务协议
|
serverProtocols *http.Protocols //服务协议
|
||||||
Protocols ProtocolsConfig //协议版本配置
|
Protocols ProtocolsConfig //协议版本配置
|
||||||
|
|
@ -73,6 +75,9 @@ func defaultErrorHandle(c *Context, code int) { // 检查客户端是否已断
|
||||||
|
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
if c.Writer.Written() {
|
||||||
|
return
|
||||||
|
}
|
||||||
// 输出json 状态码与状态码对应描述
|
// 输出json 状态码与状态码对应描述
|
||||||
c.JSON(code, H{
|
c.JSON(code, H{
|
||||||
"code": code,
|
"code": code,
|
||||||
|
|
@ -84,6 +89,22 @@ func defaultErrorHandle(c *Context, code int) { // 检查客户端是否已断
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 默认errorhandle包装 避免竞争意外问题, 保证稳定性
|
||||||
|
func defaultErrorWarp(handler ErrorHandler) ErrorHandler {
|
||||||
|
return func(c *Context, code int) {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if c.Writer.Written() {
|
||||||
|
log.Printf("errpage: response already started for status %d, skipping error page rendering", code)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handler(c, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type UnMatchFS struct {
|
type UnMatchFS struct {
|
||||||
FSForUnmatched http.FileSystem
|
FSForUnmatched http.FileSystem
|
||||||
ServeUnmatchedAsFS bool
|
ServeUnmatchedAsFS bool
|
||||||
|
|
@ -146,7 +167,7 @@ func Default() *Engine {
|
||||||
// 设置自定义错误处理
|
// 设置自定义错误处理
|
||||||
func (engine *Engine) SetErrorHandler(handler ErrorHandler) {
|
func (engine *Engine) SetErrorHandler(handler ErrorHandler) {
|
||||||
engine.errorHandle.useDefault = false
|
engine.errorHandle.useDefault = false
|
||||||
engine.errorHandle.handler = handler
|
engine.errorHandle.handler = defaultErrorWarp(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取一个默认错误处理handle
|
// 获取一个默认错误处理handle
|
||||||
|
|
@ -159,8 +180,10 @@ func (engine *Engine) SetUnMatchFS(fs http.FileSystem) {
|
||||||
if fs != nil {
|
if fs != nil {
|
||||||
engine.unMatchFS.FSForUnmatched = fs
|
engine.unMatchFS.FSForUnmatched = fs
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = true
|
engine.unMatchFS.ServeUnmatchedAsFS = true
|
||||||
|
engine.unMatchFileServer = http.FileServer(fs)
|
||||||
} else {
|
} else {
|
||||||
engine.unMatchFS.ServeUnmatchedAsFS = false
|
engine.unMatchFS.ServeUnmatchedAsFS = false
|
||||||
|
engine.unMatchFileServer = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -442,14 +465,20 @@ func (engine *Engine) handleRequest(c *Context) {
|
||||||
func unMatchFSHandle() HandlerFunc {
|
func unMatchFSHandle() HandlerFunc {
|
||||||
return func(c *Context) {
|
return func(c *Context) {
|
||||||
engine := c.engine
|
engine := c.engine
|
||||||
|
// 确保 engine.unMatchFileServer 存在
|
||||||
|
if !engine.unMatchFS.ServeUnmatchedAsFS || engine.unMatchFileServer == nil {
|
||||||
|
c.Next() // 如果未配置或 FileSystem 为 nil,则继续处理链
|
||||||
|
return
|
||||||
|
}
|
||||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead {
|
||||||
// 使用 http.FileServer 处理未匹配的请求
|
// 使用 http.FileServer 处理未匹配的请求
|
||||||
fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched)
|
//fileServer := http.FileServer(engine.unMatchFS.FSForUnmatched)
|
||||||
//ecw := newErrorCapturingResponseWriter(c, c.engine.errorHandle.handler)
|
//ecw := newErrorCapturingResponseWriter(c, c.engine.errorHandle.handler)
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler)
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
fileServer.ServeHTTP(ecw, c.Request)
|
c.engine.unMatchFileServer.ServeHTTP(ecw, c.Request)
|
||||||
ecw.processAfterFileServer()
|
ecw.processAfterFileServer()
|
||||||
|
c.Abort()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if engine.noRoute == nil {
|
if engine.noRoute == nil {
|
||||||
|
|
@ -726,7 +755,7 @@ func (engine *Engine) Static(relativePath, rootPath string) {
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
||||||
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c, c.engine.errorHandle.handler)
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
@ -790,7 +819,7 @@ func (group *RouterGroup) Static(relativePath, rootPath string) {
|
||||||
|
|
||||||
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
// 使用自定义的 ResponseWriter 包装器来捕获 FileServer 可能返回的错误状态码
|
||||||
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
// 这样我们可以在 FileServer 返回 404 或 403 时,使用 Engine 的 ErrorHandler 进行统一处理
|
||||||
ecw := AcquireErrorCapturingResponseWriter(c, group.engine.errorHandle.handler)
|
ecw := AcquireErrorCapturingResponseWriter(c)
|
||||||
defer ReleaseErrorCapturingResponseWriter(ecw)
|
defer ReleaseErrorCapturingResponseWriter(ecw)
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue