Merge pull request #88 from infinite-iroha/feat/replacer-dynamic-vars

feat: 实现动态请求变量替换
This commit is contained in:
WJQSERVER 2026-04-21 18:14:38 +08:00 committed by GitHub
commit 8fdb16ae1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 140 additions and 2 deletions

View file

@ -264,11 +264,26 @@ func (ops *HeaderOps) Provision() error {
} }
type reverseProxyReplacer struct { type reverseProxyReplacer struct {
req *http.Request method, host, path, query, scheme, uri, proto string
} }
func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer {
return &reverseProxyReplacer{req: req} if req == nil || req.URL == nil {
return &reverseProxyReplacer{}
}
uri := req.RequestURI
if uri == "" {
uri = req.URL.RequestURI()
}
return &reverseProxyReplacer{
method: req.Method,
host: req.Host,
path: req.URL.EscapedPath(),
query: req.URL.RawQuery,
scheme: reverseProxyRequestScheme(req),
uri: uri,
proto: req.Proto,
}
} }
func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer {
@ -279,6 +294,27 @@ func (r *reverseProxyReplacer) Replace(s string) string {
if r == nil || s == "" { if r == nil || s == "" {
return s return s
} }
if r.method != "" {
s = strings.ReplaceAll(s, "{method}", r.method)
}
if r.host != "" {
s = strings.ReplaceAll(s, "{host}", r.host)
}
if r.path != "" {
s = strings.ReplaceAll(s, "{path}", r.path)
}
if r.query != "" {
s = strings.ReplaceAll(s, "{query}", r.query)
}
if r.scheme != "" {
s = strings.ReplaceAll(s, "{scheme}", r.scheme)
}
if r.uri != "" {
s = strings.ReplaceAll(s, "{uri}", r.uri)
}
if r.proto != "" {
s = strings.ReplaceAll(s, "{proto}", r.proto)
}
return s return s
} }

View file

@ -426,3 +426,105 @@ func BenchmarkReplacementApplyRegexp(b *testing.B) {
_ = r.apply(s) _ = r.apply(s)
} }
} }
func TestReverseProxyReplacerDynamicVars(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil)
req.Host = "example.com"
repl := newReverseProxyReplacer(req)
tests := []struct {
name string
input string
want string
}{
{"method", "{method}", "GET"},
{"host", "{host}", "example.com"},
{"path", "{path}", "/api/v1/users"},
{"query", "{query}", "sort=name&limit=10"},
{"scheme", "{scheme}", "http"},
{"proto", "{proto}", "HTTP/1.1"},
{"combined", "X-{method}-{path}", "X-GET-/api/v1/users"},
{"no vars", "static-value", "static-value"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := repl.Replace(tt.input); got != tt.want {
t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestReverseProxyReplacerNilRequest(t *testing.T) {
repl := newReverseProxyReplacer(nil)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil request, got %q", got)
}
}
func TestReverseProxyReplacerNilReplacer(t *testing.T) {
var repl *reverseProxyReplacer
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string with nil replacer, got %q", got)
}
}
func TestReverseProxyReplacerFromHeader(t *testing.T) {
hdr := make(http.Header)
repl := newReverseProxyReplacerFromHeader(hdr)
if got := repl.Replace("{method}"); got != "{method}" {
t.Errorf("expected unchanged string from header replacer, got %q", got)
}
}
func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" {
t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Method"); got != "GET" {
t.Errorf("expected X-Forwarded-Method=GET, got %q", got)
}
if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" {
t.Errorf("expected X-Forwarded-Host=client.example, got %q", got)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()
target, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parse target: %v", err)
}
engine := New()
engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{
Target: target,
RequestHeaders: &HeaderOps{
Add: map[string][]string{
"X-Forwarded-Path": {"{path}"},
"X-Forwarded-Method": {"{method}"},
"X-Forwarded-Host": {"{host}"},
},
},
}))
proxy := httptest.NewServer(engine)
defer proxy.Close()
req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil)
req.Host = "client.example"
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}