diff --git a/reverseproxy.go b/reverseproxy.go index 5ec3693..1cf0078 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -20,6 +20,7 @@ import ( "net/netip" "net/textproto" "net/url" + "regexp" "strconv" "strings" "sync" @@ -48,34 +49,275 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL Targets []string LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool AllowH2CUpstream bool - ModifyRequest func(*http.Request) + ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error - ErrorHandler func(http.ResponseWriter, *http.Request, error) + ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy - ForwardedBy string - Via string - PreserveHost bool + ForwardedBy string + Via string + PreserveHost bool + + RequestHeaders *HeaderOps + ResponseHeaders *RespHeaderOps } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") - errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) +type HeaderOps struct { + Add map[string][]string + Set map[string][]string + Delete []string + Replace map[string][]Replacement +} + +type Replacement struct { + Search string + Replace string + SearchRegexp string + re *regexp.Regexp +} + +type RespHeaderOps struct { + *HeaderOps + Deferred bool +} + +func (ops *HeaderOps) applyToRequest(req *http.Request) { + if ops == nil { + return + } + ops.applyTo(req.Header, newReverseProxyReplacer(req)) +} + +func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { + if ops == nil { + return + } + ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) +} + +func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil { + return + } + if repl == nil { + repl = &reverseProxyReplacer{} + } + + for fieldName, vals := range ops.Add { + fieldName = repl.Replace(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = repl.Replace(fieldName) + hdr.Del(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + var deleteAll bool + var exactDeletes []string + var suffixPatterns, prefixPatterns, containsPatterns []string + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(repl.Replace(fieldName)) + if fieldName == "*" { + deleteAll = true + break + } + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + containsPatterns = append(containsPatterns, fieldName[1:len(fieldName)-1]) + case strings.HasPrefix(fieldName, "*"): + suffixPatterns = append(suffixPatterns, fieldName[1:]) + case strings.HasSuffix(fieldName, "*"): + prefixPatterns = append(prefixPatterns, fieldName[:len(fieldName)-1]) + default: + exactDeletes = append(exactDeletes, fieldName) + } + } + + if deleteAll { + for k := range hdr { + hdr.Del(k) + } + } else if len(exactDeletes) > 0 || len(suffixPatterns) > 0 || len(prefixPatterns) > 0 || len(containsPatterns) > 0 { + toDelete := make([]string, 0, len(exactDeletes)) + for k := range hdr { + kl := strings.ToLower(k) + for _, d := range exactDeletes { + if kl == d { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range containsPatterns { + if strings.Contains(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range suffixPatterns { + if strings.HasSuffix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range prefixPatterns { + if strings.HasPrefix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + skip: + } + for _, k := range toDelete { + hdr.Del(k) + } + } + + ops.applyReplace(hdr, repl) +} + +func (ops *HeaderOps) applyReplace(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil || len(ops.Replace) == 0 { + return + } + for fieldName, replacements := range ops.Replace { + fieldName = http.CanonicalHeaderKey(repl.Replace(fieldName)) + if fieldName == "*" { + for fn, vals := range hdr { + for i := range vals { + for _, r := range replacements { + hdr[fn][i] = r.apply(vals[i]) + } + } + } + continue + } + vals, ok := hdr[fieldName] + if !ok { + continue + } + for i := range vals { + for _, r := range replacements { + hdr[fieldName][i] = r.apply(vals[i]) + } + } + } +} + +func (r *Replacement) apply(s string) string { + if r == nil || s == "" { + return s + } + if r.SearchRegexp != "" && r.re != nil { + return r.re.ReplaceAllString(s, r.Replace) + } + if r.Search != "" { + return strings.ReplaceAll(s, r.Search, r.Replace) + } + return s +} + +func (ops *HeaderOps) Provision() error { + if ops == nil { + return nil + } + for fieldName, replacements := range ops.Replace { + for i, r := range replacements { + if r.SearchRegexp == "" { + continue + } + if r.Search != "" { + return fmt.Errorf("replacement %d for header field %q: cannot specify both Search and SearchRegexp", i, fieldName) + } + re, err := regexp.Compile(r.SearchRegexp) + if err != nil { + return fmt.Errorf("replacement %d for header field %q: %v", i, fieldName, err) + } + replacements[i].re = re + } + } + return nil +} + +type reverseProxyReplacer struct { + method, host, path, query, scheme, uri, proto string +} + +func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { + if req == nil || req.URL == nil { + return &reverseProxyReplacer{} + } + uri := req.RequestURI + if uri == "" { + uri = req.URL.RequestURI() + } + return &reverseProxyReplacer{ + method: req.Method, + host: req.Host, + path: req.URL.EscapedPath(), + query: req.URL.RawQuery, + scheme: reverseProxyRequestScheme(req), + uri: uri, + proto: req.Proto, + } +} + +func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { + return &reverseProxyReplacer{} +} + +func (r *reverseProxyReplacer) Replace(s string) string { + if r == nil || s == "" { + return s + } + if r.method != "" { + s = strings.ReplaceAll(s, "{method}", r.method) + } + if r.host != "" { + s = strings.ReplaceAll(s, "{host}", r.host) + } + if r.path != "" { + s = strings.ReplaceAll(s, "{path}", r.path) + } + if r.query != "" { + s = strings.ReplaceAll(s, "{query}", r.query) + } + if r.scheme != "" { + s = strings.ReplaceAll(s, "{scheme}", r.scheme) + } + if r.uri != "" { + s = strings.ReplaceAll(s, "{uri}", r.uri) + } + if r.proto != "" { + s = strings.ReplaceAll(s, "{proto}", r.proto) + } + return s +} + type reverseProxyHandler struct { config ReverseProxyConfig upstreams []*reverseProxyUpstream @@ -256,6 +498,19 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { receivedBy: reverseProxyReceivedBy(config.Via), } + if config.RequestHeaders != nil { + if err := config.RequestHeaders.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + if config.ResponseHeaders != nil && config.ResponseHeaders.HeaderOps != nil { + if err := config.ResponseHeaders.HeaderOps.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + upstreams, err := buildReverseProxyUpstreams(config) if err != nil { proxy.configError = err @@ -573,6 +828,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte outreq.Header.Set("User-Agent", "") } + if p.config.RequestHeaders != nil { + p.config.RequestHeaders.applyToRequest(outreq) + } + if p.config.ModifyRequest != nil { p.config.ModifyRequest(outreq) } @@ -808,7 +1067,14 @@ func appendXForwardedFor(header http.Header, clientIP string) { } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } + if p.config.ModifyResponse == nil { + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } if err := p.config.ModifyResponse(res); err != nil { @@ -816,6 +1082,9 @@ func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req p.handleError(c, err) return false } + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go new file mode 100644 index 0000000..0c0d599 --- /dev/null +++ b/reverseproxy_headers_replace_test.go @@ -0,0 +1,530 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "testing" +) + +func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Server"); got != "Caddy" { + t.Errorf("expected X-Server=Caddy, got %q", got) + } + if got := r.Header.Get("X-Location"); got != "/api/v2/resource" { + t.Errorf("expected X-Location=/api/v2/resource, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Server": {{Search: "NGINX", Replace: "Caddy"}}, + "X-Location": {{Search: "v1", Replace: "v2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Server", "NGINX") + req.Header.Set("X-Location", "/api/v1/resource") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Route"); got != "/proxy-upstream" { + t.Errorf("expected X-Route=/proxy-upstream, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Route", "/original/upstream") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Host-A"); got != "new.example.com" { + t.Errorf("expected X-Host-A=new.example.com, got %q", got) + } + if got := r.Header.Get("X-Host-B"); got != "new.example.com" { + t.Errorf("expected X-Host-B=new.example.com, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "old.example.com", Replace: "new.example.com"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Host-A", "old.example.com") + req.Header.Set("X-Host-B", "old.example.com") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-internal:8080") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Backend"); got != "public.example.com" { + t.Errorf("expected X-Backend=public.example.com, got %q", got) + } +} + +func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Test": {{SearchRegexp: "[invalid"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", resp.StatusCode) + } +} + +func TestReplacementApply(t *testing.T) { + tests := []struct { + name string + r *Replacement + s string + want string + }{ + {name: "nil replacement", r: nil, s: "hello", want: "hello"}, + {name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""}, + {name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"}, + {name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"}, + {name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"}, + {name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"}, + {name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"}, + {name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.apply(tt.s); got != tt.want { + t.Errorf("Replacement.apply() = %q, want %q", got, tt.want) + } + }) + } +} + +func BenchmarkHeaderOpsAdd(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Custom-1": {"value-1"}, + "X-Custom-2": {"value-2"}, + "X-Custom-3": {"value-3"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsSet(b *testing.B) { + ops := &HeaderOps{ + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + "X-Content-Type-Options": {"nosniff"}, + "X-XSS-Protection": {"1; mode=block"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteSingle(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Powered-By"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Debug-*"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Debug-1", "v1") + hdr.Set("X-Debug-2", "v2") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) { + re := regexp.MustCompile(`^http://([^/]+)(/.*)$`) + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "internal.example.com", Replace: "public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Host", "internal.example.com") + hdr.Set("X-Origin", "internal.example.com") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsMixed(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Request-ID": {"req-123"}, + }, + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + }, + Delete: []string{"X-Powered-By"}, + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("Location", "http://internal:8080/api") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkReplacementApplySubstring(b *testing.B) { + r := &Replacement{Search: "old.example.com", Replace: "new.example.com"} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} + +func BenchmarkReplacementApplyRegexp(b *testing.B) { + r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} + +func TestReverseProxyReplacerDynamicVars(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil) + req.Host = "example.com" + repl := newReverseProxyReplacer(req) + + tests := []struct { + name string + input string + want string + }{ + {"method", "{method}", "GET"}, + {"host", "{host}", "example.com"}, + {"path", "{path}", "/api/v1/users"}, + {"query", "{query}", "sort=name&limit=10"}, + {"scheme", "{scheme}", "http"}, + {"proto", "{proto}", "HTTP/1.1"}, + {"combined", "X-{method}-{path}", "X-GET-/api/v1/users"}, + {"no vars", "static-value", "static-value"}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := repl.Replace(tt.input); got != tt.want { + t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestReverseProxyReplacerNilRequest(t *testing.T) { + repl := newReverseProxyReplacer(nil) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil request, got %q", got) + } +} + +func TestReverseProxyReplacerNilReplacer(t *testing.T) { + var repl *reverseProxyReplacer + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil replacer, got %q", got) + } +} + +func TestReverseProxyReplacerFromHeader(t *testing.T) { + hdr := make(http.Header) + repl := newReverseProxyReplacerFromHeader(hdr) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string from header replacer, got %q", got) + } +} + +func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" { + t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Method"); got != "GET" { + t.Errorf("expected X-Forwarded-Method=GET, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" { + t.Errorf("expected X-Forwarded-Host=client.example, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Forwarded-Path": {"{path}"}, + "X-Forwarded-Method": {"{method}"}, + "X-Forwarded-Host": {"{host}"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil) + req.Host = "client.example" + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} diff --git a/reverseproxy_headers_test.go b/reverseproxy_headers_test.go new file mode 100644 index 0000000..4a4ae26 --- /dev/null +++ b/reverseproxy_headers_test.go @@ -0,0 +1,220 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestReverseProxyHeaderOpsAdd(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Custom-Header"); got != "test-value" { + t.Errorf("expected X-Custom-Header=test-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Custom-Header": {"test-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Sensitive") != "" { + t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive")) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Delete: []string{"X-Sensitive"}, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Sensitive", "should-be-removed") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsSet(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Replace") + if got != "new-value" { + t.Errorf("expected X-Replace=new-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Set: map[string][]string{ + "X-Replace": {"new-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Replace", "old-value") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyResponseHeaderOps(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-server") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Set: map[string][]string{ + "X-Custom": {"custom-value"}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Custom"); got != "custom-value" { + t.Errorf("expected X-Custom=custom-value, got %q", got) + } +} + +func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Powered-By", "Express") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Delete: []string{"X-Powered-By"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Powered-By"); got != "" { + t.Errorf("expected X-Powered-By to be deleted, got %q", got) + } +}