feat: add headers operations for reverse proxy

- Add HeaderOps struct for Add/Set/Delete header operations
- Add RespHeaderOps for response header manipulation with deferred support
- Support wildcard patterns for header deletion (prefix-*, *suffix, *substring*)
- Apply request headers before forwarding to upstream
- Apply response headers before sending to client
- Add comprehensive test coverage for header operations

Usage example:
  engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
    Target: target,
    RequestHeaders: &HeaderOps{
      Add: map[string][]string{"X-Custom": {"value"}},
      Delete: []string{"X-Sensitive-*"},
    },
    ResponseHeaders: &RespHeaderOps{
      HeaderOps: &HeaderOps{
        Set: map[string][]string{"X-Frame-Options": {"DENY"}},
      },
    },
  }))
This commit is contained in:
wjqserver 2026-04-19 09:30:06 +08:00
parent 3b5f2c81af
commit 06a6d42de1
2 changed files with 401 additions and 12 deletions

View file

@ -67,6 +67,9 @@ type ReverseProxyConfig struct {
ForwardedBy string ForwardedBy string
Via string Via string
PreserveHost bool PreserveHost bool
RequestHeaders *HeaderOps
ResponseHeaders *RespHeaderOps
} }
var ( var (
@ -76,6 +79,164 @@ var (
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
) )
type HeaderOps struct {
Add map[string][]string
Set map[string][]string
Delete []string
}
type RespHeaderOps struct {
*HeaderOps
Deferred bool
}
func (ops *HeaderOps) applyToRequest(req *http.Request) {
if ops == nil {
return
}
replacer := newReverseProxyReplacer(req)
for fieldName, vals := range ops.Add {
fieldName = replacer.Replace(fieldName)
for _, v := range vals {
req.Header.Add(fieldName, replacer.Replace(v))
}
}
for fieldName, vals := range ops.Set {
fieldName = replacer.Replace(fieldName)
req.Header.Del(fieldName)
for _, v := range vals {
req.Header.Add(fieldName, replacer.Replace(v))
}
}
for _, fieldName := range ops.Delete {
fieldName = strings.ToLower(replacer.Replace(fieldName))
if fieldName == "*" {
for k := range req.Header {
req.Header.Del(k)
}
continue
}
switch {
case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"):
pattern := fieldName[1:len(fieldName)-1]
for k := range req.Header {
if strings.Contains(strings.ToLower(k), pattern) {
req.Header.Del(k)
}
}
case strings.HasPrefix(fieldName, "*"):
suffix := fieldName[1:]
for k := range req.Header {
if strings.HasSuffix(strings.ToLower(k), suffix) {
req.Header.Del(k)
}
}
case strings.HasSuffix(fieldName, "*"):
prefix := fieldName[:len(fieldName)-1]
for k := range req.Header {
if strings.HasPrefix(strings.ToLower(k), prefix) {
req.Header.Del(k)
}
}
default:
req.Header.Del(fieldName)
}
}
}
func (ops *RespHeaderOps) applyToResponse(hdr http.Header) {
if ops == nil {
return
}
if ops.Deferred {
return
}
ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr))
}
func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) {
if ops == nil {
return
}
if repl == nil {
repl = &reverseProxyReplacer{}
}
for fieldName, vals := range ops.Add {
fieldName = repl.Replace(fieldName)
for _, v := range vals {
hdr.Add(fieldName, repl.Replace(v))
}
}
for fieldName, vals := range ops.Set {
fieldName = repl.Replace(fieldName)
hdr.Del(fieldName)
for _, v := range vals {
hdr.Add(fieldName, repl.Replace(v))
}
}
for _, fieldName := range ops.Delete {
fieldName = strings.ToLower(repl.Replace(fieldName))
if fieldName == "*" {
for k := range hdr {
hdr.Del(k)
}
continue
}
switch {
case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"):
pattern := fieldName[1:len(fieldName)-1]
for k := range hdr {
if strings.Contains(strings.ToLower(k), pattern) {
hdr.Del(k)
}
}
case strings.HasPrefix(fieldName, "*"):
suffix := fieldName[1:]
for k := range hdr {
if strings.HasSuffix(strings.ToLower(k), suffix) {
hdr.Del(k)
}
}
case strings.HasSuffix(fieldName, "*"):
prefix := fieldName[:len(fieldName)-1]
for k := range hdr {
if strings.HasPrefix(strings.ToLower(k), prefix) {
hdr.Del(k)
}
}
default:
hdr.Del(fieldName)
}
}
}
type reverseProxyReplacer struct {
req *http.Request
}
func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer {
return &reverseProxyReplacer{req: req}
}
func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer {
return &reverseProxyReplacer{}
}
func (r *reverseProxyReplacer) Replace(s string) string {
if r == nil || s == "" {
return s
}
return s
}
type reverseProxyHandler struct { type reverseProxyHandler struct {
config ReverseProxyConfig config ReverseProxyConfig
upstreams []*reverseProxyUpstream upstreams []*reverseProxyUpstream
@ -573,6 +734,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
outreq.Header.Set("User-Agent", "") outreq.Header.Set("User-Agent", "")
} }
if p.config.RequestHeaders != nil {
p.config.RequestHeaders.applyToRequest(outreq)
}
if p.config.ModifyRequest != nil { if p.config.ModifyRequest != nil {
p.config.ModifyRequest(outreq) p.config.ModifyRequest(outreq)
} }
@ -808,6 +973,10 @@ func appendXForwardedFor(header http.Header, clientIP string) {
} }
func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool {
if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred {
p.config.ResponseHeaders.applyToResponse(res.Header)
}
if p.config.ModifyResponse == nil { if p.config.ModifyResponse == nil {
return true return true
} }

View file

@ -0,0 +1,220 @@
package touka
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Custom-Header": {"test-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Sensitive") != "" {
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Delete: []string{"X-Sensitive"},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Sensitive", "should-be-removed")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyHeaderOpsSet(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("X-Replace")
if got != "new-value" {
t.Errorf("expected X-Replace=new-value, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Set: map[string][]string{
"X-Replace": {"new-value"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
req.Header.Set("X-Replace", "old-value")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
func TestReverseProxyResponseHeaderOps(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "backend-server")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Set: map[string][]string{
"X-Custom": {"custom-value"},
},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
t.Errorf("expected X-Custom=custom-value, got %q", got)
}
}
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
t.Helper()
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Powered-By", "Express")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
Target: target,
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Delete: []string{"X-Powered-By"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
resp, err := http.Get(proxy.URL + "/test")
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if got := resp.Header.Get("X-Powered-By"); got != "" {
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
}
}