optimize blacklist

This commit is contained in:
WJQSERVER 2025-02-14 07:19:17 +08:00
parent 97ae0044e7
commit 40c9ca5f38
5 changed files with 71 additions and 40 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ demo
demo.toml demo.toml
*.log *.log
*.bak *.bak
list.json

View file

@ -18,7 +18,7 @@ var (
func Init(cfg *config.Config) { func Init(cfg *config.Config) {
if cfg.Blacklist.Enabled { if cfg.Blacklist.Enabled {
LoadBlacklist(cfg) InitBlacklist(cfg)
} }
if cfg.Whitelist.Enabled { if cfg.Whitelist.Enabled {
LoadWhitelist(cfg) LoadWhitelist(cfg)

View file

@ -2,59 +2,89 @@ package auth
import ( import (
"encoding/json" "encoding/json"
"fmt"
"ghproxy/config" "ghproxy/config"
"os" "os"
"strings" "strings"
"sync"
) )
type BlacklistConfig struct { type Blacklist struct {
Blacklist []string `json:"blacklist"` userSet map[string]struct{} // 用户级黑名单
repoSet map[string]map[string]struct{} // 仓库级黑名单
initOnce sync.Once // 确保初始化只执行一次
initialized bool // 初始化状态标识
} }
var ( var (
cfg *config.Config instance *Blacklist
blacklistfile = "/data/ghproxy/config/blacklist.json" initErr error
blacklist *BlacklistConfig
) )
func LoadBlacklist(cfg *config.Config) { // InitBlacklist 初始化黑名单(线程安全,仅执行一次)
blacklistfile = cfg.Blacklist.BlacklistFile func InitBlacklist(cfg *config.Config) error {
blacklist = &BlacklistConfig{} instance = &Blacklist{
userSet: make(map[string]struct{}),
repoSet: make(map[string]map[string]struct{}),
}
data, err := os.ReadFile(blacklistfile) data, err := os.ReadFile(cfg.Blacklist.BlacklistFile)
if err != nil { if err != nil {
logError("Failed to read blacklist file: %v", err) return fmt.Errorf("failed to read blacklist: %w", err)
} }
err = json.Unmarshal(data, blacklist) var list struct {
if err != nil { Entries []string `json:"blacklist"`
logError("Failed to unmarshal blacklist JSON: %v", err) }
if err := json.Unmarshal(data, &list); err != nil {
return fmt.Errorf("invalid blacklist format: %w", err)
}
for _, entry := range list.Entries {
user, repo := splitUserRepo(entry)
switch {
case repo == "" || repo == "*":
instance.userSet[user] = struct{}{}
default:
if _, exists := instance.repoSet[user]; !exists {
instance.repoSet[user] = make(map[string]struct{})
}
instance.repoSet[user][repo] = struct{}{}
} }
} }
func CheckBlacklist(repouser string, user string, repo string) bool { instance.initialized = true
return forRangeCheckBlacklist(blacklist.Blacklist, repouser, user) return nil
} }
func sliceRepoName_Blacklist(fullrepo string) (string, string) { // CheckBlacklist 检查用户和仓库是否在黑名单中(无锁设计)
s := strings.Split(fullrepo, "/") func CheckBlacklist(username, repo string) bool {
if len(s) != 2 { if instance == nil || !instance.initialized {
return "", ""
}
return s[0], s[1]
}
func forRangeCheckBlacklist(blist []string, fullrepo string, user string) bool {
for _, blocked := range blist {
users, _ := sliceRepoName_Blacklist(blocked)
if user == users {
if strings.HasSuffix(blocked, "/*") {
return true
}
if fullrepo == blocked {
return true
}
}
}
return false return false
} }
// 先检查用户级黑名单
if _, exists := instance.userSet[username]; exists {
return true
}
// 再检查仓库级黑名单
if repos, userExists := instance.repoSet[username]; userExists {
// 允许仓库名为空时的全用户仓库匹配
if repo == "" {
return true
}
_, repoExists := repos[repo]
return repoExists
}
return false
}
// splitUserRepo 优化分割逻辑(仅初始化时使用)
func splitUserRepo(fullRepo string) (user, repo string) {
if idx := strings.Index(fullRepo, "/"); idx > 0 {
return fullRepo[:idx], fullRepo[idx+1:]
}
return fullRepo, ""
}

View file

@ -1,7 +1,7 @@
{ {
"blacklist": [ "blacklist": [
"black/list", "eviluser",
"test/test1", "spamuser/bad-repo",
"example/*" "malwareuser/*"
] ]
} }

View file

@ -73,7 +73,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
// 黑名单检查 // 黑名单检查
if cfg.Blacklist.Enabled { if cfg.Blacklist.Enabled {
blacklist := auth.CheckBlacklist(repouser, username, repo) blacklist := auth.CheckBlacklist(username, repo)
if blacklist { if blacklist {
logErrMsg := fmt.Sprintf("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.Get("User-Agent"), c.Request.Proto, repouser) logErrMsg := fmt.Sprintf("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.Get("User-Agent"), c.Request.Proto, repouser)
errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser) errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser)