mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
feat: add headers operations for reverse proxy
- Add HeaderOps struct for Add/Set/Delete header operations
- Add RespHeaderOps for response header manipulation with deferred support
- Support wildcard patterns for header deletion (prefix-*, *suffix, *substring*)
- Apply request headers before forwarding to upstream
- Apply response headers before sending to client
- Add comprehensive test coverage for header operations
Usage example:
engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{"X-Custom": {"value"}},
Delete: []string{"X-Sensitive-*"},
},
ResponseHeaders: &RespHeaderOps{
HeaderOps: &HeaderOps{
Set: map[string][]string{"X-Frame-Options": {"DENY"}},
},
},
}))
This commit is contained in:
parent
3b5f2c81af
commit
06a6d42de1
2 changed files with 401 additions and 12 deletions
193
reverseproxy.go
193
reverseproxy.go
|
|
@ -48,34 +48,195 @@ type BufferPool interface {
|
||||||
|
|
||||||
// ReverseProxyConfig configures the reverse proxy handler.
|
// ReverseProxyConfig configures the reverse proxy handler.
|
||||||
type ReverseProxyConfig struct {
|
type ReverseProxyConfig struct {
|
||||||
Target *url.URL
|
Target *url.URL
|
||||||
Targets []string
|
Targets []string
|
||||||
|
|
||||||
LoadBalancing ReverseProxyLoadBalancingConfig
|
LoadBalancing ReverseProxyLoadBalancingConfig
|
||||||
PassiveHealth ReverseProxyPassiveHealthConfig
|
PassiveHealth ReverseProxyPassiveHealthConfig
|
||||||
|
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
FlushInterval time.Duration
|
FlushInterval time.Duration
|
||||||
BufferPool BufferPool
|
BufferPool BufferPool
|
||||||
AllowH2CUpstream bool
|
AllowH2CUpstream bool
|
||||||
|
|
||||||
ModifyRequest func(*http.Request)
|
ModifyRequest func(*http.Request)
|
||||||
ModifyResponse func(*http.Response) error
|
ModifyResponse func(*http.Response) error
|
||||||
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||||
|
|
||||||
ForwardedHeaders ForwardedHeadersPolicy
|
ForwardedHeaders ForwardedHeadersPolicy
|
||||||
ForwardedBy string
|
ForwardedBy string
|
||||||
Via string
|
Via string
|
||||||
PreserveHost bool
|
PreserveHost bool
|
||||||
|
|
||||||
|
RequestHeaders *HeaderOps
|
||||||
|
ResponseHeaders *RespHeaderOps
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
errReverseProxyNilTarget = errors.New("reverse proxy target is nil")
|
||||||
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host")
|
||||||
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete")
|
||||||
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
|
errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type HeaderOps struct {
|
||||||
|
Add map[string][]string
|
||||||
|
Set map[string][]string
|
||||||
|
Delete []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RespHeaderOps struct {
|
||||||
|
*HeaderOps
|
||||||
|
Deferred bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ops *HeaderOps) applyToRequest(req *http.Request) {
|
||||||
|
if ops == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
replacer := newReverseProxyReplacer(req)
|
||||||
|
|
||||||
|
for fieldName, vals := range ops.Add {
|
||||||
|
fieldName = replacer.Replace(fieldName)
|
||||||
|
for _, v := range vals {
|
||||||
|
req.Header.Add(fieldName, replacer.Replace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, vals := range ops.Set {
|
||||||
|
fieldName = replacer.Replace(fieldName)
|
||||||
|
req.Header.Del(fieldName)
|
||||||
|
for _, v := range vals {
|
||||||
|
req.Header.Add(fieldName, replacer.Replace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fieldName := range ops.Delete {
|
||||||
|
fieldName = strings.ToLower(replacer.Replace(fieldName))
|
||||||
|
if fieldName == "*" {
|
||||||
|
for k := range req.Header {
|
||||||
|
req.Header.Del(k)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"):
|
||||||
|
pattern := fieldName[1:len(fieldName)-1]
|
||||||
|
for k := range req.Header {
|
||||||
|
if strings.Contains(strings.ToLower(k), pattern) {
|
||||||
|
req.Header.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(fieldName, "*"):
|
||||||
|
suffix := fieldName[1:]
|
||||||
|
for k := range req.Header {
|
||||||
|
if strings.HasSuffix(strings.ToLower(k), suffix) {
|
||||||
|
req.Header.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case strings.HasSuffix(fieldName, "*"):
|
||||||
|
prefix := fieldName[:len(fieldName)-1]
|
||||||
|
for k := range req.Header {
|
||||||
|
if strings.HasPrefix(strings.ToLower(k), prefix) {
|
||||||
|
req.Header.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
req.Header.Del(fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ops *RespHeaderOps) applyToResponse(hdr http.Header) {
|
||||||
|
if ops == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ops.Deferred {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) {
|
||||||
|
if ops == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if repl == nil {
|
||||||
|
repl = &reverseProxyReplacer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, vals := range ops.Add {
|
||||||
|
fieldName = repl.Replace(fieldName)
|
||||||
|
for _, v := range vals {
|
||||||
|
hdr.Add(fieldName, repl.Replace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, vals := range ops.Set {
|
||||||
|
fieldName = repl.Replace(fieldName)
|
||||||
|
hdr.Del(fieldName)
|
||||||
|
for _, v := range vals {
|
||||||
|
hdr.Add(fieldName, repl.Replace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fieldName := range ops.Delete {
|
||||||
|
fieldName = strings.ToLower(repl.Replace(fieldName))
|
||||||
|
if fieldName == "*" {
|
||||||
|
for k := range hdr {
|
||||||
|
hdr.Del(k)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"):
|
||||||
|
pattern := fieldName[1:len(fieldName)-1]
|
||||||
|
for k := range hdr {
|
||||||
|
if strings.Contains(strings.ToLower(k), pattern) {
|
||||||
|
hdr.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case strings.HasPrefix(fieldName, "*"):
|
||||||
|
suffix := fieldName[1:]
|
||||||
|
for k := range hdr {
|
||||||
|
if strings.HasSuffix(strings.ToLower(k), suffix) {
|
||||||
|
hdr.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case strings.HasSuffix(fieldName, "*"):
|
||||||
|
prefix := fieldName[:len(fieldName)-1]
|
||||||
|
for k := range hdr {
|
||||||
|
if strings.HasPrefix(strings.ToLower(k), prefix) {
|
||||||
|
hdr.Del(k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
hdr.Del(fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type reverseProxyReplacer struct {
|
||||||
|
req *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer {
|
||||||
|
return &reverseProxyReplacer{req: req}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer {
|
||||||
|
return &reverseProxyReplacer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *reverseProxyReplacer) Replace(s string) string {
|
||||||
|
if r == nil || s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
type reverseProxyHandler struct {
|
type reverseProxyHandler struct {
|
||||||
config ReverseProxyConfig
|
config ReverseProxyConfig
|
||||||
upstreams []*reverseProxyUpstream
|
upstreams []*reverseProxyUpstream
|
||||||
|
|
@ -573,6 +734,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte
|
||||||
outreq.Header.Set("User-Agent", "")
|
outreq.Header.Set("User-Agent", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.config.RequestHeaders != nil {
|
||||||
|
p.config.RequestHeaders.applyToRequest(outreq)
|
||||||
|
}
|
||||||
|
|
||||||
if p.config.ModifyRequest != nil {
|
if p.config.ModifyRequest != nil {
|
||||||
p.config.ModifyRequest(outreq)
|
p.config.ModifyRequest(outreq)
|
||||||
}
|
}
|
||||||
|
|
@ -808,6 +973,10 @@ func appendXForwardedFor(header http.Header, clientIP string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool {
|
func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool {
|
||||||
|
if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred {
|
||||||
|
p.config.ResponseHeaders.applyToResponse(res.Header)
|
||||||
|
}
|
||||||
|
|
||||||
if p.config.ModifyResponse == nil {
|
if p.config.ModifyResponse == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
||||||
220
reverseproxy_headers_test.go
Normal file
220
reverseproxy_headers_test.go
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsAdd(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got := r.Header.Get("X-Custom-Header"); got != "test-value" {
|
||||||
|
t.Errorf("expected X-Custom-Header=test-value, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Add: map[string][]string{
|
||||||
|
"X-Custom-Header": {"test-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsDelete(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("X-Sensitive") != "" {
|
||||||
|
t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive"))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Delete: []string{"X-Sensitive"},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Sensitive", "should-be-removed")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyHeaderOpsSet(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
got := r.Header.Get("X-Replace")
|
||||||
|
if got != "new-value" {
|
||||||
|
t.Errorf("expected X-Replace=new-value, got %q", got)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
RequestHeaders: &HeaderOps{
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Replace": {"new-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil)
|
||||||
|
req.Header.Set("X-Replace", "old-value")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyResponseHeaderOps(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Backend", "backend-server")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ResponseHeaders: &RespHeaderOps{
|
||||||
|
HeaderOps: &HeaderOps{
|
||||||
|
Set: map[string][]string{
|
||||||
|
"X-Custom": {"custom-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Custom"); got != "custom-value" {
|
||||||
|
t.Errorf("expected X-Custom=custom-value, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Powered-By", "Express")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
target, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse target: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := New()
|
||||||
|
engine.GET("/test", ReverseProxy(ReverseProxyConfig{
|
||||||
|
Target: target,
|
||||||
|
ResponseHeaders: &RespHeaderOps{
|
||||||
|
HeaderOps: &HeaderOps{
|
||||||
|
Delete: []string{"X-Powered-By"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxy := httptest.NewServer(engine)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(proxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Powered-By"); got != "" {
|
||||||
|
t.Errorf("expected X-Powered-By to be deleted, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue