This commit is contained in:
wjqserver 2025-04-24 17:50:18 +08:00
parent b955c915ff
commit 7a6544c6c9
12 changed files with 170 additions and 182 deletions

View file

@ -1,27 +1,22 @@
package proxy
import (
"bytes"
"fmt"
"ghproxy/auth"
"ghproxy/config"
"ghproxy/rate"
"html/template"
"io/fs"
"github.com/cloudwego/hertz/pkg/app"
)
func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool {
var errMsg string
// 白名单检查
if cfg.Whitelist.Enabled {
var whitelist bool
whitelist = auth.CheckWhitelist(user, repo)
if !whitelist {
errMsg = fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo)
c.JSON(403, map[string]string{"error": errMsg})
ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo)))
logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
return true
}
@ -32,8 +27,7 @@ func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo stri
var blacklist bool
blacklist = auth.CheckBlacklist(user, repo)
if blacklist {
errMsg = fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo)
c.JSON(403, map[string]string{"error": errMsg})
ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo)))
logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
return true
}
@ -48,7 +42,7 @@ func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPat
if matcher == "api" && !cfg.Auth.ForceAllowApi {
if cfg.Auth.Method != "header" || !cfg.Auth.Enabled {
c.JSON(403, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"})
ErrorPage(c, NewErrorWithStatusLookup(403, "Github API Req without AuthHeader is Not Allowed"))
logInfo("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath)
return true
}
@ -59,7 +53,7 @@ func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPat
var authcheck bool
authcheck, err = auth.AuthHandler(c, cfg)
if !authcheck {
c.JSON(401, map[string]string{"error": "Unauthorized"})
ErrorPage(c, NewErrorWithStatusLookup(401, fmt.Sprintf("Unauthorized: %v", err)))
logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err)
return true
}
@ -81,105 +75,16 @@ func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimi
allowed = limiter.Allow()
default:
logWarning("Invalid RateLimit Method")
c.JSON(500, map[string]string{"error": "Invalid RateLimit Method"})
ErrorPage(c, NewErrorWithStatusLookup(500, "Invalid RateLimit Method"))
return true
}
if !allowed {
c.JSON(429, map[string]string{"error": "Too Many Requests"})
logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
ErrorPage(c, NewErrorWithStatusLookup(429, fmt.Sprintf("Too Many Requests; Rate Limit is %d per minute", cfg.RateLimit.RatePerMinute)))
logInfo("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
return true
}
}
return false
}
var errPagesFs fs.FS
func InitErrPagesFS(pages fs.FS) error {
var err error
errPagesFs, err = fs.Sub(pages, "pages/err")
if err != nil {
return err
}
return nil
}
type ErrorPageData struct {
StatusCode int
StatusDesc string
StatusText string
HelpInfo string
ErrorMessage string
}
func ErrPageUnwarper(errInfo *GHProxyErrors) ErrorPageData {
return ErrorPageData{
StatusCode: errInfo.StatusCode,
StatusDesc: errInfo.StatusDesc,
StatusText: errInfo.StatusText,
HelpInfo: errInfo.HelpInfo,
ErrorMessage: errInfo.ErrorMessage,
}
}
func ErrorPage(c *app.RequestContext, errInfo *GHProxyErrors) {
pageData, _ := htmlTemplateRender(errPagesFs, ErrPageUnwarper(errInfo))
/*
if err != nil {
c.JSON(errInfo.StatusCode, map[string]string{"error": errInfo.ErrorMessage})
logDebug("Error reading page.tmpl: %v", err)
return
}
*/
c.Data(errInfo.StatusCode, "text/html; charset=utf-8", pageData)
return
}
func NotFoundPage(c *app.RequestContext) {
/*
pageData, err := fs.ReadFile(errPagesFs, "404.html")
if err != nil {
c.JSON(404, map[string]string{"error": "Not Found"})
logDebug("Error reading 404.html: %v", err)
return
}
*/
pageData, err := htmlTemplateRender(errPagesFs, ErrorPageData{
StatusCode: 404,
StatusDesc: "Not Found",
StatusText: "The requested URL was not found on this server.",
ErrorMessage: "The requested URL was not found on this server.",
})
if err != nil {
c.JSON(404, map[string]string{"error": "Not Found"})
logDebug("Error reading 404.html: %v", err)
return
}
c.Data(404, "text/html; charset=utf-8", pageData)
return
}
func htmlTemplateRender(fsys fs.FS, data interface{}) ([]byte, error) {
tmplPath := "page.tmpl"
tmpl, err := template.ParseFS(fsys, tmplPath)
if err != nil {
return nil, fmt.Errorf("error parsing template: %w", err)
}
if tmpl == nil {
return nil, fmt.Errorf("template is nil")
}
// 创建一个 bytes.Buffer 用于存储渲染结果
var buf bytes.Buffer
err = tmpl.Execute(&buf, data)
if err != nil {
return nil, fmt.Errorf("error executing template: %w", err)
}
// 返回 buffer 的内容作为 []byte
return buf.Bytes(), nil
}