From 4249f0192ef19329b4b64cb7ec836ec47cf3fd53 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 6 Jun 2025 21:44:45 +0800 Subject: [PATCH] add IsHijacked for respw && add recover for flush --- respw.go | 55 +++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/respw.go b/respw.go index 4821019..e94bfa7 100644 --- a/respw.go +++ b/respw.go @@ -3,13 +3,15 @@ package touka import ( "bufio" "errors" + "log" "net" "net/http" + "runtime/debug" ) // --- ResponseWriter 包装 --- -// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问。 +// ResponseWriter 接口扩展了 http.ResponseWriter 以提供对响应状态和大小的访问 type ResponseWriter interface { http.ResponseWriter http.Hijacker // 支持 WebSocket 等 @@ -21,7 +23,7 @@ type ResponseWriter interface { IsHijacked() bool } -// responseWriterImpl 是 ResponseWriter 的具体实现。 +// responseWriterImpl 是 ResponseWriter 的具体实现 type responseWriterImpl struct { http.ResponseWriter size int @@ -29,7 +31,7 @@ type responseWriterImpl struct { hijacked bool } -// NewResponseWriter 创建并返回一个 responseWriterImpl 实例。 +// NewResponseWriter 创建并返回一个 responseWriterImpl 实例 func newResponseWriter(w http.ResponseWriter) ResponseWriter { rw := &responseWriterImpl{ ResponseWriter: w, @@ -40,6 +42,13 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter { return rw } +func (rw *responseWriterImpl) reset(w http.ResponseWriter) { + rw.ResponseWriter = w + rw.status = 0 + rw.size = 0 + rw.hijacked = false +} + func (rw *responseWriterImpl) WriteHeader(statusCode int) { if rw.hijacked { return @@ -56,7 +65,7 @@ func (rw *responseWriterImpl) Write(b []byte) (int, error) { } if rw.status == 0 { // 如果 WriteHeader 没被显式调用,Go 的 http server 会默认为 200 - // 我们在这里也将其标记为 200,因为即将写入数据。 + // 我们在这里也将其标记为 200,因为即将写入数据 rw.status = http.StatusOK // ResponseWriter.Write 会在第一次写入时自动调用 WriteHeader(http.StatusOK) // 所以不需要在这里显式调用 rw.ResponseWriter.WriteHeader(http.StatusOK) @@ -78,16 +87,42 @@ func (rw *responseWriterImpl) Written() bool { return rw.status != 0 } -// Hijack 实现 http.Hijacker 接口。 +// Hijack 实现 http.Hijacker 接口 func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hj, ok := rw.ResponseWriter.(http.Hijacker); ok { - return hj.Hijack() + // 检查是否已劫持 + if rw.hijacked { + return nil, nil, errors.New("http: connection already hijacked") } - return nil, nil, errors.New("http.Hijacker interface not supported") + + // 尝试从底层 ResponseWriter 获取 Hijacker 接口 + hj, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("http.Hijacker interface not supported") + } + + // 调用底层的 Hijack 方法 + conn, brw, err := hj.Hijack() + if err != nil { + // 如果劫持失败,返回错误 + return nil, nil, err + } + + // 如果劫持成功,更新内部状态 + rw.hijacked = true + + return conn, brw, nil } -// Flush 实现 http.Flusher 接口。 +// Flush 实现 http.Flusher 接口 func (rw *responseWriterImpl) Flush() { + defer func() { + if r := recover(); r != nil { + // 记录捕获到的 panic 信息,这表明底层连接可能已经关闭或失效 + // 使用 log.Printf 记录,并包含堆栈信息,便于调试 + log.Printf("Recovered from panic during responseWriterImpl.Flush for request: %v\nStack: %s", r, debug.Stack()) + // 捕获后,不继续传播 panic,允许请求的 goroutine 优雅退出 + } + }() if rw.hijacked { return } @@ -96,7 +131,7 @@ func (rw *responseWriterImpl) Flush() { } } -// IsHijacked 方法返回连接是否已被劫持。 +// IsHijacked 方法返回连接是否已被劫持 func (rw *responseWriterImpl) IsHijacked() bool { return rw.hijacked }