From d79aeaaacd9f6faf5c43351a04126c7701228085 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Mon, 21 Apr 2025 18:52:45 +0800 Subject: [PATCH] 25w30c --- CHANGELOG.md | 7 ++ DEV-VERSION | 2 +- auth/auth.go | 3 +- main.go | 6 ++ proxy/chunkreq.go | 8 +- proxy/gitreq.go | 6 -- proxy/handler.go | 213 +++++----------------------------------------- proxy/match.go | 12 ++- proxy/routing.go | 62 ++++++++++++++ proxy/utils.go | 110 ++++++++++++++++++++++++ 10 files changed, 214 insertions(+), 215 deletions(-) create mode 100644 proxy/routing.go create mode 100644 proxy/utils.go diff --git a/CHANGELOG.md b/CHANGELOG.md index c60bba0..e59acad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # 更新日志 +25w30c - 2025-04-21 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 改进handle, 复用共同部分 +- CHANGE: 细化url匹配的返回码处理 +- CHANGE: 增加404界面 + 25w30b - 2025-04-21 --- - PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; diff --git a/DEV-VERSION b/DEV-VERSION index 087f399..a71ee32 100644 --- a/DEV-VERSION +++ b/DEV-VERSION @@ -1 +1 @@ -25w30b \ No newline at end of file +25w30c \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index 1a7f1a2..eacfbf6 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,7 +1,6 @@ package auth import ( - "context" "fmt" "ghproxy/config" @@ -36,7 +35,7 @@ func Init(cfg *config.Config) { logDebug("Auth Init") } -func AuthHandler(ctx context.Context, c *app.RequestContext, cfg *config.Config) (isValid bool, err error) { +func AuthHandler(c *app.RequestContext, cfg *config.Config) (isValid bool, err error) { if cfg.Auth.Method == "parameters" { isValid, err = AuthParametersHandler(c, cfg) return isValid, err diff --git a/main.go b/main.go index 2fd7bfa..5314f31 100644 --- a/main.go +++ b/main.go @@ -210,6 +210,12 @@ func loadEmbeddedPages(cfg *config.Config) (fs.FS, fs.FS, error) { return nil, nil, fmt.Errorf("failed to load embedded pages: %w", err) } + // 初始化errPagesFs + errPagesInitErr := proxy.InitErrPagesFS(pagesFS) + if errPagesInitErr != nil { + logWarning("errPagesInitErr: %s", errPagesInitErr) + } + var assets fs.FS assets, err = fs.Sub(pagesFS, "pages/assets") return pages, assets, nil diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index 18cc0cb..1571b27 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -38,14 +38,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c var ( method []byte - //bodyReader *bytes.Buffer - req *http.Request - resp *http.Response - err error + req *http.Request + resp *http.Response + err error ) method = c.Request.Method() - //bodyReader = bytes.NewBuffer(c.Request.Body()) req, err = client.NewRequest(string(method), u, c.Request.BodyStream()) if err != nil { diff --git a/proxy/gitreq.go b/proxy/gitreq.go index ad53f94..b8ba66d 100644 --- a/proxy/gitreq.go +++ b/proxy/gitreq.go @@ -27,14 +27,8 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co var ( resp *http.Response - //err error ) - //body := c.Request.Body() - - //bodyReader := bytes.NewBuffer(body) - // 创建请求 - if cfg.GitClone.Mode == "cache" { req, err := gitclient.NewRequest(method, u, c.Request.BodyStream()) if err != nil { diff --git a/proxy/handler.go b/proxy/handler.go index 4b9b977..444685a 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -3,8 +3,6 @@ package proxy import ( "context" "errors" - "fmt" - "ghproxy/auth" "ghproxy/config" "ghproxy/rate" "net/http" @@ -19,43 +17,22 @@ var re = regexp.MustCompile(`^(http:|https:)?/?/?(.*)`) // 匹配http://或https func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { - // 限制访问频率 - if cfg.RateLimit.Enabled { - - var allowed bool - - switch cfg.RateLimit.RateMethod { - case "ip": - allowed = iplimiter.Allow(c.ClientIP()) - case "total": - allowed = limiter.Allow() - default: - logWarning("Invalid RateLimit Method") - return - } - - if !allowed { - c.JSON(http.StatusTooManyRequests, map[string]string{"error": "Too Many Requests"}) - logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - return - } - } + rateCheck(cfg, c, limiter, iplimiter) var ( rawPath string matches []string - errMsg string ) rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ matches = re.FindStringSubmatch(rawPath) // 匹配路径 - logInfo("URL: %v", matches) + logDebug("URL: %v", matches) // 匹配路径错误处理 if len(matches) < 3 { - errMsg = fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - logWarning(errMsg) - c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) + logWarning("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + //c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) + c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid URL Format"}) return } @@ -72,68 +49,35 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra user, repo, matcher, err = Matcher(rawPath, cfg) if err != nil { if errors.Is(err, ErrInvalidURL) { - c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) + c.JSON(ErrInvalidURL.Code, map[string]string{"error": "Invalid URL Format, Path: " + rawPath}) logWarning(err.Error()) return } if errors.Is(err, ErrAuthHeaderUnavailable) { - c.String(http.StatusForbidden, "AuthHeader Unavailable") + c.JSON(ErrAuthHeaderUnavailable.Code, map[string]string{"error": "AuthHeader Unavailable"}) + logWarning(err.Error()) + return + } + if errors.Is(err, ErrNotFound) { + //c.JSON(ErrNotFound.Code, map[string]string{"error": "Not Found"}) + NotFoundPage(c) logWarning(err.Error()) return } } logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - // dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header - logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) - var repouser string - repouser = fmt.Sprintf("%s/%s", user, repo) + logDump("%s", c.Request.Header.Header()) - // 白名单检查 - if cfg.Whitelist.Enabled { - var whitelist bool - whitelist = auth.CheckWhitelist(user, repo) - if !whitelist { - errMsg = fmt.Sprintf("Whitelist Blocked repo: %s", repouser) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) - return - } - } - - // 黑名单检查 - if cfg.Blacklist.Enabled { - var blacklist bool - blacklist = auth.CheckBlacklist(user, repo) - if blacklist { - errMsg = fmt.Sprintf("Blacklist Blocked repo: %s", repouser) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) - return - } - } - - // 若匹配api.github.com/repos/用户名/仓库名/路径, 则检查是否开启HeaderAuth + listCheck(cfg, c, user, repo, rawPath) + authCheck(c, cfg, matcher, rawPath) // 处理blob/raw路径 if matcher == "blob" { rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) } - // 鉴权 - if cfg.Auth.Enabled { - var authcheck bool - authcheck, err = auth.AuthHandler(ctx, c, cfg) - if !authcheck { - //c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) - c.AbortWithStatusJSON(401, map[string]string{"error": "Unauthorized"}) - logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) - return - } - } - - // IP METHOD URL USERAGENT PROTO MATCHES - logDebug("%s %s %s %s %s Matched: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matcher) + logDebug("Matched: %v", matcher) switch matcher { case "releases", "blob", "raw", "gist", "api": @@ -141,127 +85,8 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra case "clone": GitReq(ctx, c, rawPath, cfg, "git") default: - c.String(http.StatusForbidden, "Invalid input.") - fmt.Println("Invalid input.") - return - } - } -} - -func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { - return func(ctx context.Context, c *app.RequestContext) { - // 输出所有传入参数 - logDebug("Context Params(Matcher): %v", c.GetString("matcher")) - - // 限制访问频率 - if cfg.RateLimit.Enabled { - - var allowed bool - - switch cfg.RateLimit.RateMethod { - case "ip": - allowed = iplimiter.Allow(c.ClientIP()) - case "total": - allowed = limiter.Allow() - default: - logWarning("Invalid RateLimit Method") - return - } - - if !allowed { - c.JSON(http.StatusTooManyRequests, map[string]string{"error": "Too Many Requests"}) - logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - return - } - } - - var ( - rawPath string - errMsg string - ) - - rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ - - var ( - user string - repo string - matcher string - err error - ) - - user = c.Param("user") - repo = c.Param("repo") - matcher = c.GetString("matcher") - - logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - // dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header - logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) - - // 白名单检查 - if cfg.Whitelist.Enabled { - var whitelist bool - whitelist = auth.CheckWhitelist(user, repo) - if !whitelist { - errMsg = fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - return - } - } - - // 黑名单检查 - if cfg.Blacklist.Enabled { - var blacklist bool - blacklist = auth.CheckBlacklist(user, repo) - if blacklist { - errMsg = fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - return - } - } - - if matcher == "api" && !cfg.Auth.ForceAllowApi { - if cfg.Auth.Method != "header" || !cfg.Auth.Enabled { - c.JSON(http.StatusForbidden, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"}) - logWarning("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath) - return - } - } - - // 鉴权 - if cfg.Auth.Enabled { - var authcheck bool - authcheck, err = auth.AuthHandler(ctx, c, cfg) - if !authcheck { - //c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) - c.AbortWithStatusJSON(401, map[string]string{"error": "Unauthorized"}) - logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) - return - } - } - - // 若匹配api.github.com/repos/用户名/仓库名/路径, 则检查是否开启HeaderAuth - - // 处理blob/raw路径 - if matcher == "blob" { - rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) - } - - // 为rawpath加入https:// 头 - rawPath = "https://" + rawPath - - // IP METHOD URL USERAGENT PROTO MATCHES - logDebug("%s %s %s %s %s Matched: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matcher) - - switch matcher { - case "releases", "blob", "raw", "gist", "api": - ChunkedProxyRequest(ctx, c, rawPath, cfg, matcher) - case "clone": - GitReq(ctx, c, rawPath, cfg, "git") - default: - c.String(http.StatusForbidden, "Invalid input.") - fmt.Println("Invalid input.") + c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid input."}) + logError("Invalid input") return } } diff --git a/proxy/match.go b/proxy/match.go index fa42494..5f0f6e2 100644 --- a/proxy/match.go +++ b/proxy/match.go @@ -27,6 +27,10 @@ var ( Code: 403, Msg: "AuthHeader Unavailable", } + ErrNotFound = &MatcherErrors{ + Code: 404, + Msg: "Not Found", + } ) func (e *MatcherErrors) Error() string { @@ -122,7 +126,7 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) } return user, repo, matcher, nil } - return "", "", "", ErrInvalidURL + return "", "", "", ErrNotFound } func EditorMatcher(rawPath string, cfg *config.Config) (bool, string, error) { @@ -165,12 +169,6 @@ func EditorMatcher(rawPath string, cfg *config.Config) (bool, string, error) { // 匹配文件扩展名是sh的rawPath func MatcherShell(rawPath string) bool { - /* - if strings.HasSuffix(rawPath, ".sh") { - return true - } - return false - */ return strings.HasSuffix(rawPath, ".sh") } diff --git a/proxy/routing.go b/proxy/routing.go new file mode 100644 index 0000000..c07824c --- /dev/null +++ b/proxy/routing.go @@ -0,0 +1,62 @@ +package proxy + +import ( + "context" + "ghproxy/config" + "ghproxy/rate" + "net/http" + "strings" + + "github.com/cloudwego/hertz/pkg/app" +) + +func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { + return func(ctx context.Context, c *app.RequestContext) { + + rateCheck(cfg, c, limiter, iplimiter) + + var ( + rawPath string + ) + + rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ + + var ( + user string + repo string + matcher string + ) + + user = c.Param("user") + repo = c.Param("repo") + matcher = c.GetString("matcher") + + logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + logDump("%s", c.Request.Header.Header()) + + listCheck(cfg, c, user, repo, rawPath) + authCheck(c, cfg, matcher, rawPath) + + // 处理blob/raw路径 + if matcher == "blob" { + rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) + } + + // 为rawpath加入https:// 头 + rawPath = "https://" + rawPath + + // IP METHOD URL USERAGENT PROTO MATCHES + logDebug("%s %s %s %s %s Matched: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matcher) + + switch matcher { + case "releases", "blob", "raw", "gist", "api": + ChunkedProxyRequest(ctx, c, rawPath, cfg, matcher) + case "clone": + GitReq(ctx, c, rawPath, cfg, "git") + default: + c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid input."}) + logError("Invalid input") + return + } + } +} diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 0000000..615daa2 --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,110 @@ +package proxy + +import ( + "fmt" + "ghproxy/auth" + "ghproxy/config" + "ghproxy/rate" + "io/fs" + + "github.com/cloudwego/hertz/pkg/app" +) + +func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) { + var errMsg string + + // 白名单检查 + if cfg.Whitelist.Enabled { + var whitelist bool + whitelist = auth.CheckWhitelist(user, repo) + if !whitelist { + errMsg = fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo) + c.JSON(403, map[string]string{"error": errMsg}) + logWarning("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return + } + } + + // 黑名单检查 + if cfg.Blacklist.Enabled { + var blacklist bool + blacklist = auth.CheckBlacklist(user, repo) + if blacklist { + errMsg = fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo) + c.JSON(403, map[string]string{"error": errMsg}) + logWarning("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return + } + } +} + +// 鉴权 +func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPath string) { + var err error + + if matcher == "api" && !cfg.Auth.ForceAllowApi { + if cfg.Auth.Method != "header" || !cfg.Auth.Enabled { + c.JSON(403, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"}) + logWarning("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath) + return + } + } + + // 鉴权 + if cfg.Auth.Enabled { + var authcheck bool + authcheck, err = auth.AuthHandler(c, cfg) + if !authcheck { + c.JSON(401, map[string]string{"error": "Unauthorized"}) + logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) + return + } + } +} + +func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) { + // 限制访问频率 + if cfg.RateLimit.Enabled { + + var allowed bool + + switch cfg.RateLimit.RateMethod { + case "ip": + allowed = iplimiter.Allow(c.ClientIP()) + case "total": + allowed = limiter.Allow() + default: + logWarning("Invalid RateLimit Method") + c.JSON(500, map[string]string{"error": "Invalid RateLimit Method"}) + return + } + + if !allowed { + c.JSON(429, map[string]string{"error": "Too Many Requests"}) + logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + return + } + } +} + +var errPagesFs fs.FS + +func InitErrPagesFS(pages fs.FS) error { + var err error + errPagesFs, err = fs.Sub(pages, "pages/err") + if err != nil { + return err + } + return nil +} + +func NotFoundPage(c *app.RequestContext) { + pageData, err := fs.ReadFile(errPagesFs, "404.html") + if err != nil { + c.JSON(404, map[string]string{"error": "Not Found"}) + logDebug("Error reading 404.html: %v", err) + return + } + c.Data(404, "text/html; charset=utf-8", pageData) + return +}