From 06a6d42de1b6dfa1bb51ba7482463da720f34e7f Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 19 Apr 2026 09:30:06 +0800 Subject: [PATCH] feat: add headers operations for reverse proxy - Add HeaderOps struct for Add/Set/Delete header operations - Add RespHeaderOps for response header manipulation with deferred support - Support wildcard patterns for header deletion (prefix-*, *suffix, *substring*) - Apply request headers before forwarding to upstream - Apply response headers before sending to client - Add comprehensive test coverage for header operations Usage example: engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, RequestHeaders: &HeaderOps{ Add: map[string][]string{"X-Custom": {"value"}}, Delete: []string{"X-Sensitive-*"}, }, ResponseHeaders: &RespHeaderOps{ HeaderOps: &HeaderOps{ Set: map[string][]string{"X-Frame-Options": {"DENY"}}, }, }, })) --- reverseproxy.go | 193 ++++++++++++++++++++++++++++-- reverseproxy_headers_test.go | 220 +++++++++++++++++++++++++++++++++++ 2 files changed, 401 insertions(+), 12 deletions(-) create mode 100644 reverseproxy_headers_test.go diff --git a/reverseproxy.go b/reverseproxy.go index 4c2b3cd..eb5043c 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -48,34 +48,195 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL Targets []string LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool AllowH2CUpstream bool - ModifyRequest func(*http.Request) + ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error - ErrorHandler func(http.ResponseWriter, *http.Request, error) + ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy - ForwardedBy string - Via string - PreserveHost bool + ForwardedBy string + Via string + PreserveHost bool + + RequestHeaders *HeaderOps + ResponseHeaders *RespHeaderOps } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") - errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) +type HeaderOps struct { + Add map[string][]string + Set map[string][]string + Delete []string +} + +type RespHeaderOps struct { + *HeaderOps + Deferred bool +} + +func (ops *HeaderOps) applyToRequest(req *http.Request) { + if ops == nil { + return + } + replacer := newReverseProxyReplacer(req) + + for fieldName, vals := range ops.Add { + fieldName = replacer.Replace(fieldName) + for _, v := range vals { + req.Header.Add(fieldName, replacer.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = replacer.Replace(fieldName) + req.Header.Del(fieldName) + for _, v := range vals { + req.Header.Add(fieldName, replacer.Replace(v)) + } + } + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(replacer.Replace(fieldName)) + if fieldName == "*" { + for k := range req.Header { + req.Header.Del(k) + } + continue + } + + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + pattern := fieldName[1:len(fieldName)-1] + for k := range req.Header { + if strings.Contains(strings.ToLower(k), pattern) { + req.Header.Del(k) + } + } + case strings.HasPrefix(fieldName, "*"): + suffix := fieldName[1:] + for k := range req.Header { + if strings.HasSuffix(strings.ToLower(k), suffix) { + req.Header.Del(k) + } + } + case strings.HasSuffix(fieldName, "*"): + prefix := fieldName[:len(fieldName)-1] + for k := range req.Header { + if strings.HasPrefix(strings.ToLower(k), prefix) { + req.Header.Del(k) + } + } + default: + req.Header.Del(fieldName) + } + } +} + +func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { + if ops == nil { + return + } + if ops.Deferred { + return + } + ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) +} + +func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil { + return + } + if repl == nil { + repl = &reverseProxyReplacer{} + } + + for fieldName, vals := range ops.Add { + fieldName = repl.Replace(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = repl.Replace(fieldName) + hdr.Del(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(repl.Replace(fieldName)) + if fieldName == "*" { + for k := range hdr { + hdr.Del(k) + } + continue + } + + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + pattern := fieldName[1:len(fieldName)-1] + for k := range hdr { + if strings.Contains(strings.ToLower(k), pattern) { + hdr.Del(k) + } + } + case strings.HasPrefix(fieldName, "*"): + suffix := fieldName[1:] + for k := range hdr { + if strings.HasSuffix(strings.ToLower(k), suffix) { + hdr.Del(k) + } + } + case strings.HasSuffix(fieldName, "*"): + prefix := fieldName[:len(fieldName)-1] + for k := range hdr { + if strings.HasPrefix(strings.ToLower(k), prefix) { + hdr.Del(k) + } + } + default: + hdr.Del(fieldName) + } + } +} + +type reverseProxyReplacer struct { + req *http.Request +} + +func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { + return &reverseProxyReplacer{req: req} +} + +func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { + return &reverseProxyReplacer{} +} + +func (r *reverseProxyReplacer) Replace(s string) string { + if r == nil || s == "" { + return s + } + return s +} + type reverseProxyHandler struct { config ReverseProxyConfig upstreams []*reverseProxyUpstream @@ -573,6 +734,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte outreq.Header.Set("User-Agent", "") } + if p.config.RequestHeaders != nil { + p.config.RequestHeaders.applyToRequest(outreq) + } + if p.config.ModifyRequest != nil { p.config.ModifyRequest(outreq) } @@ -808,6 +973,10 @@ func appendXForwardedFor(header http.Header, clientIP string) { } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } + if p.config.ModifyResponse == nil { return true } diff --git a/reverseproxy_headers_test.go b/reverseproxy_headers_test.go new file mode 100644 index 0000000..4a4ae26 --- /dev/null +++ b/reverseproxy_headers_test.go @@ -0,0 +1,220 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestReverseProxyHeaderOpsAdd(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Custom-Header"); got != "test-value" { + t.Errorf("expected X-Custom-Header=test-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Custom-Header": {"test-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Sensitive") != "" { + t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive")) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Delete: []string{"X-Sensitive"}, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Sensitive", "should-be-removed") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsSet(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Replace") + if got != "new-value" { + t.Errorf("expected X-Replace=new-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Set: map[string][]string{ + "X-Replace": {"new-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Replace", "old-value") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyResponseHeaderOps(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-server") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Set: map[string][]string{ + "X-Custom": {"custom-value"}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Custom"); got != "custom-value" { + t.Errorf("expected X-Custom=custom-value, got %q", got) + } +} + +func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Powered-By", "Express") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Delete: []string{"X-Powered-By"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Powered-By"); got != "" { + t.Errorf("expected X-Powered-By to be deleted, got %q", got) + } +}