diff --git a/DEV-VERSION b/DEV-VERSION index ad29565..aaa059b 100644 --- a/DEV-VERSION +++ b/DEV-VERSION @@ -1 +1 @@ -24w15b \ No newline at end of file +24w15c \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index 4a8a7b0..bdb74eb 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -7,7 +7,24 @@ import ( "github.com/gin-gonic/gin" ) -var logw = logger.Logw +// 日志模块 +var ( + logw = logger.Logw + logInfo = logger.LogInfo + LogWarning = logger.LogWarning + logError = logger.LogError +) + +// Auth Init +func Init(cfg *config.Config) { + if cfg.Blacklist.Enabled { + LoadBlacklist(cfg) + } + if cfg.Whitelist.Enabled { + LoadWhitelist(cfg) + } + logInfo("Auth Init") +} func AuthHandler(c *gin.Context, cfg *config.Config) bool { // 如果身份验证未启用,直接返回 true @@ -17,19 +34,19 @@ func AuthHandler(c *gin.Context, cfg *config.Config) bool { // 获取 auth_token 参数 authToken := c.Query("auth_token") - logw("auth_token received: %s", authToken) + logInfo("auth_token received: %s", authToken) // 验证 token if authToken == "" { - logw("auth FAILED: no auth_token provided") + LogWarning("auth FAILED: no auth_token provided") return false } isValid := authToken == cfg.Auth.AuthToken if !isValid { - logw("auth FAILED: invalid auth_token: %s", authToken) + LogWarning("auth FAILED: invalid auth_token: %s", authToken) } - logw("auth SUCCESS: %t", isValid) + logInfo("auth SUCCESS: %t", isValid) return isValid } diff --git a/auth/blacklist.go b/auth/blacklist.go index 830537f..4b8f17d 100644 --- a/auth/blacklist.go +++ b/auth/blacklist.go @@ -4,6 +4,7 @@ import ( "encoding/json" "ghproxy/config" "os" + "strings" ) type BlacklistConfig struct { @@ -22,22 +23,32 @@ func LoadBlacklist(cfg *config.Config) { data, err := os.ReadFile(blacklistfile) if err != nil { - logw("Failed to read blacklist file: %v", err) + logError("Failed to read blacklist file: %v", err) } err = json.Unmarshal(data, blacklist) if err != nil { - logw("Failed to unmarshal blacklist JSON: %v", err) + logError("Failed to unmarshal blacklist JSON: %v", err) } } +// fullrepo: "owner/repo" or "owner/*" func CheckBlacklist(fullrepo string) bool { return forRangeCheckBlacklist(blacklist.Blacklist, fullrepo) } +func sliceRepoName_Blacklist(fullrepo string) (string, string) { + s := strings.Split(fullrepo, "/") + if len(s) != 2 { + return "", "" + } + return s[0], s[1] +} + func forRangeCheckBlacklist(blist []string, fullrepo string) bool { + repoUser, _ := sliceRepoName_Blacklist(fullrepo) for _, blocked := range blist { - if blocked == fullrepo { + if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) { return true } } diff --git a/auth/whitelist.go b/auth/whitelist.go index 326b3aa..340be52 100644 --- a/auth/whitelist.go +++ b/auth/whitelist.go @@ -4,6 +4,7 @@ import ( "encoding/json" "ghproxy/config" "os" + "strings" ) type WhitelistConfig struct { @@ -21,12 +22,12 @@ func LoadWhitelist(cfg *config.Config) { data, err := os.ReadFile(whitelistfile) if err != nil { - logw("Failed to read whitelist file: %v", err) + logError("Failed to read whitelist file: %v", err) } err = json.Unmarshal(data, whitelist) if err != nil { - logw("Failed to unmarshal whitelist JSON: %v", err) + logError("Failed to unmarshal whitelist JSON: %v", err) } } @@ -34,9 +35,18 @@ func CheckWhitelist(fullrepo string) bool { return forRangeCheckWhitelist(whitelist.Whitelist, fullrepo) } -func forRangeCheckWhitelist(blist []string, fullrepo string) bool { - for _, blocked := range blist { - if blocked == fullrepo { +func sliceRepoName_Whitelist(fullrepo string) (string, string) { + s := strings.Split(fullrepo, "/") + if len(s) != 2 { + return "", "" + } + return s[0], s[1] +} + +func forRangeCheckWhitelist(wlist []string, fullrepo string) bool { + repoUser, _ := sliceRepoName_Whitelist(fullrepo) + for _, blocked := range wlist { + if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) { return true } } diff --git a/config/blacklist.json b/config/blacklist.json index 286d110..8062ee0 100644 --- a/config/blacklist.json +++ b/config/blacklist.json @@ -1,8 +1,7 @@ { - "blacklist": [ - "black/list", - "test/test1", - "example/repo2" - ] - } - \ No newline at end of file + "blacklist": [ + "black/list", + "test/test1", + "example/repo2" + ] +} \ No newline at end of file diff --git a/config/config.go b/config/config.go index 04928d5..afebe99 100644 --- a/config/config.go +++ b/config/config.go @@ -56,7 +56,7 @@ func loadYAML(filePath string, out interface{}) error { type Config struct { Server ServerConfig - Log LoggerConfig + Log LogConfig CORS CORSConfig Auth AuthConfig Blacklist BlacklistConfig @@ -69,7 +69,7 @@ type ServerConfig struct { SizeLimit int `toml:"sizelimit"` } -type LoggerConfig struct { +type LogConfig struct { LogFilePath string `toml:"logfilepath"` MaxLogSize int `toml:"maxlogsize"` } diff --git a/config/whitelist.json b/config/whitelist.json index d7ba053..e39661b 100644 --- a/config/whitelist.json +++ b/config/whitelist.json @@ -1,8 +1,7 @@ { - "whitelist": [ - "white/list", - "white/test1", - "example/white" - ] - } - \ No newline at end of file + "whitelist": [ + "white/list", + "white/test1", + "example/white" + ] +} \ No newline at end of file diff --git a/logger/logger.go b/logger/logger.go index b6e7057..09d262f 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -19,14 +19,16 @@ var ( logChannel = make(chan string, 100) quitChannel = make(chan struct{}) logFileMutex sync.Mutex // 保护 logFile 的互斥锁 + logFilePath = "/data/ghproxy/log/ghproxy.log" ) // Init 初始化日志记录器,接受日志文件路径作为参数 -func Init(logFilePath string, maxLogsize int) error { +func Init(logFilePath_input string, maxLogsize int) error { logFileMutex.Lock() defer logFileMutex.Unlock() var err error + logFilePath = logFilePath_input logFile, err = os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) if err != nil { return err @@ -62,6 +64,26 @@ func Logw(format string, args ...interface{}) { Log(message) } +// 日志等级INFO +func LogInfo(format string, args ...interface{}) { + message := fmt.Sprintf(format, args...) + output := fmt.Sprintf("[INFO] %s", message) + Log(output) +} + +// 日志等级WARNING +func LogWarning(format string, args ...interface{}) { + message := fmt.Sprintf(format, args...) + output := fmt.Sprintf("[WARNING] %s", message) + Log(output) +} + +// 日志等级ERROR +func LogError(format string, args ...interface{}) { + message := fmt.Sprintf(format, args...) + Log(message) +} + // Close 关闭日志文件 func Close() { logFileMutex.Lock() diff --git a/main.go b/main.go index ce99d12..1adc9e9 100644 --- a/main.go +++ b/main.go @@ -17,11 +17,18 @@ import ( var ( cfg *config.Config - logw = logger.Logw router *gin.Engine configfile = "/data/ghproxy/config/config.toml" ) +// 日志模块 +var ( + logw = logger.Logw + logInfo = logger.LogInfo + LogWarning = logger.LogWarning + logError = logger.LogError +) + func ReadFlag() { cfgfile := flag.String("cfg", configfile, "config file path") configfile = *cfgfile @@ -44,8 +51,8 @@ func setupLogger(cfg *config.Config) { if err != nil { log.Fatalf("Failed to initialize logger: %v", err) } - logw("Logger initialized") - logw("Init Completed") + logInfo("Logger initialized") + logInfo("Init Completed") } func Loadlist(cfg *config.Config) { @@ -87,7 +94,7 @@ func main() { // 启动服务器 err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) if err != nil { - log.Fatalf("Error starting server: %v\n", err) + logError("Error starting server: %v\n", err) } fmt.Println("Program finished") diff --git a/proxy/proxy.go b/proxy/proxy.go index d830103..309ddb0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -17,7 +17,13 @@ import ( "github.com/imroc/req/v3" ) -var logw = logger.Logw +// 日志模块 +var ( + logw = logger.Logw + logInfo = logger.LogInfo + LogWarning = logger.LogWarning + logError = logger.LogError +) var exps = []*regexp.Regexp{ regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`), @@ -34,7 +40,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { matches := re.FindStringSubmatch(rawPath) if len(matches) < 3 { - logw("Invalid URL: %s", rawPath) + LogWarning("Invalid URL: %s", rawPath) c.String(http.StatusForbidden, "Invalid URL.") return } @@ -45,14 +51,14 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { pathmatches := regexp.MustCompile(`^([^/]+)/([^/]+)/([^/]+)/.*`) pathParts := pathmatches.FindStringSubmatch(matches[2]) if len(pathParts) < 4 { - logw("Invalid path: %s", rawPath) + LogWarning("Invalid path: %s", rawPath) c.String(http.StatusForbidden, "Invalid path; expected username/repo.") return } username := pathParts[2] repo := pathParts[3] - logw("Blacklist Check > Username: %s, Repo: %s", username, repo) + LogWarning("Blacklist Check > Username: %s, Repo: %s", username, repo) fullrepo := fmt.Sprintf("%s/%s", username, repo) // 白名单检查 @@ -61,7 +67,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { if !whitelistpass { errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", fullrepo) c.JSON(http.StatusForbidden, gin.H{"error": errMsg}) - logw(errMsg) + LogWarning(errMsg) return } } @@ -72,7 +78,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { if blacklistpass { errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", fullrepo) c.JSON(http.StatusForbidden, gin.H{"error": errMsg}) - logw(errMsg) + LogWarning(errMsg) return } } @@ -89,18 +95,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { if !auth.AuthHandler(c, cfg) { c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) - logw("Unauthorized request: %s", rawPath) + LogWarning("Unauthorized request: %s", rawPath) return } - logw("Matches: %v", matches) + logInfo("Matches: %v", matches) switch { case exps[0].MatchString(rawPath), exps[1].MatchString(rawPath), exps[3].MatchString(rawPath), exps[4].MatchString(rawPath): - logw("%s Matched - USE proxy-chrome", rawPath) + logInfo("%s Matched - USE proxy-chrome", rawPath) ProxyRequest(c, rawPath, cfg, "chrome") case exps[2].MatchString(rawPath): - logw("%s Matched - USE proxy-git", rawPath) + logInfo("%s Matched - USE proxy-git", rawPath) ProxyRequest(c, rawPath, cfg, "git") default: c.String(http.StatusForbidden, "Invalid input.") @@ -111,33 +117,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc { func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) { method := c.Request.Method - logw("%s %s", method, u) + logInfo("%s %s", method, u) - client := req.C() + client := createHTTPClient(mode) - switch mode { - case "chrome": - client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"). - SetTLSFingerprintChrome(). - ImpersonateChrome() - case "git": - client.SetUserAgent("git/2.33.1") - } - - body, err := io.ReadAll(c.Request.Body) + body, err := readRequestBody(c) if err != nil { - HandleError(c, fmt.Sprintf("Failed to read request body: %v", err)) + HandleError(c, err.Error()) return } - defer c.Request.Body.Close() req := client.R().SetBody(body) - - for key, values := range c.Request.Header { - for _, value := range values { - req.SetHeader(key, value) - } - } + setRequestHeaders(c, req) resp, err := SendRequest(req, method, u) if err != nil { @@ -147,17 +138,56 @@ func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) { defer resp.Body.Close() if err := HandleResponseSize(resp, cfg, c); err != nil { - logw("Error handling response size: %v", err) + LogWarning("Error handling response size: %v", err) return } CopyResponseHeaders(resp, c, cfg) c.Status(resp.StatusCode) - if _, err := io.Copy(c.Writer, resp.Body); err != nil { - logw("Failed to copy response body: %v", err) + if err := copyResponseBody(c, resp.Body); err != nil { + logError("Failed to copy response body: %v", err) } } +// createHTTPClient 创建并配置 HTTP 客户端 +func createHTTPClient(mode string) *req.Client { + client := req.C() + switch mode { + case "chrome": + client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"). + SetTLSFingerprintChrome(). + ImpersonateChrome() + case "git": + client.SetUserAgent("git/2.33.1") + } + return client +} + +// readRequestBody 读取请求体 +func readRequestBody(c *gin.Context) ([]byte, error) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %v", err) + } + defer c.Request.Body.Close() + return body, nil +} + +// setRequestHeaders 设置请求头 +func setRequestHeaders(c *gin.Context, req *req.Request) { + for key, values := range c.Request.Header { + for _, value := range values { + req.SetHeader(key, value) + } + } +} + +// copyResponseBody 复制响应体到客户端 +func copyResponseBody(c *gin.Context, respBody io.Reader) error { + _, err := io.Copy(c.Writer, respBody) + return err +} + func SendRequest(req *req.Request, method, url string) (*req.Response, error) { switch method { case "GET": @@ -169,7 +199,7 @@ func SendRequest(req *req.Request, method, url string) (*req.Response, error) { case "DELETE": return req.Delete(url) default: - logw("Unsupported method: %s", method) + logInfo("Unsupported method: %s", method) return nil, fmt.Errorf("unsupported method: %s", method) } } @@ -181,14 +211,25 @@ func HandleResponseSize(resp *req.Response, cfg *config.Config, c *gin.Context) if err == nil && size > cfg.Server.SizeLimit { finalURL := resp.Request.URL.String() c.Redirect(http.StatusMovedPermanently, finalURL) - logw("Redirecting to %s due to size limit (%d bytes)", finalURL, size) - return fmt.Errorf("response size exceeds limit") + LogWarning("Size limit exceeded: %s, Size: %d", finalURL, size) + return fmt.Errorf("size limit exceeded: %d", size) } } return nil } func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config) { + removeHeaders(resp) + + copyHeaders(resp, c) + + setCORSHeaders(c, cfg) + + setDefaultHeaders(c) +} + +// removeHeaders 移除指定的响应头 +func removeHeaders(resp *req.Response) { headersToRemove := map[string]struct{}{ "Content-Security-Policy": {}, "Referrer-Policy": {}, @@ -198,35 +239,45 @@ func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config) for header := range headersToRemove { resp.Header.Del(header) } +} +// copyHeaders 复制响应头到 Gin 上下文 +func copyHeaders(resp *req.Response, c *gin.Context) { for key, values := range resp.Header { for _, value := range values { c.Header(key, value) } } +} - c.Header("Access-Control-Allow-Origin", "") +// setCORSHeaders 设置 CORS 相关的响应头 +func setCORSHeaders(c *gin.Context, cfg *config.Config) { if cfg.CORS.Enabled { c.Header("Access-Control-Allow-Origin", "*") + } else { + c.Header("Access-Control-Allow-Origin", "") } +} +// setDefaultHeaders 设置默认的响应头 +func setDefaultHeaders(c *gin.Context) { c.Header("Age", "10") c.Header("Cache-Control", "max-age=300") } func HandleError(c *gin.Context, message string) { c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", message)) - logw(message) + LogWarning(message) } func CheckURL(u string) []string { for _, exp := range exps { if matches := exp.FindStringSubmatch(u); matches != nil { - logw("URL matched: %s, Matches: %v", u, matches[1:]) + logInfo("URL matched: %s, Matches: %v", u, matches[1:]) return matches[1:] } } errMsg := fmt.Sprintf("Invalid URL: %s", u) - logw(errMsg) + LogWarning(errMsg) return nil }