add IsHijacked for respw && add recover for flush

This commit is contained in:
wjqserver 2025-06-06 21:44:45 +08:00
parent 740dce54a2
commit 4249f0192e

View file

@ -3,13 +3,15 @@ package touka
import ( import (
"bufio" "bufio"
"errors" "errors"
"log"
"net" "net"
"net/http" "net/http"
"runtime/debug"
) )
// --- ResponseWriter 包装 --- // --- ResponseWriter 包装 ---
// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问 // ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问
type ResponseWriter interface { type ResponseWriter interface {
http.ResponseWriter http.ResponseWriter
http.Hijacker // 支持 WebSocket 等 http.Hijacker // 支持 WebSocket 等
@ -21,7 +23,7 @@ type ResponseWriter interface {
IsHijacked() bool IsHijacked() bool
} }
// responseWriterImpl 是 ResponseWriter 的具体实现 // responseWriterImpl 是 ResponseWriter 的具体实现
type responseWriterImpl struct { type responseWriterImpl struct {
http.ResponseWriter http.ResponseWriter
size int size int
@ -29,7 +31,7 @@ type responseWriterImpl struct {
hijacked bool hijacked bool
} }
// NewResponseWriter 创建并返回一个 responseWriterImpl 实例 // NewResponseWriter 创建并返回一个 responseWriterImpl 实例
func newResponseWriter(w http.ResponseWriter) ResponseWriter { func newResponseWriter(w http.ResponseWriter) ResponseWriter {
rw := &responseWriterImpl{ rw := &responseWriterImpl{
ResponseWriter: w, ResponseWriter: w,
@ -40,6 +42,13 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter {
return rw return rw
} }
func (rw *responseWriterImpl) reset(w http.ResponseWriter) {
rw.ResponseWriter = w
rw.status = 0
rw.size = 0
rw.hijacked = false
}
func (rw *responseWriterImpl) WriteHeader(statusCode int) { func (rw *responseWriterImpl) WriteHeader(statusCode int) {
if rw.hijacked { if rw.hijacked {
return return
@ -56,7 +65,7 @@ func (rw *responseWriterImpl) Write(b []byte) (int, error) {
} }
if rw.status == 0 { if rw.status == 0 {
// 如果 WriteHeader 没被显式调用Go 的 http server 会默认为 200 // 如果 WriteHeader 没被显式调用Go 的 http server 会默认为 200
// 我们在这里也将其标记为 200因为即将写入数据 // 我们在这里也将其标记为 200因为即将写入数据
rw.status = http.StatusOK rw.status = http.StatusOK
// ResponseWriter.Write 会在第一次写入时自动调用 WriteHeader(http.StatusOK) // ResponseWriter.Write 会在第一次写入时自动调用 WriteHeader(http.StatusOK)
// 所以不需要在这里显式调用 rw.ResponseWriter.WriteHeader(http.StatusOK) // 所以不需要在这里显式调用 rw.ResponseWriter.WriteHeader(http.StatusOK)
@ -78,16 +87,42 @@ func (rw *responseWriterImpl) Written() bool {
return rw.status != 0 return rw.status != 0
} }
// Hijack 实现 http.Hijacker 接口 // Hijack 实现 http.Hijacker 接口
func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rw.ResponseWriter.(http.Hijacker); ok { // 检查是否已劫持
return hj.Hijack() if rw.hijacked {
return nil, nil, errors.New("http: connection already hijacked")
} }
// 尝试从底层 ResponseWriter 获取 Hijacker 接口
hj, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("http.Hijacker interface not supported") return nil, nil, errors.New("http.Hijacker interface not supported")
}
// 调用底层的 Hijack 方法
conn, brw, err := hj.Hijack()
if err != nil {
// 如果劫持失败,返回错误
return nil, nil, err
}
// 如果劫持成功,更新内部状态
rw.hijacked = true
return conn, brw, nil
} }
// Flush 实现 http.Flusher 接口。 // Flush 实现 http.Flusher 接口
func (rw *responseWriterImpl) Flush() { func (rw *responseWriterImpl) Flush() {
defer func() {
if r := recover(); r != nil {
// 记录捕获到的 panic 信息,这表明底层连接可能已经关闭或失效
// 使用 log.Printf 记录,并包含堆栈信息,便于调试
log.Printf("Recovered from panic during responseWriterImpl.Flush for request: %v\nStack: %s", r, debug.Stack())
// 捕获后,不继续传播 panic允许请求的 goroutine 优雅退出
}
}()
if rw.hijacked { if rw.hijacked {
return return
} }
@ -96,7 +131,7 @@ func (rw *responseWriterImpl) Flush() {
} }
} }
// IsHijacked 方法返回连接是否已被劫持 // IsHijacked 方法返回连接是否已被劫持
func (rw *responseWriterImpl) IsHijacked() bool { func (rw *responseWriterImpl) IsHijacked() bool {
return rw.hijacked return rw.hijacked
} }