This commit is contained in:
WJQSERVER 2024-10-12 03:50:34 +08:00
parent e3d56ae9b7
commit 824656f9d0
10 changed files with 191 additions and 75 deletions

View file

@ -1 +1 @@
24w15b 24w15c

View file

@ -7,7 +7,24 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var logw = logger.Logw // 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
// Auth Init
func Init(cfg *config.Config) {
if cfg.Blacklist.Enabled {
LoadBlacklist(cfg)
}
if cfg.Whitelist.Enabled {
LoadWhitelist(cfg)
}
logInfo("Auth Init")
}
func AuthHandler(c *gin.Context, cfg *config.Config) bool { func AuthHandler(c *gin.Context, cfg *config.Config) bool {
// 如果身份验证未启用,直接返回 true // 如果身份验证未启用,直接返回 true
@ -17,19 +34,19 @@ func AuthHandler(c *gin.Context, cfg *config.Config) bool {
// 获取 auth_token 参数 // 获取 auth_token 参数
authToken := c.Query("auth_token") authToken := c.Query("auth_token")
logw("auth_token received: %s", authToken) logInfo("auth_token received: %s", authToken)
// 验证 token // 验证 token
if authToken == "" { if authToken == "" {
logw("auth FAILED: no auth_token provided") LogWarning("auth FAILED: no auth_token provided")
return false return false
} }
isValid := authToken == cfg.Auth.AuthToken isValid := authToken == cfg.Auth.AuthToken
if !isValid { if !isValid {
logw("auth FAILED: invalid auth_token: %s", authToken) LogWarning("auth FAILED: invalid auth_token: %s", authToken)
} }
logw("auth SUCCESS: %t", isValid) logInfo("auth SUCCESS: %t", isValid)
return isValid return isValid
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"ghproxy/config" "ghproxy/config"
"os" "os"
"strings"
) )
type BlacklistConfig struct { type BlacklistConfig struct {
@ -22,22 +23,32 @@ func LoadBlacklist(cfg *config.Config) {
data, err := os.ReadFile(blacklistfile) data, err := os.ReadFile(blacklistfile)
if err != nil { if err != nil {
logw("Failed to read blacklist file: %v", err) logError("Failed to read blacklist file: %v", err)
} }
err = json.Unmarshal(data, blacklist) err = json.Unmarshal(data, blacklist)
if err != nil { if err != nil {
logw("Failed to unmarshal blacklist JSON: %v", err) logError("Failed to unmarshal blacklist JSON: %v", err)
} }
} }
// fullrepo: "owner/repo" or "owner/*"
func CheckBlacklist(fullrepo string) bool { func CheckBlacklist(fullrepo string) bool {
return forRangeCheckBlacklist(blacklist.Blacklist, fullrepo) return forRangeCheckBlacklist(blacklist.Blacklist, fullrepo)
} }
func sliceRepoName_Blacklist(fullrepo string) (string, string) {
s := strings.Split(fullrepo, "/")
if len(s) != 2 {
return "", ""
}
return s[0], s[1]
}
func forRangeCheckBlacklist(blist []string, fullrepo string) bool { func forRangeCheckBlacklist(blist []string, fullrepo string) bool {
repoUser, _ := sliceRepoName_Blacklist(fullrepo)
for _, blocked := range blist { for _, blocked := range blist {
if blocked == fullrepo { if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) {
return true return true
} }
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"ghproxy/config" "ghproxy/config"
"os" "os"
"strings"
) )
type WhitelistConfig struct { type WhitelistConfig struct {
@ -21,12 +22,12 @@ func LoadWhitelist(cfg *config.Config) {
data, err := os.ReadFile(whitelistfile) data, err := os.ReadFile(whitelistfile)
if err != nil { if err != nil {
logw("Failed to read whitelist file: %v", err) logError("Failed to read whitelist file: %v", err)
} }
err = json.Unmarshal(data, whitelist) err = json.Unmarshal(data, whitelist)
if err != nil { if err != nil {
logw("Failed to unmarshal whitelist JSON: %v", err) logError("Failed to unmarshal whitelist JSON: %v", err)
} }
} }
@ -34,9 +35,18 @@ func CheckWhitelist(fullrepo string) bool {
return forRangeCheckWhitelist(whitelist.Whitelist, fullrepo) return forRangeCheckWhitelist(whitelist.Whitelist, fullrepo)
} }
func forRangeCheckWhitelist(blist []string, fullrepo string) bool { func sliceRepoName_Whitelist(fullrepo string) (string, string) {
for _, blocked := range blist { s := strings.Split(fullrepo, "/")
if blocked == fullrepo { if len(s) != 2 {
return "", ""
}
return s[0], s[1]
}
func forRangeCheckWhitelist(wlist []string, fullrepo string) bool {
repoUser, _ := sliceRepoName_Whitelist(fullrepo)
for _, blocked := range wlist {
if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) {
return true return true
} }
} }

View file

@ -5,4 +5,3 @@
"example/repo2" "example/repo2"
] ]
} }

View file

@ -56,7 +56,7 @@ func loadYAML(filePath string, out interface{}) error {
type Config struct { type Config struct {
Server ServerConfig Server ServerConfig
Log LoggerConfig Log LogConfig
CORS CORSConfig CORS CORSConfig
Auth AuthConfig Auth AuthConfig
Blacklist BlacklistConfig Blacklist BlacklistConfig
@ -69,7 +69,7 @@ type ServerConfig struct {
SizeLimit int `toml:"sizelimit"` SizeLimit int `toml:"sizelimit"`
} }
type LoggerConfig struct { type LogConfig struct {
LogFilePath string `toml:"logfilepath"` LogFilePath string `toml:"logfilepath"`
MaxLogSize int `toml:"maxlogsize"` MaxLogSize int `toml:"maxlogsize"`
} }

View file

@ -5,4 +5,3 @@
"example/white" "example/white"
] ]
} }

View file

@ -19,14 +19,16 @@ var (
logChannel = make(chan string, 100) logChannel = make(chan string, 100)
quitChannel = make(chan struct{}) quitChannel = make(chan struct{})
logFileMutex sync.Mutex // 保护 logFile 的互斥锁 logFileMutex sync.Mutex // 保护 logFile 的互斥锁
logFilePath = "/data/ghproxy/log/ghproxy.log"
) )
// Init 初始化日志记录器,接受日志文件路径作为参数 // Init 初始化日志记录器,接受日志文件路径作为参数
func Init(logFilePath string, maxLogsize int) error { func Init(logFilePath_input string, maxLogsize int) error {
logFileMutex.Lock() logFileMutex.Lock()
defer logFileMutex.Unlock() defer logFileMutex.Unlock()
var err error var err error
logFilePath = logFilePath_input
logFile, err = os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) logFile, err = os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil { if err != nil {
return err return err
@ -62,6 +64,26 @@ func Logw(format string, args ...interface{}) {
Log(message) Log(message)
} }
// 日志等级INFO
func LogInfo(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
output := fmt.Sprintf("[INFO] %s", message)
Log(output)
}
// 日志等级WARNING
func LogWarning(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
output := fmt.Sprintf("[WARNING] %s", message)
Log(output)
}
// 日志等级ERROR
func LogError(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
Log(message)
}
// Close 关闭日志文件 // Close 关闭日志文件
func Close() { func Close() {
logFileMutex.Lock() logFileMutex.Lock()

15
main.go
View file

@ -17,11 +17,18 @@ import (
var ( var (
cfg *config.Config cfg *config.Config
logw = logger.Logw
router *gin.Engine router *gin.Engine
configfile = "/data/ghproxy/config/config.toml" configfile = "/data/ghproxy/config/config.toml"
) )
// 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
func ReadFlag() { func ReadFlag() {
cfgfile := flag.String("cfg", configfile, "config file path") cfgfile := flag.String("cfg", configfile, "config file path")
configfile = *cfgfile configfile = *cfgfile
@ -44,8 +51,8 @@ func setupLogger(cfg *config.Config) {
if err != nil { if err != nil {
log.Fatalf("Failed to initialize logger: %v", err) log.Fatalf("Failed to initialize logger: %v", err)
} }
logw("Logger initialized") logInfo("Logger initialized")
logw("Init Completed") logInfo("Init Completed")
} }
func Loadlist(cfg *config.Config) { func Loadlist(cfg *config.Config) {
@ -87,7 +94,7 @@ func main() {
// 启动服务器 // 启动服务器
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
if err != nil { if err != nil {
log.Fatalf("Error starting server: %v\n", err) logError("Error starting server: %v\n", err)
} }
fmt.Println("Program finished") fmt.Println("Program finished")

View file

@ -17,7 +17,13 @@ import (
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
var logw = logger.Logw // 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
var exps = []*regexp.Regexp{ var exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
@ -34,7 +40,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
matches := re.FindStringSubmatch(rawPath) matches := re.FindStringSubmatch(rawPath)
if len(matches) < 3 { if len(matches) < 3 {
logw("Invalid URL: %s", rawPath) LogWarning("Invalid URL: %s", rawPath)
c.String(http.StatusForbidden, "Invalid URL.") c.String(http.StatusForbidden, "Invalid URL.")
return return
} }
@ -45,14 +51,14 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
pathmatches := regexp.MustCompile(`^([^/]+)/([^/]+)/([^/]+)/.*`) pathmatches := regexp.MustCompile(`^([^/]+)/([^/]+)/([^/]+)/.*`)
pathParts := pathmatches.FindStringSubmatch(matches[2]) pathParts := pathmatches.FindStringSubmatch(matches[2])
if len(pathParts) < 4 { if len(pathParts) < 4 {
logw("Invalid path: %s", rawPath) LogWarning("Invalid path: %s", rawPath)
c.String(http.StatusForbidden, "Invalid path; expected username/repo.") c.String(http.StatusForbidden, "Invalid path; expected username/repo.")
return return
} }
username := pathParts[2] username := pathParts[2]
repo := pathParts[3] repo := pathParts[3]
logw("Blacklist Check > Username: %s, Repo: %s", username, repo) LogWarning("Blacklist Check > Username: %s, Repo: %s", username, repo)
fullrepo := fmt.Sprintf("%s/%s", username, repo) fullrepo := fmt.Sprintf("%s/%s", username, repo)
// 白名单检查 // 白名单检查
@ -61,7 +67,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if !whitelistpass { if !whitelistpass {
errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", fullrepo) errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", fullrepo)
c.JSON(http.StatusForbidden, gin.H{"error": errMsg}) c.JSON(http.StatusForbidden, gin.H{"error": errMsg})
logw(errMsg) LogWarning(errMsg)
return return
} }
} }
@ -72,7 +78,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if blacklistpass { if blacklistpass {
errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", fullrepo) errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", fullrepo)
c.JSON(http.StatusForbidden, gin.H{"error": errMsg}) c.JSON(http.StatusForbidden, gin.H{"error": errMsg})
logw(errMsg) LogWarning(errMsg)
return return
} }
} }
@ -89,18 +95,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if !auth.AuthHandler(c, cfg) { if !auth.AuthHandler(c, cfg) {
c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"}) c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"})
logw("Unauthorized request: %s", rawPath) LogWarning("Unauthorized request: %s", rawPath)
return return
} }
logw("Matches: %v", matches) logInfo("Matches: %v", matches)
switch { switch {
case exps[0].MatchString(rawPath), exps[1].MatchString(rawPath), exps[3].MatchString(rawPath), exps[4].MatchString(rawPath): case exps[0].MatchString(rawPath), exps[1].MatchString(rawPath), exps[3].MatchString(rawPath), exps[4].MatchString(rawPath):
logw("%s Matched - USE proxy-chrome", rawPath) logInfo("%s Matched - USE proxy-chrome", rawPath)
ProxyRequest(c, rawPath, cfg, "chrome") ProxyRequest(c, rawPath, cfg, "chrome")
case exps[2].MatchString(rawPath): case exps[2].MatchString(rawPath):
logw("%s Matched - USE proxy-git", rawPath) logInfo("%s Matched - USE proxy-git", rawPath)
ProxyRequest(c, rawPath, cfg, "git") ProxyRequest(c, rawPath, cfg, "git")
default: default:
c.String(http.StatusForbidden, "Invalid input.") c.String(http.StatusForbidden, "Invalid input.")
@ -111,33 +117,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) { func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) {
method := c.Request.Method method := c.Request.Method
logw("%s %s", method, u) logInfo("%s %s", method, u)
client := req.C() client := createHTTPClient(mode)
switch mode { body, err := readRequestBody(c)
case "chrome":
client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36").
SetTLSFingerprintChrome().
ImpersonateChrome()
case "git":
client.SetUserAgent("git/2.33.1")
}
body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
HandleError(c, fmt.Sprintf("Failed to read request body: %v", err)) HandleError(c, err.Error())
return return
} }
defer c.Request.Body.Close()
req := client.R().SetBody(body) req := client.R().SetBody(body)
setRequestHeaders(c, req)
for key, values := range c.Request.Header {
for _, value := range values {
req.SetHeader(key, value)
}
}
resp, err := SendRequest(req, method, u) resp, err := SendRequest(req, method, u)
if err != nil { if err != nil {
@ -147,17 +138,56 @@ func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) {
defer resp.Body.Close() defer resp.Body.Close()
if err := HandleResponseSize(resp, cfg, c); err != nil { if err := HandleResponseSize(resp, cfg, c); err != nil {
logw("Error handling response size: %v", err) LogWarning("Error handling response size: %v", err)
return return
} }
CopyResponseHeaders(resp, c, cfg) CopyResponseHeaders(resp, c, cfg)
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil { if err := copyResponseBody(c, resp.Body); err != nil {
logw("Failed to copy response body: %v", err) logError("Failed to copy response body: %v", err)
} }
} }
// createHTTPClient 创建并配置 HTTP 客户端
func createHTTPClient(mode string) *req.Client {
client := req.C()
switch mode {
case "chrome":
client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36").
SetTLSFingerprintChrome().
ImpersonateChrome()
case "git":
client.SetUserAgent("git/2.33.1")
}
return client
}
// readRequestBody 读取请求体
func readRequestBody(c *gin.Context) ([]byte, error) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
defer c.Request.Body.Close()
return body, nil
}
// setRequestHeaders 设置请求头
func setRequestHeaders(c *gin.Context, req *req.Request) {
for key, values := range c.Request.Header {
for _, value := range values {
req.SetHeader(key, value)
}
}
}
// copyResponseBody 复制响应体到客户端
func copyResponseBody(c *gin.Context, respBody io.Reader) error {
_, err := io.Copy(c.Writer, respBody)
return err
}
func SendRequest(req *req.Request, method, url string) (*req.Response, error) { func SendRequest(req *req.Request, method, url string) (*req.Response, error) {
switch method { switch method {
case "GET": case "GET":
@ -169,7 +199,7 @@ func SendRequest(req *req.Request, method, url string) (*req.Response, error) {
case "DELETE": case "DELETE":
return req.Delete(url) return req.Delete(url)
default: default:
logw("Unsupported method: %s", method) logInfo("Unsupported method: %s", method)
return nil, fmt.Errorf("unsupported method: %s", method) return nil, fmt.Errorf("unsupported method: %s", method)
} }
} }
@ -181,14 +211,25 @@ func HandleResponseSize(resp *req.Response, cfg *config.Config, c *gin.Context)
if err == nil && size > cfg.Server.SizeLimit { if err == nil && size > cfg.Server.SizeLimit {
finalURL := resp.Request.URL.String() finalURL := resp.Request.URL.String()
c.Redirect(http.StatusMovedPermanently, finalURL) c.Redirect(http.StatusMovedPermanently, finalURL)
logw("Redirecting to %s due to size limit (%d bytes)", finalURL, size) LogWarning("Size limit exceeded: %s, Size: %d", finalURL, size)
return fmt.Errorf("response size exceeds limit") return fmt.Errorf("size limit exceeded: %d", size)
} }
} }
return nil return nil
} }
func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config) { func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config) {
removeHeaders(resp)
copyHeaders(resp, c)
setCORSHeaders(c, cfg)
setDefaultHeaders(c)
}
// removeHeaders 移除指定的响应头
func removeHeaders(resp *req.Response) {
headersToRemove := map[string]struct{}{ headersToRemove := map[string]struct{}{
"Content-Security-Policy": {}, "Content-Security-Policy": {},
"Referrer-Policy": {}, "Referrer-Policy": {},
@ -198,35 +239,45 @@ func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config)
for header := range headersToRemove { for header := range headersToRemove {
resp.Header.Del(header) resp.Header.Del(header)
} }
}
// copyHeaders 复制响应头到 Gin 上下文
func copyHeaders(resp *req.Response, c *gin.Context) {
for key, values := range resp.Header { for key, values := range resp.Header {
for _, value := range values { for _, value := range values {
c.Header(key, value) c.Header(key, value)
} }
} }
c.Header("Access-Control-Allow-Origin", "")
if cfg.CORS.Enabled {
c.Header("Access-Control-Allow-Origin", "*")
} }
// setCORSHeaders 设置 CORS 相关的响应头
func setCORSHeaders(c *gin.Context, cfg *config.Config) {
if cfg.CORS.Enabled {
c.Header("Access-Control-Allow-Origin", "*")
} else {
c.Header("Access-Control-Allow-Origin", "")
}
}
// setDefaultHeaders 设置默认的响应头
func setDefaultHeaders(c *gin.Context) {
c.Header("Age", "10") c.Header("Age", "10")
c.Header("Cache-Control", "max-age=300") c.Header("Cache-Control", "max-age=300")
} }
func HandleError(c *gin.Context, message string) { func HandleError(c *gin.Context, message string) {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", message)) c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", message))
logw(message) LogWarning(message)
} }
func CheckURL(u string) []string { func CheckURL(u string) []string {
for _, exp := range exps { for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil { if matches := exp.FindStringSubmatch(u); matches != nil {
logw("URL matched: %s, Matches: %v", u, matches[1:]) logInfo("URL matched: %s, Matches: %v", u, matches[1:])
return matches[1:] return matches[1:]
} }
} }
errMsg := fmt.Sprintf("Invalid URL: %s", u) errMsg := fmt.Sprintf("Invalid URL: %s", u)
logw(errMsg) LogWarning(errMsg)
return nil return nil
} }