diff --git a/reverseproxy.go b/reverseproxy.go index e01f4d0..bb1784b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -829,6 +829,19 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + }) + } + go func() { + <-req.Context().Done() + closeTunnel() + }() + errc := make(chan error, 2) go func() { _, err := io.Copy(backWrite, c.Request.Body) @@ -849,19 +862,24 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt errc <- closeErr }() - firstErr := <-errc - _ = c.Request.Body.Close() - _ = backWrite.Close() - _ = res.Body.Close() - secondErr := <-errc - - for _, err := range []error{firstErr, secondErr} { + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc if reverseProxyIsBenignTunnelError(err) { continue } - return err + if firstErr == nil { + firstErr = err + closeTunnel() + } } - return nil + closeTunnel() + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + + return firstErr + } func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { @@ -902,7 +920,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { @@ -1371,7 +1389,19 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { } func reverseProxyIsBenignTunnelError(err error) bool { - return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) +} + +func reverseProxyIsClosedBodyError(err error) bool { + if err == nil { + return false + } + switch err.Error() { + case "body closed by handler", "http2: response body closed", "response body closed": + return true + default: + return false + } } func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 345dd97..e56aa5e 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -967,6 +967,223 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + reader := bufio.NewReader(r.Body) + line, err := reader.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "ack:"+line); err != nil { + errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) + return + } + _ = controller.Flush() + + if _, err := io.Copy(io.Discard, reader); err != nil { + errCh <- fmt.Errorf("wait for request half-close failed: %w", err) + return + } + if _, err := io.WriteString(w, "after-close\n"); err != nil { + errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + reader := bufio.NewReader(resp.Body) + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read immediate tunneled response: %v", err) + } + if message != "ack:ping\n" { + t.Fatalf("unexpected immediate tunneled response: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + message, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("read post-close tunneled response: %v", err) + } + if message != "after-close\n" { + t.Fatalf("unexpected post-close tunneled response: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + <-r.Context().Done() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + proxyErrCh := make(chan error, 1) + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + select { + case proxyErrCh <- err: + default: + } + }, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + writeErrCh := make(chan error, 1) + go func() { + _, err := io.WriteString(pw, strings.Repeat("x", 1<<20)) + writeErrCh <- err + }() + time.Sleep(50 * time.Millisecond) + + cancel() + _ = pw.CloseWithError(context.Canceled) + select { + case <-writeErrCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for request body writer to unblock") + } + + select { + case err := <-proxyErrCh: + t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper()