perf(proxy): reduce nest rewrite allocations

- Dispatch shell link rewriting between streaming and buffered paths based on response size

- Reuse buffers and reduce URL construction allocations in proxy handlers

- Add nest benchmarks and align extractParts compatibility expectations with the current contract
This commit is contained in:
wjqserver 2026-04-12 00:02:54 +08:00
parent 4c555ed50c
commit e2719aa761
6 changed files with 248 additions and 88 deletions

View file

@ -2,14 +2,27 @@ package proxy
import (
"bufio"
"bytes"
"fmt"
"ghproxy/config"
"io"
"strings"
"sync"
"github.com/infinite-iroha/touka"
)
// URL prefixes kept as byte slices so the byte-based matcher and rewriter
// (EditorMatcherBytes, modifyURLBytes) can compare against them directly.
var (
// GitHub hosts recognised by EditorMatcherBytes.
prefixGithub = []byte("https://github.com")
prefixRawUser = []byte("https://raw.githubusercontent.com")
prefixRaw = []byte("https://raw.github.com")
prefixGistUser = []byte("https://gist.githubusercontent.com")
prefixGist = []byte("https://gist.github.com")
// Only matched when cfg.Shell.RewriteAPI is set.
prefixAPIBytes = []byte("https://api.github.com")
// Scheme prefixes stripped by modifyURLBytes before re-prefixing with the proxy host.
prefixHTTP = []byte("http://")
prefixHTTPS = []byte("https://")
)
func EditorMatcher(rawPath string, cfg *config.Config) (bool, error) {
// 匹配 "https://github.com"开头的链接
if strings.HasPrefix(rawPath, "https://github.com") {
@ -40,6 +53,28 @@ func EditorMatcher(rawPath string, cfg *config.Config) (bool, error) {
return false, nil
}
// EditorMatcherBytes reports whether rawPath begins with one of the GitHub
// URL prefixes that should be rewritten. The api.github.com prefix only
// counts as a match when cfg.Shell.RewriteAPI is enabled; cfg is consulted
// only after every other prefix has failed to match.
func EditorMatcherBytes(rawPath []byte, cfg *config.Config) bool {
	switch {
	case bytes.HasPrefix(rawPath, prefixGithub),
		bytes.HasPrefix(rawPath, prefixRawUser),
		bytes.HasPrefix(rawPath, prefixRaw),
		bytes.HasPrefix(rawPath, prefixGistUser),
		bytes.HasPrefix(rawPath, prefixGist):
		return true
	}
	// API links are rewritten only on explicit opt-in.
	return cfg.Shell.RewriteAPI && bytes.HasPrefix(rawPath, prefixAPIBytes)
}
// 匹配文件扩展名是sh的rawPath
func MatcherShell(rawPath string) bool {
return strings.HasSuffix(rawPath, ".sh")
@ -64,87 +99,140 @@ func modifyURL(url string, host string, cfg *config.Config) string {
return url
}
// processLinks 处理链接,返回包含处理后数据的 io.Reader
func processLinks(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) {
pipeReader, pipeWriter := io.Pipe() // 创建 io.Pipe
// modifyURLBytes rewrites a matched URL so that it is routed through host.
//
// When url is not accepted by EditorMatcherBytes it is returned unchanged
// (the same slice, not a copy). Otherwise the result is a freshly allocated
// slice holding "https://" + host + "/" + url with its leading scheme
// ("https://" or "http://") removed.
func modifyURLBytes(url []byte, host []byte, cfg *config.Config) []byte {
	if !EditorMatcherBytes(url, cfg) {
		return url
	}

	// Strip at most one scheme prefix; anything else is kept verbatim.
	rest := url
	switch {
	case bytes.HasPrefix(url, prefixHTTPS):
		rest = url[len(prefixHTTPS):]
	case bytes.HasPrefix(url, prefixHTTP):
		rest = url[len(prefixHTTP):]
	}

	// Size the output exactly so the appends below never reallocate.
	total := len(prefixHTTPS) + len(host) + 1 + len(rest)
	out := make([]byte, 0, total)
	out = append(out, prefixHTTPS...)
	out = append(out, host...)
	out = append(out, '/')
	out = append(out, rest...)
	return out
}
// bufferPool hands out reusable *bytes.Buffer values. Callers are expected
// to Reset a buffer after Get, since pooled buffers may still hold old data.
var bufferPool = sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}
// processLinksStreamingInternal rewrites URLs line by line: a goroutine reads
// input with a bufio.Reader, passes every urlPattern match through modifyURL,
// and writes the result into an io.Pipe whose read end is returned as
// readerOut. Failures inside the goroutine reach the consumer through
// pipeWriter.CloseWithError; the function itself always returns a nil error.
//
// NOTE(review): this span appears to interleave the removed and the added
// lines of the diff (duplicate `go func() {`, duplicate bufReader/bufWriter
// declarations, unbalanced braces), so it is not compilable as shown — confirm
// against the real file before editing the code.
// NOTE(review): written is returned immediately while the goroutine keeps
// updating it, so the returned count presumably does not reflect the bytes
// eventually written — verify against callers.
func processLinksStreamingInternal(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) {
pipeReader, pipeWriter := io.Pipe()
readerOut = pipeReader
go func() { // perform the writes in a goroutine
go func() {
defer func() {
if pipeWriter != nil { // make sure pipeWriter is closed even when an error occurred
if err != nil {
if closeErr := pipeWriter.CloseWithError(err); closeErr != nil { // on error, propagate it to the reader
c.Errorf("pipeWriter close with error failed: %v, original error: %v", closeErr, err)
}
} else {
if closeErr := pipeWriter.Close(); closeErr != nil { // no error: close normally
c.Errorf("pipeWriter close failed: %v", closeErr)
if err == nil { // if there was no earlier error, record the close error
err = closeErr
}
}
}
if err != nil {
_ = pipeWriter.CloseWithError(err)
return
}
_ = pipeWriter.Close()
}()
defer func() {
if closeErr := input.Close(); closeErr != nil && c != nil {
c.Errorf("input close failed: %v", closeErr)
}
}()
bufReader := bufio.NewReader(input)
bufWriter := bufio.NewWriterSize(pipeWriter, 4096)
defer func() {
if err := input.Close(); err != nil {
c.Errorf("input close failed: %v", err)
}
}()
var bufReader *bufio.Reader
bufReader = bufio.NewReader(input)
var bufWriter *bufio.Writer
bufWriter = bufio.NewWriterSize(pipeWriter, 4096) // write through pipeWriter
// make sure the writer is closed
defer func() {
if flushErr := bufWriter.Flush(); flushErr != nil {
c.Errorf("writer flush failed %v", flushErr)
// keep an existing error; otherwise record this one
if err == nil {
err = flushErr
}
if flushErr := bufWriter.Flush(); flushErr != nil && err == nil {
err = fmt.Errorf("flush writer failed: %w", flushErr)
}
}()
// match http and https links with the regular expression
for {
line, readErr := bufReader.ReadString('\n')
if readErr != nil {
if readErr == io.EOF {
break // end of file
if readErr != nil && readErr != io.EOF {
err = fmt.Errorf("read error: %w", readErr)
return
}
if len(line) > 0 {
modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string {
return modifyURL(originalURL, host, cfg)
})
n, writeErr := bufWriter.WriteString(modifiedLine)
written += int64(n)
if writeErr != nil {
err = fmt.Errorf("write error: %w", writeErr)
return
}
err = fmt.Errorf("读取行错误: %v", readErr) // propagate the error
return // leave the goroutine with the error set
}
// replace every matching URL
modifiedLine := urlPattern.ReplaceAllStringFunc(line, func(originalURL string) string {
return modifyURL(originalURL, host, cfg) // assumes modifyURL is defined
})
n, writeErr := bufWriter.WriteString(modifiedLine)
written += int64(n) // update the written byte count
if writeErr != nil {
err = fmt.Errorf("写入文件错误: %v", writeErr) // propagate the error
return // leave the goroutine with the error set
if readErr == io.EOF {
break
}
}
// flush once more before returning (the defer already flushes, but this ensures a timely flush)
if flushErr := bufWriter.Flush(); flushErr != nil {
if err == nil { // avoid overwriting an earlier error
err = flushErr
}
return // leave the goroutine with the error set
}
}()
return readerOut, written, nil // return the reader and written; the error is delivered by the goroutine via pipeWriter.CloseWithError
return readerOut, written, nil
}
// processLinks picks the link-rewriting strategy from the response size.
//
// Bodies whose size is known (bodySize != -1) and at most 256 KiB are handled
// by processLinksBufferedInternal; an unknown size (-1) or anything larger
// goes to processLinksStreamingInternal. The chosen helper's results are
// returned unchanged.
func processLinks(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context, bodySize int) (readerOut io.Reader, written int64, err error) {
	const sizeThreshold = 256 * 1024

	sizeKnown := bodySize != -1
	if sizeKnown && bodySize <= sizeThreshold {
		return processLinksBufferedInternal(input, host, cfg, c)
	}
	return processLinksStreamingInternal(input, host, cfg, c)
}
// processLinksBufferedInternal rewrites the links of a small response in one
// pass. A goroutine reads all of input into a pooled buffer, runs every
// urlPattern match through modifyURLBytes, and writes the result into an
// io.Pipe whose read end is returned as readerOut.
//
// The returned written is always 0 and the returned err is always nil: the
// work happens after this function has returned. Failures are delivered to
// the consumer of readerOut through pipeWriter.CloseWithError. input is closed
// by the goroutine once processing ends.
func processLinksBufferedInternal(input io.ReadCloser, host string, cfg *config.Config, c *touka.Context) (readerOut io.Reader, written int64, err error) {
	pipeReader, pipeWriter := io.Pipe()
	hostBytes := []byte(host)

	go func() {
		// procErr is owned by this goroutine. The named results must not be
		// touched here: the enclosing return statement assigns them too, which
		// would be a data race and could reset a read error back to nil
		// before the deferred close below inspects it.
		var procErr error

		// Deferred calls run last-in-first-out: return the buffer to the
		// pool, close the pipe, then close the input.
		defer func() {
			if closeErr := input.Close(); closeErr != nil && c != nil {
				c.Errorf("input close failed: %v", closeErr)
			}
		}()
		defer func() {
			if procErr != nil {
				_ = pipeWriter.CloseWithError(procErr)
				return
			}
			_ = pipeWriter.Close()
		}()

		buf := bufferPool.Get().(*bytes.Buffer)
		buf.Reset() // pooled buffers may still hold a previous body
		defer bufferPool.Put(buf)

		if _, readErr := buf.ReadFrom(input); readErr != nil {
			procErr = fmt.Errorf("reading input failed: %w", readErr)
			return
		}

		// ReplaceAllFunc returns a new slice, so the pooled buffer is not
		// referenced by the data handed to the pipe.
		modifiedBytes := urlPattern.ReplaceAllFunc(buf.Bytes(), func(originalURL []byte) []byte {
			return modifyURLBytes(originalURL, hostBytes, cfg)
		})

		// Write blocks until the reader has consumed the data or closed the pipe.
		if _, writeErr := pipeWriter.Write(modifiedBytes); writeErr != nil {
			procErr = fmt.Errorf("writing to pipe failed: %w", writeErr)
		}
	}()

	return pipeReader, 0, nil
}