diff --git a/reverseproxy.go b/reverseproxy.go index 2d6dfea..afdbd9c 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -96,20 +96,21 @@ type reverseProxyExtendedConnectBridge struct { type reverseProxyH2ReadWriteCloser struct { io.ReadCloser ResponseWriter + controller *http.ResponseController } -func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { +func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { return n, err } - if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { return n, err } return n, nil } -func (rwc reverseProxyH2ReadWriteCloser) Close() error { +func (rwc *reverseProxyH2ReadWriteCloser) Close() error { if rwc.ReadCloser == nil { return nil } @@ -1012,7 +1013,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } - conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} backConnClosed := make(chan struct{}) go func() { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 8a250b2..85a64d4 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -836,9 +836,10 @@ func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *te flushErr := errors.New("flush failed") writer := &flushErrorResponseWriter{flushErr: flushErr} - conn := reverseProxyH2ReadWriteCloser{ + conn := &reverseProxyH2ReadWriteCloser{ ReadCloser: io.NopCloser(strings.NewReader("")), ResponseWriter: writer, + controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)), } n, err := conn.Write([]byte("ping"))