fix hijack

This commit is contained in:
wjqserver 2025-06-06 01:25:35 +08:00
parent 643fcd77ef
commit 1618f89ba5

View file

@ -18,6 +18,7 @@ type ResponseWriter interface {
Status() int // 返回写入的 HTTP 状态码,如果未写入则为 0 Status() int // 返回写入的 HTTP 状态码,如果未写入则为 0
Size() int // 返回已写入响应体的字节数 Size() int // 返回已写入响应体的字节数
Written() bool // 返回 WriteHeader 是否已被调用 Written() bool // 返回 WriteHeader 是否已被调用
IsHijacked() bool
} }
// responseWriterImpl 是 ResponseWriter 的具体实现。 // responseWriterImpl 是 ResponseWriter 的具体实现。
@ -25,6 +26,7 @@ type responseWriterImpl struct {
http.ResponseWriter http.ResponseWriter
size int size int
status int // 0 表示尚未写入状态码 status int // 0 表示尚未写入状态码
hijacked bool
} }
// NewResponseWriter 创建并返回一个 responseWriterImpl 实例。 // NewResponseWriter 创建并返回一个 responseWriterImpl 实例。
@ -33,11 +35,15 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter {
ResponseWriter: w, ResponseWriter: w,
status: 0, // 明确初始状态 status: 0, // 明确初始状态
size: 0, size: 0,
hijacked: false,
} }
return rw return rw
} }
func (rw *responseWriterImpl) WriteHeader(statusCode int) { func (rw *responseWriterImpl) WriteHeader(statusCode int) {
if rw.hijacked {
return
}
if rw.status == 0 { // 确保只设置一次 if rw.status == 0 { // 确保只设置一次
rw.status = statusCode rw.status = statusCode
rw.ResponseWriter.WriteHeader(statusCode) rw.ResponseWriter.WriteHeader(statusCode)
@ -45,6 +51,9 @@ func (rw *responseWriterImpl) WriteHeader(statusCode int) {
} }
func (rw *responseWriterImpl) Write(b []byte) (int, error) { func (rw *responseWriterImpl) Write(b []byte) (int, error) {
if rw.hijacked {
return 0, errors.New("http: response already hijacked")
}
if rw.status == 0 { if rw.status == 0 {
// 如果 WriteHeader 没被显式调用Go 的 http server 会默认为 200 // 如果 WriteHeader 没被显式调用Go 的 http server 会默认为 200
// 我们在这里也将其标记为 200因为即将写入数据。 // 我们在这里也将其标记为 200因为即将写入数据。
@ -79,7 +88,15 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Flush 实现 http.Flusher 接口。 // Flush 实现 http.Flusher 接口。
func (rw *responseWriterImpl) Flush() { func (rw *responseWriterImpl) Flush() {
if rw.hijacked {
return
}
if fl, ok := rw.ResponseWriter.(http.Flusher); ok { if fl, ok := rw.ResponseWriter.(http.Flusher); ok {
fl.Flush() fl.Flush()
} }
} }
// IsHijacked 方法返回连接是否已被劫持。
func (rw *responseWriterImpl) IsHijacked() bool {
return rw.hijacked
}