From 5d9bb3187d86bcfe150e632c48a28f58d96de30b Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 17:20:30 +0800 Subject: [PATCH] perf: optimize wildcard header deletion; test: assert invalid regex returns 500 - refactor Delete logic to iterate headers once, reducing ToLower calls from O(patterns * headers) to O(headers) - rewrite invalid regex test to verify runtime 500 response --- reverseproxy.go | 71 +++++++++++++++++++--------- reverseproxy_headers_replace_test.go | 34 +++++++++++-- 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 6ccca43..30fa370 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -136,39 +136,64 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { } } + var deleteAll bool + var exactDeletes []string + var suffixPatterns, prefixPatterns, containsPatterns []string + for _, fieldName := range ops.Delete { fieldName = strings.ToLower(repl.Replace(fieldName)) if fieldName == "*" { - for k := range hdr { - hdr.Del(k) - } - continue + deleteAll = true + break } - switch { case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): - pattern := fieldName[1:len(fieldName)-1] - for k := range hdr { - if strings.Contains(strings.ToLower(k), pattern) { - hdr.Del(k) - } - } + containsPatterns = append(containsPatterns, fieldName[1:len(fieldName)-1]) case strings.HasPrefix(fieldName, "*"): - suffix := fieldName[1:] - for k := range hdr { - if strings.HasSuffix(strings.ToLower(k), suffix) { - hdr.Del(k) - } - } + suffixPatterns = append(suffixPatterns, fieldName[1:]) case strings.HasSuffix(fieldName, "*"): - prefix := fieldName[:len(fieldName)-1] - for k := range hdr { - if strings.HasPrefix(strings.ToLower(k), prefix) { - hdr.Del(k) + prefixPatterns = append(prefixPatterns, fieldName[:len(fieldName)-1]) + default: + exactDeletes = append(exactDeletes, fieldName) + } + } + + if deleteAll { + for k := range hdr { + hdr.Del(k) + } + } else if len(exactDeletes) > 0 || len(suffixPatterns) > 0 || len(prefixPatterns) > 0 || len(containsPatterns) > 0 { + toDelete := make([]string, 0, len(exactDeletes)) + for k := range hdr { + kl := strings.ToLower(k) + for _, d := range exactDeletes { + if kl == d { + toDelete = append(toDelete, k) + goto skip } } - default: - hdr.Del(fieldName) + for _, p := range containsPatterns { + if strings.Contains(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range suffixPatterns { + if strings.HasSuffix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range prefixPatterns { + if strings.HasPrefix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + skip: + } + for _, k := range toDelete { + hdr.Del(k) } } diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go index 54e7889..1eb0d04 100644 --- a/reverseproxy_headers_replace_test.go +++ b/reverseproxy_headers_replace_test.go @@ -193,15 +193,41 @@ func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { } func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { - _ = New() - ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, RequestHeaders: &HeaderOps{ Replace: map[string][]Replacement{ "X-Test": {{SearchRegexp: "[invalid"}}, }, }, - }) + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", resp.StatusCode) + } } func TestReplacementApply(t *testing.T) {