This commit is contained in:
WJQSERVER 2024-10-06 00:22:42 +08:00
parent b9a7f30705
commit 57f67278a3
5 changed files with 23 additions and 40 deletions

View file

@ -33,26 +33,3 @@ func AuthHandler(c *gin.Context, cfg *config.Config) bool {
return isValid return isValid
} }
func IsBlacklisted(username, repo string, blacklist map[string][]string, enabled bool) bool {
if !enabled {
return false
}
// 检查 blacklist 是否为 nil
if blacklist == nil {
// 可以选择记录日志或返回 false
logw("Warning: Blacklist map is nil")
return false
}
if repos, ok := blacklist[username]; ok {
for _, blacklistedRepo := range repos {
if blacklistedRepo == repo {
return true
}
}
}
return false
}

10
auth/blacklist.go Normal file
View file

@ -0,0 +1,10 @@
package auth
func CheckBlacklist(fullrepo string) bool {
if fullrepo == "test/test1" {
logw("%s in blacklist", fullrepo)
return true
}
logw("%s not in blacklist", fullrepo)
return false
}

View file

@ -33,7 +33,7 @@ type Config struct {
} `yaml:"blacklist"` } `yaml:"blacklist"`
} }
type Blacklist struct { type BlacklistMap struct {
Blist map[string][]string `yaml:"blacklist"` Blist map[string][]string `yaml:"blacklist"`
} }
@ -47,8 +47,8 @@ func LoadConfig(filePath string) (*Config, error) {
} }
// LoadBlacklistConfig 从 YAML 配置文件加载黑名单配置 // LoadBlacklistConfig 从 YAML 配置文件加载黑名单配置
func LoadBlacklistConfig(filePath string) (*Blacklist, error) { func LoadBlacklistConfig(filePath string) (*BlacklistMap, error) {
var blacklist Blacklist var blacklist BlacklistMap
if err := loadYAML(filePath, &blacklist); err != nil { if err := loadYAML(filePath, &blacklist); err != nil {
return nil, err return nil, err
} }

View file

@ -16,7 +16,7 @@ import (
var ( var (
cfg *config.Config cfg *config.Config
blacklist *config.Blacklist blacklist *config.BlacklistMap
logw = logger.Logw logw = logger.Logw
router *gin.Engine router *gin.Engine
configfile = "/data/ghproxy/config/config.yaml" configfile = "/data/ghproxy/config/config.yaml"
@ -44,7 +44,7 @@ func loadConfig() {
func loadBlacklistConfig() { func loadBlacklistConfig() {
// 初始化黑名单配置 // 初始化黑名单配置
blacklist, err := config.LoadBlacklistConfig("/data/ghproxy/config/blacklist.yaml") blacklist, err := config.LoadBlacklistConfig(cfg.Blacklist.BlacklistFile)
if err != nil { if err != nil {
log.Fatalf("Failed to load blacklist: %v", err) log.Fatalf("Failed to load blacklist: %v", err)
} }
@ -87,7 +87,7 @@ func init() {
// 未匹配路由处理 // 未匹配路由处理
router.NoRoute(func(c *gin.Context) { router.NoRoute(func(c *gin.Context) {
proxy.NoRouteHandler(cfg, blacklist)(c) proxy.NoRouteHandler(cfg, config.BlacklistMap{})(c)
}) })
} }

View file

@ -30,7 +30,7 @@ var exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?gist\.github\.com/([^/]+)/.+?/.+`), regexp.MustCompile(`^(?:https?://)?gist\.github\.com/([^/]+)/.+?/.+`),
} }
func NoRouteHandler(cfg *config.Config, blacklist *config.Blacklist) gin.HandlerFunc { func NoRouteHandler(cfg *config.Config, bmap config.BlacklistMap) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
re := regexp.MustCompile(`^(http:|https:)?/?/?(.*)`) re := regexp.MustCompile(`^(http:|https:)?/?/?(.*)`)
@ -57,17 +57,13 @@ func NoRouteHandler(cfg *config.Config, blacklist *config.Blacklist) gin.Handler
username := pathParts[2] username := pathParts[2]
repo := pathParts[3] repo := pathParts[3]
logw("Blacklist Check > Username: %s, Repo: %s", username, repo) logw("Blacklist Check > Username: %s, Repo: %s", username, repo)
fullrepo := fmt.Sprintf("%s/%s", username, repo)
if blacklist.Blist == nil { // 黑名单检查
logw("Warning: Blacklist map is nil") blacklistpass := auth.CheckBlacklist(fullrepo)
// 根据需要初始化或处理 if !blacklistpass {
blacklist.Blist = make(map[string][]string) c.AbortWithStatusJSON(404, gin.H{"error": "Not found"})
} logw("Blacklisted repo: %s", fullrepo)
// 检查仓库是否在黑名单中
if auth.IsBlacklisted(username, repo, blacklist.Blist, cfg.Blacklist.Enabled) {
c.String(http.StatusForbidden, "Access denied: repository is blacklisted.")
logw("Blacklisted repository: %s/%s", username, repo)
return return
} }