perf: modernize io paths and reduce proxy allocations

This commit is contained in:
wjqserver 2026-04-11 01:43:34 +08:00
parent 02861b5537
commit 54f7de0c60
11 changed files with 312 additions and 29 deletions

View file

@ -128,6 +128,19 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) {
} }
} }
func (c *Context) writeResponseBody(data []byte, contextMsg string) {
if len(data) == 0 {
return
}
if _, err := c.Writer.Write(data); err != nil {
wrapped := fmt.Errorf("%s: %w", contextMsg, err)
c.AddError(wrapped)
if c != nil && c.engine != nil && c.engine.LogReco != nil {
c.engine.LogReco.Errorf("%s: %v", contextMsg, err)
}
}
}
// Next 在处理链中执行下一个处理函数 // Next 在处理链中执行下一个处理函数
// 这是中间件模式的核心,允许请求依次经过多个处理函数 // 这是中间件模式的核心,允许请求依次经过多个处理函数
func (c *Context) Next() { func (c *Context) Next() {
@ -344,20 +357,20 @@ func (c *Context) Param(key string) string {
func (c *Context) Raw(code int, contentType string, data []byte) { func (c *Context) Raw(code int, contentType string, data []byte) {
c.Writer.Header().Set("Content-Type", contentType) c.Writer.Header().Set("Content-Type", contentType)
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(data) c.writeResponseBody(data, "failed to write raw response")
} }
// String 向响应写入格式化的字符串 // String 向响应写入格式化的字符串
func (c *Context) String(code int, format string, values ...any) { func (c *Context) String(code int, format string, values ...any) {
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(fmt.Appendf(nil, format, values...)) c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response")
} }
// Text 向响应写入无需格式化的string // Text 向响应写入无需格式化的string
func (c *Context) Text(code int, text string) { func (c *Context) Text(code int, text string) {
c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write([]byte(text)) c.writeResponseBody([]byte(text), "failed to write text response")
} }
// FileText // FileText
@ -495,7 +508,7 @@ func (c *Context) JSONBuf(code int, obj any) {
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(buf.Bytes()) c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response")
} }
// GOB 向响应写入GOB数据 // GOB 向响应写入GOB数据
@ -524,7 +537,7 @@ func (c *Context) GOBBuf(code int, obj any) {
} }
c.Writer.Header().Set("Content-Type", "application/octet-stream") c.Writer.Header().Set("Content-Type", "application/octet-stream")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(buf.Bytes()) c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response")
} }
// WANF向响应写入WANF数据 // WANF向响应写入WANF数据
@ -553,7 +566,7 @@ func (c *Context) WANFBuf(code int, obj any) {
} }
c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(buf.Bytes()) c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response")
} }
// HTML 渲染 HTML 模板 // HTML 渲染 HTML 模板
@ -577,7 +590,7 @@ func (c *Context) HTML(code int, name string, obj any) {
// 可以扩展支持其他渲染器接口 // 可以扩展支持其他渲染器接口
} }
// 默认简单输出,用于未配置 HTMLRender 的情况 // 默认简单输出,用于未配置 HTMLRender 的情况
c.Writer.Write(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj)) c.writeResponseBody(fmt.Appendf(nil, "<!-- HTML rendered for %s -->\n<pre>%v</pre>", name, obj), "failed to write HTML response")
} }
// HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体. // HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体.
@ -602,7 +615,7 @@ func (c *Context) HTMLBuf(code int, name string, obj any) {
// 渲染成功,写入响应 // 渲染成功,写入响应
c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
c.Writer.Write(buf.Bytes()) c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response")
return return
} }
@ -938,7 +951,7 @@ func (c *Context) GetReqBodyFull() ([]byte, error) {
} }
}() }()
data, err := iox.ReadAll(body) data, err := io.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
@ -959,7 +972,7 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) {
} }
}() }()
data, err := iox.ReadAll(body) data, err := io.ReadAll(body)
if err != nil { if err != nil {
c.AddError(fmt.Errorf("failed to read request body: %w", err)) c.AddError(fmt.Errorf("failed to read request body: %w", err))
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)

View file

@ -154,7 +154,7 @@ func writeDefaultErrorJSON(c *Context, code int, body []byte) {
} }
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
c.Writer.WriteHeader(code) c.Writer.WriteHeader(code)
_, _ = c.Writer.Write(body) c.writeResponseBody(body, "failed to write default error response")
c.Writer.Flush() c.Writer.Flush()
c.Abort() c.Abort()
} }

View file

@ -1,11 +1,66 @@
package touka package touka
import ( import (
"bufio"
"encoding/json" "encoding/json"
"errors"
"html/template"
"net"
"net/http" "net/http"
"testing" "testing"
) )
type failingResponseWriter struct {
header http.Header
status int
err error
}
func (w *failingResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *failingResponseWriter) WriteHeader(statusCode int) {
if w.status == 0 {
w.status = statusCode
}
}
func (w *failingResponseWriter) Write(p []byte) (int, error) {
if w.status == 0 {
w.status = http.StatusOK
}
if w.err != nil {
return 0, w.err
}
return len(p), nil
}
func (w *failingResponseWriter) Flush() {}
func (w *failingResponseWriter) Status() int {
return w.status
}
func (w *failingResponseWriter) Size() int {
return 0
}
func (w *failingResponseWriter) Written() bool {
return w.status != 0
}
func (w *failingResponseWriter) IsHijacked() bool {
return false
}
func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, http.ErrNotSupported
}
func TestHandleRequestRedirectFixedPath(t *testing.T) { func TestHandleRequestRedirectFixedPath(t *testing.T) {
engine := New() engine := New()
engine.GET("/api/v1/users/:id/settings", func(c *Context) { engine.GET("/api/v1/users/:id/settings", func(c *Context) {
@ -185,3 +240,67 @@ func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) {
t.Fatalf("expected custom error body, got %q", rr.Body.String()) t.Fatalf("expected custom error body, got %q", rr.Body.String())
} }
} }
func TestResponseHelpersCaptureWriteErrors(t *testing.T) {
testCases := []struct {
name string
run func(*Context)
}{
{name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }},
{name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }},
{name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }},
{name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }},
{name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }},
{name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }},
{name: "HTMLBuf", run: func(c *Context) {
c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`))
c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"})
}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
c, _ := CreateTestContext(w)
tc.run(c)
if got := len(c.Errors); got != 1 {
t.Fatalf("expected exactly one captured error, got %d", got)
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
})
}
}
func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) {
writerErr := errors.New("write failed")
w := &failingResponseWriter{err: writerErr}
engine := New()
c, _ := CreateTestContext(w)
c.engine = engine
req, err := http.NewRequest(http.MethodGet, "/missing", nil)
if err != nil {
t.Fatalf("failed to build request: %v", err)
}
c.reset(w, req)
defaultErrorHandle(c, http.StatusNotFound, errNotFound)
if len(c.Errors) == 0 {
t.Fatal("expected write error to be captured")
}
if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) {
t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1])
}
if c.Writer.Status() != http.StatusNotFound {
t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status())
}
if !c.IsAborted() {
t.Fatal("expected fast path to abort context")
}
}

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/infinite-iroha/touka
go 1.26 go 1.26
require ( require (
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3
github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER-STUDIO/httpc v0.9.0
github.com/WJQSERVER/wanf v0.0.8 github.com/WJQSERVER/wanf v0.0.8
github.com/fenthope/reco v0.0.5 github.com/fenthope/reco v0.0.5

2
go.sum
View file

@ -1,5 +1,7 @@
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM=
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok=
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck=
github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA=
github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4=
github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg=
github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8=

150
iox_benchmark_test.go Normal file
View file

@ -0,0 +1,150 @@
package touka
import (
"bytes"
"io"
"testing"
"github.com/WJQSERVER-STUDIO/go-utils/iox"
)
type benchmarkResetReader struct {
data []byte
off int
}
func (r *benchmarkResetReader) Read(p []byte) (int, error) {
if r.off >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.off:])
r.off += n
return n, nil
}
func (r *benchmarkResetReader) Reset() {
r.off = 0
}
type benchmarkDiscardWriter struct{}
func (benchmarkDiscardWriter) Write(p []byte) (int, error) {
return len(p), nil
}
var benchmarkIOXResult int64
var benchmarkIOXBytes []byte
func BenchmarkIOXCopyComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.Copy(w, r)
if err != nil {
b.Fatalf("io.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.Copy", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.Copy(w, r)
if err != nil {
b.Fatalf("iox.Copy failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXCopyBufferComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := io.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("io.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
b.Run("iox.CopyBuffer", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
w := benchmarkDiscardWriter{}
buf := make([]byte, 32*1024)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
n, err := iox.CopyBuffer(w, r, buf)
if err != nil {
b.Fatalf("iox.CopyBuffer failed: %v", err)
}
benchmarkIOXResult = n
}
})
}
func BenchmarkIOXReadAllComparison(b *testing.B) {
payload := bytes.Repeat([]byte("0123456789abcdef"), 4096)
b.Run("io.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("io.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
b.Run("iox.ReadAll", func(b *testing.B) {
r := &benchmarkResetReader{data: payload}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Reset()
data, err := io.ReadAll(r)
if err != nil {
b.Fatalf("iox.ReadAll failed: %v", err)
}
benchmarkIOXBytes = data
}
})
}

View file

@ -1041,7 +1041,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r
go copyer.copyFromBackend(errc) go copyer.copyFromBackend(errc)
var firstErr error var firstErr error
for i := 0; i < 2; i++ { for range 2 {
err := <-errc err := <-errc
if reverseProxyIsBenignTunnelError(err) { if reverseProxyIsBenignTunnelError(err) {
continue continue
@ -1123,7 +1123,7 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt
}() }()
var firstErr error var firstErr error
for i := 0; i < 2; i++ { for range 2 {
err := <-errc err := <-errc
if reverseProxyIsBenignTunnelError(err) { if reverseProxyIsBenignTunnelError(err) {
continue continue
@ -1587,8 +1587,8 @@ func reverseProxyViaProtocol(major, minor int, raw string) string {
if major > 0 { if major > 0 {
return strconv.Itoa(major) + "." + strconv.Itoa(minor) return strconv.Itoa(major) + "." + strconv.Itoa(minor)
} }
if strings.HasPrefix(raw, "HTTP/") { if after, ok := strings.CutPrefix(raw, "HTTP/"); ok {
return strings.TrimPrefix(raw, "HTTP/") return after
} }
return raw return raw
} }
@ -1702,7 +1702,7 @@ var reverseProxyHopHeaders = []string{
func removeHopByHopHeaders(header http.Header) { func removeHopByHopHeaders(header http.Header) {
for _, connectionValue := range header["Connection"] { for _, connectionValue := range header["Connection"] {
for _, token := range strings.Split(connectionValue, ",") { for token := range strings.SplitSeq(connectionValue, ",") {
trimmed := textproto.TrimString(token) trimmed := textproto.TrimString(token)
if trimmed != "" { if trimmed != "" {
header.Del(trimmed) header.Del(trimmed)
@ -1726,7 +1726,7 @@ func headerValuesContainToken(values []string, token string) bool {
return false return false
} }
for _, value := range values { for _, value := range values {
for _, part := range strings.Split(value, ",") { for part := range strings.SplitSeq(value, ",") {
if strings.EqualFold(textproto.TrimString(part), token) { if strings.EqualFold(textproto.TrimString(part), token) {
return true return true
} }

View file

@ -235,7 +235,7 @@ func (r *recordingReader) Read(p []byte) (int, error) {
if n == 0 { if n == 0 {
return 0, errors.New("reader received zero-length buffer") return 0, errors.New("reader received zero-length buffer")
} }
for i := 0; i < n; i++ { for i := range n {
p[i] = 'x' p[i] = 'x'
} }
r.left -= n r.left -= n

View file

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -404,10 +405,5 @@ func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, statu
if status <= 0 { if status <= 0 {
return false return false
} }
for _, unhealthyStatus := range config.UnhealthyStatus { return slices.Contains(config.UnhealthyStatus, status)
if status == unhealthyStatus {
return true
}
}
return false
} }

View file

@ -1866,7 +1866,9 @@ func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T)
if message != "echo:ping\n" { if message != "echo:ping\n" {
t.Fatalf("unexpected tunneled response body: %q", message) t.Fatalf("unexpected tunneled response body: %q", message)
} }
_ = pw.Close() if err := pw.Close(); err != nil {
t.Fatalf("close tunneled request body: %v", err)
}
select { select {
case err := <-errCh: case err := <-errCh:
@ -2215,7 +2217,9 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
cancel() cancel()
_ = pw.CloseWithError(context.Canceled) if err := pw.CloseWithError(context.Canceled); err != nil {
t.Fatalf("close request body with cancellation: %v", err)
}
select { select {
case <-writeErrCh: case <-writeErrCh:
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):

View file

@ -182,8 +182,7 @@ func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) {
func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) {
cfg := defaultRunConfig() cfg := defaultRunConfig()
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
if err := WithShutdownContext(ctx).apply(&cfg); err != nil { if err := WithShutdownContext(ctx).apply(&cfg); err != nil {
t.Fatalf("apply shutdown context option: %v", err) t.Fatalf("apply shutdown context option: %v", err)
} }