From 7e5b12dff8a35f029719030c190a4dd043e89865 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:50:04 +0800 Subject: [PATCH] Fix: Optimize header forwarding by excluding headers in a single pass --- proxy/chunkreq.go | 66 +++++++++++++++++++++++++++++++++-------------- proxy/handler.go | 37 ++++++++++++++++++-------- 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index 47f0e59..7187e1f 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -5,20 +5,39 @@ import ( "context" "fmt" "ghproxy/config" + "io" "net/http" "strconv" "github.com/cloudwego/hertz/pkg/app" ) +var ( + headersToRemove = map[string]struct{}{ + "Content-Security-Policy": {}, + "Referrer-Policy": {}, + "Strict-Transport-Security": {}, + "X-Github-Request-Id": {}, + "X-Timer": {}, + "X-Served-By": {}, + "X-Fastly-Request-Id": {}, + } +) + func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, matcher string) { - method := c.Request.Method - body := c.Request.Body() + var ( + method []byte + bodyReader *bytes.Buffer + req *http.Request + resp *http.Response + err error + ) - bodyReader := bytes.NewBuffer(body) + method = c.Request.Method() + bodyReader = bytes.NewBuffer(c.Request.Body()) - req, err := client.NewRequest(string(method()), u, bodyReader) + req, err = client.NewRequest(string(method), u, bodyReader) if err != nil { HandleError(c, fmt.Sprintf("Failed to create request: %v", err)) return @@ -27,7 +46,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c removeWSHeader(req) // 删除Conection Upgrade头, 避免与HTTP/2冲突(检查是否存在Upgrade头) AuthPassThrough(c, cfg, req) - resp, err := client.Do(req) + resp, err = client.Do(req) if err != nil { HandleError(c, fmt.Sprintf("Failed to send request: %v", err)) return @@ -55,8 +74,9 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c bodySize = -1 } if err == nil && bodySize > sizelimit { - finalURL := resp.Request.URL.String() - err := resp.Body.Close() + var finalURL string + finalURL = resp.Request.URL.String() + err = resp.Body.Close() if err != nil { logError("Failed to close response body: %v", err) } @@ -66,20 +86,26 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c } } - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) + /* + for header := range headersToRemove { + resp.Header.Del(header) } - } - headersToRemove := map[string]struct{}{ - "Content-Security-Policy": {}, - "Referrer-Policy": {}, - "Strict-Transport-Security": {}, - } + for key := range resp.Header { + var values []string = resp.Header.Values(key) + for _, value := range values { + c.Header(key, value) + } + } + */ - for header := range headersToRemove { - resp.Header.Del(header) + // 复制响应头,排除需要移除的 header + for key, values := range resp.Header { + if _, shouldRemove := headersToRemove[key]; !shouldRemove { + for _, value := range values { + c.Header(key, value) + } + } } switch cfg.Server.Cors { @@ -105,7 +131,9 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c logInfo("Is Shell: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol()) c.Header("Content-Length", "") - reader, _, err := processLinks(resp.Body, compress, string(c.Request.Host()), cfg) + var reader io.Reader + + reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg) c.SetBodyStream(reader, -1) if err != nil { diff --git a/proxy/handler.go b/proxy/handler.go index f8b998c..9d214d3 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -41,13 +41,19 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra } } - rawPath := strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ - matches := re.FindStringSubmatch(rawPath) // 匹配路径 + var ( + rawPath string + matches []string + errMsg string + ) + + rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ + matches = re.FindStringSubmatch(rawPath) // 匹配路径 logInfo("URL: %v", matches) // 匹配路径错误处理 if len(matches) < 3 { - errMsg := fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + errMsg = fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) logWarning(errMsg) c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) return @@ -56,7 +62,14 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra // 制作url rawPath = "https://" + matches[2] - user, repo, matcher, err := Matcher(rawPath, cfg) + var ( + user string + repo string + matcher string + err error + ) + + user, repo, matcher, err = Matcher(rawPath, cfg) if err != nil { if errors.Is(err, ErrInvalidURL) { c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) @@ -69,18 +82,19 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra return } } - username := user - logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), username, repo) + logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) // dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) - repouser := fmt.Sprintf("%s/%s", username, repo) + var repouser string + repouser = fmt.Sprintf("%s/%s", user, repo) // 白名单检查 if cfg.Whitelist.Enabled { - whitelist := auth.CheckWhitelist(username, repo) + var whitelist bool + whitelist = auth.CheckWhitelist(user, repo) if !whitelist { - errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", repouser) + errMsg = fmt.Sprintf("Whitelist Blocked repo: %s", repouser) c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) return @@ -89,9 +103,10 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra // 黑名单检查 if cfg.Blacklist.Enabled { - blacklist := auth.CheckBlacklist(username, repo) + var blacklist bool + blacklist = auth.CheckBlacklist(user, repo) if blacklist { - errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser) + errMsg = fmt.Sprintf("Blacklist Blocked repo: %s", repouser) c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) return