diff --git a/.gitignore b/.gitignore index 5b7c484..04ef99a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ demo demo.toml *.log -*.bak \ No newline at end of file +*.bak +list.json \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index 0705b8d..fdc5833 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -18,7 +18,7 @@ var ( func Init(cfg *config.Config) { if cfg.Blacklist.Enabled { - LoadBlacklist(cfg) + InitBlacklist(cfg) } if cfg.Whitelist.Enabled { LoadWhitelist(cfg) diff --git a/auth/blacklist.go b/auth/blacklist.go index 3aac712..ba2091b 100644 --- a/auth/blacklist.go +++ b/auth/blacklist.go @@ -2,59 +2,89 @@ package auth import ( "encoding/json" + "fmt" "ghproxy/config" "os" "strings" + "sync" ) -type BlacklistConfig struct { - Blacklist []string `json:"blacklist"` +type Blacklist struct { + userSet map[string]struct{} // 用户级黑名单 + repoSet map[string]map[string]struct{} // 仓库级黑名单 + initOnce sync.Once // 确保初始化只执行一次 + initialized bool // 初始化状态标识 } var ( - cfg *config.Config - blacklistfile = "/data/ghproxy/config/blacklist.json" - blacklist *BlacklistConfig + instance *Blacklist + initErr error ) -func LoadBlacklist(cfg *config.Config) { - blacklistfile = cfg.Blacklist.BlacklistFile - blacklist = &BlacklistConfig{} +// InitBlacklist 初始化黑名单(线程安全,仅执行一次) +func InitBlacklist(cfg *config.Config) error { + instance = &Blacklist{ + userSet: make(map[string]struct{}), + repoSet: make(map[string]map[string]struct{}), + } - data, err := os.ReadFile(blacklistfile) + data, err := os.ReadFile(cfg.Blacklist.BlacklistFile) if err != nil { - logError("Failed to read blacklist file: %v", err) + return fmt.Errorf("failed to read blacklist: %w", err) } - err = json.Unmarshal(data, blacklist) - if err != nil { - logError("Failed to unmarshal blacklist JSON: %v", err) + var list struct { + Entries []string `json:"blacklist"` } -} - -func CheckBlacklist(repouser string, user string, repo string) bool { - return forRangeCheckBlacklist(blacklist.Blacklist, repouser, user) -} - -func sliceRepoName_Blacklist(fullrepo string) (string, string) { - s := strings.Split(fullrepo, "/") - if len(s) != 2 { - return "", "" + if err := json.Unmarshal(data, &list); err != nil { + return fmt.Errorf("invalid blacklist format: %w", err) } - return s[0], s[1] -} -func forRangeCheckBlacklist(blist []string, fullrepo string, user string) bool { - for _, blocked := range blist { - users, _ := sliceRepoName_Blacklist(blocked) - if user == users { - if strings.HasSuffix(blocked, "/*") { - return true - } - if fullrepo == blocked { - return true + for _, entry := range list.Entries { + user, repo := splitUserRepo(entry) + switch { + case repo == "" || repo == "*": + instance.userSet[user] = struct{}{} + default: + if _, exists := instance.repoSet[user]; !exists { + instance.repoSet[user] = make(map[string]struct{}) } + instance.repoSet[user][repo] = struct{}{} } } + + instance.initialized = true + return nil +} + +// CheckBlacklist 检查用户和仓库是否在黑名单中(无锁设计) +func CheckBlacklist(username, repo string) bool { + if instance == nil || !instance.initialized { + return false + } + + // 先检查用户级黑名单 + if _, exists := instance.userSet[username]; exists { + return true + } + + // 再检查仓库级黑名单 + if repos, userExists := instance.repoSet[username]; userExists { + // 允许仓库名为空时的全用户仓库匹配 + if repo == "" { + return true + } + _, repoExists := repos[repo] + return repoExists + } + return false } + +// splitUserRepo 优化分割逻辑(仅初始化时使用) +func splitUserRepo(fullRepo string) (user, repo string) { + if idx := strings.Index(fullRepo, "/"); idx > 0 { + return fullRepo[:idx], fullRepo[idx+1:] + } + return fullRepo, "" +} diff --git a/config/blacklist.json b/config/blacklist.json index 839d3f6..c1ed260 100644 --- a/config/blacklist.json +++ b/config/blacklist.json @@ -1,7 +1,7 @@ { "blacklist": [ - "black/list", - "test/test1", - "example/*" + "eviluser", + "spamuser/bad-repo", + "malwareuser/*" ] } \ No newline at end of file diff --git a/proxy/handler.go b/proxy/handler.go index a052125..480aa67 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -73,7 +73,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra // 黑名单检查 if cfg.Blacklist.Enabled { - blacklist := auth.CheckBlacklist(repouser, username, repo) + blacklist := auth.CheckBlacklist(username, repo) if blacklist { logErrMsg := fmt.Sprintf("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.Get("User-Agent"), c.Request.Proto, repouser) errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser)