diff --git a/respw.go b/respw.go index ea8aa85..4821019 100644 --- a/respw.go +++ b/respw.go @@ -18,13 +18,15 @@ type ResponseWriter interface { Status() int // 返回写入的 HTTP 状态码,如果未写入则为 0 Size() int // 返回已写入响应体的字节数 Written() bool // 返回 WriteHeader 是否已被调用 + IsHijacked() bool } // responseWriterImpl 是 ResponseWriter 的具体实现。 type responseWriterImpl struct { http.ResponseWriter - size int - status int // 0 表示尚未写入状态码 + size int + status int // 0 表示尚未写入状态码 + hijacked bool } // NewResponseWriter 创建并返回一个 responseWriterImpl 实例。 @@ -33,11 +35,15 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter { ResponseWriter: w, status: 0, // 明确初始状态 size: 0, + hijacked: false, } return rw } func (rw *responseWriterImpl) WriteHeader(statusCode int) { + if rw.hijacked { + return + } if rw.status == 0 { // 确保只设置一次 rw.status = statusCode rw.ResponseWriter.WriteHeader(statusCode) @@ -45,6 +51,9 @@ func (rw *responseWriterImpl) WriteHeader(statusCode int) { } func (rw *responseWriterImpl) Write(b []byte) (int, error) { + if rw.hijacked { + return 0, errors.New("http: response already hijacked") + } if rw.status == 0 { // 如果 WriteHeader 没被显式调用,Go 的 http server 会默认为 200 // 我们在这里也将其标记为 200,因为即将写入数据。 @@ -79,7 +88,15 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Flush 实现 http.Flusher 接口。 func (rw *responseWriterImpl) Flush() { + if rw.hijacked { + return + } if fl, ok := rw.ResponseWriter.(http.Flusher); ok { fl.Flush() } } + +// IsHijacked 方法返回连接是否已被劫持。 +func (rw *responseWriterImpl) IsHijacked() bool { + return rw.hijacked +}