diff --git a/config/config.go b/config/config.go index 14c9301..c5b8aa0 100644 --- a/config/config.go +++ b/config/config.go @@ -129,11 +129,35 @@ type WhitelistConfig struct { WhitelistFile string `toml:"whitelistFile"` } +/* +[rateLimit] +enabled = false +rateMethod = "total" # "total" or "ip" +ratePerMinute = 100 +burst = 10 + + [rateLimit.bandwidthLimit] + enabled = false + totalLimit = "100mbps" + totalBurst = "100mbps" + singleLimit = "10mbps" + singleBurst = "10mbps" +*/ + type RateLimitConfig struct { - Enabled bool `toml:"enabled"` - RateMethod string `toml:"rateMethod"` - RatePerMinute int `toml:"ratePerMinute"` - Burst int `toml:"burst"` + Enabled bool `toml:"enabled"` + RateMethod string `toml:"rateMethod"` + RatePerMinute int `toml:"ratePerMinute"` + Burst int `toml:"burst"` + BandwidthLimit BandwidthLimitConfig +} + +type BandwidthLimitConfig struct { + Enabled bool `toml:"enabled"` + TotalLimit string `toml:"totalLimit"` + TotalBurst string `toml:"totalBurst"` + SingleLimit string `toml:"singleLimit"` + SingleBurst string `toml:"singleBurst"` } /* @@ -252,6 +276,13 @@ func DefaultConfig() *Config { RateMethod: "total", RatePerMinute: 100, Burst: 10, + BandwidthLimit: BandwidthLimitConfig{ + Enabled: false, + TotalLimit: "100mbps", + TotalBurst: "100mbps", + SingleLimit: "10mbps", + SingleBurst: "10mbps", + }, }, Outbound: OutboundConfig{ Enabled: false, diff --git a/config/config.toml b/config/config.toml index b43ebf3..ca7b80a 100644 --- a/config/config.toml +++ b/config/config.toml @@ -57,6 +57,13 @@ rateMethod = "total" # "ip" or "total" ratePerMinute = 180 burst = 5 +[rateLimit.bandwidthLimit] + enabled = false + totalLimit = "100mbps" + totalBurst = "100mbps" + singleLimit = "10mbps" + singleBurst = "10mbps" + [outbound] enabled = false url = "socks5://127.0.0.1:1080" # "http://127.0.0.1:7890" diff --git a/go.mod b/go.mod index 24d73f2..d20e90a 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,7 @@ go 1.24.3 require ( github.com/BurntSushi/toml v1.5.0 - //github.com/WJQSERVER-STUDIO/httpc v0.5.0 - github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489 + github.com/WJQSERVER-STUDIO/httpc v0.5.1 github.com/WJQSERVER-STUDIO/logger v1.6.0 github.com/cloudwego/hertz v0.9.7 github.com/hertz-contrib/http2 v0.1.8 @@ -13,6 +12,8 @@ require ( golang.org/x/time v0.11.0 ) +require github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2 + require ( github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 // indirect github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 // indirect @@ -38,4 +39,4 @@ require ( google.golang.org/protobuf v1.36.6 // indirect ) -//replace github.com/WJQSERVER-STUDIO/httpc v0.5.0 => /data/github/WJQSERVER-STUDIO/httpc +//replace github.com/WJQSERVER-STUDIO/httpc v0.5.1 => /data/github/WJQSERVER-STUDIO/httpc diff --git a/go.sum b/go.sum index 8676974..d906a16 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,12 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKUGPOAijN1sMtEYoFg= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc= +github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2 h1:8bBkKk6E2Zr+I5szL7gyc5f0DK8N9agIJCpM1Cqw2NE= +github.com/WJQSERVER-STUDIO/go-utils/limitreader v0.0.2/go.mod h1:yPX8xuZH+py7eLJwOYj3VVI/4/Yuy5+x8Mhq8qezcPg= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2 h1:9CSf+V0ZQPl2ijC/g6v/ObemmhpKcikKVIodsaLExTA= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.2/go.mod h1:j9Q+xnwpOfve7/uJnZ2izRQw6NNoXjvJHz7vUQAaLZE= -github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489 h1:BScWEkOFYMDaSB4SNhBa6XeBoBjg1IHxmGE3NSNW6zw= -github.com/WJQSERVER-STUDIO/httpc v0.5.1-0.20250513102952-d961182b2489/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE= +github.com/WJQSERVER-STUDIO/httpc v0.5.1 h1:+TKCPYBuj7PAHuiduGCGAqsHAa4QtsUfoVwRN777q64= +github.com/WJQSERVER-STUDIO/httpc v0.5.1/go.mod h1:M7KNUZjjhCkzzcg9lBPs9YfkImI+7vqjAyjdA19+joE= github.com/WJQSERVER-STUDIO/logger v1.6.0 h1:xK2xV7hlkMXaWzvj4+cNoNWA+JfnJaHX6VU+RrPnr7Q= github.com/WJQSERVER-STUDIO/logger v1.6.0/go.mod h1:TICMsR7geROHBg6rxwkqUNGydo34XVsX93yeoxyfuyY= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= diff --git a/main.go b/main.go index 8d61589..d0237d8 100644 --- a/main.go +++ b/main.go @@ -181,7 +181,11 @@ func setupRateLimit(cfg *config.Config) { } func InitReq(cfg *config.Config) { - proxy.InitReq(cfg) + err := proxy.InitReq(cfg) + if err != nil { + fmt.Printf("Failed to initialize request: %v\n", err) + os.Exit(1) + } } // loadEmbeddedPages 加载嵌入式页面资源 diff --git a/proxy/bandwidth.go b/proxy/bandwidth.go new file mode 100644 index 0000000..a7591c2 --- /dev/null +++ b/proxy/bandwidth.go @@ -0,0 +1,64 @@ +package proxy + +import ( + "errors" + "ghproxy/config" + + "github.com/WJQSERVER-STUDIO/go-utils/limitreader" + "golang.org/x/time/rate" +) + +var ( + bandwidthLimit rate.Limit + bandwidthBurst rate.Limit +) + +func UnDefiendRateStringErrHandle(err error) error { + if errors.Is(err, &limitreader.UnDefiendRateStringErr{}) { + logWarning("UnDefiendRateStringErr: %s", err) + return nil + } + return err +} + +func SetGlobalRateLimit(cfg *config.Config) error { + if cfg.RateLimit.BandwidthLimit.Enabled { + var err error + var totalLimit rate.Limit + var totalBurst rate.Limit + totalLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalLimit) + if UnDefiendRateStringErrHandle(err) != nil { + logError("Failed to parse total bandwidth limit: %v", err) + return err + } + totalBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalBurst) + if UnDefiendRateStringErrHandle(err) != nil { + logError("Failed to parse total bandwidth burst: %v", err) + return err + } + limitreader.SetGlobalRateLimit(totalLimit, int(totalBurst)) + err = SetBandwidthLimit(cfg) + if UnDefiendRateStringErrHandle(err) != nil { + logError("Failed to set bandwidth limit: %v", err) + return err + } + } else { + limitreader.SetGlobalRateLimit(rate.Inf, 0) + } + return nil +} + +func SetBandwidthLimit(cfg *config.Config) error { + var err error + bandwidthLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleLimit) + if UnDefiendRateStringErrHandle(err) != nil { + logError("Failed to parse bandwidth limit: %v", err) + return err + } + bandwidthBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleBurst) + if UnDefiendRateStringErrHandle(err) != nil { + logError("Failed to parse bandwidth burst: %v", err) + return err + } + return nil +} diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index 35b615e..56e55a3 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -8,6 +8,7 @@ import ( "net/http" "strconv" + "github.com/WJQSERVER-STUDIO/go-utils/limitreader" "github.com/cloudwego/hertz/pkg/app" ) @@ -94,6 +95,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c c.Status(resp.StatusCode) + bodyReader := resp.Body + + if cfg.RateLimit.BandwidthLimit.Enabled { + bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx) + } + if MatcherShell(u) && matchString(matcher, matchedMatchers) && cfg.Shell.Editor { // 判断body是不是gzip var compress string @@ -106,7 +113,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c var reader io.Reader - reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg) + reader, _, err = processLinks(bodyReader, compress, string(c.Request.Host()), cfg) c.SetBodyStream(reader, -1) if err != nil { logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), c.Request.Method(), u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) @@ -114,11 +121,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c return } } else { + if contentLength != "" { - c.SetBodyStream(resp.Body, bodySize) + c.SetBodyStream(bodyReader, bodySize) return } - c.SetBodyStream(resp.Body, -1) + c.SetBodyStream(bodyReader, -1) } } diff --git a/proxy/docker.go b/proxy/docker.go index 7fbb039..f12a74e 100644 --- a/proxy/docker.go +++ b/proxy/docker.go @@ -7,6 +7,7 @@ import ( "net/http" "strconv" + "github.com/WJQSERVER-STUDIO/go-utils/limitreader" "github.com/cloudwego/hertz/pkg/app" ) @@ -107,10 +108,16 @@ func GhcrRequest(ctx context.Context, c *app.RequestContext, u string, cfg *conf c.Status(resp.StatusCode) + bodyReader := resp.Body + + if cfg.RateLimit.BandwidthLimit.Enabled { + bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx) + } + if contentLength != "" { - c.SetBodyStream(resp.Body, bodySize) + c.SetBodyStream(bodyReader, bodySize) return } - c.SetBodyStream(resp.Body, -1) + c.SetBodyStream(bodyReader, -1) } diff --git a/proxy/gitreq.go b/proxy/gitreq.go index 4398bd7..1afd856 100644 --- a/proxy/gitreq.go +++ b/proxy/gitreq.go @@ -8,15 +8,16 @@ import ( "net/http" "strconv" + "github.com/WJQSERVER-STUDIO/go-utils/limitreader" "github.com/cloudwego/hertz/pkg/app" ) func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, mode string) { method := string(c.Request.Method()) - bodyReader := bytes.NewBuffer(c.Request.Body()) + reqBodyReader := bytes.NewBuffer(c.Request.Body()) - //bodyReader := c.Request.BodyStream() + //bodyReader := c.Request.BodyStream() // 不可替换为此实现 if cfg.GitClone.Mode == "cache" { userPath, repoPath, remainingPath, queryParams, err := extractParts(u) @@ -35,7 +36,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co if cfg.GitClone.Mode == "cache" { rb := gitclient.NewRequestBuilder(method, u) rb.NoDefaultHeaders() - rb.SetBody(bodyReader) + rb.SetBody(reqBodyReader) rb.WithContext(ctx) req, err := rb.Build() @@ -55,7 +56,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co } else { rb := client.NewRequestBuilder(string(c.Request.Method()), u) rb.NoDefaultHeaders() - rb.SetBody(bodyReader) + rb.SetBody(reqBodyReader) rb.WithContext(ctx) req, err := rb.Build() @@ -91,7 +92,6 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co for key, values := range resp.Header { for _, value := range values { - //c.Header(key, value) c.Response.Header.Add(key, value) } } @@ -124,5 +124,11 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co c.Response.Header.Set("Expires", "0") } - c.SetBodyStream(resp.Body, -1) + bodyReader := resp.Body + + if cfg.RateLimit.BandwidthLimit.Enabled { + bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx) + } + + c.SetBodyStream(bodyReader, -1) } diff --git a/proxy/httpc.go b/proxy/httpc.go index 120d8a7..83de29b 100644 --- a/proxy/httpc.go +++ b/proxy/httpc.go @@ -18,11 +18,16 @@ var ( gitclient *httpc.Client ) -func InitReq(cfg *config.Config) { +func InitReq(cfg *config.Config) error { initHTTPClient(cfg) if cfg.GitClone.Mode == "cache" { initGitHTTPClient(cfg) } + err := SetGlobalRateLimit(cfg) + if err != nil { + return err + } + return nil } func initHTTPClient(cfg *config.Config) {