mirror of
https://github.com/WJQSERVER-STUDIO/ghproxy.git
synced 2026-02-03 08:11:11 +08:00
add bandwidth limiter
This commit is contained in:
parent
3f8d16511e
commit
71bc2aaed7
10 changed files with 157 additions and 22 deletions
64
proxy/bandwidth.go
Normal file
64
proxy/bandwidth.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"ghproxy/config"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
bandwidthLimit rate.Limit
|
||||
bandwidthBurst rate.Limit
|
||||
)
|
||||
|
||||
func UnDefiendRateStringErrHandle(err error) error {
|
||||
if errors.Is(err, &limitreader.UnDefiendRateStringErr{}) {
|
||||
logWarning("UnDefiendRateStringErr: %s", err)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func SetGlobalRateLimit(cfg *config.Config) error {
|
||||
if cfg.RateLimit.BandwidthLimit.Enabled {
|
||||
var err error
|
||||
var totalLimit rate.Limit
|
||||
var totalBurst rate.Limit
|
||||
totalLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalLimit)
|
||||
if UnDefiendRateStringErrHandle(err) != nil {
|
||||
logError("Failed to parse total bandwidth limit: %v", err)
|
||||
return err
|
||||
}
|
||||
totalBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.TotalBurst)
|
||||
if UnDefiendRateStringErrHandle(err) != nil {
|
||||
logError("Failed to parse total bandwidth burst: %v", err)
|
||||
return err
|
||||
}
|
||||
limitreader.SetGlobalRateLimit(totalLimit, int(totalBurst))
|
||||
err = SetBandwidthLimit(cfg)
|
||||
if UnDefiendRateStringErrHandle(err) != nil {
|
||||
logError("Failed to set bandwidth limit: %v", err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
limitreader.SetGlobalRateLimit(rate.Inf, 0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetBandwidthLimit(cfg *config.Config) error {
|
||||
var err error
|
||||
bandwidthLimit, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleLimit)
|
||||
if UnDefiendRateStringErrHandle(err) != nil {
|
||||
logError("Failed to parse bandwidth limit: %v", err)
|
||||
return err
|
||||
}
|
||||
bandwidthBurst, err = limitreader.ParseRate(cfg.RateLimit.BandwidthLimit.SingleBurst)
|
||||
if UnDefiendRateStringErrHandle(err) != nil {
|
||||
logError("Failed to parse bandwidth burst: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
)
|
||||
|
||||
|
|
@ -94,6 +95,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
|||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
bodyReader := resp.Body
|
||||
|
||||
if cfg.RateLimit.BandwidthLimit.Enabled {
|
||||
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
|
||||
}
|
||||
|
||||
if MatcherShell(u) && matchString(matcher, matchedMatchers) && cfg.Shell.Editor {
|
||||
// 判断body是不是gzip
|
||||
var compress string
|
||||
|
|
@ -106,7 +113,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
|||
|
||||
var reader io.Reader
|
||||
|
||||
reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg)
|
||||
reader, _, err = processLinks(bodyReader, compress, string(c.Request.Host()), cfg)
|
||||
c.SetBodyStream(reader, -1)
|
||||
if err != nil {
|
||||
logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), c.Request.Method(), u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err)
|
||||
|
|
@ -114,11 +121,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
|||
return
|
||||
}
|
||||
} else {
|
||||
|
||||
if contentLength != "" {
|
||||
c.SetBodyStream(resp.Body, bodySize)
|
||||
c.SetBodyStream(bodyReader, bodySize)
|
||||
return
|
||||
}
|
||||
c.SetBodyStream(resp.Body, -1)
|
||||
c.SetBodyStream(bodyReader, -1)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
)
|
||||
|
||||
|
|
@ -107,10 +108,16 @@ func GhcrRequest(ctx context.Context, c *app.RequestContext, u string, cfg *conf
|
|||
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
bodyReader := resp.Body
|
||||
|
||||
if cfg.RateLimit.BandwidthLimit.Enabled {
|
||||
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
|
||||
}
|
||||
|
||||
if contentLength != "" {
|
||||
c.SetBodyStream(resp.Body, bodySize)
|
||||
c.SetBodyStream(bodyReader, bodySize)
|
||||
return
|
||||
}
|
||||
c.SetBodyStream(resp.Body, -1)
|
||||
c.SetBodyStream(bodyReader, -1)
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,15 +8,16 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/WJQSERVER-STUDIO/go-utils/limitreader"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
)
|
||||
|
||||
func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, mode string) {
|
||||
method := string(c.Request.Method())
|
||||
|
||||
bodyReader := bytes.NewBuffer(c.Request.Body())
|
||||
reqBodyReader := bytes.NewBuffer(c.Request.Body())
|
||||
|
||||
//bodyReader := c.Request.BodyStream()
|
||||
//bodyReader := c.Request.BodyStream() // 不可替换为此实现
|
||||
|
||||
if cfg.GitClone.Mode == "cache" {
|
||||
userPath, repoPath, remainingPath, queryParams, err := extractParts(u)
|
||||
|
|
@ -35,7 +36,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
|
|||
if cfg.GitClone.Mode == "cache" {
|
||||
rb := gitclient.NewRequestBuilder(method, u)
|
||||
rb.NoDefaultHeaders()
|
||||
rb.SetBody(bodyReader)
|
||||
rb.SetBody(reqBodyReader)
|
||||
rb.WithContext(ctx)
|
||||
|
||||
req, err := rb.Build()
|
||||
|
|
@ -55,7 +56,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
|
|||
} else {
|
||||
rb := client.NewRequestBuilder(string(c.Request.Method()), u)
|
||||
rb.NoDefaultHeaders()
|
||||
rb.SetBody(bodyReader)
|
||||
rb.SetBody(reqBodyReader)
|
||||
rb.WithContext(ctx)
|
||||
|
||||
req, err := rb.Build()
|
||||
|
|
@ -91,7 +92,6 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
|
|||
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
//c.Header(key, value)
|
||||
c.Response.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
|
@ -124,5 +124,11 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
|
|||
c.Response.Header.Set("Expires", "0")
|
||||
}
|
||||
|
||||
c.SetBodyStream(resp.Body, -1)
|
||||
bodyReader := resp.Body
|
||||
|
||||
if cfg.RateLimit.BandwidthLimit.Enabled {
|
||||
bodyReader = limitreader.NewRateLimitedReader(bodyReader, bandwidthLimit, int(bandwidthBurst), ctx)
|
||||
}
|
||||
|
||||
c.SetBodyStream(bodyReader, -1)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,11 +18,16 @@ var (
|
|||
gitclient *httpc.Client
|
||||
)
|
||||
|
||||
func InitReq(cfg *config.Config) {
|
||||
func InitReq(cfg *config.Config) error {
|
||||
initHTTPClient(cfg)
|
||||
if cfg.GitClone.Mode == "cache" {
|
||||
initGitHTTPClient(cfg)
|
||||
}
|
||||
err := SetGlobalRateLimit(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initHTTPClient(cfg *config.Config) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue