ghproxy/rate/rate.go

63 lines
1.2 KiB
Go

package rate
import (
	"sync"
	"time"

	"github.com/WJQSERVER-STUDIO/go-utils/logger"
	"golang.org/x/time/rate"
)
// Log output: package-local aliases for the logger package's Logw, LogInfo,
// LogWarning and LogError, so the rest of this package can log without
// qualifying every call with the logger package name.
var (
logw = logger.Logw
logInfo = logger.LogInfo
logWarning = logger.LogWarning
logError = logger.LogError
)
// RateLimiter is the overall (global) rate limiter. It is a thin wrapper
// around a single *rate.Limiter from golang.org/x/time/rate.
type RateLimiter struct {
limiter *rate.Limiter
}
// IPRateLimiter is an IP-based rate limiter. It keeps one RateLimiter per
// client IP; each one is created lazily on first use with the same limit,
// burst and duration settings.
//
// The limiters map is guarded by mu so that Allow may be called from multiple
// goroutines at once (for example from concurrent HTTP handlers). Because the
// struct contains a mutex it must not be copied after first use; always pass
// it around as *IPRateLimiter.
//
// NOTE(review): entries are never removed from limiters in this file, so the
// map grows with the number of distinct IPs seen. Consider adding eviction if
// the set of client IPs is unbounded.
type IPRateLimiter struct {
	mu       sync.Mutex
	limiters map[string]*RateLimiter
	limit    int
	burst    int
	duration time.Duration
}

// New returns a RateLimiter whose refill rate is limit events per duration
// (converted to events per second) and whose burst size is burst.
//
// NOTE(review): duration is used as a divisor. A zero duration yields an
// infinite (or, when limit is also 0, NaN) rate and a negative duration a
// negative rate — callers are presumably expected to pass a positive
// duration; confirm at the call sites.
func New(limit int, burst int, duration time.Duration) *RateLimiter {
	return &RateLimiter{
		limiter: rate.NewLimiter(rate.Limit(float64(limit)/duration.Seconds()), burst),
	}
}

// Allow reports whether one event may happen now, delegating to the
// underlying rate.Limiter.
func (rl *RateLimiter) Allow() bool {
	return rl.limiter.Allow()
}

// NewIPRateLimiter returns an IPRateLimiter with an empty limiter table. The
// limit, burst and duration values are stored and later passed to New for
// every IP that is seen for the first time.
func NewIPRateLimiter(limit int, burst int, duration time.Duration) *IPRateLimiter {
	return &IPRateLimiter{
		limiters: make(map[string]*RateLimiter),
		limit:    limit,
		burst:    burst,
		duration: duration,
	}
}

// Allow reports whether a request from ip may proceed now. An empty ip is
// logged as a warning and always rejected. It is safe for concurrent use.
func (rl *IPRateLimiter) Allow(ip string) bool {
	if ip == "" {
		logWarning("empty ip")
		return false
	}
	// The map lock is released before consuming a token so that a slow or
	// contended limiter for one IP does not serialize requests from others.
	return rl.limiterFor(ip).Allow()
}

// limiterFor returns the RateLimiter associated with ip, creating and storing
// a new one if none exists yet. The lookup and insert happen under mu so that
// concurrent callers never race on the map and always share one limiter per IP.
func (rl *IPRateLimiter) limiterFor(ip string) *RateLimiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	limiter, ok := rl.limiters[ip]
	if !ok {
		limiter = New(rl.limit, rl.burst, rl.duration)
		rl.limiters[ip] = limiter
	}
	return limiter
}