package proxy

import (
	"bytes"
	"errors"
	"fmt"
	"html/template"
	"io/fs"

	"ghproxy/auth"
	"ghproxy/config"
	"ghproxy/rate"

	"github.com/cloudwego/hertz/pkg/app"
)

// listCheck applies the repository whitelist and blacklist configured in cfg
// to user/repo. The whitelist is evaluated first, then the blacklist.
//
// When the repository is rejected it writes a 403 JSON response, logs the
// blocked request and returns true; the caller should stop handling the
// request. It returns false when the request may proceed.
func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool {
	// Whitelist check: reject when auth.CheckWhitelist does not accept the repo.
	if cfg.Whitelist.Enabled && !auth.CheckWhitelist(user, repo) {
		errMsg := fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo)
		c.JSON(403, map[string]string{"error": errMsg})
		logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
		return true
	}

	// Blacklist check: reject when auth.CheckBlacklist matches the repo.
	if cfg.Blacklist.Enabled && auth.CheckBlacklist(user, repo) {
		errMsg := fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo)
		c.JSON(403, map[string]string{"error": errMsg})
		logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
		return true
	}

	return false
}

// authCheck enforces the authentication settings in cfg.
//
// Requests whose matcher is "api" are rejected with 403 unless
// cfg.Auth.ForceAllowApi is set or header authentication is enabled.
// When cfg.Auth.Enabled is true, auth.AuthHandler must accept the request,
// otherwise a 401 is written. It returns true when a rejection response has
// been written and the caller should stop, false otherwise.
func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPath string) bool {
	// GitHub API requests are only proxied with header-based auth, unless
	// explicitly force-allowed.
	if matcher == "api" && !cfg.Auth.ForceAllowApi {
		if cfg.Auth.Method != "header" || !cfg.Auth.Enabled {
			c.JSON(403, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"})
			// All five %s verbs need an argument; previously UserAgent and
			// protocol were missing, producing "%!s(MISSING)" in the log.
			logInfo("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
			return true
		}
	}

	// General authentication.
	if cfg.Auth.Enabled {
		if ok, err := auth.AuthHandler(c, cfg); !ok {
			c.JSON(401, map[string]string{"error": "Unauthorized"})
			logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err)
			return true
		}
	}

	return false
}

// rateCheck applies request rate limiting when cfg.RateLimit.Enabled is set.
//
// With RateMethod "ip" the per-client iplimiter is consulted (keyed by
// c.ClientIP()); with "total" the shared limiter is used. Any other method is
// a configuration error and yields a 500. A denied request gets a 429.
// It returns true when a response has been written and the caller should
// stop, false when the request may proceed.
func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) bool {
	if !cfg.RateLimit.Enabled {
		return false
	}

	var allowed bool
	switch cfg.RateLimit.RateMethod {
	case "ip":
		allowed = iplimiter.Allow(c.ClientIP())
	case "total":
		allowed = limiter.Allow()
	default:
		logWarning("Invalid RateLimit Method")
		c.JSON(500, map[string]string{"error": "Invalid RateLimit Method"})
		return true
	}

	if !allowed {
		c.JSON(429, map[string]string{"error": "Too Many Requests"})
		logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
		return true
	}

	return false
}

// errPagesFs holds the "pages/err" subtree used to render error pages.
// It is nil until InitErrPagesFS succeeds.
var errPagesFs fs.FS

// InitErrPagesFS sets the error-page filesystem to the "pages/err"
// subdirectory of pages. On failure the previously configured filesystem is
// left unchanged and the error is returned.
func InitErrPagesFS(pages fs.FS) error {
	sub, err := fs.Sub(pages, "pages/err")
	if err != nil {
		return fmt.Errorf("loading error pages: %w", err)
	}
	errPagesFs = sub
	return nil
}

// ErrorPageData is the data passed to the error-page template (page.tmpl).
type ErrorPageData struct {
	StatusCode   int
	StatusDesc   string
	StatusText   string
	HelpInfo     string
	ErrorMessage string
}

// ErrPageUnwarper copies the display fields of errInfo into an ErrorPageData
// suitable for template rendering.
func ErrPageUnwarper(errInfo *GHProxyErrors) ErrorPageData {
	return ErrorPageData{
		StatusCode:   errInfo.StatusCode,
		StatusDesc:   errInfo.StatusDesc,
		StatusText:   errInfo.StatusText,
		HelpInfo:     errInfo.HelpInfo,
		ErrorMessage: errInfo.ErrorMessage,
	}
}

// ErrorPage writes an HTML error page for errInfo with errInfo.StatusCode.
// If the page cannot be rendered it falls back to a JSON body containing
// errInfo.ErrorMessage, so the client never receives an empty response.
func ErrorPage(c *app.RequestContext, errInfo *GHProxyErrors) {
	pageData, err := htmlTemplateRender(errPagesFs, ErrPageUnwarper(errInfo))
	if err != nil {
		c.JSON(errInfo.StatusCode, map[string]string{"error": errInfo.ErrorMessage})
		logDebug("Error reading page.tmpl: %v", err)
		return
	}
	c.Data(errInfo.StatusCode, "text/html; charset=utf-8", pageData)
}

// NotFoundPage writes the 404 HTML page, falling back to a JSON body if the
// page cannot be rendered.
func NotFoundPage(c *app.RequestContext) {
	pageData, err := htmlTemplateRender(errPagesFs, ErrorPageData{
		StatusCode:   404,
		StatusDesc:   "Not Found",
		StatusText:   "The requested URL was not found on this server.",
		ErrorMessage: "The requested URL was not found on this server.",
	})
	if err != nil {
		c.JSON(404, map[string]string{"error": "Not Found"})
		logDebug("Error rendering 404 page: %v", err)
		return
	}
	c.Data(404, "text/html; charset=utf-8", pageData)
}

// htmlTemplateRender parses "page.tmpl" from fsys, executes it with data and
// returns the rendered bytes. It returns an error if fsys is nil or if the
// template cannot be parsed or executed.
func htmlTemplateRender(fsys fs.FS, data interface{}) ([]byte, error) {
	// errPagesFs is nil until InitErrPagesFS succeeds; template.ParseFS would
	// otherwise end up calling Open on a nil fs.FS and panic.
	if fsys == nil {
		return nil, errors.New("error parsing template: template filesystem is nil")
	}

	const tmplPath = "page.tmpl"
	tmpl, err := template.ParseFS(fsys, tmplPath)
	if err != nil {
		return nil, fmt.Errorf("error parsing template: %w", err)
	}

	// Render into a buffer first so a mid-execution failure never leaves a
	// partially written response.
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return nil, fmt.Errorf("error executing template: %w", err)
	}
	return buf.Bytes(), nil
}