ghproxy/rate/rate.go
2025-04-24 17:50:18 +08:00

107 lines
2.4 KiB
Go

package rate
import (
"sync"
"time"
"github.com/WJQSERVER-STUDIO/go-utils/logger"
"golang.org/x/time/rate"
)
// Package-level aliases for the go-utils logger functions, so code in this
// package can call them without the "logger." qualifier.
var (
	logw       = logger.Logw
	logDump    = logger.LogDump
	logDebug   = logger.LogDebug
	logInfo    = logger.LogInfo
	logWarning = logger.LogWarning
	logError   = logger.LogError
)
// RateLimiter is a rate limiter backed by a golang.org/x/time/rate limiter.
// It is used directly as the overall limiter, and IPRateLimiter also creates
// one per client IP.
type RateLimiter struct {
	limiter *rate.Limiter // underlying x/time/rate limiter; built by New
}
// New creates a limiter that admits limit requests per duration, allowing
// bursts of up to burst requests.
//
// Invalid arguments are replaced rather than rejected, with a warning logged
// for each substitution: a non-positive limit or burst becomes 1, and a
// non-positive duration becomes one minute.
func New(limit int, burst int, duration time.Duration) *RateLimiter {
	if limit <= 0 {
		limit = 1
		logWarning("rate limit per minute must be positive, setting to 1")
	}
	if burst <= 0 {
		burst = 1
		logWarning("rate limit burst must be positive, setting to 1")
	}
	if duration <= 0 {
		// Dividing by a zero duration would produce an infinite rate, which
		// effectively disables limiting; a negative one would produce a
		// meaningless negative rate.
		duration = time.Minute
		logWarning("rate limit duration must be positive, setting to 1m")
	}
	// rate.Limit is expressed in events per second.
	perSecond := rate.Limit(float64(limit) / duration.Seconds())
	return &RateLimiter{
		limiter: rate.NewLimiter(perSecond, burst),
	}
}
// Allow reports whether a request may proceed right now, by delegating to the
// underlying x/time/rate limiter's Allow.
func (rl *RateLimiter) Allow() bool {
	lim := rl.limiter
	return lim.Allow()
}
// IPRateLimiter applies an independent rate limit to each client IP.
//
// A limiter is created lazily the first time an IP is seen. So that memory
// stays bounded when many distinct IPs pass through, a limiter that has been
// idle long enough for its bucket to refill completely is evicted. Only Allow
// is ever called on these limiters, so such a limiter behaves like a freshly
// created one, and evicting it does not change rate-limiting decisions.
type IPRateLimiter struct {
	limiters  map[string]*ipLimiterEntry // per-IP limiters; guarded by mu
	mu        sync.Mutex                 // guards limiters, every entry's lastSeen, and lastSweep
	limit     int                        // requests allowed per duration
	burst     int                        // maximum burst size
	duration  time.Duration              // period over which limit applies
	idleTTL   time.Duration              // idle time after which an entry may be evicted
	lastSweep time.Time                  // when idle entries were last evicted; guarded by mu
}

// ipLimiterEntry pairs an IP's limiter with the time it was last used.
type ipLimiterEntry struct {
	limiter  *RateLimiter
	lastSeen time.Time // guarded by IPRateLimiter.mu
}

// ipIdleTTL returns how long an IP's limiter must go unused before it can be
// discarded. This is the time for an empty bucket of burst tokens to refill at
// limit tokens per duration (never less than duration), plus a safety margin.
// limit, burst and duration must all be positive.
func ipIdleTTL(limit, burst int, duration time.Duration) time.Duration {
	const maxTTL = time.Duration(1<<63 - 1)
	refill := float64(duration) * float64(burst) / float64(limit)
	if float64(duration) > refill {
		refill = float64(duration)
	}
	if refill >= float64(maxTTL/2) {
		// Absurdly large configuration: effectively never evict.
		return maxTTL
	}
	// The margin absorbs float truncation and the small gap between the time
	// Allow records for an IP and the time the limiter stamps its bucket.
	return time.Duration(refill) + time.Second
}

// NewIPRateLimiter creates a per-IP limiter that allows each IP ipLimit
// requests per duration, with bursts of up to ipBurst requests.
//
// Invalid arguments are replaced rather than rejected, with a warning logged
// for each: a non-positive ipLimit or ipBurst becomes 1, and a non-positive
// duration becomes one minute.
func NewIPRateLimiter(ipLimit int, ipBurst int, duration time.Duration) *IPRateLimiter {
	if ipLimit <= 0 {
		ipLimit = 1
		logWarning("IP rate limit per minute must be positive, setting to 1")
	}
	if ipBurst <= 0 {
		ipBurst = 1
		logWarning("IP rate limit burst must be positive, setting to 1")
	}
	if duration <= 0 {
		// A zero or negative period would make the per-second rate infinite or
		// negative when each IP's limiter is built.
		duration = time.Minute
		logWarning("IP rate limit duration must be positive, setting to 1m")
	}
	logInfo("IP Rate Limiter initialized with limit: %d, burst: %d, duration: %v", ipLimit, ipBurst, duration)
	return &IPRateLimiter{
		limiters:  make(map[string]*ipLimiterEntry),
		limit:     ipLimit,
		burst:     ipBurst,
		duration:  duration,
		idleTTL:   ipIdleTTL(ipLimit, ipBurst, duration),
		lastSweep: time.Now(),
	}
}

// Allow reports whether a request from ip may proceed now. An empty ip is
// always rejected and a warning is logged. The IP's limiter is created on
// first use.
func (rl *IPRateLimiter) Allow(ip string) bool {
	if ip == "" {
		logWarning("empty ip for rate limiting")
		return false
	}

	now := time.Now()

	rl.mu.Lock()
	entry, found := rl.limiters[ip]
	if !found {
		// Only insertions grow the map, so this is where idle entries are reclaimed.
		rl.evictIdleLocked(now)
		entry = &ipLimiterEntry{limiter: New(rl.limit, rl.burst, rl.duration)}
		rl.limiters[ip] = entry
	}
	entry.lastSeen = now
	rl.mu.Unlock()

	// The x/time/rate limiter guards its own state, so it is consulted after
	// releasing rl.mu to keep the critical section short.
	return entry.limiter.Allow()
}

// evictIdleLocked removes limiters that have gone unused for longer than
// rl.idleTTL. The O(n) scan runs at most once per idleTTL, so its cost is
// spread across many calls. The caller must hold rl.mu.
func (rl *IPRateLimiter) evictIdleLocked(now time.Time) {
	if now.Sub(rl.lastSweep) < rl.idleTTL {
		return
	}
	rl.lastSweep = now
	for ip, e := range rl.limiters {
		if now.Sub(e.lastSeen) > rl.idleTTL {
			delete(rl.limiters, ip)
		}
	}
}