add bandwidth limiter

This commit is contained in:
wjqserver 2025-05-14 01:33:54 +08:00
parent 3f8d16511e
commit 71bc2aaed7
10 changed files with 157 additions and 22 deletions

View file

@ -129,11 +129,35 @@ type WhitelistConfig struct {
WhitelistFile string `toml:"whitelistFile"` WhitelistFile string `toml:"whitelistFile"`
} }
/*
[rateLimit]
enabled = false
rateMethod = "total" # "total" or "ip"
ratePerMinute = 100
burst = 10
[rateLimit.bandwidthLimit]
enabled = false
totalLimit = "100mbps"
totalBurst = "100mbps"
singleLimit = "10mbps"
singleBurst = "10mbps"
*/
type RateLimitConfig struct { type RateLimitConfig struct {
Enabled bool `toml:"enabled"` Enabled bool `toml:"enabled"`
RateMethod string `toml:"rateMethod"` RateMethod string `toml:"rateMethod"`
RatePerMinute int `toml:"ratePerMinute"` RatePerMinute int `toml:"ratePerMinute"`
Burst int `toml:"burst"` Burst int `toml:"burst"`
BandwidthLimit BandwidthLimitConfig
}
type BandwidthLimitConfig struct {
Enabled bool `toml:"enabled"`
TotalLimit string `toml:"totalLimit"`
TotalBurst string `toml:"totalBurst"`
SingleLimit string `toml:"singleLimit"`
SingleBurst string `toml:"singleBurst"`
} }
/* /*
@ -252,6 +276,13 @@ func DefaultConfig() *Config {
RateMethod: "total", RateMethod: "total",
RatePerMinute: 100, RatePerMinute: 100,
Burst: 10, Burst: 10,
BandwidthLimit: BandwidthLimitConfig{
Enabled: false,
TotalLimit: "100mbps",
TotalBurst: "100mbps",
SingleLimit: "10mbps",
SingleBurst: "10mbps",
},
}, },
Outbound: OutboundConfig{ Outbound: OutboundConfig{
Enabled: false, Enabled: false,

View file

@ -57,6 +57,13 @@ rateMethod = "total" # "ip" or "total"
ratePerMinute = 180 ratePerMinute = 180
burst = 5 burst = 5
[rateLimit.bandwidthLimit]
enabled = false
totalLimit = "100mbps"
totalBurst = "100mbps"
singleLimit = "10mbps"
singleBurst = "10mbps"
[outbound] [outbound]
enabled = false enabled = false
url = "socks5://127.0.0.1:1080" # "http://127.0.0.1:7890" url = "socks5://127.0.0.1:1080" # "http://127.0.0.1:7890"

7
go.mod
View file

@ -4,8 +4,7 @@ go 1.24.3
require ( require (
github.com/BurntSushi/toml v1.5.0 github.com/BurntSushi/toml v1.5.0
//github.com/WJQSERVER-STUDIO/httpc v0.5.0 github.com/WJQSERVER-STUDIO/httpc v0.5.1
github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489
github.com/WJQSERVER-STUDIO/logger v1.6.0 github.com/WJQSERVER-STUDIO/logger v1.6.0
github.com/cloudwego/hertz v0.9.7 github.com/cloudwego/hertz v0.9.7
github.com/hertz-contrib/http2 v0.1.8 github.com/hertz-contrib/http2 v0.1.8
@ -13,6 +12,8 @@ require (
golang.org/x/time v0.11.0 golang.org/x/time v0.11.0
) )
require github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2
require ( require (
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 // indirect github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 // indirect
github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 // indirect github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 // indirect
@ -38,4 +39,4 @@ require (
google.golang.org/protobuf v1.36.6 // indirect google.golang.org/protobuf v1.36.6 // indirect
) )
//replace github.com/WJQSERVER-STUDIO/httpc v0.5.0 => /data/github/WJQSERVER-STUDIO/httpc //replace github.com/WJQSERVER-STUDIO/httpc v0.5.1 => /data/github/WJQSERVER-STUDIO/httpc

6
go.sum
View file

@ -2,10 +2,12 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKUGPOAijN1sMtEYoFg= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKUGPOAijN1sMtEYoFg=
github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc=
github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2 h1:8bBkKk6E2Zr+I5szL7gyc5f0DK8N9agIJCpM1Cqw2NE=
github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2/go.mod h1:yPX8xuZH+py7eLJwOYj3VVI/4/Yuy5+x8Mhq8qezcPg=
github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 h1:9CSf+V0ZQPl2ijC/g6v/ObemmhpKcikKVIodsaLExTA= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 h1:9CSf+V0ZQPl2ijC/g6v/ObemmhpKcikKVIodsaLExTA=
github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2/go.mod h1:j9Q+xnwpOfve7/uJnZ2izRQw6NNoXjvJHz7vUQAaLZE= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2/go.mod h1:j9Q+xnwpOfve7/uJnZ2izRQw6NNoXjvJHz7vUQAaLZE=
github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489 h1:BScWEkOFYMDaSB4SNhBa6XeBoBjg1IHxmGE3NSNW6zw= github.com/WJQSERVER-STUDIO/httpc v0.5.1 h1:+TKCPYBuj7PAHuiduGCGAqsHAa4QtsUfoVwRN777q64=
github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE= github.com/WJQSERVER-STUDIO/httpc v0.5.1/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE=
github.com/WJQSERVER-STUDIO/logger v1.6.0 h1:xK2xV7hlkMXaWzvj4+cNoNWA+JfnJaHX6VU+RrPnr7Q= github.com/WJQSERVER-STUDIO/logger v1.6.0 h1:xK2xV7hlkMXaWzvj4+cNoNWA+JfnJaHX6VU+RrPnr7Q=
github.com/WJQSERVER-STUDIO/logger v1.6.0/go.mod h1:TICMsR7geROHBg6rxwkqUNGydo34XVsX93yeoxyfuyY= github.com/WJQSERVER-STUDIO/logger v1.6.0/go.mod h1:TICMsR7geROHBg6rxwkqUNGydo34XVsX93yeoxyfuyY=
github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=

View file

@ -181,7 +181,11 @@ func setupRateLimit(cfg *config.Config) {
} }
func InitReq(cfg *config.Config) { func InitReq(cfg *config.Config) {
proxy.InitReq(cfg) err := proxy.InitReq(cfg)
if err != nil {
fmt.Printf("Failed to initialize request: %v\n", err)
os.Exit(1)
}
} }
// loadEmbeddedPages 加载嵌入式页面资源 // loadEmbeddedPages 加载嵌入式页面资源

64
proxy/bandwidth.go Normal file
View file

@ -0,0 +1,64 @@
package proxy
import (
"errors"
"ghproxy/config"
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
"golang.org/x/time/rate"
)
var (
bandwidthLimit rate.Limit
bandwidthBurst rate.Limit
)
func UnDefiendRateStringErrHandle(err error) error {
if errors.Is(err, &limitreader.UnDefiendRateStringErr{}) {
logWarning("UnDefiendRateStringErr: %s", err)
return nil
}
return err
}
func SetGlobalRateLimit(cfg *config.Config) error {
if cfg.RateLimit.BandwidthLimit.Enabled {
var err error
var totalLimit rate.Limit
var totalBurst rate.Limit
totalLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalLimit)
if UnDefiendRateStringErrHandle(err) != nil {
logError("Failed to parse total bandwidth limit: %v", err)
return err
}
totalBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalBurst)
if UnDefiendRateStringErrHandle(err) != nil {
logError("Failed to parse total bandwidth burst: %v", err)
return err
}
limitreader.SetGlobalRateLimit(totalLimit, int(totalBurst))
err = SetBandwidthLimit(cfg)
if UnDefiendRateStringErrHandle(err) != nil {
logError("Failed to set bandwidth limit: %v", err)
return err
}
} else {
limitreader.SetGlobalRateLimit(rate.Inf, 0)
}
return nil
}
func SetBandwidthLimit(cfg *config.Config) error {
var err error
bandwidthLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleLimit)
if UnDefiendRateStringErrHandle(err) != nil {
logError("Failed to parse bandwidth limit: %v", err)
return err
}
bandwidthBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleBurst)
if UnDefiendRateStringErrHandle(err) != nil {
logError("Failed to parse bandwidth burst: %v", err)
return err
}
return nil
}

View file

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
) )
@ -94,6 +95,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
bodyReader := resp.Body
if cfg.RateLimit.BandwidthLimit.Enabled {
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
}
if MatcherShell(u) && matchString(matcher, matchedMatchers) && cfg.Shell.Editor { if MatcherShell(u) && matchString(matcher, matchedMatchers) && cfg.Shell.Editor {
// 判断body是不是gzip // 判断body是不是gzip
var compress string var compress string
@ -106,7 +113,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
var reader io.Reader var reader io.Reader
reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg) reader, _, err = processLinks(bodyReader, compress, string(c.Request.Host()), cfg)
c.SetBodyStream(reader, -1) c.SetBodyStream(reader, -1)
if err != nil { if err != nil {
logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), c.Request.Method(), u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), c.Request.Method(), u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err)
@ -114,11 +121,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
return return
} }
} else { } else {
if contentLength != "" { if contentLength != "" {
c.SetBodyStream(resp.Body, bodySize) c.SetBodyStream(bodyReader, bodySize)
return return
} }
c.SetBodyStream(resp.Body, -1) c.SetBodyStream(bodyReader, -1)
} }
} }

View file

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
) )
@ -107,10 +108,16 @@ func GhcrRequest(ctx context.Context, c *app.RequestContext, u string, cfg *conf
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
bodyReader := resp.Body
if cfg.RateLimit.BandwidthLimit.Enabled {
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
}
if contentLength != "" { if contentLength != "" {
c.SetBodyStream(resp.Body, bodySize) c.SetBodyStream(bodyReader, bodySize)
return return
} }
c.SetBodyStream(resp.Body, -1) c.SetBodyStream(bodyReader, -1)
} }

View file

@ -8,15 +8,16 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
) )
func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, mode string) { func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, mode string) {
method := string(c.Request.Method()) method := string(c.Request.Method())
bodyReader := bytes.NewBuffer(c.Request.Body()) reqBodyReader := bytes.NewBuffer(c.Request.Body())
//bodyReader := c.Request.BodyStream() //bodyReader := c.Request.BodyStream() // 不可替换为此实现
if cfg.GitClone.Mode == "cache" { if cfg.GitClone.Mode == "cache" {
userPath, repoPath, remainingPath, queryParams, err := extractParts(u) userPath, repoPath, remainingPath, queryParams, err := extractParts(u)
@ -35,7 +36,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
if cfg.GitClone.Mode == "cache" { if cfg.GitClone.Mode == "cache" {
rb := gitclient.NewRequestBuilder(method, u) rb := gitclient.NewRequestBuilder(method, u)
rb.NoDefaultHeaders() rb.NoDefaultHeaders()
rb.SetBody(bodyReader) rb.SetBody(reqBodyReader)
rb.WithContext(ctx) rb.WithContext(ctx)
req, err := rb.Build() req, err := rb.Build()
@ -55,7 +56,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
} else { } else {
rb := client.NewRequestBuilder(string(c.Request.Method()), u) rb := client.NewRequestBuilder(string(c.Request.Method()), u)
rb.NoDefaultHeaders() rb.NoDefaultHeaders()
rb.SetBody(bodyReader) rb.SetBody(reqBodyReader)
rb.WithContext(ctx) rb.WithContext(ctx)
req, err := rb.Build() req, err := rb.Build()
@ -91,7 +92,6 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
for key, values := range resp.Header { for key, values := range resp.Header {
for _, value := range values { for _, value := range values {
//c.Header(key, value)
c.Response.Header.Add(key, value) c.Response.Header.Add(key, value)
} }
} }
@ -124,5 +124,11 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
c.Response.Header.Set("Expires", "0") c.Response.Header.Set("Expires", "0")
} }
c.SetBodyStream(resp.Body, -1) bodyReader := resp.Body
if cfg.RateLimit.BandwidthLimit.Enabled {
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
}
c.SetBodyStream(bodyReader, -1)
} }

View file

@ -18,11 +18,16 @@ var (
gitclient *httpc.Client gitclient *httpc.Client
) )
func InitReq(cfg *config.Config) { func InitReq(cfg *config.Config) error {
initHTTPClient(cfg) initHTTPClient(cfg)
if cfg.GitClone.Mode == "cache" { if cfg.GitClone.Mode == "cache" {
initGitHTTPClient(cfg) initGitHTTPClient(cfg)
} }
err := SetGlobalRateLimit(cfg)
if err != nil {
return err
}
return nil
} }
func initHTTPClient(cfg *config.Config) { func initHTTPClient(cfg *config.Config) {