mirror of
https://github.com/WJQSERVER-STUDIO/ghproxy.git
synced 2026-02-03 08:11:11 +08:00
25w30e
This commit is contained in:
parent
b955c915ff
commit
7a6544c6c9
12 changed files with 170 additions and 182 deletions
|
|
@ -1,5 +1,13 @@
|
||||||
# 更新日志
|
# 更新日志
|
||||||
|
|
||||||
|
25w30e - 2025-04-24
|
||||||
|
---
|
||||||
|
- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用;
|
||||||
|
- CHANGE: 改进`rate`模块, 避免并发竞争问题
|
||||||
|
- CHANGE: 将大部分状态码返回改为新的`html/tmpl`方式处理
|
||||||
|
- CHANGE: 修改部分log等级
|
||||||
|
- FIX: 修正默认配置的填充错误
|
||||||
|
|
||||||
25w30d - 2025-04-22
|
25w30d - 2025-04-22
|
||||||
---
|
---
|
||||||
- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用;
|
- PRE-RELEASE: 此版本是v3.1.0预发布版本,请勿在生产环境中使用;
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
25w30d
|
25w30e
|
||||||
|
|
@ -46,7 +46,7 @@ func AuthHandler(c *app.RequestContext, cfg *config.Config) (isValid bool, err e
|
||||||
logError("Auth method not set")
|
logError("Auth method not set")
|
||||||
return true, nil
|
return true, nil
|
||||||
} else {
|
} else {
|
||||||
logError("Auth method not supported")
|
logError("Auth method not supported %s", cfg.Auth.Method)
|
||||||
return false, fmt.Errorf(fmt.Sprintf("Auth method %s not supported", cfg.Auth.Method))
|
return false, fmt.Errorf("%s", fmt.Sprintf("Auth method %s not supported", cfg.Auth.Method))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -228,11 +228,11 @@ func DefaultConfig() *Config {
|
||||||
},
|
},
|
||||||
Blacklist: BlacklistConfig{
|
Blacklist: BlacklistConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
BlacklistFile: "/data/ghproxy/config/blacklist.txt",
|
BlacklistFile: "/data/ghproxy/config/blacklist.json",
|
||||||
},
|
},
|
||||||
Whitelist: WhitelistConfig{
|
Whitelist: WhitelistConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
WhitelistFile: "/data/ghproxy/config/whitelist.txt",
|
WhitelistFile: "/data/ghproxy/config/whitelist.json",
|
||||||
},
|
},
|
||||||
RateLimit: RateLimitConfig{
|
RateLimit: RateLimitConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,7 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques
|
||||||
req.Header.Set("Authorization", "token "+token)
|
req.Header.Set("Authorization", "token "+token)
|
||||||
} else {
|
} else {
|
||||||
logWarning("%s %s %s %s %s Auth-Error: Conflict Auth Method", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol())
|
logWarning("%s %s %s %s %s Auth-Error: Conflict Auth Method", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol())
|
||||||
// 500 Internal Server Error
|
ErrorPage(c, NewErrorWithStatusLookup(500, "Conflict Auth Method"))
|
||||||
c.JSON(http.StatusInternalServerError, map[string]string{"error": "Conflict Auth Method"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "header":
|
case "header":
|
||||||
|
|
@ -28,8 +27,7 @@ func AuthPassThrough(c *app.RequestContext, cfg *config.Config, req *http.Reques
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
logWarning("%s %s %s %s %s Invalid Auth Method / Auth Method is not be set", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol())
|
logWarning("%s %s %s %s %s Invalid Auth Method / Auth Method is not be set", c.ClientIP(), c.Method(), string(c.Path()), c.UserAgent(), c.Request.Header.GetProtocol())
|
||||||
// 500 Internal Server Error
|
ErrorPage(c, NewErrorWithStatusLookup(500, "Invalid Auth Method / Auth Method is not be set"))
|
||||||
c.JSON(http.StatusInternalServerError, map[string]string{"error": "Invalid Auth Method / Auth Method is not be set"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
|
|
||||||
// 错误处理(404)
|
// 错误处理(404)
|
||||||
if resp.StatusCode == 404 {
|
if resp.StatusCode == 404 {
|
||||||
//c.String(http.StatusNotFound, "File Not Found")
|
ErrorPage(c, NewErrorWithStatusLookup(404, "Page Not Found (From Github)"))
|
||||||
c.Status(http.StatusNotFound)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -89,25 +88,12 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError("Failed to close response body: %v", err)
|
logError("Failed to close response body: %v", err)
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusMovedPermanently, []byte(finalURL))
|
c.Redirect(301, []byte(finalURL))
|
||||||
logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.UserAgent(), c.Request.Header.GetProtocol(), finalURL, bodySize)
|
logWarning("%s %s %s %s %s Final-URL: %s Size-Limit-Exceeded: %d", c.ClientIP(), c.Method(), c.Path(), c.UserAgent(), c.Request.Header.GetProtocol(), finalURL, bodySize)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
for header := range headersToRemove {
|
|
||||||
resp.Header.Del(header)
|
|
||||||
}
|
|
||||||
|
|
||||||
for key := range resp.Header {
|
|
||||||
var values []string = resp.Header.Values(key)
|
|
||||||
for _, value := range values {
|
|
||||||
c.Header(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// 复制响应头,排除需要移除的 header
|
// 复制响应头,排除需要移除的 header
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
if _, shouldRemove := respHeadersToRemove[key]; !shouldRemove {
|
if _, shouldRemove := respHeadersToRemove[key]; !shouldRemove {
|
||||||
|
|
@ -137,16 +123,16 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
compress = "gzip"
|
compress = "gzip"
|
||||||
}
|
}
|
||||||
|
|
||||||
logInfo("Is Shell: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol())
|
logDebug("Use Shell Editor: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol())
|
||||||
c.Header("Content-Length", "")
|
c.Header("Content-Length", "")
|
||||||
|
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
|
|
||||||
reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg)
|
reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg)
|
||||||
c.SetBodyStream(reader, -1)
|
c.SetBodyStream(reader, -1)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err)
|
logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err)
|
||||||
|
ErrorPage(c, NewErrorWithStatusLookup(500, fmt.Sprintf("Failed to copy response body: %v", err)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
"github.com/WJQSERVER-STUDIO/go-utils/logger"
|
"github.com/WJQSERVER-STUDIO/go-utils/logger"
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
|
|
@ -18,7 +21,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleError(c *app.RequestContext, message string) {
|
func HandleError(c *app.RequestContext, message string) {
|
||||||
c.JSON(http.StatusInternalServerError, map[string]string{"error": message})
|
ErrorPage(c, NewErrorWithStatusLookup(500, message))
|
||||||
logError(message)
|
logError(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -55,6 +58,12 @@ var (
|
||||||
StatusText: "页面未找到",
|
StatusText: "页面未找到",
|
||||||
HelpInfo: "抱歉,您访问的页面不存在。",
|
HelpInfo: "抱歉,您访问的页面不存在。",
|
||||||
}
|
}
|
||||||
|
ErrTooManyRequests = &GHProxyErrors{
|
||||||
|
StatusCode: 429,
|
||||||
|
StatusDesc: "Too Many Requests",
|
||||||
|
StatusText: "请求过于频繁",
|
||||||
|
HelpInfo: "您的请求过于频繁,请稍后再试。",
|
||||||
|
}
|
||||||
ErrInternalServerError = &GHProxyErrors{
|
ErrInternalServerError = &GHProxyErrors{
|
||||||
StatusCode: 500,
|
StatusCode: 500,
|
||||||
StatusDesc: "Internal Server Error",
|
StatusDesc: "Internal Server Error",
|
||||||
|
|
@ -71,6 +80,7 @@ func init() {
|
||||||
ErrAuthHeaderUnavailable.StatusCode: ErrAuthHeaderUnavailable,
|
ErrAuthHeaderUnavailable.StatusCode: ErrAuthHeaderUnavailable,
|
||||||
ErrForbidden.StatusCode: ErrForbidden,
|
ErrForbidden.StatusCode: ErrForbidden,
|
||||||
ErrNotFound.StatusCode: ErrNotFound,
|
ErrNotFound.StatusCode: ErrNotFound,
|
||||||
|
ErrTooManyRequests.StatusCode: ErrTooManyRequests,
|
||||||
ErrInternalServerError.StatusCode: ErrInternalServerError,
|
ErrInternalServerError.StatusCode: ErrInternalServerError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -93,3 +103,65 @@ func NewErrorWithStatusLookup(statusCode int, errMsg string) *GHProxyErrors {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errPagesFs fs.FS
|
||||||
|
|
||||||
|
func InitErrPagesFS(pages fs.FS) error {
|
||||||
|
var err error
|
||||||
|
errPagesFs, err = fs.Sub(pages, "pages/err")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorPageData struct {
|
||||||
|
StatusCode int
|
||||||
|
StatusDesc string
|
||||||
|
StatusText string
|
||||||
|
HelpInfo string
|
||||||
|
ErrorMessage string
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrPageUnwarper(errInfo *GHProxyErrors) ErrorPageData {
|
||||||
|
return ErrorPageData{
|
||||||
|
StatusCode: errInfo.StatusCode,
|
||||||
|
StatusDesc: errInfo.StatusDesc,
|
||||||
|
StatusText: errInfo.StatusText,
|
||||||
|
HelpInfo: errInfo.HelpInfo,
|
||||||
|
ErrorMessage: errInfo.ErrorMessage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorPage(c *app.RequestContext, errInfo *GHProxyErrors) {
|
||||||
|
pageData, err := htmlTemplateRender(errPagesFs, ErrPageUnwarper(errInfo))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(errInfo.StatusCode, map[string]string{"error": errInfo.ErrorMessage})
|
||||||
|
logDebug("Error reading page.tmpl: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(errInfo.StatusCode, "text/html; charset=utf-8", pageData)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func htmlTemplateRender(fsys fs.FS, data interface{}) ([]byte, error) {
|
||||||
|
tmplPath := "page.tmpl"
|
||||||
|
tmpl, err := template.ParseFS(fsys, tmplPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing template: %w", err)
|
||||||
|
}
|
||||||
|
if tmpl == nil {
|
||||||
|
return nil, fmt.Errorf("template is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建一个 bytes.Buffer 用于存储渲染结果
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
err = tmpl.Execute(&buf, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error executing template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回 buffer 的内容作为 []byte
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"ghproxy/config"
|
"ghproxy/config"
|
||||||
"ghproxy/rate"
|
"ghproxy/rate"
|
||||||
"net/http"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
|
@ -29,12 +29,11 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
|
|
||||||
rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/
|
rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/
|
||||||
matches = re.FindStringSubmatch(rawPath) // 匹配路径
|
matches = re.FindStringSubmatch(rawPath) // 匹配路径
|
||||||
logDebug("URL: %v", matches)
|
|
||||||
|
|
||||||
// 匹配路径错误处理
|
// 匹配路径错误处理
|
||||||
if len(matches) < 3 {
|
if len(matches) < 3 {
|
||||||
logWarning("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
logWarning("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), c.Path(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
||||||
c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid URL Format"})
|
ErrorPage(c, NewErrorWithStatusLookup(400, fmt.Sprintf("Invalid URL Format: %s", c.Path())))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -80,8 +79,8 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
case "clone":
|
case "clone":
|
||||||
GitReq(ctx, c, rawPath, cfg, "git")
|
GitReq(ctx, c, rawPath, cfg, "git")
|
||||||
default:
|
default:
|
||||||
c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid input."})
|
ErrorPage(c, NewErrorWithStatusLookup(500, "Matched But Not Matched"))
|
||||||
logError("Invalid input")
|
logError("Matched But Not Matched Path: %s rawPath: %s matcher: %s", c.Path(), rawPath, matcher)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,6 @@ import (
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
|
||||||
// 设置请求头
|
|
||||||
func setRequestHeaders(c *app.RequestContext, req *http.Request) {
|
|
||||||
c.Request.Header.VisitAll(func(key, value []byte) {
|
|
||||||
req.Header.Set(string(key), string(value))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func setRequestHeaders(c *app.RequestContext, req *http.Request) {
|
func setRequestHeaders(c *app.RequestContext, req *http.Request) {
|
||||||
c.Request.Header.VisitAll(func(key, value []byte) {
|
c.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
headerKey := string(key)
|
headerKey := string(key)
|
||||||
|
|
@ -22,16 +13,5 @@ func setRequestHeaders(c *app.RequestContext, req *http.Request) {
|
||||||
if _, shouldRemove := reqHeadersToRemove[headerKey]; !shouldRemove {
|
if _, shouldRemove := reqHeadersToRemove[headerKey]; !shouldRemove {
|
||||||
req.Header.Set(headerKey, headerValue)
|
req.Header.Set(headerKey, headerValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
// removeWSHeader removes the "Upgrade" and "Connection" headers from the given
|
|
||||||
// Request, which are added by the client when it wants to upgrade the
|
|
||||||
// connection to a WebSocket connection.
|
|
||||||
func removeWSHeader(req *http.Request) {
|
|
||||||
req.Header.Del("Upgrade")
|
|
||||||
req.Header.Del("Connection")
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"ghproxy/config"
|
"ghproxy/config"
|
||||||
"ghproxy/rate"
|
"ghproxy/rate"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
|
|
@ -57,8 +56,7 @@ func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
// 为rawpath加入https:// 头
|
// 为rawpath加入https:// 头
|
||||||
rawPath = "https://" + rawPath
|
rawPath = "https://" + rawPath
|
||||||
|
|
||||||
// IP METHOD URL USERAGENT PROTO MATCHES
|
logDebug("Matched: %v", matcher)
|
||||||
logDebug("%s %s %s %s %s Matched: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), matcher)
|
|
||||||
|
|
||||||
switch matcher {
|
switch matcher {
|
||||||
case "releases", "blob", "raw", "gist", "api":
|
case "releases", "blob", "raw", "gist", "api":
|
||||||
|
|
@ -66,8 +64,8 @@ func RoutingHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
case "clone":
|
case "clone":
|
||||||
GitReq(ctx, c, rawPath, cfg, "git")
|
GitReq(ctx, c, rawPath, cfg, "git")
|
||||||
default:
|
default:
|
||||||
c.JSON(http.StatusForbidden, map[string]string{"error": "Invalid input."})
|
ErrorPage(c, NewErrorWithStatusLookup(500, "Matched But Not Matched"))
|
||||||
logError("Invalid input")
|
logError("Matched But Not Matched Path: %s rawPath: %s matcher: %s", c.Path(), rawPath, matcher)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
109
proxy/utils.go
109
proxy/utils.go
|
|
@ -1,27 +1,22 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"ghproxy/auth"
|
"ghproxy/auth"
|
||||||
"ghproxy/config"
|
"ghproxy/config"
|
||||||
"ghproxy/rate"
|
"ghproxy/rate"
|
||||||
"html/template"
|
|
||||||
"io/fs"
|
|
||||||
|
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
)
|
)
|
||||||
|
|
||||||
func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool {
|
func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo string, rawPath string) bool {
|
||||||
var errMsg string
|
|
||||||
|
|
||||||
// 白名单检查
|
// 白名单检查
|
||||||
if cfg.Whitelist.Enabled {
|
if cfg.Whitelist.Enabled {
|
||||||
var whitelist bool
|
var whitelist bool
|
||||||
whitelist = auth.CheckWhitelist(user, repo)
|
whitelist = auth.CheckWhitelist(user, repo)
|
||||||
if !whitelist {
|
if !whitelist {
|
||||||
errMsg = fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo)
|
ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Whitelist Blocked repo: %s/%s", user, repo)))
|
||||||
c.JSON(403, map[string]string{"error": errMsg})
|
|
||||||
logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
|
logInfo("%s %s %s %s %s Whitelist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
@ -32,8 +27,7 @@ func listCheck(cfg *config.Config, c *app.RequestContext, user string, repo stri
|
||||||
var blacklist bool
|
var blacklist bool
|
||||||
blacklist = auth.CheckBlacklist(user, repo)
|
blacklist = auth.CheckBlacklist(user, repo)
|
||||||
if blacklist {
|
if blacklist {
|
||||||
errMsg = fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo)
|
ErrorPage(c, NewErrorWithStatusLookup(403, fmt.Sprintf("Blacklist Blocked repo: %s/%s", user, repo)))
|
||||||
c.JSON(403, map[string]string{"error": errMsg})
|
|
||||||
logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
|
logInfo("%s %s %s %s %s Blacklist Blocked repo: %s/%s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
@ -48,7 +42,7 @@ func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPat
|
||||||
|
|
||||||
if matcher == "api" && !cfg.Auth.ForceAllowApi {
|
if matcher == "api" && !cfg.Auth.ForceAllowApi {
|
||||||
if cfg.Auth.Method != "header" || !cfg.Auth.Enabled {
|
if cfg.Auth.Method != "header" || !cfg.Auth.Enabled {
|
||||||
c.JSON(403, map[string]string{"error": "Github API Req without AuthHeader is Not Allowed"})
|
ErrorPage(c, NewErrorWithStatusLookup(403, "Github API Req without AuthHeader is Not Allowed"))
|
||||||
logInfo("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath)
|
logInfo("%s %s %s %s %s AuthHeader Unavailable", c.ClientIP(), c.Method(), rawPath)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
@ -59,7 +53,7 @@ func authCheck(c *app.RequestContext, cfg *config.Config, matcher string, rawPat
|
||||||
var authcheck bool
|
var authcheck bool
|
||||||
authcheck, err = auth.AuthHandler(c, cfg)
|
authcheck, err = auth.AuthHandler(c, cfg)
|
||||||
if !authcheck {
|
if !authcheck {
|
||||||
c.JSON(401, map[string]string{"error": "Unauthorized"})
|
ErrorPage(c, NewErrorWithStatusLookup(401, fmt.Sprintf("Unauthorized: %v", err)))
|
||||||
logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err)
|
logInfo("%s %s %s %s %s Auth-Error: %v", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
@ -81,105 +75,16 @@ func rateCheck(cfg *config.Config, c *app.RequestContext, limiter *rate.RateLimi
|
||||||
allowed = limiter.Allow()
|
allowed = limiter.Allow()
|
||||||
default:
|
default:
|
||||||
logWarning("Invalid RateLimit Method")
|
logWarning("Invalid RateLimit Method")
|
||||||
c.JSON(500, map[string]string{"error": "Invalid RateLimit Method"})
|
ErrorPage(c, NewErrorWithStatusLookup(500, "Invalid RateLimit Method"))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !allowed {
|
if !allowed {
|
||||||
c.JSON(429, map[string]string{"error": "Too Many Requests"})
|
ErrorPage(c, NewErrorWithStatusLookup(429, fmt.Sprintf("Too Many Requests; Rate Limit is %d per minute", cfg.RateLimit.RatePerMinute)))
|
||||||
logWarning("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
logInfo("%s %s %s %s %s 429-TooManyRequests", c.ClientIP(), c.Method(), c.Request.RequestURI(), c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var errPagesFs fs.FS
|
|
||||||
|
|
||||||
func InitErrPagesFS(pages fs.FS) error {
|
|
||||||
var err error
|
|
||||||
errPagesFs, err = fs.Sub(pages, "pages/err")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrorPageData struct {
|
|
||||||
StatusCode int
|
|
||||||
StatusDesc string
|
|
||||||
StatusText string
|
|
||||||
HelpInfo string
|
|
||||||
ErrorMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
func ErrPageUnwarper(errInfo *GHProxyErrors) ErrorPageData {
|
|
||||||
return ErrorPageData{
|
|
||||||
StatusCode: errInfo.StatusCode,
|
|
||||||
StatusDesc: errInfo.StatusDesc,
|
|
||||||
StatusText: errInfo.StatusText,
|
|
||||||
HelpInfo: errInfo.HelpInfo,
|
|
||||||
ErrorMessage: errInfo.ErrorMessage,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ErrorPage(c *app.RequestContext, errInfo *GHProxyErrors) {
|
|
||||||
pageData, _ := htmlTemplateRender(errPagesFs, ErrPageUnwarper(errInfo))
|
|
||||||
/*
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(errInfo.StatusCode, map[string]string{"error": errInfo.ErrorMessage})
|
|
||||||
logDebug("Error reading page.tmpl: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
c.Data(errInfo.StatusCode, "text/html; charset=utf-8", pageData)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func NotFoundPage(c *app.RequestContext) {
|
|
||||||
/*
|
|
||||||
pageData, err := fs.ReadFile(errPagesFs, "404.html")
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(404, map[string]string{"error": "Not Found"})
|
|
||||||
logDebug("Error reading 404.html: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
pageData, err := htmlTemplateRender(errPagesFs, ErrorPageData{
|
|
||||||
StatusCode: 404,
|
|
||||||
StatusDesc: "Not Found",
|
|
||||||
StatusText: "The requested URL was not found on this server.",
|
|
||||||
ErrorMessage: "The requested URL was not found on this server.",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(404, map[string]string{"error": "Not Found"})
|
|
||||||
logDebug("Error reading 404.html: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Data(404, "text/html; charset=utf-8", pageData)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func htmlTemplateRender(fsys fs.FS, data interface{}) ([]byte, error) {
|
|
||||||
tmplPath := "page.tmpl"
|
|
||||||
tmpl, err := template.ParseFS(fsys, tmplPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing template: %w", err)
|
|
||||||
}
|
|
||||||
if tmpl == nil {
|
|
||||||
return nil, fmt.Errorf("template is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建一个 bytes.Buffer 用于存储渲染结果
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
err = tmpl.Execute(&buf, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error executing template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 返回 buffer 的内容作为 []byte
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
82
rate/rate.go
82
rate/rate.go
|
|
@ -1,13 +1,14 @@
|
||||||
package rate
|
package rate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/WJQSERVER-STUDIO/go-utils/logger"
|
"github.com/WJQSERVER-STUDIO/go-utils/logger"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 日志输出
|
// 日志模块
|
||||||
var (
|
var (
|
||||||
logw = logger.Logw
|
logw = logger.Logw
|
||||||
logDump = logger.LogDump
|
logDump = logger.LogDump
|
||||||
|
|
@ -17,49 +18,90 @@ var (
|
||||||
logError = logger.LogError
|
logError = logger.LogError
|
||||||
)
|
)
|
||||||
|
|
||||||
// 总体限流器
|
// RateLimiter 总体限流器
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// 基于IP的限流器
|
// New 创建一个总体限流器
|
||||||
type IPRateLimiter struct {
|
|
||||||
limiters map[string]*RateLimiter
|
|
||||||
limit int
|
|
||||||
burst int
|
|
||||||
duration time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(limit int, burst int, duration time.Duration) *RateLimiter {
|
func New(limit int, burst int, duration time.Duration) *RateLimiter {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 1
|
||||||
|
logWarning("rate limit per minute must be positive, setting to 1")
|
||||||
|
}
|
||||||
|
if burst <= 0 {
|
||||||
|
burst = 1
|
||||||
|
logWarning("rate limit burst must be positive, setting to 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimit := rate.Limit(float64(limit) / duration.Seconds())
|
||||||
|
|
||||||
return &RateLimiter{
|
return &RateLimiter{
|
||||||
limiter: rate.NewLimiter(rate.Limit(float64(limit)/duration.Seconds()), burst),
|
limiter: rate.NewLimiter(rateLimit, burst),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow 检查是否允许请求通过
|
||||||
func (rl *RateLimiter) Allow() bool {
|
func (rl *RateLimiter) Allow() bool {
|
||||||
return rl.limiter.Allow()
|
return rl.limiter.Allow()
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIPRateLimiter(limit int, burst int, duration time.Duration) *IPRateLimiter {
|
// IPRateLimiter 基于IP的限流器
|
||||||
|
type IPRateLimiter struct {
|
||||||
|
limiters map[string]*RateLimiter // 用户级限流器 map
|
||||||
|
mu sync.RWMutex // 保护 limiters map
|
||||||
|
limit int // 每 duration 时间段内允许的请求数
|
||||||
|
burst int // 突发请求数
|
||||||
|
duration time.Duration // 限流周期
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPRateLimiter 创建一个基于IP的限流器
|
||||||
|
func NewIPRateLimiter(ipLimit int, ipBurst int, duration time.Duration) *IPRateLimiter {
|
||||||
|
if ipLimit <= 0 {
|
||||||
|
ipLimit = 1
|
||||||
|
logWarning("IP rate limit per minute must be positive, setting to 1")
|
||||||
|
}
|
||||||
|
if ipBurst <= 0 {
|
||||||
|
ipBurst = 1
|
||||||
|
logWarning("IP rate limit burst must be positive, setting to 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
logInfo("IP Rate Limiter initialized with limit: %d, burst: %d, duration: %v", ipLimit, ipBurst, duration)
|
||||||
|
|
||||||
return &IPRateLimiter{
|
return &IPRateLimiter{
|
||||||
limiters: make(map[string]*RateLimiter),
|
limiters: make(map[string]*RateLimiter),
|
||||||
limit: limit,
|
limit: ipLimit,
|
||||||
burst: burst,
|
burst: ipBurst,
|
||||||
duration: duration,
|
duration: duration,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow 检查给定IP的请求是否允许通过
|
||||||
func (rl *IPRateLimiter) Allow(ip string) bool {
|
func (rl *IPRateLimiter) Allow(ip string) bool {
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
logWarning("empty ip")
|
logWarning("empty ip for rate limiting")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
limiter, ok := rl.limiters[ip]
|
// 使用读锁快速查找
|
||||||
if !ok {
|
rl.mu.RLock()
|
||||||
// 创建新的 RateLimiter 并存储
|
limiter, found := rl.limiters[ip]
|
||||||
limiter = New(rl.limit, rl.burst, rl.duration)
|
rl.mu.RUnlock()
|
||||||
rl.limiters[ip] = limiter
|
|
||||||
|
if found {
|
||||||
|
return limiter.Allow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 未找到,获取写锁来创建和添加
|
||||||
|
rl.mu.Lock()
|
||||||
|
// 双重检查
|
||||||
|
limiter, found = rl.limiters[ip]
|
||||||
|
if !found {
|
||||||
|
newL := New(rl.limit, rl.burst, rl.duration)
|
||||||
|
rl.limiters[ip] = newL
|
||||||
|
limiter = newL
|
||||||
|
}
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
return limiter.Allow()
|
return limiter.Allow()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue