This commit is contained in:
WJQSERVER 2024-10-12 03:50:34 +08:00
parent e3d56ae9b7
commit 824656f9d0
10 changed files with 191 additions and 75 deletions

View file

@ -1 +1 @@
24w15b
24w15c

View file

@ -7,7 +7,24 @@ import (
"github.com/gin-gonic/gin"
)
var logw = logger.Logw
// 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
// Auth Init
func Init(cfg *config.Config) {
if cfg.Blacklist.Enabled {
LoadBlacklist(cfg)
}
if cfg.Whitelist.Enabled {
LoadWhitelist(cfg)
}
logInfo("Auth Init")
}
func AuthHandler(c *gin.Context, cfg *config.Config) bool {
// 如果身份验证未启用,直接返回 true
@ -17,19 +34,19 @@ func AuthHandler(c *gin.Context, cfg *config.Config) bool {
// 获取 auth_token 参数
authToken := c.Query("auth_token")
logw("auth_token received: %s", authToken)
logInfo("auth_token received: %s", authToken)
// 验证 token
if authToken == "" {
logw("auth FAILED: no auth_token provided")
LogWarning("auth FAILED: no auth_token provided")
return false
}
isValid := authToken == cfg.Auth.AuthToken
if !isValid {
logw("auth FAILED: invalid auth_token: %s", authToken)
LogWarning("auth FAILED: invalid auth_token: %s", authToken)
}
logw("auth SUCCESS: %t", isValid)
logInfo("auth SUCCESS: %t", isValid)
return isValid
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"ghproxy/config"
"os"
"strings"
)
type BlacklistConfig struct {
@ -22,22 +23,32 @@ func LoadBlacklist(cfg *config.Config) {
data, err := os.ReadFile(blacklistfile)
if err != nil {
logw("Failed to read blacklist file: %v", err)
logError("Failed to read blacklist file: %v", err)
}
err = json.Unmarshal(data, blacklist)
if err != nil {
logw("Failed to unmarshal blacklist JSON: %v", err)
logError("Failed to unmarshal blacklist JSON: %v", err)
}
}
// fullrepo: "owner/repo" or "owner/*"
func CheckBlacklist(fullrepo string) bool {
return forRangeCheckBlacklist(blacklist.Blacklist, fullrepo)
}
func sliceRepoName_Blacklist(fullrepo string) (string, string) {
s := strings.Split(fullrepo, "/")
if len(s) != 2 {
return "", ""
}
return s[0], s[1]
}
func forRangeCheckBlacklist(blist []string, fullrepo string) bool {
repoUser, _ := sliceRepoName_Blacklist(fullrepo)
for _, blocked := range blist {
if blocked == fullrepo {
if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) {
return true
}
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"ghproxy/config"
"os"
"strings"
)
type WhitelistConfig struct {
@ -21,12 +22,12 @@ func LoadWhitelist(cfg *config.Config) {
data, err := os.ReadFile(whitelistfile)
if err != nil {
logw("Failed to read whitelist file: %v", err)
logError("Failed to read whitelist file: %v", err)
}
err = json.Unmarshal(data, whitelist)
if err != nil {
logw("Failed to unmarshal whitelist JSON: %v", err)
logError("Failed to unmarshal whitelist JSON: %v", err)
}
}
@ -34,9 +35,18 @@ func CheckWhitelist(fullrepo string) bool {
return forRangeCheckWhitelist(whitelist.Whitelist, fullrepo)
}
func forRangeCheckWhitelist(blist []string, fullrepo string) bool {
for _, blocked := range blist {
if blocked == fullrepo {
func sliceRepoName_Whitelist(fullrepo string) (string, string) {
s := strings.Split(fullrepo, "/")
if len(s) != 2 {
return "", ""
}
return s[0], s[1]
}
func forRangeCheckWhitelist(wlist []string, fullrepo string) bool {
repoUser, _ := sliceRepoName_Whitelist(fullrepo)
for _, blocked := range wlist {
if blocked == fullrepo || (strings.HasSuffix(blocked, "/*") && strings.HasPrefix(repoUser, blocked[:len(blocked)-2])) {
return true
}
}

View file

@ -1,8 +1,7 @@
{
"blacklist": [
"black/list",
"test/test1",
"example/repo2"
]
}
"blacklist": [
"black/list",
"test/test1",
"example/repo2"
]
}

View file

@ -56,7 +56,7 @@ func loadYAML(filePath string, out interface{}) error {
type Config struct {
Server ServerConfig
Log LoggerConfig
Log LogConfig
CORS CORSConfig
Auth AuthConfig
Blacklist BlacklistConfig
@ -69,7 +69,7 @@ type ServerConfig struct {
SizeLimit int `toml:"sizelimit"`
}
type LoggerConfig struct {
type LogConfig struct {
LogFilePath string `toml:"logfilepath"`
MaxLogSize int `toml:"maxlogsize"`
}

View file

@ -1,8 +1,7 @@
{
"whitelist": [
"white/list",
"white/test1",
"example/white"
]
}
"whitelist": [
"white/list",
"white/test1",
"example/white"
]
}

View file

@ -19,14 +19,16 @@ var (
logChannel = make(chan string, 100)
quitChannel = make(chan struct{})
logFileMutex sync.Mutex // 保护 logFile 的互斥锁
logFilePath = "/data/ghproxy/log/ghproxy.log"
)
// Init 初始化日志记录器,接受日志文件路径作为参数
func Init(logFilePath string, maxLogsize int) error {
func Init(logFilePath_input string, maxLogsize int) error {
logFileMutex.Lock()
defer logFileMutex.Unlock()
var err error
logFilePath = logFilePath_input
logFile, err = os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
return err
@ -62,6 +64,26 @@ func Logw(format string, args ...interface{}) {
Log(message)
}
// 日志等级INFO
func LogInfo(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
output := fmt.Sprintf("[INFO] %s", message)
Log(output)
}
// 日志等级WARNING
func LogWarning(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
output := fmt.Sprintf("[WARNING] %s", message)
Log(output)
}
// 日志等级ERROR
func LogError(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
Log(message)
}
// Close 关闭日志文件
func Close() {
logFileMutex.Lock()

15
main.go
View file

@ -17,11 +17,18 @@ import (
var (
cfg *config.Config
logw = logger.Logw
router *gin.Engine
configfile = "/data/ghproxy/config/config.toml"
)
// 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
func ReadFlag() {
cfgfile := flag.String("cfg", configfile, "config file path")
configfile = *cfgfile
@ -44,8 +51,8 @@ func setupLogger(cfg *config.Config) {
if err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
logw("Logger initialized")
logw("Init Completed")
logInfo("Logger initialized")
logInfo("Init Completed")
}
func Loadlist(cfg *config.Config) {
@ -87,7 +94,7 @@ func main() {
// 启动服务器
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
if err != nil {
log.Fatalf("Error starting server: %v\n", err)
logError("Error starting server: %v\n", err)
}
fmt.Println("Program finished")

View file

@ -17,7 +17,13 @@ import (
"github.com/imroc/req/v3"
)
var logw = logger.Logw
// 日志模块
var (
logw = logger.Logw
logInfo = logger.LogInfo
LogWarning = logger.LogWarning
logError = logger.LogError
)
var exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
@ -34,7 +40,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
matches := re.FindStringSubmatch(rawPath)
if len(matches) < 3 {
logw("Invalid URL: %s", rawPath)
LogWarning("Invalid URL: %s", rawPath)
c.String(http.StatusForbidden, "Invalid URL.")
return
}
@ -45,14 +51,14 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
pathmatches := regexp.MustCompile(`^([^/]+)/([^/]+)/([^/]+)/.*`)
pathParts := pathmatches.FindStringSubmatch(matches[2])
if len(pathParts) < 4 {
logw("Invalid path: %s", rawPath)
LogWarning("Invalid path: %s", rawPath)
c.String(http.StatusForbidden, "Invalid path; expected username/repo.")
return
}
username := pathParts[2]
repo := pathParts[3]
logw("Blacklist Check > Username: %s, Repo: %s", username, repo)
LogWarning("Blacklist Check > Username: %s, Repo: %s", username, repo)
fullrepo := fmt.Sprintf("%s/%s", username, repo)
// 白名单检查
@ -61,7 +67,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if !whitelistpass {
errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", fullrepo)
c.JSON(http.StatusForbidden, gin.H{"error": errMsg})
logw(errMsg)
LogWarning(errMsg)
return
}
}
@ -72,7 +78,7 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if blacklistpass {
errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", fullrepo)
c.JSON(http.StatusForbidden, gin.H{"error": errMsg})
logw(errMsg)
LogWarning(errMsg)
return
}
}
@ -89,18 +95,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
if !auth.AuthHandler(c, cfg) {
c.AbortWithStatusJSON(401, gin.H{"error": "Unauthorized"})
logw("Unauthorized request: %s", rawPath)
LogWarning("Unauthorized request: %s", rawPath)
return
}
logw("Matches: %v", matches)
logInfo("Matches: %v", matches)
switch {
case exps[0].MatchString(rawPath), exps[1].MatchString(rawPath), exps[3].MatchString(rawPath), exps[4].MatchString(rawPath):
logw("%s Matched - USE proxy-chrome", rawPath)
logInfo("%s Matched - USE proxy-chrome", rawPath)
ProxyRequest(c, rawPath, cfg, "chrome")
case exps[2].MatchString(rawPath):
logw("%s Matched - USE proxy-git", rawPath)
logInfo("%s Matched - USE proxy-git", rawPath)
ProxyRequest(c, rawPath, cfg, "git")
default:
c.String(http.StatusForbidden, "Invalid input.")
@ -111,33 +117,18 @@ func NoRouteHandler(cfg *config.Config) gin.HandlerFunc {
func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) {
method := c.Request.Method
logw("%s %s", method, u)
logInfo("%s %s", method, u)
client := req.C()
client := createHTTPClient(mode)
switch mode {
case "chrome":
client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36").
SetTLSFingerprintChrome().
ImpersonateChrome()
case "git":
client.SetUserAgent("git/2.33.1")
}
body, err := io.ReadAll(c.Request.Body)
body, err := readRequestBody(c)
if err != nil {
HandleError(c, fmt.Sprintf("Failed to read request body: %v", err))
HandleError(c, err.Error())
return
}
defer c.Request.Body.Close()
req := client.R().SetBody(body)
for key, values := range c.Request.Header {
for _, value := range values {
req.SetHeader(key, value)
}
}
setRequestHeaders(c, req)
resp, err := SendRequest(req, method, u)
if err != nil {
@ -147,17 +138,56 @@ func ProxyRequest(c *gin.Context, u string, cfg *config.Config, mode string) {
defer resp.Body.Close()
if err := HandleResponseSize(resp, cfg, c); err != nil {
logw("Error handling response size: %v", err)
LogWarning("Error handling response size: %v", err)
return
}
CopyResponseHeaders(resp, c, cfg)
c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
logw("Failed to copy response body: %v", err)
if err := copyResponseBody(c, resp.Body); err != nil {
logError("Failed to copy response body: %v", err)
}
}
// createHTTPClient 创建并配置 HTTP 客户端
func createHTTPClient(mode string) *req.Client {
client := req.C()
switch mode {
case "chrome":
client.SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36").
SetTLSFingerprintChrome().
ImpersonateChrome()
case "git":
client.SetUserAgent("git/2.33.1")
}
return client
}
// readRequestBody 读取请求体
func readRequestBody(c *gin.Context) ([]byte, error) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
defer c.Request.Body.Close()
return body, nil
}
// setRequestHeaders 设置请求头
func setRequestHeaders(c *gin.Context, req *req.Request) {
for key, values := range c.Request.Header {
for _, value := range values {
req.SetHeader(key, value)
}
}
}
// copyResponseBody 复制响应体到客户端
func copyResponseBody(c *gin.Context, respBody io.Reader) error {
_, err := io.Copy(c.Writer, respBody)
return err
}
func SendRequest(req *req.Request, method, url string) (*req.Response, error) {
switch method {
case "GET":
@ -169,7 +199,7 @@ func SendRequest(req *req.Request, method, url string) (*req.Response, error) {
case "DELETE":
return req.Delete(url)
default:
logw("Unsupported method: %s", method)
logInfo("Unsupported method: %s", method)
return nil, fmt.Errorf("unsupported method: %s", method)
}
}
@ -181,14 +211,25 @@ func HandleResponseSize(resp *req.Response, cfg *config.Config, c *gin.Context)
if err == nil && size > cfg.Server.SizeLimit {
finalURL := resp.Request.URL.String()
c.Redirect(http.StatusMovedPermanently, finalURL)
logw("Redirecting to %s due to size limit (%d bytes)", finalURL, size)
return fmt.Errorf("response size exceeds limit")
LogWarning("Size limit exceeded: %s, Size: %d", finalURL, size)
return fmt.Errorf("size limit exceeded: %d", size)
}
}
return nil
}
func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config) {
removeHeaders(resp)
copyHeaders(resp, c)
setCORSHeaders(c, cfg)
setDefaultHeaders(c)
}
// removeHeaders 移除指定的响应头
func removeHeaders(resp *req.Response) {
headersToRemove := map[string]struct{}{
"Content-Security-Policy": {},
"Referrer-Policy": {},
@ -198,35 +239,45 @@ func CopyResponseHeaders(resp *req.Response, c *gin.Context, cfg *config.Config)
for header := range headersToRemove {
resp.Header.Del(header)
}
}
// copyHeaders 复制响应头到 Gin 上下文
func copyHeaders(resp *req.Response, c *gin.Context) {
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
}
c.Header("Access-Control-Allow-Origin", "")
// setCORSHeaders 设置 CORS 相关的响应头
func setCORSHeaders(c *gin.Context, cfg *config.Config) {
if cfg.CORS.Enabled {
c.Header("Access-Control-Allow-Origin", "*")
} else {
c.Header("Access-Control-Allow-Origin", "")
}
}
// setDefaultHeaders 设置默认的响应头
func setDefaultHeaders(c *gin.Context) {
c.Header("Age", "10")
c.Header("Cache-Control", "max-age=300")
}
func HandleError(c *gin.Context, message string) {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", message))
logw(message)
LogWarning(message)
}
func CheckURL(u string) []string {
for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil {
logw("URL matched: %s, Matches: %v", u, matches[1:])
logInfo("URL matched: %s, Matches: %v", u, matches[1:])
return matches[1:]
}
}
errMsg := fmt.Sprintf("Invalid URL: %s", u)
logw(errMsg)
LogWarning(errMsg)
return nil
}