From 55158c0cb10fbcad0eaf2f350710d43e5d4c4ea7 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 25 Mar 2025 23:35:40 +0800 Subject: [PATCH] update --- proxy/chunkreq.go | 33 ++--- proxy/gitreq.go | 39 +++--- proxy/match.go | 309 +++++++++++++++------------------------------ proxy/reqheader.go | 1 + 4 files changed, 141 insertions(+), 241 deletions(-) diff --git a/proxy/chunkreq.go b/proxy/chunkreq.go index fb78762..130e22e 100644 --- a/proxy/chunkreq.go +++ b/proxy/chunkreq.go @@ -9,9 +9,8 @@ import ( "net/http" "strconv" - "github.com/WJQSERVER-STUDIO/go-utils/copyb" "github.com/cloudwego/hertz/pkg/app" - hresp "github.com/cloudwego/hertz/pkg/protocol/http1/resp" + //hresp "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, cfg *config.Config, matcher string) { @@ -71,7 +70,7 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c HandleError(c, fmt.Sprintf("Failed to send request: %v", err)) return } - defer resp.Body.Close() + //defer resp.Body.Close() // 错误处理(404) if resp.StatusCode == 404 { @@ -117,8 +116,8 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c c.Header("Access-Control-Allow-Origin", cfg.Server.Cors) } - c.Status(resp.StatusCode) - c.Response.HijackWriter(hresp.NewChunkedBodyWriter(&c.Response, c.GetWriter())) + //c.Status(resp.StatusCode) + //c.Response.HijackWriter(hresp.NewChunkedBodyWriter(&c.Response, c.GetWriter())) if MatcherShell(u) && matchString(matcher, matchedMatchers) && cfg.Shell.Editor { // 判断body是不是gzip @@ -131,23 +130,27 @@ func ChunkedProxyRequest(ctx context.Context, c *app.RequestContext, u string, c c.Header("Content-Length", "") //err := ProcessLinksAndWriteChunked(resp.Body, compress, string(c.Request.Host()), cfg, c) - _, err := processLinks(resp.Body, c.Response.BodyWriter(), compress, string(c.Request.Host()), cfg) + reader, _, err := processLinks(resp.Body, compress, string(c.Request.Host()), cfg) + c.SetBodyStream(reader, -1) if err != nil { logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) return - } else { - c.Flush() // 确保刷入 } } else { //err = hwriter.Writer(resp.Body, c) //writer := c.Response.BodyWriter() - _, err := copyb.Copy(c.Response.BodyWriter(), resp.Body) - if err != nil { - logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) - return - } else { - c.Flush() // 确保刷入 - } + + /* + _, err := copyb.Copy(c.Response.BodyWriter(), resp.Body) + if err != nil { + logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) + return + } else { + c.Flush() // 确保刷入 + } + */ + c.SetBodyStream(resp.Body, -1) } + } diff --git a/proxy/gitreq.go b/proxy/gitreq.go index 0405a0f..8472303 100644 --- a/proxy/gitreq.go +++ b/proxy/gitreq.go @@ -5,11 +5,9 @@ import ( "context" "fmt" "ghproxy/config" - "io" "net/http" "strconv" - "github.com/WJQSERVER-STUDIO/go-utils/copyb" "github.com/cloudwego/hertz/pkg/app" ) @@ -30,7 +28,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co var ( resp *http.Response - err error + //err error ) body := c.Request.Body() @@ -46,7 +44,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co } setRequestHeaders(c, req) removeWSHeader(req) - reWriteEncodeHeader(req) + //reWriteEncodeHeader(req) AuthPassThrough(c, cfg, req) resp, err = gitclient.Do(req) @@ -62,7 +60,7 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co } setRequestHeaders(c, req) removeWSHeader(req) - reWriteEncodeHeader(req) + //reWriteEncodeHeader(req) AuthPassThrough(c, cfg, req) resp, err = client.Do(req) @@ -71,12 +69,14 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co return } } - //defer resp.Body.Close() - defer func(Body io.ReadCloser) { - if err := Body.Close(); err != nil { - logError("Failed to close response body: %v", err) - } - }(resp.Body) + /* + //defer resp.Body.Close() + defer func(Body io.ReadCloser) { + if err := Body.Close(); err != nil { + logError("Failed to close response body: %v", err) + } + }(resp.Body) + */ contentLength := resp.Header.Get("Content-Length") if contentLength != "" { @@ -118,15 +118,18 @@ func GitReq(ctx context.Context, c *app.RequestContext, u string, cfg *config.Co } c.Status(resp.StatusCode) + c.SetBodyStream(resp.Body, -1) //err = hwriter.Writer(resp.Body, c) - _, err = copyb.Copy(c.Response.BodyWriter(), resp.Body) + /* + _, err = copyb.Copy(c.Response.BodyWriter(), resp.Body) - if err != nil { - logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) - return - } else { + if err != nil { + logError("%s %s %s %s %s Failed to copy response body: %v", c.ClientIP(), method, u, c.Request.Header.Get("User-Agent"), c.Request.Header.GetProtocol(), err) + return + } else { - c.Flush() // 确保刷入 - } + c.Flush() // 确保刷入 + } + */ } diff --git a/proxy/match.go b/proxy/match.go index 5f1f40f..67d1412 100644 --- a/proxy/match.go +++ b/proxy/match.go @@ -6,15 +6,9 @@ import ( "fmt" "ghproxy/config" "io" - "net/http" "net/url" "regexp" "strings" - "sync" - - "github.com/cloudwego/hertz/pkg/app" - "github.com/cloudwego/hertz/pkg/protocol/http1/resp" - "github.com/valyala/bytebufferpool" ) // 定义错误类型, error承载描述, 便于处理 @@ -185,9 +179,9 @@ func modifyURL(url string, host string, cfg *config.Config) string { return url } if matched { - - u := strings.TrimPrefix(url, "https://") - u = strings.TrimPrefix(url, "http://") + var u = url + u = strings.TrimPrefix(u, "https://") + u = strings.TrimPrefix(u, "http://") logDump("Modified URL: %s", "https://"+host+"/"+u) return "https://" + host + "/" + u } @@ -212,136 +206,6 @@ func matchString(target string, stringsToMatch []string) bool { return exists } -func ProcessLinksAndWriteChunked(input io.Reader, compress string, host string, cfg *config.Config, c *app.RequestContext) error { - pr, pw := io.Pipe() // 创建一个管道,用于进程间通信 - var wg sync.WaitGroup - wg.Add(2) - - var processErr error // 用于存储处理过程中发生的错误 - - go func() { - defer wg.Done() // 协程结束时通知 WaitGroup - defer pw.Close() // 协程结束时关闭管道的写端 - - var reader *bufio.Reader - if compress == "gzip" { // 如果需要解压 - gzipReader, err := gzip.NewReader(input) // 创建 gzip 解压器 - if err != nil { - c.String(http.StatusInternalServerError, fmt.Sprintf("gzip 解压错误: %v", err)) // 设置 HTTP 状态码和错误信息 - processErr = fmt.Errorf("gzip decompression error: %w", err) // gzip decompression error - return - } - defer gzipReader.Close() // 延迟关闭 gzip 解压器 - reader = bufio.NewReader(gzipReader) // 使用 bufio 读取解压后的数据 - } else { - reader = bufio.NewReader(input) // 直接使用 bufio 读取原始数据 - } - - var writer io.Writer = pw // 默认写入管道 - var gzipWriter *gzip.Writer - - if compress == "gzip" { // 如果需要压缩 - gzipWriter = gzip.NewWriter(writer) // 创建 gzip 压缩器 - writer = gzipWriter // 将 writer 设置为 gzip 压缩器 - defer func() { // 延迟关闭 gzip 压缩器 - if err := gzipWriter.Close(); err != nil { - logError("gzipWriter close failed: %v", err) - processErr = fmt.Errorf("gzipwriter close failed: %w", err) // gzipwriter close failed - } - }() - } - - urlPattern := regexp.MustCompile(`https?://[^\s'"]+`) // 编译正则表达式,用于匹配 URL - scanner := bufio.NewScanner(reader) // 创建 scanner 用于逐行扫描 - for scanner.Scan() { // 循环读取每一行 - line := scanner.Text() // 获取当前行 - modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { // 替换 URL - return modifyURL(originalURL, host, cfg) // 调用 modifyURL 函数修改 URL - }) - modifiedLineWithNewline := modifiedLine + "\n" // 添加换行符 - - _, err := writer.Write([]byte(modifiedLineWithNewline)) // 将修改后的行写入管道/gzip - if err != nil { - logError("写入 pipe 错误: %v", err) // 记录错误 - processErr = fmt.Errorf("write to pipe error: %w", err) // write to pipe error - return - } - } - - if err := scanner.Err(); err != nil { - logError("读取输入错误: %v", err) // 记录错误 - c.String(http.StatusInternalServerError, fmt.Sprintf("读取输入错误: %v", err)) // 设置 HTTP 状态码和错误信息 - processErr = fmt.Errorf("read input error: %w", err) // read input error - return - } - }() - - go func() { - defer wg.Done() // 协程结束时通知 WaitGroup - - c.Response.HijackWriter(resp.NewChunkedBodyWriter(&c.Response, c.GetWriter())) // 劫持 writer,启用分块编码 - - bufWrapper := bytebufferpool.Get() // 从对象池获取 bytebuffer - buf := bufWrapper.B - size := 32768 // 32KB, 设置缓冲区大小 - buf = buf[:cap(buf)] - if len(buf) < size { - buf = append(buf, make([]byte, size-len(buf))...) - } - buf = buf[:size] // 将缓冲区限制为 'size' - defer bytebufferpool.Put(bufWrapper) // 延迟将 bytebuffer 放回对象池 - - for { // 循环读取和写入数据 - n, err := pr.Read(buf) // 从管道读取数据 - if err != nil { - if err == io.EOF { // 如果读取到文件末尾 - if n > 0 { // 确保写入所有剩余数据 - _, err := c.Write(buf[:n]) // 写入最后的数据块 - if err != nil { - processErr = fmt.Errorf("failed to write final chunk: %w", err) // failed to write final chunk - break - } - } - c.Flush() // 刷新缓冲区 - break // 读取到文件末尾, 退出循环 - } - logError("hwriter.Writer read error: %v", err) // 记录错误 - if processErr == nil { - processErr = fmt.Errorf("failed to read from pipe: %w", err) // failed to read from pipe - // 不要在这里设置 http status code. 如果 read 失败, process 协程可能还没有完成, 它可能正在尝试设置 status code. 两个地方都设置会导致 race condition. - } - break // 读取错误,退出循环 - } - - if n > 0 { // 只有在实际读取到数据时才写入 - _, err = c.Write(buf[:n]) // 将数据写入响应 - if err != nil { - // 处理写入错误 (考虑记录日志并可能中止) - logError("hwriter.Writer write error: %v", err) - if processErr == nil { // 仅当 processErr 尚未设置时才设置. - processErr = fmt.Errorf("failed to write chunk: %w", err) // failed to write chunk - } - break // 写入错误, 退出循环 - } - - // 在大多数情况下,考虑移除 Flush. 仅在 *真正* 需要时保留它。 - if err := c.Flush(); err != nil { - // 更强大的 Flush() 错误处理 - c.AbortWithStatus(http.StatusInternalServerError) // 中止响应 - logError("hwriter.Writer flush error: %v", err) - if processErr == nil { - processErr = fmt.Errorf("failed to flush chunk: %w", err) // failed to flush chunk - } - break // 刷新错误, 退出循环 - } - } - } - }() - - wg.Wait() // 等待两个协程结束 - return processErr // 返回错误 -} - // extractParts 从给定的 URL 中提取所需的部分 func extractParts(rawURL string) (string, string, string, url.Values, error) { // 解析 URL @@ -374,82 +238,111 @@ func extractParts(rawURL string) (string, string, string, url.Values, error) { return repoOwner, repoName, remainingPath, queryParams, nil } -// processLinks 处理链接并将结果写入输出流 -func processLinks(input io.Reader, output io.Writer, compress string, host string, cfg *config.Config) (written int64, err error) { - var reader *bufio.Reader +// processLinks 处理链接,返回包含处理后数据的 io.Reader +func processLinks(input io.Reader, compress string, host string, cfg *config.Config) (readerOut io.Reader, written int64, err error) { + pipeReader, pipeWriter := io.Pipe() // 创建 io.Pipe + readerOut = pipeReader - if compress == "gzip" { - // 解压gzip - gzipReader, err := gzip.NewReader(input) - if err != nil { - return 0, fmt.Errorf("gzip解压错误: %v", err) - } - defer gzipReader.Close() - reader = bufio.NewReader(gzipReader) - } else { - reader = bufio.NewReader(input) - } - - var writer *bufio.Writer - var gzipWriter *gzip.Writer - - // 根据是否gzip确定 writer 的创建 - if compress == "gzip" { - gzipWriter = gzip.NewWriter(output) - writer = bufio.NewWriterSize(gzipWriter, 4096) //设置缓冲区大小 - } else { - writer = bufio.NewWriterSize(output, 4096) - } - - //确保writer关闭 - defer func() { - var closeErr error // 局部变量,用于保存defer中可能发生的错误 - - if gzipWriter != nil { - if closeErr = gzipWriter.Close(); closeErr != nil { - logError("gzipWriter close failed %v", closeErr) - // 如果已经存在错误,则保留。否则,记录此错误。 - if err == nil { - err = closeErr + go func() { // 在 Goroutine 中执行写入操作 + defer func() { + if pipeWriter != nil { // 确保 pipeWriter 关闭,即使发生错误 + if err != nil { + if closeErr := pipeWriter.CloseWithError(err); closeErr != nil { // 如果有错误,传递错误给 reader + logError("pipeWriter close with error failed: %v, original error: %v", closeErr, err) + } + } else { + if closeErr := pipeWriter.Close(); closeErr != nil { // 没有错误,正常关闭 + logError("pipeWriter close failed: %v", closeErr) + if err == nil { // 如果之前没有错误,记录关闭错误 + err = closeErr + } + } } } + }() + + var bufReader *bufio.Reader + + if compress == "gzip" { + // 解压gzip + gzipReader, gzipErr := gzip.NewReader(input) + if gzipErr != nil { + err = fmt.Errorf("gzip解压错误: %v", gzipErr) + return // Goroutine 中使用 return 返回错误 + } + defer gzipReader.Close() + bufReader = bufio.NewReader(gzipReader) + } else { + bufReader = bufio.NewReader(input) } - if flushErr := writer.Flush(); flushErr != nil { - logError("writer flush failed %v", flushErr) - // 如果已经存在错误,则保留。否则,记录此错误。 - if err == nil { + + var bufWriter *bufio.Writer + var gzipWriter *gzip.Writer + + // 根据是否gzip确定 writer 的创建 + if compress == "gzip" { + gzipWriter = gzip.NewWriter(pipeWriter) // 使用 pipeWriter + bufWriter = bufio.NewWriterSize(gzipWriter, 4096) //设置缓冲区大小 + } else { + bufWriter = bufio.NewWriterSize(pipeWriter, 4096) // 使用 pipeWriter + } + + //确保writer关闭 + defer func() { + var closeErr error // 局部变量,用于保存defer中可能发生的错误 + + if gzipWriter != nil { + if closeErr = gzipWriter.Close(); closeErr != nil { + logError("gzipWriter close failed %v", closeErr) + // 如果已经存在错误,则保留。否则,记录此错误。 + if err == nil { + err = closeErr + } + } + } + if flushErr := bufWriter.Flush(); flushErr != nil { + logError("writer flush failed %v", flushErr) + // 如果已经存在错误,则保留。否则,记录此错误。 + if err == nil { + err = flushErr + } + } + }() + + // 使用正则表达式匹配 http 和 https 链接 + urlPattern := regexp.MustCompile(`https?://[^\s'"]+`) + for { + line, readErr := bufReader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + break // 文件结束 + } + err = fmt.Errorf("读取行错误: %v", readErr) // 传递错误 + return // Goroutine 中使用 return 返回错误 + } + + // 替换所有匹配的 URL + modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { + logDump("originalURL: %s", originalURL) + return modifyURL(originalURL, host, cfg) // 假设 modifyURL 函数已定义 + }) + + n, writeErr := bufWriter.WriteString(modifiedLine) + written += int64(n) // 更新写入的字节数 + if writeErr != nil { + err = fmt.Errorf("写入文件错误: %v", writeErr) // 传递错误 + return // Goroutine 中使用 return 返回错误 + } + } + + // 在返回之前,再刷新一次 (虽然 defer 中已经有 flush,但这里再加一次确保及时刷新) + if flushErr := bufWriter.Flush(); flushErr != nil { + if err == nil { // 避免覆盖之前的错误 err = flushErr } + return // Goroutine 中使用 return 返回错误 } }() - // 使用正则表达式匹配 http 和 https 链接 - urlPattern := regexp.MustCompile(`https?://[^\s'"]+`) - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - break // 文件结束 - } - return written, fmt.Errorf("读取行错误: %v", err) // 传递错误 - } - - // 替换所有匹配的 URL - modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string { - return modifyURL(originalURL, host, cfg) - }) - - n, werr := writer.WriteString(modifiedLine) - written += int64(n) // 更新写入的字节数 - if werr != nil { - return written, fmt.Errorf("写入文件错误: %v", werr) // 传递错误 - } - } - - // 在返回之前,再刷新一次 - if fErr := writer.Flush(); fErr != nil { - return written, fErr - } - - return written, nil + return readerOut, written, nil // 返回 reader 和 written,error 由 Goroutine 通过 pipeWriter.CloseWithError 传递 } diff --git a/proxy/reqheader.go b/proxy/reqheader.go index 4926186..86e33d4 100644 --- a/proxy/reqheader.go +++ b/proxy/reqheader.go @@ -20,6 +20,7 @@ func removeWSHeader(req *http.Request) { } func reWriteEncodeHeader(req *http.Request) { + if isGzipAccepted(req.Header) { req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Accept-Encoding", "gzip")