From 7abedc1acea918edb522d00a5e1c3c7e7d45870d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:33:18 +0800 Subject: [PATCH] enhance: improve reverse proxy error handling and add tests --- reverseproxy.go | 24 ++++---- reverseproxy_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 10 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index e674e7e..2d6dfea 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -101,10 +101,10 @@ type reverseProxyH2ReadWriteCloser struct { func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { - return 0, err + return n, err } if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - return 0, err + return n, err } return n, nil } @@ -479,7 +479,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) - bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) + if err != nil { + return nil, nil, nil, err + } if bridged { outreq = outreq.WithContext(bridgeCtx) } @@ -1370,13 +1373,13 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } -func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { protocol := reverseProxyExtendedConnectProtocol(req) if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { if req == nil { - return context.Background(), false + return context.Background(), false, nil } - return req.Context(), false + return req.Context(), false, nil } bridge := &reverseProxyExtendedConnectBridge{body: req.Body} @@ -1389,10 +1392,11 @@ func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Contex req.Header.Set("Connection", "Upgrade") req.Header.Set("Sec-WebSocket-Version", "13") key, err := reverseProxyGenerateWebSocketKey() - if err == nil { - req.Header.Set("Sec-WebSocket-Key", key) + if err != nil { + return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err) } - return ctx, true + req.Header.Set("Sec-WebSocket-Key", key) + return ctx, true, nil } func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { @@ -1405,7 +1409,7 @@ func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseP func reverseProxyGenerateWebSocketKey() (string, error) { key := make([]byte, 16) - if _, err := rand.Read(key); err != nil { + if _, err := io.ReadFull(rand.Reader, key); err != nil { return "", err } return base64.StdEncoding.EncodeToString(key), nil diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b05f426..8a250b2 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,7 +2,9 @@ package touka import ( "bufio" + "bytes" "context" + crand "crypto/rand" "crypto/tls" "errors" "fmt" @@ -829,6 +831,67 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) { + t.Helper() + + flushErr := errors.New("flush failed") + writer := &flushErrorResponseWriter{flushErr: flushErr} + conn := reverseProxyH2ReadWriteCloser{ + ReadCloser: io.NopCloser(strings.NewReader("")), + ResponseWriter: writer, + } + + n, err := conn.Write([]byte("ping")) + if n != len("ping") { + t.Fatalf("unexpected bytes written: %d", n) + } + if !errors.Is(err, flushErr) { + t.Fatalf("unexpected write error: %v", err) + } + if got := writer.body.String(); got != "ping" { + t.Fatalf("unexpected buffered body: %q", got) + } +} + +func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) { + t.Helper() + + transportCalled := atomic.Bool{} + entropyErr := errors.New("entropy source unavailable") + originalReader := crand.Reader + crand.Reader = errorReader{err: entropyErr} + t.Cleanup(func() { + crand.Reader = originalReader + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + transportCalled.Store(true) + return nil, errors.New("unexpected round trip") + }), + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(reverseProxyStatusCode(err)) + _, _ = io.WriteString(w, err.Error()) + }, + })) + + headers := make(http.Header) + headers.Set(":protocol", "websocket") + rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers) + + if transportCalled.Load() { + t.Fatal("transport should not be called when websocket key generation fails") + } + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) { + t.Fatalf("unexpected error body: %q", body) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -2137,6 +2200,73 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } +type flushErrorResponseWriter struct { + header http.Header + body bytes.Buffer + status int + written bool + flushErr error +} + +func (w *flushErrorResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *flushErrorResponseWriter) WriteHeader(statusCode int) { + if w.written { + return + } + w.status = statusCode + w.written = true +} + +func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(p) +} + +func (w *flushErrorResponseWriter) Flush() {} + +func (w *flushErrorResponseWriter) FlushError() error { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.flushErr +} + +func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *flushErrorResponseWriter) Status() int { + return w.status +} + +func (w *flushErrorResponseWriter) Size() int { + return w.body.Len() +} + +func (w *flushErrorResponseWriter) Written() bool { + return w.written +} + +func (w *flushErrorResponseWriter) IsHijacked() bool { + return false +} + +type errorReader struct { + err error +} + +func (r errorReader) Read([]byte) (int, error) { + return 0, r.err +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw)