diff --git a/CHANGELOG.md b/CHANGELOG.md index 8361a09..9f96690 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,50 @@ # 更新日志 +3.1.0 - 2025-04-24 +--- +- CHANGE: 对标准url使用`HertZ`路由匹配器, 而不是自制匹配器, 以提升效率 +- CHANGE: 使用`bodystream`进行req方向的body复制, 而不是使用额外的`buffer reader` +- CHANGE: 使用`HertZ`的`requestContext`传递matcher参数, 而不是`25w30a`中的ctx +- CHANGE: 改进`rate`模块, 避免并发竞争问题 +- CHANGE: 将大部分状态码返回改为新的`html/tmpl`方式处理 +- CHANGE: 修改部分log等级 +- FIX: 修正默认配置的填充错误 +- CHANGE: 使用go `html/tmpl`处理状态码页面, 同时实现错误信息显示 +- CHANGE: 改进handle, 复用共同部分 +- CHANGE: 细化url匹配的返回码处理 +- CHANGE: 增加404界面 + +25w30e - 2025-04-24 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 改进`rate`模块, 避免并发竞争问题 +- CHANGE: 将大部分状态码返回改为新的`html/tmpl`方式处理 +- CHANGE: 修改部分log等级 +- FIX: 修正默认配置的填充错误 + +25w30d - 2025-04-22 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 使用go `html/tmpl`处理状态码页面, 同时实现错误信息显示 + +25w30c - 2025-04-21 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 改进handle, 复用共同部分 +- CHANGE: 细化url匹配的返回码处理 +- CHANGE: 增加404界面 + +25w30b - 2025-04-21 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 使用`bodystream`进行req方向的body复制, 而不是使用额外的`buffer reader` +- CHANGE: 使用`HertZ`的`requestContext`传递matcher参数, 而不是`25w30a`中的标准ctx + +25w30a - 2025-04-19 +--- +- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用; +- CHANGE: 对标准url使用`HertZ`路由匹配器, 而不是自制匹配器 + 3.0.3 - 2025-04-19 --- - CHANGE: 增加移除部分header的处置, 避免向服务端/客户端透露过多信息 diff --git a/DEV-VERSION b/DEV-VERSION index c307029..1fecf83 100644 --- a/DEV-VERSION +++ b/DEV-VERSION @@ -1 +1 @@ -25w29b \ No newline at end of file +25w30e \ No newline at end of file diff --git a/VERSION b/VERSION index 282895a..a0cd9f0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.0.3 \ No newline at end of file +3.1.0 \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index 1a7f1a2..7ebff30 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,7 +1,6 @@ package auth import ( - "context" "fmt" "ghproxy/config" @@ -36,7 +35,7 @@ func Init(cfg *config.Config) { logDebug("Auth Init") } -func AuthHandler(ctx context.Context, c *app.RequestContext, cfg *config.Config) (isValid bool, err error) { +func AuthHandler(c *app.RequestContext, cfg *config.Config) (isValid bool, err error) { if cfg.Auth.Method == "parameters" { isValid, err = AuthParametersHandler(c, cfg) return isValid, err @@ -47,7 +46,7 @@ func AuthHandler(ctx context.Context, c *app.RequestContext, cfg *config.Config) logError("Auth method not set") return true, nil } else { - logError("Auth method not supported") - return false, fmt.Errorf(fmt.Sprintf("Auth method %s not supported", cfg.Auth.Method)) + logError("Auth method not supported %s", cfg.Auth.Method) + return false, fmt.Errorf("%s", fmt.Sprintf("Auth method %s not supported", cfg.Auth.Method)) } } diff --git a/config/config.go b/config/config.go index e737a63..984dc9e 100644 --- a/config/config.go +++ b/config/config.go @@ -228,11 +228,11 @@ func DefaultConfig() *Config { }, Blacklist: BlacklistConfig{ Enabled: false, - BlacklistFile: "/data/ghproxy/config/blacklist.txt", + BlacklistFile: "/data/ghproxy/config/blacklist.json", }, Whitelist: WhitelistConfig{ Enabled: false, - WhitelistFile: "/data/ghproxy/config/whitelist.txt", + WhitelistFile: "/data/ghproxy/config/whitelist.json", }, RateLimit: RateLimitConfig{ Enabled: false, diff --git a/go.mod b/go.mod index 2c6514f..1820f16 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect - github.com/nyaruka/phonenumbers v1.6.0 // indirect + github.com/nyaruka/phonenumbers v1.6.1 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index b036c8b..b6da4cc 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgSh github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/nyaruka/phonenumbers v1.6.0 h1:r9ax45fFg+YLUs2X4bNXm5RAxWl00hYjFgNlv32vtHk= -github.com/nyaruka/phonenumbers v1.6.0/go.mod h1:7gjs+Lchqm49adhAKB5cdcng5ZXgt6x7Jgvi0ZorUtU= +github.com/nyaruka/phonenumbers v1.6.1 h1:XAJcTdYow16VrVKfglznMpJZz8KMJoMjx/91sX+K940= +github.com/nyaruka/phonenumbers v1.6.1/go.mod h1:7gjs+Lchqm49adhAKB5cdcng5ZXgt6x7Jgvi0ZorUtU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/satomitouka/touka-httpc v0.4.0 h1:cnOONdyJHJImMY8L64bvYF+7Ow/5CPf2Yr3RQRRMZOU= diff --git a/main.go b/main.go index 50df92e..5314f31 100644 --- a/main.go +++ b/main.go @@ -210,6 +210,12 @@ func loadEmbeddedPages(cfg *config.Config) (fs.FS, fs.FS, error) { return nil, nil, fmt.Errorf("failed to load embedded pages: %w", err) } + // 初始化errPagesFs + errPagesInitErr := proxy.InitErrPagesFS(pagesFS) + if errPagesInitErr != nil { + logWarning("errPagesInitErr: %s", errPagesInitErr) + } + var assets fs.FS assets, err = fs.Sub(pagesFS, "pages/assets") return pages, assets, nil @@ -404,56 +410,54 @@ func main() { os.Exit(1) } - // 添加Recovery中间件 - r.Use(recovery.Recovery()) - // 添加log中间件 - r.Use(loggin.Middleware()) - + r.Use(recovery.Recovery()) // Recovery中间件 + r.Use(loggin.Middleware()) // log中间件 setupApi(cfg, r, version) - setupPages(cfg, r) - /* - // 1. GitHub Releases/Archive - Use distinct path segments for type - r.GET("/github.com/:username/:repo/releases/*filepath", func(ctx context.Context, c *app.RequestContext) { // Distinct path for releases - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/github.com/:username/:repo/releases/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "release") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - r.GET("/github.com/:username/:repo/archive/*filepath", func(ctx context.Context, c *app.RequestContext) { // Distinct path for archive - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/github.com/:username/:repo/archive/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "release") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - // 2. GitHub Blob/Raw - Use distinct path segments for type - r.GET("/github.com/:username/:repo/blob/*filepath", func(ctx context.Context, c *app.RequestContext) { // Distinct path for blob - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/github.com/:username/:repo/blob/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "blob") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - r.GET("/github.com/:username/:repo/raw/*filepath", func(ctx context.Context, c *app.RequestContext) { // Distinct path for raw - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/github.com/:username/:repo/raw/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "raw") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - r.GET("/github.com/:username/:repo/info/*filepath", func(ctx context.Context, c *app.RequestContext) { // Distinct path for info - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) - r.GET("/github.com/:username/:repo/git-upload-pack", func(ctx context.Context, c *app.RequestContext) { - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/github.com/:username/:repo/info/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "gitclone") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) + r.GET("/github.com/:username/:repo/git-upload-pack", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "gitclone") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - // 4. Raw GitHubusercontent - Keep as is (assuming it's distinct enough) - r.GET("/raw.githubusercontent.com/:username/:repo/*filepath", func(ctx context.Context, c *app.RequestContext) { - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/raw.githubusercontent.com/:username/:repo/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "raw") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) - // 5. Gist GitHubusercontent - Keep as is (assuming it's distinct enough) - r.GET("/gist.githubusercontent.com/:username/*filepath", func(ctx context.Context, c *app.RequestContext) { - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) + r.GET("/gist.githubusercontent.com/:username/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "gist") + proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) + }) - // 6. GitHub API Repos - Keep as is (assuming it's distinct enough) - r.GET("/api.github.com/repos/:username/:repo/*filepath", func(ctx context.Context, c *app.RequestContext) { - proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) - }) - */ + r.GET("/api.github.com/repos/:username/:repo/*filepath", func(ctx context.Context, c *app.RequestContext) { + c.Set("matcher", "api") + proxy.RoutingHandler(cfg, limiter, iplimiter)(ctx, c) + }) r.NoRoute(func(ctx context.Context, c *app.RequestContext) { proxy.NoRouteHandler(cfg, limiter, iplimiter)(ctx, c) diff --git a/proxy/authpass.go b/proxy/authpass.go index e506a9e..16887c7 100644 --- a/proxy/authpass.go +++ b/proxy/authpass.go @@ -18,8 +18,7 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques req.Header.Set("Authorization", "token "+token) } else { logWarning("%s %s %s %s %s Auth-Error: Conflict Auth Method", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) - // 500 Internal Server Error - c.JSON(http.StatusInternalServerError, map[string]string{"error": "Conflict Auth Method"}) + ErrorPage(c, NewErrorWithStatusLookup(500, "Conflict Auth Method")) return } case "header": @@ -28,8 +27,7 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques } default: logWarning("%s %s %s %s %s Invalid Auth Method / Auth Method is not be set", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) - // 500 Internal Server Error - c.JSON(http.StatusInternalServerError, map[string]string{"error": "Invalid Auth Method / Auth Method is not be set"}) + ErrorPage(c, NewErrorWithStatusLookup(500, "Invalid Auth Method / Auth Method is not be set")) return } } diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index b02ec4c..f68b07f 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -1,7 +1,6 @@ package proxy import ( - "bytes" "context" "fmt" "ghproxy/config" @@ -38,17 +37,15 @@ var ( func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, matcher string) { var ( - method []byte - bodyReader *bytes.Buffer - req *http.Request - resp *http.Response - err error + method []byte + req *http.Request + resp *http.Response + err error ) method = c.Request.Method() - bodyReader = bytes.NewBuffer(c.Request.Body()) - req, err = client.NewRequest(string(method), u, bodyReader) + req, err = client.NewRequest(string(method), u, c.Request.BodyStream()) if err != nil { HandleError(c, fmt.Sprintf("Failed to create request: %v", err)) return @@ -66,8 +63,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c // 错误处理(404) if resp.StatusCode == 404 { - //c.String(http.StatusNotFound, "File Not Found") - c.Status(http.StatusNotFound) + ErrorPage(c, NewErrorWithStatusLookup(404, "Page Not Found (From Github)")) return } @@ -92,25 +88,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c if err != nil { logError("Failed to close response body: %v", err) } - c.Redirect(http.StatusMovedPermanently, []byte(finalURL)) + c.Redirect(301, []byte(finalURL)) logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.UserAgent(), c.Request.Header.GetProtocol(), finalURL, bodySize) return } } - /* - for header := range headersToRemove { - resp.Header.Del(header) - } - - for key := range resp.Header { - var values []string = resp.Header.Values(key) - for _, value := range values { - c.Header(key, value) - } - } - */ - // 复制响应头,排除需要移除的 header for key, values := range resp.Header { if _, shouldRemove := respHeadersToRemove[key]; !shouldRemove { @@ -140,16 +123,16 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c compress = "gzip" } - logInfo("Is Shell: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol()) + logDebug("Use Shell Editor: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol()) c.Header("Content-Length", "") var reader io.Reader reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg) c.SetBodyStream(reader, -1) - if err != nil { logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) + ErrorPage(c, NewErrorWithStatusLookup(500, fmt.Sprintf("Failed to copy response body: %v", err))) return } } else { diff --git a/proxy/error.go b/proxy/error.go index d8b27c5..b5ff39d 100644 --- a/proxy/error.go +++ b/proxy/error.go @@ -1,7 +1,10 @@ package proxy import ( - "net/http" + "bytes" + "fmt" + "html/template" + "io/fs" "github.com/WJQSERVER-STUDIO/go-utils/logger" "github.com/cloudwego/hertz/pkg/app" @@ -18,6 +21,147 @@ var ( ) func HandleError(c *app.RequestContext, message string) { - c.JSON(http.StatusInternalServerError, map[string]string{"error": message}) + ErrorPage(c, NewErrorWithStatusLookup(500, message)) logError(message) } + +type GHProxyErrors struct { + StatusCode int + StatusDesc string + StatusText string + HelpInfo string + ErrorMessage string +} + +var ( + ErrInvalidURL = &GHProxyErrors{ + StatusCode: 400, + StatusDesc: "Bad Request", + StatusText: "无效请求", + HelpInfo: "请求的URL格式不正确,请检查后重试。", + } + ErrAuthHeaderUnavailable = &GHProxyErrors{ + StatusCode: 401, + StatusDesc: "Unauthorized", + StatusText: "认证失败", + HelpInfo: "缺少或无效的鉴权信息。", + } + ErrForbidden = &GHProxyErrors{ + StatusCode: 403, + StatusDesc: "Forbidden", + StatusText: "权限不足", + HelpInfo: "您没有权限访问此资源。", + } + ErrNotFound = &GHProxyErrors{ + StatusCode: 404, + StatusDesc: "Not Found", + StatusText: "页面未找到", + HelpInfo: "抱歉,您访问的页面不存在。", + } + ErrTooManyRequests = &GHProxyErrors{ + StatusCode: 429, + StatusDesc: "Too Many Requests", + StatusText: "请求过于频繁", + HelpInfo: "您的请求过于频繁,请稍后再试。", + } + ErrInternalServerError = &GHProxyErrors{ + StatusCode: 500, + StatusDesc: "Internal Server Error", + StatusText: "服务器内部错误", + HelpInfo: "服务器处理您的请求时发生错误,请稍后重试或联系管理员。", + } +) + +var statusErrorMap map[int]*GHProxyErrors + +func init() { + statusErrorMap = map[int]*GHProxyErrors{ + ErrInvalidURL.StatusCode: ErrInvalidURL, + ErrAuthHeaderUnavailable.StatusCode: ErrAuthHeaderUnavailable, + ErrForbidden.StatusCode: ErrForbidden, + ErrNotFound.StatusCode: ErrNotFound, + ErrTooManyRequests.StatusCode: ErrTooManyRequests, + ErrInternalServerError.StatusCode: ErrInternalServerError, + } +} + +func NewErrorWithStatusLookup(statusCode int, errMsg string) *GHProxyErrors { + baseErr, found := statusErrorMap[statusCode] + + if found { + return &GHProxyErrors{ + StatusCode: baseErr.StatusCode, + StatusDesc: baseErr.StatusDesc, + StatusText: baseErr.StatusText, + HelpInfo: baseErr.HelpInfo, + ErrorMessage: errMsg, + } + } else { + return &GHProxyErrors{ + StatusCode: statusCode, + ErrorMessage: errMsg, + } + } +} + +var errPagesFs fs.FS + +func InitErrPagesFS(pages fs.FS) error { + var err error + errPagesFs, err = fs.Sub(pages, "pages/err") + if err != nil { + return err + } + return nil +} + +type ErrorPageData struct { + StatusCode int + StatusDesc string + StatusText string + HelpInfo string + ErrorMessage string +} + +func ErrPageUnwarper(errInfo *GHProxyErrors) ErrorPageData { + return ErrorPageData{ + StatusCode: errInfo.StatusCode, + StatusDesc: errInfo.StatusDesc, + StatusText: errInfo.StatusText, + HelpInfo: errInfo.HelpInfo, + ErrorMessage: errInfo.ErrorMessage, + } +} + +func ErrorPage(c *app.RequestContext, errInfo *GHProxyErrors) { + pageData, err := htmlTemplateRender(errPagesFs, ErrPageUnwarper(errInfo)) + if err != nil { + c.JSON(errInfo.StatusCode, map[string]string{"error": errInfo.ErrorMessage}) + logDebug("Error reading page.tmpl: %v", err) + return + } + c.Data(errInfo.StatusCode, "text/html; charset=utf-8", pageData) + return +} + +func htmlTemplateRender(fsys fs.FS, data interface{}) ([]byte, error) { + tmplPath := "page.tmpl" + tmpl, err := template.ParseFS(fsys, tmplPath) + if err != nil { + return nil, fmt.Errorf("error parsing template: %w", err) + } + if tmpl == nil { + return nil, fmt.Errorf("template is nil") + } + + // 创建一个 bytes.Buffer 用于存储渲染结果 + var buf bytes.Buffer + + err = tmpl.Execute(&buf, data) + if err != nil { + return nil, fmt.Errorf("error executing template: %w", err) + } + + // 返回 buffer 的内容作为 []byte + return buf.Bytes(), nil +} diff --git a/proxy/gitreq.go b/proxy/gitreq.go index 55ede9a..b8ba66d 100644 --- a/proxy/gitreq.go +++ b/proxy/gitreq.go @@ -1,7 +1,6 @@ package proxy import ( - "bytes" "context" "fmt" "ghproxy/config" @@ -28,16 +27,10 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co var ( resp *http.Response - //err error ) - body := c.Request.Body() - - bodyReader := bytes.NewBuffer(body) - // 创建请求 - if cfg.GitClone.Mode == "cache" { - req, err := gitclient.NewRequest(method, u, bodyReader) + req, err := gitclient.NewRequest(method, u, c.Request.BodyStream()) if err != nil { HandleError(c, fmt.Sprintf("Failed to create request: %v", err)) return @@ -52,7 +45,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co return } } else { - req, err := client.NewRequest(method, u, bodyReader) + req, err := client.NewRequest(method, u, c.Request.BodyStream()) if err != nil { HandleError(c, fmt.Sprintf("Failed to create request: %v", err)) return diff --git a/proxy/handler.go b/proxy/handler.go index 9d214d3..bdd7ecb 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -2,12 +2,9 @@ package proxy import ( "context" - "errors" "fmt" - "ghproxy/auth" "ghproxy/config" "ghproxy/rate" - "net/http" "regexp" "strings" @@ -19,43 +16,24 @@ var re = regexp.MustCompile(`^(http:|https:)?/?/?(.*)`) // 匹配http://或https func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { - // 限制访问频率 - if cfg.RateLimit.Enabled { - - var allowed bool - - switch cfg.RateLimit.RateMethod { - case "ip": - allowed = iplimiter.Allow(c.ClientIP()) - case "total": - allowed = limiter.Allow() - default: - logWarning("Invalid RateLimit Method") - return - } - - if !allowed { - c.JSON(http.StatusTooManyRequests, map[string]string{"error": "Too Many Requests"}) - logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - return - } + var shoudBreak bool + shoudBreak = rateCheck(cfg, c, limiter, iplimiter) + if shoudBreak { + return } var ( rawPath string matches []string - errMsg string ) rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ matches = re.FindStringSubmatch(rawPath) // 匹配路径 - logInfo("URL: %v", matches) // 匹配路径错误处理 if len(matches) < 3 { - errMsg = fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) - logWarning(errMsg) - c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) + logWarning("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), c.Path(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + ErrorPage(c, NewErrorWithStatusLookup(400, fmt.Sprintf("Invalid URL Format: %s", c.Path()))) return } @@ -66,74 +44,34 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra user string repo string matcher string - err error ) - user, repo, matcher, err = Matcher(rawPath, cfg) - if err != nil { - if errors.Is(err, ErrInvalidURL) { - c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) - logWarning(err.Error()) - return - } - if errors.Is(err, ErrAuthHeaderUnavailable) { - c.String(http.StatusForbidden, "AuthHeader Unavailable") - logWarning(err.Error()) - return - } + var matcherErr *GHProxyErrors + user, repo, matcher, matcherErr = Matcher(rawPath, cfg) + if matcherErr != nil { + ErrorPage(c, matcherErr) + return } - logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) - // dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header - logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) - var repouser string - repouser = fmt.Sprintf("%s/%s", user, repo) + logDump("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + logDump("%s", c.Request.Header.Header()) - // 白名单检查 - if cfg.Whitelist.Enabled { - var whitelist bool - whitelist = auth.CheckWhitelist(user, repo) - if !whitelist { - errMsg = fmt.Sprintf("Whitelist Blocked repo: %s", repouser) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) - return - } + shoudBreak = listCheck(cfg, c, user, repo, rawPath) + if shoudBreak { + return } - // 黑名单检查 - if cfg.Blacklist.Enabled { - var blacklist bool - blacklist = auth.CheckBlacklist(user, repo) - if blacklist { - errMsg = fmt.Sprintf("Blacklist Blocked repo: %s", repouser) - c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) - return - } + shoudBreak = authCheck(c, cfg, matcher, rawPath) + if shoudBreak { + return } - // 若匹配api.github.com/repos/用户名/仓库名/路径, 则检查是否开启HeaderAuth - // 处理blob/raw路径 if matcher == "blob" { rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) } - // 鉴权 - if cfg.Auth.Enabled { - var authcheck bool - authcheck, err = auth.AuthHandler(ctx, c, cfg) - if !authcheck { - //c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) - c.AbortWithStatusJSON(401, map[string]string{"error": "Unauthorized"}) - logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) - return - } - } - - // IP METHOD URL USERAGENT PROTO MATCHES - logDebug("%s %s %s %s %s Matched: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matcher) + logDebug("Matched: %v", matcher) switch matcher { case "releases", "blob", "raw", "gist", "api": @@ -141,8 +79,8 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra case "clone": GitReq(ctx, c, rawPath, cfg, "git") default: - c.String(http.StatusForbidden, "Invalid input.") - fmt.Println("Invalid input.") + ErrorPage(c, NewErrorWithStatusLookup(500, "Matched But Not Matched")) + logError("Matched But Not Matched Path: %s rawPath: %s matcher: %s", c.Path(), rawPath, matcher) return } } diff --git a/proxy/match.go b/proxy/match.go index 1cae46c..9b55339 100644 --- a/proxy/match.go +++ b/proxy/match.go @@ -11,36 +11,7 @@ import ( "strings" ) -// 定义错误类型, error承载描述, 便于处理 -type MatcherErrors struct { - Code int - Msg string - Err error -} - -var ( - ErrInvalidURL = &MatcherErrors{ - Code: 403, - Msg: "Invalid URL Format", - } - ErrAuthHeaderUnavailable = &MatcherErrors{ - Code: 403, - Msg: "AuthHeader Unavailable", - } -) - -func (e *MatcherErrors) Error() string { - if e.Err != nil { - return fmt.Sprintf("Code: %d, Msg: %s, Err: %s", e.Code, e.Msg, e.Err.Error()) - } - return fmt.Sprintf("Code: %d, Msg: %s", e.Code, e.Msg) -} - -func (e *MatcherErrors) Unwrap() error { - return e.Err -} - -func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) { +func Matcher(rawPath string, cfg *config.Config) (string, string, string, *GHProxyErrors) { var ( user string repo string @@ -56,7 +27,8 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) // 取出user和repo和最后部分 parts := strings.Split(remainingPath, "/") if len(parts) <= 2 { - return "", "", "", ErrInvalidURL + errMsg := "Not enough parts in path after matching 'https://github.com*'" + return "", "", "", NewErrorWithStatusLookup(400, errMsg) } user = parts[0] repo = parts[1] @@ -65,12 +37,15 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) switch parts[2] { case "releases", "archive": matcher = "releases" - case "blob", "raw": + case "blob": matcher = "blob" + case "raw": + matcher = "raw" case "info", "git-upload-pack": matcher = "clone" default: - return "", "", "", ErrInvalidURL + errMsg := "Url Matched 'https://github.com*', but didn't match the next matcher" + return "", "", "", NewErrorWithStatusLookup(400, errMsg) } } return user, repo, matcher, nil @@ -80,7 +55,8 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) remainingPath := strings.TrimPrefix(rawPath, "https://") parts := strings.Split(remainingPath, "/") if len(parts) <= 3 { - return "", "", "", ErrInvalidURL + errMsg := "URL after matched 'https://raw*' should have at least 4 parts (user/repo/branch/file)." + return "", "", "", NewErrorWithStatusLookup(400, errMsg) } user = parts[1] repo = parts[2] @@ -93,7 +69,8 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) remainingPath := strings.TrimPrefix(rawPath, "https://") parts := strings.Split(remainingPath, "/") if len(parts) <= 3 { - return "", "", "", ErrInvalidURL + errMsg := "URL after matched 'https://gist*' should have at least 4 parts (user/gist_id)." + return "", "", "", NewErrorWithStatusLookup(400, errMsg) } user = parts[1] repo = "" @@ -115,12 +92,16 @@ func Matcher(rawPath string, cfg *config.Config) (string, string, string, error) } if !cfg.Auth.ForceAllowApi { if cfg.Auth.Method != "header" || !cfg.Auth.Enabled { - return "", "", "", ErrAuthHeaderUnavailable + //return "", "", "", ErrAuthHeaderUnavailable + errMsg := "AuthHeader Unavailable, Need to open header auth to enable api proxy" + return "", "", "", NewErrorWithStatusLookup(403, errMsg) } } return user, repo, matcher, nil } - return "", "", "", ErrInvalidURL + //return "", "", "", ErrNotFound + errMsg := "Didn't match any matcher" + return "", "", "", NewErrorWithStatusLookup(404, errMsg) } func EditorMatcher(rawPath string, cfg *config.Config) (bool, string, error) { @@ -158,17 +139,11 @@ func EditorMatcher(rawPath string, cfg *config.Config) (bool, string, error) { return true, matcher, nil } } - return false, "", ErrInvalidURL + return false, "", nil } // 匹配文件扩展名是sh的rawPath func MatcherShell(rawPath string) bool { - /* - if strings.HasSuffix(rawPath, ".sh") { - return true - } - return false - */ return strings.HasSuffix(rawPath, ".sh") } diff --git a/proxy/reqheader.go b/proxy/reqheader.go index c338706..99bef3a 100644 --- a/proxy/reqheader.go +++ b/proxy/reqheader.go @@ -6,15 +6,6 @@ import ( "github.com/cloudwego/hertz/pkg/app" ) -/* -// 设置请求头 -func setRequestHeaders(c *app.RequestContext, req *http.Request) { - c.Request.Header.VisitAll(func(key, value []byte) { - req.Header.Set(string(key), string(value)) - }) -} -*/ - func setRequestHeaders(c *app.RequestContext, req *http.Request) { c.Request.Header.VisitAll(func(key, value []byte) { headerKey := string(key) @@ -22,16 +13,5 @@ func setRequestHeaders(c *app.RequestContext, req *http.Request) { if _, shouldRemove := reqHeadersToRemove[headerKey]; !shouldRemove { req.Header.Set(headerKey, headerValue) } - }) } - -/* -// removeWSHeader removes the "Upgrade" and "Connection" headers from the given -// Request, which are added by the client when it wants to upgrade the -// connection to a WebSocket connection. -func removeWSHeader(req *http.Request) { - req.Header.Del("Upgrade") - req.Header.Del("Connection") -} -*/ diff --git a/proxy/routing.go b/proxy/routing.go new file mode 100644 index 0000000..a3135ec --- /dev/null +++ b/proxy/routing.go @@ -0,0 +1,72 @@ +package proxy + +import ( + "context" + "ghproxy/config" + "ghproxy/rate" + "strings" + + "github.com/cloudwego/hertz/pkg/app" +) + +func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) app.HandlerFunc { + return func(ctx context.Context, c *app.RequestContext) { + + var shoudBreak bool + + shoudBreak = rateCheck(cfg, c, limiter, iplimiter) + if shoudBreak { + return + } + + var ( + rawPath string + ) + + rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/ + + var ( + user string + repo string + matcher string + ) + + user = c.Param("user") + repo = c.Param("repo") + matcher = c.GetString("matcher") + + logDump("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + logDump("%s", c.Request.Header.Header()) + + shoudBreak = listCheck(cfg, c, user, repo, rawPath) + if shoudBreak { + return + } + + shoudBreak = authCheck(c, cfg, matcher, rawPath) + if shoudBreak { + return + } + + // 处理blob/raw路径 + if matcher == "blob" { + rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) + } + + // 为rawpath加入https:// 头 + rawPath = "https://" + rawPath + + logDebug("Matched: %v", matcher) + + switch matcher { + case "releases", "blob", "raw", "gist", "api": + ChunkedProxyRequest(ctx, c, rawPath, cfg, matcher) + case "clone": + GitReq(ctx, c, rawPath, cfg, "git") + default: + ErrorPage(c, NewErrorWithStatusLookup(500, "Matched But Not Matched")) + logError("Matched But Not Matched Path: %s rawPath: %s matcher: %s", c.Path(), rawPath, matcher) + return + } + } +} diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 0000000..21e7886 --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,90 @@ +package proxy + +import ( + "fmt" + "ghproxy/auth" + "ghproxy/config" + "ghproxy/rate" + + "github.com/cloudwego/hertz/pkg/app" +) + +func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool { + + // 白名单检查 + if cfg.Whitelist.Enabled { + var whitelist bool + whitelist = auth.CheckWhitelist(user, repo) + if !whitelist { + ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo))) + logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return true + } + } + + // 黑名单检查 + if cfg.Blacklist.Enabled { + var blacklist bool + blacklist = auth.CheckBlacklist(user, repo) + if blacklist { + ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo))) + logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo) + return true + } + } + + return false +} + +// 鉴权 +func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPath string) bool { + var err error + + if matcher == "api" && !cfg.Auth.ForceAllowApi { + if cfg.Auth.Method != "header" || !cfg.Auth.Enabled { + ErrorPage(c, NewErrorWithStatusLookup(403, "Github API Req without AuthHeader is Not Allowed")) + logInfo("%s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath) + return true + } + } + + // 鉴权 + if cfg.Auth.Enabled { + var authcheck bool + authcheck, err = auth.AuthHandler(c, cfg) + if !authcheck { + ErrorPage(c, NewErrorWithStatusLookup(401, fmt.Sprintf("Unauthorized: %v", err))) + logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) + return true + } + } + + return false +} + +func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimiter, iplimiter *rate.IPRateLimiter) bool { + // 限制访问频率 + if cfg.RateLimit.Enabled { + + var allowed bool + + switch cfg.RateLimit.RateMethod { + case "ip": + allowed = iplimiter.Allow(c.ClientIP()) + case "total": + allowed = limiter.Allow() + default: + logWarning("Invalid RateLimit Method") + ErrorPage(c, NewErrorWithStatusLookup(500, "Invalid RateLimit Method")) + return true + } + + if !allowed { + ErrorPage(c, NewErrorWithStatusLookup(429, fmt.Sprintf("Too Many Requests; Rate Limit is %d per minute", cfg.RateLimit.RatePerMinute))) + logInfo("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + return true + } + } + + return false +} diff --git a/rate/rate.go b/rate/rate.go index a1237b2..8305801 100644 --- a/rate/rate.go +++ b/rate/rate.go @@ -1,13 +1,14 @@ package rate import ( + "sync" "time" "github.com/WJQSERVER-STUDIO/go-utils/logger" "golang.org/x/time/rate" ) -// 日志输出 +// 日志模块 var ( logw = logger.Logw logDump = logger.LogDump @@ -17,49 +18,90 @@ var ( logError = logger.LogError ) -// 总体限流器 +// RateLimiter 总体限流器 type RateLimiter struct { limiter *rate.Limiter } -// 基于IP的限流器 -type IPRateLimiter struct { - limiters map[string]*RateLimiter - limit int - burst int - duration time.Duration -} - +// New 创建一个总体限流器 func New(limit int, burst int, duration time.Duration) *RateLimiter { + if limit <= 0 { + limit = 1 + logWarning("rate limit per minute must be positive, setting to 1") + } + if burst <= 0 { + burst = 1 + logWarning("rate limit burst must be positive, setting to 1") + } + + rateLimit := rate.Limit(float64(limit) / duration.Seconds()) + return &RateLimiter{ - limiter: rate.NewLimiter(rate.Limit(float64(limit)/duration.Seconds()), burst), + limiter: rate.NewLimiter(rateLimit, burst), } } +// Allow 检查是否允许请求通过 func (rl *RateLimiter) Allow() bool { return rl.limiter.Allow() } -func NewIPRateLimiter(limit int, burst int, duration time.Duration) *IPRateLimiter { +// IPRateLimiter 基于IP的限流器 +type IPRateLimiter struct { + limiters map[string]*RateLimiter // 用户级限流器 map + mu sync.RWMutex // 保护 limiters map + limit int // 每 duration 时间段内允许的请求数 + burst int // 突发请求数 + duration time.Duration // 限流周期 +} + +// NewIPRateLimiter 创建一个基于IP的限流器 +func NewIPRateLimiter(ipLimit int, ipBurst int, duration time.Duration) *IPRateLimiter { + if ipLimit <= 0 { + ipLimit = 1 + logWarning("IP rate limit per minute must be positive, setting to 1") + } + if ipBurst <= 0 { + ipBurst = 1 + logWarning("IP rate limit burst must be positive, setting to 1") + } + + logInfo("IP Rate Limiter initialized with limit: %d, burst: %d, duration: %v", ipLimit, ipBurst, duration) + return &IPRateLimiter{ limiters: make(map[string]*RateLimiter), - limit: limit, - burst: burst, + limit: ipLimit, + burst: ipBurst, duration: duration, } } +// Allow 检查给定IP的请求是否允许通过 func (rl *IPRateLimiter) Allow(ip string) bool { if ip == "" { - logWarning("empty ip") + logWarning("empty ip for rate limiting") return false } - limiter, ok := rl.limiters[ip] - if !ok { - // 创建新的 RateLimiter 并存储 - limiter = New(rl.limit, rl.burst, rl.duration) - rl.limiters[ip] = limiter + // 使用读锁快速查找 + rl.mu.RLock() + limiter, found := rl.limiters[ip] + rl.mu.RUnlock() + + if found { + return limiter.Allow() } + + // 未找到,获取写锁来创建和添加 + rl.mu.Lock() + // 双重检查 + limiter, found = rl.limiters[ip] + if !found { + newL := New(rl.limit, rl.burst, rl.duration) + rl.limiters[ip] = newL + limiter = newL + } + rl.mu.Unlock() + return limiter.Allow() }