diff --git a/CHANGELOG.md b/CHANGELOG.md index bc8603b..9b7cf74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # 更新日志 +3.0.1 - 2025-03-21 +--- +- RELEASE: Next Step; 下一步; 完善v3的同时, 修正已知问题; v3会与v2.4.0及以上版本保证兼容关系, 可平顺升级; +- CHANGE: 改进cli +- CHANGE: 重写`ProcessLinksAndWriteChunked`(脚本嵌套加速处理器), 修正已知问题的同时提高性能与效率 +- CHANGE: 完善`gitreq`部分 +- FIX: 修正日志输出格式问题 +- FIX: 使用更新的`hwriter`以修正相关问题 + +25w21e - 2025-03-21 +--- +- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用; +- CHANGE: 重写`ProcessLinksAndWriteChunked`(脚本嵌套加速处理器), 修正已知问题的同时提高性能与效率 + +25w21d - 2025-03-21 +--- +- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用; +- FIX: 使用更新的`hwriter`以修正相关问题 + +25w21c - 2025-03-20 +--- +- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用; +- TEST: 测试新的`hwriter` + +25w21b - 2025-03-20 +--- +- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用; +- FIX: 修正日志输出格式问题 + +25w21a - 2025-03-20 +--- +- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用; +- CHANGE: 改进cli +- CHANGE: 完善`gitreq`部分 + 3.0.0 - 2025-03-19 --- - RELEASE: Next Gen; 下一个起点; v3会与v2.4.0及以上版本保证兼容关系, 可平顺升级; diff --git a/DEV-VERSION b/DEV-VERSION index b7d6d4a..34cebe9 100644 --- a/DEV-VERSION +++ b/DEV-VERSION @@ -1 +1 @@ -25w20b \ No newline at end of file +25w21e \ No newline at end of file diff --git a/SECURITY.MD b/SECURITY.MD index 5ac050f..eca438a 100644 --- a/SECURITY.MD +++ b/SECURITY.MD @@ -6,8 +6,9 @@ | 版本 | 是否支持 | | --- | --- | -| v2.x.x | :white_check_mark: 当前最新版本序列, 受支持 | -| v1.x.x | :x: 这些版本已结束生命周期,不再受支持 | +| v3.x.x | :white_check_mark: 当前最新版本序列 | +| v2.x.x | :x: 这些版本已结束生命周期,不受支持 | +| v1.x.x | :x: 这些版本已结束生命周期,不受支持 | | 25w*a/b/c... | :warning: 此为PRE-RELEASE版本,用于开发与测试,可能存在未知的问题 | | 24w*a/b/c... | :warning: 此为PRE-RELEASE版本,用于开发与测试,可能存在未知的问题 生命周期已完全结束 | | v0.x.x | :x: 这些版本不再受支持 | diff --git a/VERSION b/VERSION index 56fea8a..13d683c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.0.0 \ No newline at end of file +3.0.1 \ No newline at end of file diff --git a/api/api.go b/api/api.go index 2a0068b..e463d9a 100644 --- a/api/api.go +++ b/api/api.go @@ -65,7 +65,7 @@ func InitHandleRouter(cfg *config.Config, r *server.Hertz, version string) { func SizeLimitHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { sizeLimit := cfg.Server.SizeLimit - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "MaxResponseBodySize": sizeLimit, @@ -73,7 +73,7 @@ func SizeLimitHandler(cfg *config.Config, c *app.RequestContext, ctx context.Con } func WhiteListStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "Whitelist": cfg.Whitelist.Enabled, @@ -81,7 +81,7 @@ func WhiteListStatusHandler(cfg *config.Config, c *app.RequestContext, ctx conte } func BlackListStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "Blacklist": cfg.Blacklist.Enabled, @@ -89,7 +89,7 @@ func BlackListStatusHandler(cfg *config.Config, c *app.RequestContext, ctx conte } func CorsStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "Cors": cfg.Server.Cors, @@ -97,7 +97,7 @@ func CorsStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Co } func HealthcheckHandler(c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "Status": "OK", @@ -105,7 +105,7 @@ func HealthcheckHandler(c *app.RequestContext, ctx context.Context) { } func VersionHandler(c *app.RequestContext, ctx context.Context, version string) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "Version": version, @@ -113,7 +113,7 @@ func VersionHandler(c *app.RequestContext, ctx context.Context, version string) } func RateLimitStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "RateLimit": cfg.RateLimit.Enabled, @@ -121,7 +121,7 @@ func RateLimitStatusHandler(cfg *config.Config, c *app.RequestContext, ctx conte } func RateLimitLimitHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "RatePerMinute": cfg.RateLimit.RatePerMinute, @@ -129,7 +129,7 @@ func RateLimitLimitHandler(cfg *config.Config, c *app.RequestContext, ctx contex } func SmartGitStatusHandler(cfg *config.Config, c *app.RequestContext, ctx context.Context) { - logInfo("%s %s %s %s %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logInfo("%s %s %s %s %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) c.Response.Header.Set("Content-Type", "application/json") c.JSON(200, (map[string]interface{}{ "enabled": cfg.GitClone.Mode == "cache", diff --git a/auth/auth-header.go b/auth/auth-header.go index b63983d..18852ba 100644 --- a/auth/auth-header.go +++ b/auth/auth-header.go @@ -13,7 +13,7 @@ func AuthHeaderHandler(c *app.RequestContext, cfg *config.Config) (isValid bool, } // 获取"GH-Auth"的值 authToken := string(c.GetHeader("GH-Auth")) - logDebug("%s %s %s %s %s AUTH_TOKEN: %s", c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), authToken) + logDebug("%s %s %s %s %s AUTH_TOKEN: %s", c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), authToken) if authToken == "" { return false, fmt.Errorf("Auth token not found") } diff --git a/auth/auth-parameters.go b/auth/auth-parameters.go index 2e1139d..3635f92 100644 --- a/auth/auth-parameters.go +++ b/auth/auth-parameters.go @@ -13,7 +13,7 @@ func AuthParametersHandler(c *app.RequestContext, cfg *config.Config) (isValid b } authToken := c.Query("auth_token") - logDebug("%s %s %s %s %s AUTH_TOKEN: %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), authToken) + logDebug("%s %s %s %s %s AUTH_TOKEN: %s", c.ClientIP(), c.Method(), string(c.Path()), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), authToken) if authToken == "" { return false, fmt.Errorf("Auth token not found") diff --git a/go.mod b/go.mod index 588e63d..8527184 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.1 require ( github.com/BurntSushi/toml v1.5.0 - github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.2 + github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.3 github.com/WJQSERVER-STUDIO/go-utils/logger v1.5.0 github.com/cloudwego/hertz v0.9.6 github.com/hertz-contrib/http2 v0.1.8 diff --git a/go.sum b/go.sum index a4d7633..5174ab3 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4 h1:JLtFd00AdFg/TP+dtvIzLkdHwKUGPOAijN1sMtEYoFg= github.com/WJQSERVER-STUDIO/go-utils/copyb v0.0.4/go.mod h1:FZ6XE+4TKy4MOfX1xWKe6Rwsg0ucYFCdNh1KLvyKTfc= -github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.2 h1:z9xSC3qkt8Qjjb+KRV0Az5klUBJ/gE3berBbjVSFVzY= -github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.2/go.mod h1:U3dVP2MzKJfK6dPiobxmSdynibqCOn1mxQEVLylESWA= +github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.3 h1:4kZH8GauRDR2R3ywgyob2Clyh3o1o/DPZCxknzi9HUU= +github.com/WJQSERVER-STUDIO/go-utils/hwriter v0.0.3/go.mod h1:U3dVP2MzKJfK6dPiobxmSdynibqCOn1mxQEVLylESWA= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.1 h1:gJEQspQPB527Vp2FPcdOrynQEj3YYtrg1ixVSB/JvZM= github.com/WJQSERVER-STUDIO/go-utils/log v0.0.1/go.mod h1:j9Q+xnwpOfve7/uJnZ2izRQw6NNoXjvJHz7vUQAaLZE= github.com/WJQSERVER-STUDIO/go-utils/logger v1.5.0 h1:Uk4N7Sh4OPth3am3xVv17JlAm7tsna97ZLQRpQj7r5c= diff --git a/main.go b/main.go index 568832c..e1332a6 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "fmt" "io/fs" "net/http" + "os" "time" "ghproxy/api" @@ -27,14 +28,16 @@ import ( ) var ( - cfg *config.Config - r *server.Hertz - configfile = "/data/ghproxy/config/config.toml" - cfgfile string - version string - runMode string - limiter *rate.RateLimiter - iplimiter *rate.IPRateLimiter + cfg *config.Config + r *server.Hertz + configfile = "/data/ghproxy/config/config.toml" + cfgfile string + version string + runMode string + limiter *rate.RateLimiter + iplimiter *rate.IPRateLimiter + showVersion bool // 新增的版本号标志 + showHelp bool // 新增的帮助标志 ) var ( @@ -61,6 +64,38 @@ var ( func readFlag() { flag.StringVar(&cfgfile, "cfg", configfile, "config file path") + flag.BoolVar(&showVersion, "v", false, "show version and exit") // 添加-v标志 + flag.BoolVar(&showHelp, "h", false, "show help message and exit") // 添加-h标志 + + // 捕获未定义的 flag + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + fmt.Fprintln(os.Stderr, "\nInvalid flags:") + + // 检查未定义的flags + invalidFlags := []string{} + for _, arg := range os.Args[1:] { + if arg[0] == '-' && arg != "-h" && arg != "-v" { // 检查是否是flag, 排除 -h 和 -v + defined := false + flag.VisitAll(func(f *flag.Flag) { + if "-"+f.Name == arg { + defined = true + } + }) + if !defined { + invalidFlags = append(invalidFlags, arg) + } + } + } + for _, flag := range invalidFlags { + fmt.Fprintf(os.Stderr, " %s\n", flag) + } + if len(invalidFlags) > 0 { + os.Exit(2) // 使用非零状态码退出,表示有错误 + } + + } } func loadConfig() { @@ -68,8 +103,11 @@ func loadConfig() { cfg, err = config.LoadConfig(cfgfile) if err != nil { fmt.Printf("Failed to load config: %v\n", err) + // 如果配置文件加载失败,也显示帮助信息并退出 + flag.Usage() + os.Exit(1) } - if cfg.Server.Debug { + if cfg != nil && cfg.Server.Debug { // 确保 cfg 不为 nil fmt.Println("Config File Path: ", cfgfile) fmt.Printf("Loaded config: %v\n", cfg) } @@ -80,10 +118,12 @@ func setupLogger(cfg *config.Config) { err = logger.Init(cfg.Log.LogFilePath, cfg.Log.MaxLogSize) if err != nil { fmt.Printf("Failed to initialize logger: %v\n", err) + os.Exit(1) } err = logger.SetLogLevel(cfg.Log.Level) if err != nil { fmt.Printf("Logger Level Error: %v\n", err) + os.Exit(1) } fmt.Printf("Log Level: %s\n", cfg.Log.Level) logDebug("Config File Path: ", cfgfile) @@ -260,27 +300,51 @@ func setupPages(cfg *config.Config, r *server.Hertz) { func init() { readFlag() flag.Parse() + + // 如果设置了 -h,则显示帮助信息并退出 + if showHelp { + flag.Usage() + os.Exit(0) + } + + // 如果设置了 -v,则显示版本号并退出 + if showVersion { + fmt.Printf("GHProxy Version: %s \n", version) + os.Exit(0) + } + loadConfig() - setupLogger(cfg) - InitReq(cfg) - loadlist(cfg) - setupRateLimit(cfg) + if cfg != nil { // 在setupLogger前添加空值检查 + setupLogger(cfg) + InitReq(cfg) + loadlist(cfg) + setupRateLimit(cfg) - if cfg.Server.Debug { - runMode = "dev" - } else { - runMode = "release" + if cfg.Server.Debug { + runMode = "dev" + } else { + runMode = "release" + } + + if cfg.Server.Debug { + version = "Dev" // 如果是Debug模式,版本设置为"Dev" + } } - - if cfg.Server.Debug { - version = "Dev" - } - } func main() { + // 如果 showVersion 为 true,则在 init 阶段已退出,这里直接返回 + if showVersion || showHelp { + return + } logDebug("Run Mode: %s", runMode) + // 确保在程序配置加载且非版本显示模式下执行 + if cfg == nil { + fmt.Println("Config not loaded, exiting.") + return // 如果配置未加载,则不继续执行 + } + addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) r := server.New( diff --git a/proxy/authpass.go b/proxy/authpass.go index 4bd3095..46d982c 100644 --- a/proxy/authpass.go +++ b/proxy/authpass.go @@ -11,13 +11,13 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques if cfg.Auth.PassThrough { token := c.Query("token") if token != "" { - logDebug("%s %s %s %s %s Auth-PassThrough: token %s", c.ClientIP(), c.Request.Method, string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol(), token) + logDebug("%s %s %s %s %s Auth-PassThrough: token %s", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol(), token) switch cfg.Auth.AuthMethod { case "parameters": if !cfg.Auth.Enabled { req.Header.Set("Authorization", "token "+token) } else { - logWarning("%s %s %s %s %s Auth-Error: Conflict Auth Method", c.ClientIP(), c.Request.Method, string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) + logWarning("%s %s %s %s %s Auth-Error: Conflict Auth Method", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) // 500 Internal Server Error c.JSON(http.StatusInternalServerError, map[string]string{"error": "Conflict Auth Method"}) return @@ -27,7 +27,7 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques req.Header.Set("Authorization", "token "+token) } default: - logWarning("%s %s %s %s %s Invalid Auth Method / Auth Method is not be set", c.ClientIP(), c.Request.Method, string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) + logWarning("%s %s %s %s %s Invalid Auth Method / Auth Method is not be set", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol()) // 500 Internal Server Error c.JSON(http.StatusInternalServerError, map[string]string{"error": "Invalid Auth Method / Auth Method is not be set"}) return diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index a00dda9..819d5c0 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -47,7 +47,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c if err == nil && size > sizelimit { finalURL := headResp.Request.URL.String() c.Redirect(http.StatusMovedPermanently, []byte(finalURL)) - logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Request.Method, c.Path(), c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), finalURL, size) + logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), finalURL, size) return } } @@ -85,7 +85,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c if err == nil && size > sizelimit { finalURL := resp.Request.URL.String() c.Redirect(http.StatusMovedPermanently, []byte(finalURL)) - logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Request.Method, c.Path(), c.UserAgent(), c.Request.Header.GetProtocol(), finalURL, size) + logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.UserAgent(), c.Request.Header.GetProtocol(), finalURL, size) return } } diff --git a/proxy/gitreq.go b/proxy/gitreq.go index 2f781ed..f1330de 100644 --- a/proxy/gitreq.go +++ b/proxy/gitreq.go @@ -45,6 +45,8 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co return } setRequestHeaders(c, req) + removeWSHeader(req) + reWriteEncodeHeader(req) AuthPassThrough(c, cfg, req) resp, err = gitclient.Do(req) @@ -59,6 +61,8 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co return } setRequestHeaders(c, req) + removeWSHeader(req) + reWriteEncodeHeader(req) AuthPassThrough(c, cfg, req) resp, err = client.Do(req) @@ -81,7 +85,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co if err == nil && size > sizelimit { finalURL := []byte(resp.Request.URL.String()) c.Redirect(http.StatusMovedPermanently, finalURL) - logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Request.Method, c.Path(), c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), finalURL, size) + logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), finalURL, size) return } } diff --git a/proxy/handler.go b/proxy/handler.go index fbf679f..435e7d5 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -36,7 +36,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra if !allowed { c.JSON(http.StatusTooManyRequests, map[string]string{"error": "Too Many Requests"}) - logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Request.Method, c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) return } } @@ -47,7 +47,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra // 匹配路径错误处理 if len(matches) < 3 { - errMsg := fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) + errMsg := fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol()) logWarning(errMsg) c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath) return @@ -71,9 +71,9 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra } username := user - logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), username, repo) - // dump log 记录详细信息 c.ClientIP(), c.Request.Method, rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header - logDump("%s %s %s %s %s %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) + logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), username, repo) + // dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header + logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header()) repouser := fmt.Sprintf("%s/%s", username, repo) // 白名单检查 @@ -82,7 +82,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra if !whitelist { errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", repouser) c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) + logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) return } } @@ -93,7 +93,7 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra if blacklist { errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser) c.JSON(http.StatusForbidden, map[string]string{"error": errMsg}) - logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) + logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser) return } } @@ -111,12 +111,12 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra if !authcheck { //c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(401, map[string]string{"error": "Unauthorized"}) - logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) + logWarning("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err) return } // IP METHOD URL USERAGENT PROTO MATCHES - logDebug("%s %s %s %s %s Matches: %v", c.ClientIP(), c.Request.Method, rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matches) + logDebug("%s %s %s %s %s Matches: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matches) switch matcher { case "releases", "blob", "raw", "gist", "api": diff --git a/proxy/match.go b/proxy/match.go index 523b958..7a6e818 100644 --- a/proxy/match.go +++ b/proxy/match.go @@ -10,9 +10,10 @@ import ( "net/url" "regexp" "strings" + "sync" "github.com/cloudwego/hertz/pkg/app" - hresp "github.com/cloudwego/hertz/pkg/protocol/http1/resp" + "github.com/cloudwego/hertz/pkg/protocol/http1/resp" "github.com/valyala/bytebufferpool" ) @@ -211,87 +212,134 @@ func matchString(target string, stringsToMatch []string) bool { return exists } -// processLinksAndWriteChunked 处理链接并将结果以 chunked 方式写入响应 func ProcessLinksAndWriteChunked(input io.Reader, compress string, host string, cfg *config.Config, c *app.RequestContext) error { - var reader *bufio.Reader + pr, pw := io.Pipe() // 创建一个管道,用于进程间通信 + var wg sync.WaitGroup + wg.Add(2) - if compress == "gzip" { - // 解压 gzip - gzipReader, err := gzip.NewReader(input) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("gzip 解压错误: %v", err)) - return fmt.Errorf("gzip 解压错误: %w", err) - } - defer gzipReader.Close() - reader = bufio.NewReader(gzipReader) - } else { - reader = bufio.NewReader(input) - } + var processErr error // 用于存储处理过程中发生的错误 - // 获取 chunked body writer - chunkedWriter := hresp.NewChunkedBodyWriter(&c.Response, c.GetWriter()) + go func() { + defer wg.Done() // 协程结束时通知 WaitGroup + defer pw.Close() // 协程结束时关闭管道的写端 - var writer io.Writer = chunkedWriter - var gzipWriter *gzip.Writer - - if compress == "gzip" { - gzipWriter = gzip.NewWriter(writer) - writer = gzipWriter - defer func() { - if err := gzipWriter.Close(); err != nil { - logError("gzipWriter close failed: %v", err) + var reader *bufio.Reader + if compress == "gzip" { // 如果需要解压 + gzipReader, err := gzip.NewReader(input) // 创建 gzip 解压器 + if err != nil { + c.String(http.StatusInternalServerError, fmt.Sprintf("gzip 解压错误: %v", err)) // 设置 HTTP 状态码和错误信息 + processErr = fmt.Errorf("gzip decompression error: %w", err) // gzip decompression error + return } - }() - } - - bufWrapper := bytebufferpool.Get() - buf := bufWrapper.B - size := 32768 // 32KB - buf = buf[:cap(buf)] - if len(buf) < size { - buf = append(buf, make([]byte, size-len(buf))...) - } - buf = buf[:size] // 将缓冲区限制为 'size' - defer bytebufferpool.Put(bufWrapper) - - urlPattern := regexp.MustCompile(`https?://[^\s'"]+`) - scanner := bufio.NewScanner(reader) - for scanner.Scan() { - line := scanner.Text() - modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { - return modifyURL(originalURL, host, cfg) - }) - modifiedLineWithNewline := modifiedLine + "\n" - - _, err := writer.Write([]byte(modifiedLineWithNewline)) - if err != nil { - logError("写入 chunk 错误: %v", err) - return fmt.Errorf("写入 chunk 错误: %w", err) + defer gzipReader.Close() // 延迟关闭 gzip 解压器 + reader = bufio.NewReader(gzipReader) // 使用 bufio 读取解压后的数据 + } else { + reader = bufio.NewReader(input) // 直接使用 bufio 读取原始数据 } - if compress != "gzip" { - if fErr := chunkedWriter.Flush(); fErr != nil { - logError("chunkedWriter flush failed: %v", fErr) - return fmt.Errorf("chunkedWriter flush failed: %w", fErr) + var writer io.Writer = pw // 默认写入管道 + var gzipWriter *gzip.Writer + + if compress == "gzip" { // 如果需要压缩 + gzipWriter = gzip.NewWriter(writer) // 创建 gzip 压缩器 + writer = gzipWriter // 将 writer 设置为 gzip 压缩器 + defer func() { // 延迟关闭 gzip 压缩器 + if err := gzipWriter.Close(); err != nil { + logError("gzipWriter close failed: %v", err) + processErr = fmt.Errorf("gzipwriter close failed: %w", err) // gzipwriter close failed + } + }() + } + + urlPattern := regexp.MustCompile(`https?://[^\s'"]+`) // 编译正则表达式,用于匹配 URL + scanner := bufio.NewScanner(reader) // 创建 scanner 用于逐行扫描 + for scanner.Scan() { // 循环读取每一行 + line := scanner.Text() // 获取当前行 + modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { // 替换 URL + return modifyURL(originalURL, host, cfg) // 调用 modifyURL 函数修改 URL + }) + modifiedLineWithNewline := modifiedLine + "\n" // 添加换行符 + + _, err := writer.Write([]byte(modifiedLineWithNewline)) // 将修改后的行写入管道/gzip + if err != nil { + logError("写入 pipe 错误: %v", err) // 记录错误 + processErr = fmt.Errorf("write to pipe error: %w", err) // write to pipe error + return } } - } - if err := scanner.Err(); err != nil { - logError("读取输入错误: %v", err) - c.String(http.StatusInternalServerError, fmt.Sprintf("读取输入错误: %v", err)) - return fmt.Errorf("读取输入错误: %w", err) - } - - // 对于 gzip,chunkedWriter 的关闭会触发最后的 chunk - if compress != "gzip" { - if fErr := chunkedWriter.Flush(); fErr != nil { - logError("final chunkedWriter flush failed: %v", fErr) - return fmt.Errorf("final chunkedWriter flush failed: %w", fErr) + if err := scanner.Err(); err != nil { + logError("读取输入错误: %v", err) // 记录错误 + c.String(http.StatusInternalServerError, fmt.Sprintf("读取输入错误: %v", err)) // 设置 HTTP 状态码和错误信息 + processErr = fmt.Errorf("read input error: %w", err) // read input error + return } - } + }() - return nil // 成功完成处理 + go func() { + defer wg.Done() // 协程结束时通知 WaitGroup + + c.Response.HijackWriter(resp.NewChunkedBodyWriter(&c.Response, c.GetWriter())) // 劫持 writer,启用分块编码 + + bufWrapper := bytebufferpool.Get() // 从对象池获取 bytebuffer + buf := bufWrapper.B + size := 32768 // 32KB, 设置缓冲区大小 + buf = buf[:cap(buf)] + if len(buf) < size { + buf = append(buf, make([]byte, size-len(buf))...) + } + buf = buf[:size] // 将缓冲区限制为 'size' + defer bytebufferpool.Put(bufWrapper) // 延迟将 bytebuffer 放回对象池 + + for { // 循环读取和写入数据 + n, err := pr.Read(buf) // 从管道读取数据 + if err != nil { + if err == io.EOF { // 如果读取到文件末尾 + if n > 0 { // 确保写入所有剩余数据 + _, err := c.Write(buf[:n]) // 写入最后的数据块 + if err != nil { + processErr = fmt.Errorf("failed to write final chunk: %w", err) // failed to write final chunk + break + } + } + c.Flush() // 刷新缓冲区 + break // 读取到文件末尾, 退出循环 + } + logError("hwriter.Writer read error: %v", err) // 记录错误 + if processErr == nil { + processErr = fmt.Errorf("failed to read from pipe: %w", err) // failed to read from pipe + // 不要在这里设置 http status code. 如果 read 失败, process 协程可能还没有完成, 它可能正在尝试设置 status code. 两个地方都设置会导致 race condition. + } + break // 读取错误,退出循环 + } + + if n > 0 { // 只有在实际读取到数据时才写入 + _, err = c.Write(buf[:n]) // 将数据写入响应 + if err != nil { + // 处理写入错误 (考虑记录日志并可能中止) + logError("hwriter.Writer write error: %v", err) + if processErr == nil { // 仅当 processErr 尚未设置时才设置. + processErr = fmt.Errorf("failed to write chunk: %w", err) // failed to write chunk + } + break // 写入错误, 退出循环 + } + + // 在大多数情况下,考虑移除 Flush. 仅在 *真正* 需要时保留它。 + if err := c.Flush(); err != nil { + // 更强大的 Flush() 错误处理 + c.AbortWithStatus(http.StatusInternalServerError) // 中止响应 + logError("hwriter.Writer flush error: %v", err) + if processErr == nil { + processErr = fmt.Errorf("failed to flush chunk: %w", err) // failed to flush chunk + } + break // 刷新错误, 退出循环 + } + } + } + }() + + wg.Wait() // 等待两个协程结束 + return processErr // 返回错误 } // extractParts 从给定的 URL 中提取所需的部分 diff --git a/proxy/reqheader.go b/proxy/reqheader.go index 86e33d4..4926186 100644 --- a/proxy/reqheader.go +++ b/proxy/reqheader.go @@ -20,7 +20,6 @@ func removeWSHeader(req *http.Request) { } func reWriteEncodeHeader(req *http.Request) { - if isGzipAccepted(req.Header) { req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")