This commit is contained in:
wjqserver 2025-03-20 14:47:02 +08:00
parent 27cc30ab8b
commit c931017f03
4 changed files with 97 additions and 23 deletions

View file

@ -1,5 +1,11 @@
# 更新日志 # 更新日志
25w21a - 2025-03-20
---
- PRE-RELEASE: 此版本是v3.0.1的预发布版本,请勿在生产环境中使用;
- CHANGE: 改进cli
- CHANGE: 完善`gitreq`部分
3.0.0 - 2025-03-19 3.0.0 - 2025-03-19
--- ---
- RELEASE: Next Gen; 下一个起点; v3会与v2.4.0及以上版本保证兼容关系, 可平顺升级; - RELEASE: Next Gen; 下一个起点; v3会与v2.4.0及以上版本保证兼容关系, 可平顺升级;

View file

@ -1 +1 @@
25w20b 25w21a

108
main.go
View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net/http" "net/http"
"os"
"time" "time"
"ghproxy/api" "ghproxy/api"
@ -27,14 +28,16 @@ import (
) )
var ( var (
cfg *config.Config cfg *config.Config
r *server.Hertz r *server.Hertz
configfile = "/data/ghproxy/config/config.toml" configfile = "/data/ghproxy/config/config.toml"
cfgfile string cfgfile string
version string version string
runMode string runMode string
limiter *rate.RateLimiter limiter *rate.RateLimiter
iplimiter *rate.IPRateLimiter iplimiter *rate.IPRateLimiter
showVersion bool // 新增的版本号标志
showHelp bool // 新增的帮助标志
) )
var ( var (
@ -61,6 +64,38 @@ var (
func readFlag() { func readFlag() {
flag.StringVar(&cfgfile, "cfg", configfile, "config file path") flag.StringVar(&cfgfile, "cfg", configfile, "config file path")
flag.BoolVar(&showVersion, "v", false, "show version and exit") // 添加-v标志
flag.BoolVar(&showHelp, "h", false, "show help message and exit") // 添加-h标志
// 捕获未定义的 flag
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprintln(os.Stderr, "\nInvalid flags:")
// 检查未定义的flags
invalidFlags := []string{}
for _, arg := range os.Args[1:] {
if arg[0] == '-' && arg != "-h" && arg != "-v" { // 检查是否是flag, 排除 -h 和 -v
defined := false
flag.VisitAll(func(f *flag.Flag) {
if "-"+f.Name == arg {
defined = true
}
})
if !defined {
invalidFlags = append(invalidFlags, arg)
}
}
}
for _, flag := range invalidFlags {
fmt.Fprintf(os.Stderr, " %s\n", flag)
}
if len(invalidFlags) > 0 {
os.Exit(2) // 使用非零状态码退出,表示有错误
}
}
} }
func loadConfig() { func loadConfig() {
@ -68,8 +103,11 @@ func loadConfig() {
cfg, err = config.LoadConfig(cfgfile) cfg, err = config.LoadConfig(cfgfile)
if err != nil { if err != nil {
fmt.Printf("Failed to load config: %v\n", err) fmt.Printf("Failed to load config: %v\n", err)
// 如果配置文件加载失败,也显示帮助信息并退出
flag.Usage()
os.Exit(1)
} }
if cfg.Server.Debug { if cfg != nil && cfg.Server.Debug { // 确保 cfg 不为 nil
fmt.Println("Config File Path: ", cfgfile) fmt.Println("Config File Path: ", cfgfile)
fmt.Printf("Loaded config: %v\n", cfg) fmt.Printf("Loaded config: %v\n", cfg)
} }
@ -80,10 +118,12 @@ func setupLogger(cfg *config.Config) {
err = logger.Init(cfg.Log.LogFilePath, cfg.Log.MaxLogSize) err = logger.Init(cfg.Log.LogFilePath, cfg.Log.MaxLogSize)
if err != nil { if err != nil {
fmt.Printf("Failed to initialize logger: %v\n", err) fmt.Printf("Failed to initialize logger: %v\n", err)
os.Exit(1)
} }
err = logger.SetLogLevel(cfg.Log.Level) err = logger.SetLogLevel(cfg.Log.Level)
if err != nil { if err != nil {
fmt.Printf("Logger Level Error: %v\n", err) fmt.Printf("Logger Level Error: %v\n", err)
os.Exit(1)
} }
fmt.Printf("Log Level: %s\n", cfg.Log.Level) fmt.Printf("Log Level: %s\n", cfg.Log.Level)
logDebug("Config File Path: ", cfgfile) logDebug("Config File Path: ", cfgfile)
@ -260,27 +300,51 @@ func setupPages(cfg *config.Config, r *server.Hertz) {
func init() { func init() {
readFlag() readFlag()
flag.Parse() flag.Parse()
// 如果设置了 -h则显示帮助信息并退出
if showHelp {
flag.Usage()
os.Exit(0)
}
// 如果设置了 -v则显示版本号并退出
if showVersion {
fmt.Printf("GHProxy Version: %s \n", version)
os.Exit(0)
}
loadConfig() loadConfig()
setupLogger(cfg) if cfg != nil { // 在setupLogger前添加空值检查
InitReq(cfg) setupLogger(cfg)
loadlist(cfg) InitReq(cfg)
setupRateLimit(cfg) loadlist(cfg)
setupRateLimit(cfg)
if cfg.Server.Debug { if cfg.Server.Debug {
runMode = "dev" runMode = "dev"
} else { } else {
runMode = "release" runMode = "release"
}
if cfg.Server.Debug {
version = "Dev" // 如果是Debug模式版本设置为"Dev"
}
} }
if cfg.Server.Debug {
version = "Dev"
}
} }
func main() { func main() {
// 如果 showVersion 为 true则在 init 阶段已退出,这里直接返回
if showVersion || showHelp {
return
}
logDebug("Run Mode: %s", runMode) logDebug("Run Mode: %s", runMode)
// 确保在程序配置加载且非版本显示模式下执行
if cfg == nil {
fmt.Println("Config not loaded, exiting.")
return // 如果配置未加载,则不继续执行
}
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
r := server.New( r := server.New(

View file

@ -45,6 +45,8 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
return return
} }
setRequestHeaders(c, req) setRequestHeaders(c, req)
removeWSHeader(req)
reWriteEncodeHeader(req)
AuthPassThrough(c, cfg, req) AuthPassThrough(c, cfg, req)
resp, err = gitclient.Do(req) resp, err = gitclient.Do(req)
@ -59,6 +61,8 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co
return return
} }
setRequestHeaders(c, req) setRequestHeaders(c, req)
removeWSHeader(req)
reWriteEncodeHeader(req)
AuthPassThrough(c, cfg, req) AuthPassThrough(c, cfg, req)
resp, err = client.Do(req) resp, err = client.Do(req)