add IsHijacked for respw && add recover for flush

This commit is contained in:
wjqserver 2025-06-06 21:44:45 +08:00
parent 740dce54a2
commit 4249f0192e

View file

@ -3,13 +3,15 @@ package touka
import (
"bufio"
"errors"
"log"
"net"
"net/http"
"runtime/debug"
)
// --- ResponseWriter 包装 ---
// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问
// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问
type ResponseWriter interface {
http.ResponseWriter
http.Hijacker // 支持 WebSocket 等
@ -21,7 +23,7 @@ type ResponseWriter interface {
IsHijacked() bool
}
// responseWriterImpl 是 ResponseWriter 的具体实现
// responseWriterImpl 是 ResponseWriter 的具体实现
type responseWriterImpl struct {
http.ResponseWriter
size int
@ -29,7 +31,7 @@ type responseWriterImpl struct {
hijacked bool
}
// NewResponseWriter 创建并返回一个 responseWriterImpl 实例
// NewResponseWriter 创建并返回一个 responseWriterImpl 实例
func newResponseWriter(w http.ResponseWriter) ResponseWriter {
rw := &responseWriterImpl{
ResponseWriter: w,
@ -40,6 +42,13 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter {
return rw
}
func (rw *responseWriterImpl) reset(w http.ResponseWriter) {
rw.ResponseWriter = w
rw.status = 0
rw.size = 0
rw.hijacked = false
}
func (rw *responseWriterImpl) WriteHeader(statusCode int) {
if rw.hijacked {
return
@ -56,7 +65,7 @@ func (rw *responseWriterImpl) Write(b []byte) (int, error) {
}
if rw.status == 0 {
// 如果 WriteHeader 没被显式调用Go 的 http server 会默认为 200
// 我们在这里也将其标记为 200因为即将写入数据
// 我们在这里也将其标记为 200因为即将写入数据
rw.status = http.StatusOK
// ResponseWriter.Write 会在第一次写入时自动调用 WriteHeader(http.StatusOK)
// 所以不需要在这里显式调用 rw.ResponseWriter.WriteHeader(http.StatusOK)
@ -78,16 +87,42 @@ func (rw *responseWriterImpl) Written() bool {
return rw.status != 0
}
// Hijack 实现 http.Hijacker 接口
// Hijack 实现 http.Hijacker 接口
func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rw.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
// 检查是否已劫持
if rw.hijacked {
return nil, nil, errors.New("http: connection already hijacked")
}
return nil, nil, errors.New("http.Hijacker interface not supported")
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
hj, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("http.Hijacker interface not supported")
}
// 调用底层的 Hijack 方法
conn, brw, err := hj.Hijack()
if err != nil {
// 如果劫持失败,返回错误
return nil, nil, err
}
// 如果劫持成功,更新内部状态
rw.hijacked = true
return conn, brw, nil
}
// Flush 实现 http.Flusher 接口。
// Flush 实现 http.Flusher 接口
func (rw *responseWriterImpl) Flush() {
defer func() {
if r := recover(); r != nil {
// 记录捕获到的 panic 信息,这表明底层连接可能已经关闭或失效
// 使用 log.Printf 记录,并包含堆栈信息,便于调试
log.Printf("Recovered from panic during responseWriterImpl.Flush for request: %v\nStack: %s", r, debug.Stack())
// 捕获后,不继续传播 panic允许请求的 goroutine 优雅退出
}
}()
if rw.hijacked {
return
}
@ -96,7 +131,7 @@ func (rw *responseWriterImpl) Flush() {
}
}
// IsHijacked 方法返回连接是否已被劫持
// IsHijacked 方法返回连接是否已被劫持
func (rw *responseWriterImpl) IsHijacked() bool {
return rw.hijacked
}