enhance: improve reverse proxy error handling and add tests

This commit is contained in:
wjqserver 2026-04-02 19:33:18 +08:00
parent 50c6a23614
commit 7abedc1ace
2 changed files with 144 additions and 10 deletions

View file

@ -2,7 +2,9 @@ package touka
import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"crypto/tls"
"errors"
"fmt"
@ -829,6 +831,67 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
}
}
func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) {
t.Helper()
flushErr := errors.New("flush failed")
writer := &flushErrorResponseWriter{flushErr: flushErr}
conn := reverseProxyH2ReadWriteCloser{
ReadCloser: io.NopCloser(strings.NewReader("")),
ResponseWriter: writer,
}
n, err := conn.Write([]byte("ping"))
if n != len("ping") {
t.Fatalf("unexpected bytes written: %d", n)
}
if !errors.Is(err, flushErr) {
t.Fatalf("unexpected write error: %v", err)
}
if got := writer.body.String(); got != "ping" {
t.Fatalf("unexpected buffered body: %q", got)
}
}
func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) {
t.Helper()
transportCalled := atomic.Bool{}
entropyErr := errors.New("entropy source unavailable")
originalReader := crand.Reader
crand.Reader = errorReader{err: entropyErr}
t.Cleanup(func() {
crand.Reader = originalReader
})
engine := New()
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
Target: mustParseURL(t, "http://example.com"),
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
transportCalled.Store(true)
return nil, errors.New("unexpected round trip")
}),
ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
w.WriteHeader(reverseProxyStatusCode(err))
_, _ = io.WriteString(w, err.Error())
},
}))
headers := make(http.Header)
headers.Set(":protocol", "websocket")
rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers)
if transportCalled.Load() {
t.Fatal("transport should not be called when websocket key generation fails")
}
if rr.Code != http.StatusBadGateway {
t.Fatalf("unexpected status: %d", rr.Code)
}
if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) {
t.Fatalf("unexpected error body: %q", body)
}
}
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
t.Helper()
@ -2137,6 +2200,73 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error)
return fn(req)
}
type flushErrorResponseWriter struct {
header http.Header
body bytes.Buffer
status int
written bool
flushErr error
}
func (w *flushErrorResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *flushErrorResponseWriter) WriteHeader(statusCode int) {
if w.written {
return
}
w.status = statusCode
w.written = true
}
func (w *flushErrorResponseWriter) Write(p []byte) (int, error) {
if !w.written {
w.WriteHeader(http.StatusOK)
}
return w.body.Write(p)
}
func (w *flushErrorResponseWriter) Flush() {}
func (w *flushErrorResponseWriter) FlushError() error {
if !w.written {
w.WriteHeader(http.StatusOK)
}
return w.flushErr
}
func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func (w *flushErrorResponseWriter) Status() int {
return w.status
}
func (w *flushErrorResponseWriter) Size() int {
return w.body.Len()
}
func (w *flushErrorResponseWriter) Written() bool {
return w.written
}
func (w *flushErrorResponseWriter) IsHijacked() bool {
return false
}
type errorReader struct {
err error
}
func (r errorReader) Read([]byte) (int, error) {
return 0, r.err
}
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
u, err := url.Parse(raw)