mirror of
https://github.com/WJQSERVER-STUDIO/ghproxy.git
synced 2026-02-03 08:11:11 +08:00
Fix: Optimize header forwarding by excluding headers in a single pass
This commit is contained in:
parent
26a42b6510
commit
7e5b12dff8
2 changed files with 73 additions and 30 deletions
|
|
@ -5,20 +5,39 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"ghproxy/config"
|
"ghproxy/config"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
headersToRemove = map[string]struct{}{
|
||||||
|
"Content-Security-Policy": {},
|
||||||
|
"Referrer-Policy": {},
|
||||||
|
"Strict-Transport-Security": {},
|
||||||
|
"X-Github-Request-Id": {},
|
||||||
|
"X-Timer": {},
|
||||||
|
"X-Served-By": {},
|
||||||
|
"X-Fastly-Request-Id": {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, matcher string) {
|
func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, matcher string) {
|
||||||
method := c.Request.Method
|
|
||||||
|
|
||||||
body := c.Request.Body()
|
var (
|
||||||
|
method []byte
|
||||||
|
bodyReader *bytes.Buffer
|
||||||
|
req *http.Request
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
bodyReader := bytes.NewBuffer(body)
|
method = c.Request.Method()
|
||||||
|
bodyReader = bytes.NewBuffer(c.Request.Body())
|
||||||
|
|
||||||
req, err := client.NewRequest(string(method()), u, bodyReader)
|
req, err = client.NewRequest(string(method), u, bodyReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
HandleError(c, fmt.Sprintf("Failed to create request: %v", err))
|
HandleError(c, fmt.Sprintf("Failed to create request: %v", err))
|
||||||
return
|
return
|
||||||
|
|
@ -27,7 +46,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
removeWSHeader(req) // 删除Conection Upgrade头, 避免与HTTP/2冲突(检查是否存在Upgrade头)
|
removeWSHeader(req) // 删除Conection Upgrade头, 避免与HTTP/2冲突(检查是否存在Upgrade头)
|
||||||
AuthPassThrough(c, cfg, req)
|
AuthPassThrough(c, cfg, req)
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err = client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
HandleError(c, fmt.Sprintf("Failed to send request: %v", err))
|
HandleError(c, fmt.Sprintf("Failed to send request: %v", err))
|
||||||
return
|
return
|
||||||
|
|
@ -55,8 +74,9 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
bodySize = -1
|
bodySize = -1
|
||||||
}
|
}
|
||||||
if err == nil && bodySize > sizelimit {
|
if err == nil && bodySize > sizelimit {
|
||||||
finalURL := resp.Request.URL.String()
|
var finalURL string
|
||||||
err := resp.Body.Close()
|
finalURL = resp.Request.URL.String()
|
||||||
|
err = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError("Failed to close response body: %v", err)
|
logError("Failed to close response body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -66,20 +86,26 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, values := range resp.Header {
|
/*
|
||||||
|
for header := range headersToRemove {
|
||||||
|
resp.Header.Del(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range resp.Header {
|
||||||
|
var values []string = resp.Header.Values(key)
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
headersToRemove := map[string]struct{}{
|
// 复制响应头,排除需要移除的 header
|
||||||
"Content-Security-Policy": {},
|
for key, values := range resp.Header {
|
||||||
"Referrer-Policy": {},
|
if _, shouldRemove := headersToRemove[key]; !shouldRemove {
|
||||||
"Strict-Transport-Security": {},
|
for _, value := range values {
|
||||||
|
c.Header(key, value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for header := range headersToRemove {
|
|
||||||
resp.Header.Del(header)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch cfg.Server.Cors {
|
switch cfg.Server.Cors {
|
||||||
|
|
@ -105,7 +131,9 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c
|
||||||
logInfo("Is Shell: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol())
|
logInfo("Is Shell: %s %s %s %s %s", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol())
|
||||||
c.Header("Content-Length", "")
|
c.Header("Content-Length", "")
|
||||||
|
|
||||||
reader, _, err := processLinks(resp.Body, compress, string(c.Request.Host()), cfg)
|
var reader io.Reader
|
||||||
|
|
||||||
|
reader, _, err = processLinks(resp.Body, compress, string(c.Request.Host()), cfg)
|
||||||
c.SetBodyStream(reader, -1)
|
c.SetBodyStream(reader, -1)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -41,13 +41,19 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rawPath := strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/
|
var (
|
||||||
matches := re.FindStringSubmatch(rawPath) // 匹配路径
|
rawPath string
|
||||||
|
matches []string
|
||||||
|
errMsg string
|
||||||
|
)
|
||||||
|
|
||||||
|
rawPath = strings.TrimPrefix(string(c.Request.RequestURI()), "/") // 去掉前缀/
|
||||||
|
matches = re.FindStringSubmatch(rawPath) // 匹配路径
|
||||||
logInfo("URL: %v", matches)
|
logInfo("URL: %v", matches)
|
||||||
|
|
||||||
// 匹配路径错误处理
|
// 匹配路径错误处理
|
||||||
if len(matches) < 3 {
|
if len(matches) < 3 {
|
||||||
errMsg := fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
errMsg = fmt.Sprintf("%s %s %s %s %s Invalid URL", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol())
|
||||||
logWarning(errMsg)
|
logWarning(errMsg)
|
||||||
c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath)
|
c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath)
|
||||||
return
|
return
|
||||||
|
|
@ -56,7 +62,14 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
// 制作url
|
// 制作url
|
||||||
rawPath = "https://" + matches[2]
|
rawPath = "https://" + matches[2]
|
||||||
|
|
||||||
user, repo, matcher, err := Matcher(rawPath, cfg)
|
var (
|
||||||
|
user string
|
||||||
|
repo string
|
||||||
|
matcher string
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
user, repo, matcher, err = Matcher(rawPath, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrInvalidURL) {
|
if errors.Is(err, ErrInvalidURL) {
|
||||||
c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath)
|
c.String(http.StatusForbidden, "Invalid URL Format. Path: %s", rawPath)
|
||||||
|
|
@ -69,18 +82,19 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
username := user
|
|
||||||
|
|
||||||
logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), username, repo)
|
logInfo("%s %s %s %s %s Matched-Username: %s, Matched-Repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), user, repo)
|
||||||
// dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header
|
// dump log 记录详细信息 c.ClientIP(), c.Method(), rawPath,c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), full Header
|
||||||
logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header())
|
logDump("%s %s %s %s %s %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), c.Request.Header.Header())
|
||||||
repouser := fmt.Sprintf("%s/%s", username, repo)
|
var repouser string
|
||||||
|
repouser = fmt.Sprintf("%s/%s", user, repo)
|
||||||
|
|
||||||
// 白名单检查
|
// 白名单检查
|
||||||
if cfg.Whitelist.Enabled {
|
if cfg.Whitelist.Enabled {
|
||||||
whitelist := auth.CheckWhitelist(username, repo)
|
var whitelist bool
|
||||||
|
whitelist = auth.CheckWhitelist(user, repo)
|
||||||
if !whitelist {
|
if !whitelist {
|
||||||
errMsg := fmt.Sprintf("Whitelist Blocked repo: %s", repouser)
|
errMsg = fmt.Sprintf("Whitelist Blocked repo: %s", repouser)
|
||||||
c.JSON(http.StatusForbidden, map[string]string{"error": errMsg})
|
c.JSON(http.StatusForbidden, map[string]string{"error": errMsg})
|
||||||
logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser)
|
logWarning("%s %s %s %s %s Whitelist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser)
|
||||||
return
|
return
|
||||||
|
|
@ -89,9 +103,10 @@ func NoRouteHandler(cfg *config.Config, limiter *rate.RateLimiter, iplimiter *ra
|
||||||
|
|
||||||
// 黑名单检查
|
// 黑名单检查
|
||||||
if cfg.Blacklist.Enabled {
|
if cfg.Blacklist.Enabled {
|
||||||
blacklist := auth.CheckBlacklist(username, repo)
|
var blacklist bool
|
||||||
|
blacklist = auth.CheckBlacklist(user, repo)
|
||||||
if blacklist {
|
if blacklist {
|
||||||
errMsg := fmt.Sprintf("Blacklist Blocked repo: %s", repouser)
|
errMsg = fmt.Sprintf("Blacklist Blocked repo: %s", repouser)
|
||||||
c.JSON(http.StatusForbidden, map[string]string{"error": errMsg})
|
c.JSON(http.StatusForbidden, map[string]string{"error": errMsg})
|
||||||
logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser)
|
logWarning("%s %s %s %s %s Blacklist Blocked repo: %s", c.ClientIP(), c.Method(), rawPath, c.Request.Header.UserAgent(), c.Request.Header.GetProtocol(), repouser)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue