mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-15 16:47:38 +08:00
enhance: improve reverse proxy error handling and add tests
This commit is contained in:
parent
50c6a23614
commit
7abedc1ace
2 changed files with 144 additions and 10 deletions
|
|
@ -101,10 +101,10 @@ type reverseProxyH2ReadWriteCloser struct {
|
||||||
func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
|
func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) {
|
||||||
n, err := rwc.ResponseWriter.Write(p)
|
n, err := rwc.ResponseWriter.Write(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return n, err
|
||||||
}
|
}
|
||||||
if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||||
return 0, err
|
return n, err
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
@ -479,7 +479,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte
|
||||||
|
|
||||||
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) {
|
||||||
outreq := c.Request.Clone(ctx)
|
outreq := c.Request.Clone(ctx)
|
||||||
bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq)
|
bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
if bridged {
|
if bridged {
|
||||||
outreq = outreq.WithContext(bridgeCtx)
|
outreq = outreq.WithContext(bridgeCtx)
|
||||||
}
|
}
|
||||||
|
|
@ -1370,13 +1373,13 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool {
|
||||||
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
return policy == ForwardedBoth || policy == ForwardedRFC7239Only
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) {
|
func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) {
|
||||||
protocol := reverseProxyExtendedConnectProtocol(req)
|
protocol := reverseProxyExtendedConnectProtocol(req)
|
||||||
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") {
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return context.Background(), false
|
return context.Background(), false, nil
|
||||||
}
|
}
|
||||||
return req.Context(), false
|
return req.Context(), false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
|
bridge := &reverseProxyExtendedConnectBridge{body: req.Body}
|
||||||
|
|
@ -1389,10 +1392,11 @@ func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Contex
|
||||||
req.Header.Set("Connection", "Upgrade")
|
req.Header.Set("Connection", "Upgrade")
|
||||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
key, err := reverseProxyGenerateWebSocketKey()
|
key, err := reverseProxyGenerateWebSocketKey()
|
||||||
if err == nil {
|
if err != nil {
|
||||||
req.Header.Set("Sec-WebSocket-Key", key)
|
return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err)
|
||||||
}
|
}
|
||||||
return ctx, true
|
req.Header.Set("Sec-WebSocket-Key", key)
|
||||||
|
return ctx, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
|
func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge {
|
||||||
|
|
@ -1405,7 +1409,7 @@ func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseP
|
||||||
|
|
||||||
func reverseProxyGenerateWebSocketKey() (string, error) {
|
func reverseProxyGenerateWebSocketKey() (string, error) {
|
||||||
key := make([]byte, 16)
|
key := make([]byte, 16)
|
||||||
if _, err := rand.Read(key); err != nil {
|
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return base64.StdEncoding.EncodeToString(key), nil
|
return base64.StdEncoding.EncodeToString(key), nil
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ package touka
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -829,6 +831,67 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
flushErr := errors.New("flush failed")
|
||||||
|
writer := &flushErrorResponseWriter{flushErr: flushErr}
|
||||||
|
conn := reverseProxyH2ReadWriteCloser{
|
||||||
|
ReadCloser: io.NopCloser(strings.NewReader("")),
|
||||||
|
ResponseWriter: writer,
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := conn.Write([]byte("ping"))
|
||||||
|
if n != len("ping") {
|
||||||
|
t.Fatalf("unexpected bytes written: %d", n)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, flushErr) {
|
||||||
|
t.Fatalf("unexpected write error: %v", err)
|
||||||
|
}
|
||||||
|
if got := writer.body.String(); got != "ping" {
|
||||||
|
t.Fatalf("unexpected buffered body: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
transportCalled := atomic.Bool{}
|
||||||
|
entropyErr := errors.New("entropy source unavailable")
|
||||||
|
originalReader := crand.Reader
|
||||||
|
crand.Reader = errorReader{err: entropyErr}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
crand.Reader = originalReader
|
||||||
|
})
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: mustParseURL(t, "http://example.com"),
|
||||||
|
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
transportCalled.Store(true)
|
||||||
|
return nil, errors.New("unexpected round trip")
|
||||||
|
}),
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
|
||||||
|
w.WriteHeader(reverseProxyStatusCode(err))
|
||||||
|
_, _ = io.WriteString(w, err.Error())
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set(":protocol", "websocket")
|
||||||
|
rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers)
|
||||||
|
|
||||||
|
if transportCalled.Load() {
|
||||||
|
t.Fatal("transport should not be called when websocket key generation fails")
|
||||||
|
}
|
||||||
|
if rr.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("unexpected status: %d", rr.Code)
|
||||||
|
}
|
||||||
|
if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) {
|
||||||
|
t.Fatalf("unexpected error body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
@ -2137,6 +2200,73 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error)
|
||||||
return fn(req)
|
return fn(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type flushErrorResponseWriter struct {
|
||||||
|
header http.Header
|
||||||
|
body bytes.Buffer
|
||||||
|
status int
|
||||||
|
written bool
|
||||||
|
flushErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Header() http.Header {
|
||||||
|
if w.header == nil {
|
||||||
|
w.header = make(http.Header)
|
||||||
|
}
|
||||||
|
return w.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
if w.written {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.status = statusCode
|
||||||
|
w.written = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
if !w.written {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
return w.body.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Flush() {}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) FlushError() error {
|
||||||
|
if !w.written {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
return w.flushErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Status() int {
|
||||||
|
return w.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Size() int {
|
||||||
|
return w.body.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) Written() bool {
|
||||||
|
return w.written
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flushErrorResponseWriter) IsHijacked() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorReader struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r errorReader) Read([]byte) (int, error) {
|
||||||
|
return 0, r.err
|
||||||
|
}
|
||||||
|
|
||||||
func mustParseURL(t *testing.T, raw string) *url.URL {
|
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
u, err := url.Parse(raw)
|
u, err := url.Parse(raw)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue