From b955c915ff2badf400bdbda51ceaceb925fa482c Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 24 Apr 2025 01:09:53 +0800 Subject: [PATCH] fix callback issue --- proxy/handler.go | 37 ++++++++++++++----------------------- proxy/routing.go | 18 +++++++++++++++--- proxy/utils.go | 32 +++++++++++++++++++------------- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/proxy/handler.go b/proxy/handler.go index dae537b..e7aa236 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -16,7 +16,11 @@ var re = regexp.MustCompile(`^(http:|https:)?/?/?(.*)`) // 匹配http://或https func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { - rateCheck(cfg, c, limiter, iplimiter) + var shoudBreak bool + shoudBreak = rateCheck(cfg, c, limiter, iplimiter) + if shoudBreak { + return + } var ( rawPath string @@ -30,7 +34,6 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra // 匹配路径错误处理 if len(matches) < 3 { logWarning("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - //c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid URL Format"}) return } @@ -42,7 +45,6 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra user string repo string matcher string - //err error ) var matcherErr *GHProxyErrors @@ -50,31 +52,20 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra if matcherErr != nil { ErrorPage(c, matcherErr) return - /* - if errors.Is(err, ErrInvalidURL) { - c.JSON(ErrInvalidURL.Code, map[string]string{"error": "Invalid URL Format, Path: " + rawPath}) - logWarning(err.Error()) - return - } - if errors.Is(err, ErrAuthHeaderUnavailable) { - c.JSON(ErrAuthHeaderUnavailable.Code, map[string]string{"error": "AuthHeader Unavailable"}) - logWarning(err.Error()) - return - } - if errors.Is(err, ErrNotFound) { - //c.JSON(ErrNotFound.Code, map[string]string{"error": "Not Found"}) - NotFoundPage(c) - logWarning(err.Error()) - return - } - */ } logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) logDump("%s", c.Request.Header.Header()) - listCheck(cfg, c, user, repo, rawPath) - authCheck(c, cfg, matcher, rawPath) + shoudBreak = listCheck(cfg, c, user, repo, rawPath) + if shoudBreak { + return + } + + shoudBreak = authCheck(c, cfg, matcher, rawPath) + if shoudBreak { + return + } // 处理blob/raw路径 if matcher == "blob" { diff --git a/proxy/routing.go b/proxy/routing.go index c07824c..140c5c5 100644 --- a/proxy/routing.go +++ b/proxy/routing.go @@ -13,7 +13,12 @@ import ( func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { - rateCheck(cfg, c, limiter, iplimiter) + var shoudBreak bool + + shoudBreak = rateCheck(cfg, c, limiter, iplimiter) + if shoudBreak { + return + } var ( rawPath string @@ -34,8 +39,15 @@ func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) logDump("%s", c.Request.Header.Header()) - listCheck(cfg, c, user, repo, rawPath) - authCheck(c, cfg, matcher, rawPath) + shoudBreak = listCheck(cfg, c, user, repo, rawPath) + if shoudBreak { + return + } + + shoudBreak = authCheck(c, cfg, matcher, rawPath) + if shoudBreak { + return + } // 处理blob/raw路径 if matcher == "blob" { diff --git a/proxy/utils.go b/proxy/utils.go index 25027b6..cb6dedf 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -12,7 +12,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" ) -func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) { +func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool { var errMsg string // 白名单检查 @@ -22,8 +22,8 @@ func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo stri if !whitelist { errMsg = fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo) c.JSON(403, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - return + logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return true } } @@ -34,21 +34,23 @@ func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo stri if blacklist { errMsg = fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo) c.JSON(403, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - return + logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return true } } + + return false } // 鉴权 -func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPath string) { +func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPath string) bool { var err error if matcher == "api" && !cfg.Auth.ForceAllowApi { if cfg.Auth.Method != "header" || !cfg.Auth.Enabled { c.JSON(403, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"}) - logWarning("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath) - return + logInfo("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath) + return true } } @@ -58,13 +60,15 @@ func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPat authcheck, err = auth.AuthHandler(c, cfg) if !authcheck { c.JSON(401, map[string]string{"error": "Unauthorized"}) - logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) - return + logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) + return true } } + + return false } -func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) { +func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) bool { // 限制访问频率 if cfg.RateLimit.Enabled { @@ -78,15 +82,17 @@ func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimi default: logWarning("Invalid RateLimit Method") c.JSON(500, map[string]string{"error": "Invalid RateLimit Method"}) - return + return true } if !allowed { c.JSON(429, map[string]string{"error": "Too Many Requests"}) logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - return + return true } } + + return false } var errPagesFs fs.FS