From 64e2ad9e7b4f514f44622967d0a6fcc5c8468445 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 31 Mar 2026 16:38:04 +0800 Subject: [PATCH 01/55] Fix FileText status code and unify request body size limits - FileText: now respects the provided status code instead of defaulting to 200 OK - Request body limits: prepareRequestBody() is now only called when MaxRequestBodySize > 0 - ShouldBindJSON, ShouldBindWANF, ShouldBindGOB, ShouldBindForm, GetReqBody, PostForm all now use the original c.Request.Body path when no limit is configured - maxBytesReader: fixed exact-limit boundary case where body size == limit was incorrectly rejected - Added regression tests for FileText status codes and body limit behavior All existing tests pass, and new tests verify the corrected behavior. --- context.go | 161 +++++++++++++++++++++++++------------- context_bodylimit_test.go | 125 +++++++++++++++++++++++++++++ maxreader.go | 22 +++++- 3 files changed, 252 insertions(+), 56 deletions(-) create mode 100644 context_bodylimit_test.go diff --git a/context.go b/context.go index 2e4d2bb..855206e 100644 --- a/context.go +++ b/context.go @@ -44,6 +44,8 @@ type Context struct { handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) index int8 // 当前执行到处理链的哪个位置 + requestBodyPrepared bool + mu sync.RWMutex Keys map[string]any // 用于在中间件之间传递数据 @@ -102,6 +104,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize + c.requestBodyPrepared = false if cap(c.SkippedNodes) > 0 { c.SkippedNodes = c.SkippedNodes[:0] @@ -237,6 +240,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) { c.MaxRequestBodySize = size } +func (c *Context) prepareRequestBody() io.ReadCloser { + if c.Request == nil || c.Request.Body == nil { + return nil + } + if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 { + return c.Request.Body + } + c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + c.requestBodyPrepared = true + return c.Request.Body +} + // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -258,7 +273,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string { // 懒加载解析表单数据,并进行缓存 func (c *Context) PostForm(key string) string { if c.formCache == nil { - c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + default: + if err := c.Request.ParseForm(); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + } + } else { + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + if !errors.Is(err, http.ErrNotMultipart) { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + } + } c.formCache = c.Request.PostForm } return c.formCache.Get(key) @@ -338,8 +385,11 @@ func (c *Context) FileText(code int, filePath string) { } c.SetHeader("Content-Type", "text/plain; charset=utf-8") - - c.SetBodyStream(file, int(fileInfo.Size())) + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) + c.Writer.WriteHeader(code) + if _, err := iox.Copy(c.Writer, file); err != nil { + c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err)) + } } /* @@ -557,10 +607,16 @@ func (c *Context) Redirect(code int, location string) { // ShouldBindJSON 尝试将请求体绑定到 JSON 对象 func (c *Context) ShouldBindJSON(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - err := json.UnmarshalRead(c.Request.Body, obj) + err := json.UnmarshalRead(body, obj) if err != nil { return fmt.Errorf("json binding error: %w", err) } @@ -569,10 +625,16 @@ func (c *Context) ShouldBindJSON(obj any) error { // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder, err := wanf.NewStreamDecoder(c.Request.Body) + decoder, err := wanf.NewStreamDecoder(body) if err != nil { return fmt.Errorf("failed to create WANF decoder: %w", err) } @@ -585,10 +647,16 @@ func (c *Context) ShouldBindWANF(obj any) error { // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 func (c *Context) ShouldBindGOB(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder := gob.NewDecoder(c.Request.Body) + decoder := gob.NewDecoder(body) if err := decoder.Decode(obj); err != nil { return fmt.Errorf("GOB binding error: %w", err) } @@ -705,6 +773,10 @@ func setFieldValue(field reflect.Value, values []string) error { // ShouldBindForm 尝试将表单数据绑定到结构体 // 支持 application/x-www-form-urlencoded 和 multipart/form-data func (c *Context) ShouldBindForm(obj any) error { + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + } + contentType := c.Request.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -713,7 +785,7 @@ func (c *Context) ShouldBindForm(obj any) error { switch mediaType { case "multipart/form-data": - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { return fmt.Errorf("parse multipart form error: %w", err) } case "application/x-www-form-urlencoded": @@ -727,6 +799,7 @@ func (c *Context) ShouldBindForm(obj any) error { if err := bindForm(c.Request.Form, obj); err != nil { return fmt.Errorf("form binding error: %w", err) } + c.formCache = c.Request.PostForm return nil } @@ -827,37 +900,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) { // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // 注意:请求体只能读取一次 func (c *Context) GetReqBody() io.ReadCloser { + if c.MaxRequestBodySize > 0 { + return c.prepareRequestBody() + } + if c.Request == nil || c.Request.Body == nil { + return nil + } return c.Request.Body } // GetReqBodyFull 读取并返回请求体的所有内容 // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := iox.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -867,31 +933,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { // 类似 GetReqBodyFull, 返回 *bytes.Buffer func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := iox.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go new file mode 100644 index 0000000..546f06e --- /dev/null +++ b/context_bodylimit_test.go @@ -0,0 +1,125 @@ +package touka + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestFileTextUsesProvidedStatusCode(t *testing.T) { + t.Helper() + + dir := t.TempDir() + filePath := filepath.Join(dir, "hello.txt") + if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil { + t.Fatalf("write temp file: %v", err) + } + + rr := httptest.NewRecorder() + c, _ := CreateTestContext(rr) + + c.FileText(http.StatusCreated, filePath) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) + } + if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("unexpected content type: %q", got) + } + if body := rr.Body.String(); body != "hello touka" { + t.Fatalf("unexpected body: %q", body) + } +} + +func TestMaxBytesReaderAllowsExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4) + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected exact limit read to succeed, got %v", err) + } + if string(data) != "abcd" { + t.Fatalf("unexpected data: %q", string(data)) + } +} + +func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4) + defer reader.Close() + + _, err := io.ReadAll(reader) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader(`{"name":"abcdef"}`) + req := httptest.NewRequest(http.MethodPost, "/json", body) + req.Header.Set("Content-Type", "application/json") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(8) + + var payload struct { + Name string `json:"name"` + } + + err := c.ShouldBindJSON(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + var payload struct { + Name string `form:"name"` + } + + err := c.ShouldBindForm(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestPostFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + if got := c.PostForm("name"); got != "" { + t.Fatalf("expected empty value on over-limit form body, got %q", got) + } + if len(c.Errors) == 0 { + t.Fatal("expected parse error to be recorded") + } + if !errors.Is(c.Errors[0], ErrBodyTooLarge) { + t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0]) + } +} diff --git a/maxreader.go b/maxreader.go index c6201e6..96e54c7 100644 --- a/maxreader.go +++ b/maxreader.go @@ -46,11 +46,29 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. func (mbr *maxBytesReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. readSoFar := mbr.read.Load() - // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. - if readSoFar >= mbr.n { + if readSoFar > mbr.n { + return 0, ErrBodyTooLarge + } + + // 当已恰好读满限制时, 需要探测底层是否还有额外数据. + // 如果下一次读取立即 EOF, 说明请求体大小恰好等于限制, 属于合法情况. + if readSoFar == mbr.n { + var probe [1]byte + n, err := mbr.r.Read(probe[:]) + if n > 0 { + mbr.read.Add(int64(n)) + return 0, ErrBodyTooLarge + } + if err != nil { + return 0, err + } return 0, ErrBodyTooLarge } From 85cc9b5cf660af787b0d752066890d8dc5723fda Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:59:32 +0800 Subject: [PATCH 02/55] fix(form): align PostForm parsing with body limit handling --- context.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/context.go b/context.go index 855206e..9c4ba7e 100644 --- a/context.go +++ b/context.go @@ -275,29 +275,29 @@ func (c *Context) PostForm(key string) string { if c.formCache == nil { if c.MaxRequestBodySize > 0 { c.prepareRequestBody() - contentType := c.Request.Header.Get("Content-Type") - mediaType, _, err := mime.ParseMediaType(contentType) - if err != nil { + } + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { c.AddError(fmt.Errorf("parse form error: %w", err)) c.formCache = make(url.Values) return "" } - - switch mediaType { - case "multipart/form-data": - if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } - default: - if err := c.Request.ParseForm(); err != nil { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" } - } else { + default: if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { if !errors.Is(err, http.ErrNotMultipart) { c.AddError(fmt.Errorf("parse form error: %w", err)) From 91c50536c49c0a0eee61ca544cfd615695072e5c Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 31 Mar 2026 23:37:02 +0800 Subject: [PATCH 03/55] fix(maxreader): avoid hangs after reaching body limit --- context_bodylimit_test.go | 62 +++++++++++++++++++++++++++++++++++++++ maxreader.go | 53 +++++++++++++-------------------- 2 files changed, 83 insertions(+), 32 deletions(-) diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go index 546f06e..37f5e46 100644 --- a/context_bodylimit_test.go +++ b/context_bodylimit_test.go @@ -11,6 +11,32 @@ import ( "testing" ) +type zeroNilThenEOFReader struct { + readCalls int +} + +func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) { + r.readCalls++ + if r.readCalls == 1 { + return 0, nil + } + return 0, io.EOF +} + +func (r *zeroNilThenEOFReader) Close() error { + return nil +} + +type zeroNilForeverReader struct{} + +func (r *zeroNilForeverReader) Read(_ []byte) (int, error) { + return 0, nil +} + +func (r *zeroNilForeverReader) Close() error { + return nil +} + func TestFileTextUsesProvidedStatusCode(t *testing.T) { t.Helper() @@ -63,6 +89,42 @@ func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { } } +func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err) + } +} + +func TestMaxBytesReaderRejectsOverLimitWithoutProbeLoop(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilForeverReader{}, 0) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge after repeated zero,nil reads, got n=%d err=%v", n, err) + } +} + func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { t.Helper() diff --git a/maxreader.go b/maxreader.go index 96e54c7..8191853 100644 --- a/maxreader.go +++ b/maxreader.go @@ -23,6 +23,8 @@ type maxBytesReader struct { n int64 // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. read atomic.Int64 + // emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读. + emptyAtLimit atomic.Bool } // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, @@ -52,14 +54,11 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) { // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. readSoFar := mbr.read.Load() - - if readSoFar > mbr.n { + remaining := mbr.n - readSoFar + if remaining < 0 { return 0, ErrBodyTooLarge } - - // 当已恰好读满限制时, 需要探测底层是否还有额外数据. - // 如果下一次读取立即 EOF, 说明请求体大小恰好等于限制, 属于合法情况. - if readSoFar == mbr.n { + if remaining == 0 { var probe [1]byte n, err := mbr.r.Read(probe[:]) if n > 0 { @@ -69,43 +68,33 @@ func (mbr *maxBytesReader) Read(p []byte) (int, error) { if err != nil { return 0, err } - return 0, ErrBodyTooLarge + if mbr.emptyAtLimit.Swap(true) { + return 0, ErrBodyTooLarge + } + return 0, nil } + mbr.emptyAtLimit.Store(false) - // 计算当前还可以读取多少字节. - remaining := mbr.n - readSoFar - - // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. - // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. - if int64(len(p)) > remaining { - p = p[:remaining] + // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 + if int64(len(p))-1 > remaining { + p = p[:remaining+1] } // 从底层 Reader 读取数据. n, err := mbr.r.Read(p) - // 如果实际读取到了数据, 更新原子计数器. - if n > 0 { - readSoFar = mbr.read.Add(int64(n)) - } - - // 如果底层 Read 返回错误 (例如 io.EOF). - if err != nil { - // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. - // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. + if int64(n) <= remaining { + if n > 0 { + mbr.read.Add(int64(n)) + } return n, err } - // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. - // 这是处理"跨越"限制情况的关键. - if readSoFar > mbr.n { - // 返回实际读取的字节数 n, 并附上超限错误. - // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. - return n, ErrBodyTooLarge + // 读取结果跨过了限制,只向上层暴露允许的部分。 + if remaining > 0 { + mbr.read.Add(remaining) } - - // 一切正常, 返回读取的字节数和 nil 错误. - return n, nil + return int(remaining), ErrBodyTooLarge } // Close 方法关闭底层的 ReadCloser, 保证资源释放. From e6ff0fa6b9c32cee88e591a81c841fa7a0fedfd7 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:03:23 +0800 Subject: [PATCH 04/55] fix(maxreader): treat non-positive limits as unlimited --- context_bodylimit_test.go | 27 +++++++-------------------- maxreader.go | 6 +++--- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go index 37f5e46..1e7696a 100644 --- a/context_bodylimit_test.go +++ b/context_bodylimit_test.go @@ -27,16 +27,6 @@ func (r *zeroNilThenEOFReader) Close() error { return nil } -type zeroNilForeverReader struct{} - -func (r *zeroNilForeverReader) Read(_ []byte) (int, error) { - return 0, nil -} - -func (r *zeroNilForeverReader) Close() error { - return nil -} - func TestFileTextUsesProvidedStatusCode(t *testing.T) { t.Helper() @@ -107,21 +97,18 @@ func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { } } -func TestMaxBytesReaderRejectsOverLimitWithoutProbeLoop(t *testing.T) { +func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) { t.Helper() - reader := NewMaxBytesReader(&zeroNilForeverReader{}, 0) + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0) defer reader.Close() - buf := make([]byte, 1) - n, err := reader.Read(buf) - if n != 0 || err != nil { - t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected zero limit to leave body unlimited, got %v", err) } - - n, err = reader.Read(buf) - if n != 0 || !errors.Is(err, ErrBodyTooLarge) { - t.Fatalf("expected ErrBodyTooLarge after repeated zero,nil reads, got n=%d err=%v", n, err) + if string(data) != "abc" { + t.Fatalf("unexpected data: %q", string(data)) } } diff --git a/maxreader.go b/maxreader.go index 8191853..4d3fb2c 100644 --- a/maxreader.go +++ b/maxreader.go @@ -31,13 +31,13 @@ type maxBytesReader struct { // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // // 如果 r 为 nil, 会 panic. -// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. +// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r. func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { if r == nil { panic("NewMaxBytesReader called with a nil reader") } - // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. - if n < 0 { + // 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser. + if n <= 0 { return r } return &maxBytesReader{ From ed44c592d314c3212222748eb674211916c47120 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 03:18:49 +0800 Subject: [PATCH 05/55] fix(reverseproxy): align forwarding and tunnel semantics --- docs/reverse-proxy.md | 15 +- ecw.go | 2 +- engine.go | 52 +++-- respw.go | 2 +- reverseproxy.go | 368 +++++++++++++++++++++++++++++++++- reverseproxy_test.go | 451 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 864 insertions(+), 26 deletions(-) diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 5dfcbd1..1dfd760 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -242,11 +242,20 @@ const ( r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "edge-1", })) ``` +如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。 + +- IPv4:`203.0.113.43` +- IPv6 / 带端口:`[2001:db8::17]:443` +- 匿名标识:`_gateway-1` +- 未知:`unknown` + +像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。 + `Via` 不是“留空即禁用”的开关。当前实现中: - 如果 `Via` 非空,则使用该值追加 `Via` @@ -282,11 +291,13 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: +- `CONNECT` 隧道转发(HTTP/1.x) - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 - Trailer 透传 - 1xx 响应透传 +- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理 例如,代理 WebSocket 服务: @@ -341,7 +352,7 @@ func main() { r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "gateway-1", FlushInterval: 100 * time.Millisecond, ModifyRequest: func(req *http.Request) { diff --git a/ecw.go b/ecw.go index 754571f..dedbe27 100644 --- a/ecw.go +++ b/ecw.go @@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := ecw.w.(http.Hijacker) if !ok { - return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") + return nil, nil, http.ErrNotSupported } return hijacker.Hijack() } diff --git a/engine.go b/engine.go index c2eae91..a4350c0 100644 --- a/engine.go +++ b/engine.go @@ -475,21 +475,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { func MethodNotAllowed() HandlerFunc { return func(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) engine := c.engine // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := []string{} - for _, treeIter := range engine.methodTrees { - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - allowedMethods = append(allowedMethods, treeIter.method) - } - } + allowedMethods := engine.allowedMethodsForPath(requestPath) if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) @@ -705,7 +696,7 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) // 查找对应的路由树的根节点 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 @@ -725,7 +716,7 @@ func (engine *Engine) handleRequest(c *Context) { } // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ redirectPath := requestPath @@ -782,6 +773,41 @@ func (engine *Engine) handleRequest(c *Context) { //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } +func routeLookupPath(req *http.Request) string { + if req == nil { + return "" + } + + if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") { + return "/" + req.RequestURI + } + if isGeneralOptionsRequest(req) { + return "" + } + if req.URL == nil { + return "" + } + return req.URL.Path +} + +func isGeneralOptionsRequest(req *http.Request) bool { + return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" +} + +func (engine *Engine) allowedMethodsForPath(requestPath string) []string { + allowedMethods := make([]string, 0, len(engine.methodTrees)) + for _, treeIter := range engine.methodTrees { + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + return allowedMethods +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/respw.go b/respw.go index dd94db3..ef5cc3c 100644 --- a/respw.go +++ b/respw.go @@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // 尝试从底层 ResponseWriter 获取 Hijacker 接口 hj, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("http.Hijacker interface not supported") + return nil, nil, http.ErrNotSupported } // 调用底层的 Hijack 方法 diff --git a/reverseproxy.go b/reverseproxy.go index 1730b1e..977402b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "net/http/httptrace" + "net/http/httputil" "net/netip" "net/textproto" "net/url" @@ -217,6 +218,12 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { default: proxy.config.ForwardedHeaders = ForwardedBoth } + proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) + if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { + if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { + proxy.configError = err + } + } return proxy } @@ -234,11 +241,20 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport = http.DefaultTransport } + updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) + if err != nil { + p.handleError(c, err) + return + } + if handledLocally { + return + } + ctx, cancel := p.requestContext(c) defer cancel() outreq := c.Request.Clone(ctx) - if c.Request.ContentLength == 0 { + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } if outreq.Body != nil { @@ -249,12 +265,35 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { outreq.Header = make(http.Header) } outreq.Close = false - - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" + var connectWriter *io.PipeWriter + defer func() { + if connectWriter != nil { + _ = connectWriter.Close() + } + }() + if outreq.Method == http.MethodConnect { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + defer outreq.Body.Close() + connectWriter = pipeWriter + } + + if outreq.Method == http.MethodConnect { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } + } else { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) reqUpType := reverseProxyUpgradeType(outreq.Header) if reqUpType != "" && !isPrintableASCII(reqUpType) { @@ -318,6 +357,23 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { + removeHopByHopHeaders(res.Header) + res.Header.Del("Content-Length") + res.Header.Del("Transfer-Encoding") + res.ContentLength = -1 + res.TransferEncoding = nil + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return + } + if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + p.handleError(c, err) + } + connectWriter = nil + return + } + if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { @@ -353,6 +409,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) + if reverseProxyShouldPanicOnCopyError(c.Request) { + panic(http.ErrAbortHandler) + } return } res.Body.Close() @@ -378,6 +437,86 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } } +func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { + if c == nil || c.Request == nil { + return "", false, nil + } + + switch c.Request.Method { + case http.MethodOptions, http.MethodTrace: + default: + return "", false, nil + } + + rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) + if rawValue == "" { + return "", false, nil + } + + value, err := strconv.Atoi(rawValue) + if err != nil || value < 0 { + return "", false, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), + } + } + if value == 0 { + switch c.Request.Method { + case http.MethodTrace: + return "", true, p.writeLocalTraceResponse(c) + case http.MethodOptions: + p.writeLocalOptionsResponse(c) + return "", true, nil + } + } + + return strconv.Itoa(value - 1), false, nil +} + +func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { + if c == nil || c.Request == nil { + return nil + } + + traceReq := c.Request.Clone(c.Request.Context()) + traceReq.Body = nil + traceReq.ContentLength = 0 + traceReq.TransferEncoding = nil + traceReq.RequestURI = c.Request.RequestURI + if traceReq.RequestURI == "" && traceReq.URL != nil { + traceReq.RequestURI = traceReq.URL.RequestURI() + } + traceReq.Header = traceReq.Header.Clone() + for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { + traceReq.Header.Del(key) + } + + dump, err := httputil.DumpRequest(traceReq, false) + if err != nil { + return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} + } + + c.Writer.Header().Set("Content-Type", "message/http") + c.Writer.WriteHeader(http.StatusOK) + _, err = c.Writer.Write(dump) + return err +} + +func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { + if c == nil { + return + } + + if c.engine != nil { + if c.Request != nil && c.Request.RequestURI != "*" { + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { + c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + } + } + } + c.Writer.WriteHeader(http.StatusOK) +} + func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { ctx := c.Request.Context() if ctx.Done() != nil { @@ -522,7 +661,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() @@ -561,6 +704,80 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), + } + } + backRead := res.Body + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backRead.Close() + _ = backWrite.Close() + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} + } + + defer clientConn.Close() + defer backRead.Close() + defer backWrite.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backRead.Close() + _ = backWrite.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + if _, err := io.Copy(clientConn, backRead); err != nil { + errc <- err + return + } + if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone + }() + go func() { + if _, err := io.Copy(backWrite, clientConn); err != nil { + errc <- err + return + } + errc <- backWrite.Close() + }() + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -638,6 +855,10 @@ func reverseProxyStatusCode(err error) int { if errors.As(err, &statusErr) && statusErr.status > 0 { return statusErr.status } + var netErr net.Error + if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { + return http.StatusGatewayTimeout + } return http.StatusBadGateway } @@ -651,6 +872,17 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func validateReverseProxyForwardedBy(value string) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil + } + if !isValidForwardedNodeIdentifier(trimmed) { + return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value) + } + return nil +} + func normalizeReverseProxyTarget(target *url.URL) { switch strings.ToLower(target.Scheme) { case "ws": @@ -732,6 +964,83 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { return strings.Join(pairs, ";") } +func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { + return policy == ForwardedBoth || policy == ForwardedRFC7239Only +} + +func isValidForwardedNodeIdentifier(value string) bool { + if value == "" { + return false + } + if strings.HasPrefix(value, "[") { + closing := strings.IndexByte(value, ']') + if closing <= 1 { + return false + } + addr, err := netip.ParseAddr(value[1:closing]) + if err != nil || !addr.Is6() { + return false + } + if closing == len(value)-1 { + return true + } + if value[closing+1] != ':' { + return false + } + return isValidForwardedNodePort(value[closing+2:]) + } + + host, port, hasPort := strings.Cut(value, ":") + if hasPort { + switch { + case host == "unknown", isValidForwardedObfuscatedIdentifier(host): + return isValidForwardedNodePort(port) + default: + addr, err := netip.ParseAddr(host) + return err == nil && addr.Is4() && isValidForwardedNodePort(port) + } + } + + if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) { + return true + } + addr, err := netip.ParseAddr(value) + return err == nil && addr.Is4() +} + +func isValidForwardedNodePort(value string) bool { + if value == "" { + return false + } + if isValidForwardedObfuscatedIdentifier(value) { + return true + } + if len(value) > 5 { + return false + } + port, err := strconv.Atoi(value) + return err == nil && port > 0 && port <= 65535 +} + +func isValidForwardedObfuscatedIdentifier(value string) bool { + if len(value) < 2 || value[0] != '_' { + return false + } + for i := 1; i < len(value); i++ { + b := value[i] + if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') { + continue + } + switch b { + case '.', '_', '-': + continue + default: + return false + } + } + return true +} + func formatForwardedFor(clientIP string) string { addr, err := netip.ParseAddr(clientIP) if err != nil { @@ -817,6 +1126,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) { } } +func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error { + connectTarget, err := reverseProxyConnectTarget(target) + if err != nil { + return &reverseProxyStatusError{status: http.StatusBadRequest, err: err} + } + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = "" + req.URL.RawPath = "" + req.URL.RawQuery = "" + req.URL.Opaque = connectTarget + req.Host = connectTarget + return nil +} + +func reverseProxyConnectTarget(target *url.URL) (string, error) { + if target == nil { + return "", errReverseProxyNilTarget + } + host := target.Hostname() + if host == "" { + return "", errReverseProxyInvalidTarget + } + port := target.Port() + if port == "" { + switch strings.ToLower(target.Scheme) { + case "http": + port = "80" + case "https": + port = "443" + default: + return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme) + } + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum <= 0 || portNum > 65535 { + return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port) + } + return net.JoinHostPort(host, port), nil +} + func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { if base.RawPath == "" && incoming.RawPath == "" { return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" @@ -919,6 +1269,10 @@ func cleanReverseProxyQueryParams(rawQuery string) string { return values.Encode() } +func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { + return req != nil && req.Context().Value(http.ServerContextKey) != nil +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index f82aff9..b7df512 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,6 +2,7 @@ package touka import ( "bufio" + "context" "errors" "fmt" "io" @@ -70,7 +71,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, ForwardedHeaders: ForwardedBoth, - ForwardedBy: "proxy-node", + ForwardedBy: "_proxy-node", Via: "proxy.test", })) @@ -144,7 +145,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { if !strings.Contains(got.Forwarded, "for=198.51.100.10") { t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) } - if !strings.Contains(got.Forwarded, "by=proxy-node") { + if !strings.Contains(got.Forwarded, "by=_proxy-node") { t.Fatalf("forwarded header missing by token: %q", got.Forwarded) } if !strings.Contains(got.Forwarded, "host=client.example") { @@ -170,6 +171,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } +func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) { + t.Helper() + + forwardedCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedCh <- r.Header.Get("Forwarded") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: " _proxy-node ", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + req.RemoteAddr = "198.51.100.10:4567" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case forwarded := <-forwardedCh: + if !strings.Contains(forwarded, "by=_proxy-node") { + t.Fatalf("unexpected Forwarded header: %q", forwarded) + } + if strings.Contains(forwarded, `by=" _proxy-node "`) { + t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Forwarded header") + } +} + func TestReverseProxyDefaultViaFallback(t *testing.T) { t.Helper() @@ -229,6 +285,23 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, context.DeadlineExceeded + }), + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { t.Helper() @@ -452,6 +525,362 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { } } +func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.RequestURI = "/trace" + req.Header.Set("Max-Forwards", "0") + req.Header.Set("Authorization", "secret") + req.Header.Set("Cookie", "a=b") + req.Header.Set("Forwarded", "for=192.0.2.1") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got != "message/http" { + t.Fatalf("unexpected content type: %q", got) + } + if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") { + t.Fatalf("trace body missing request line: %q", string(body)) + } + if strings.Contains(string(body), "Authorization:") { + t.Fatalf("trace body leaked authorization header: %q", string(body)) + } + if strings.Contains(string(body), "Cookie:") { + t.Fatalf("trace body leaked cookie header: %q", string(body)) + } + if strings.Contains(string(body), "Forwarded:") { + t.Fatalf("trace body leaked forwarded header: %q", string(body)) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) { + t.Helper() + + maxForwardsCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + maxForwardsCh <- r.Header.Get("Max-Forwards") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.Header.Set("Max-Forwards", "2") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case got := <-maxForwardsCh: + if got != "1" { + t.Fatalf("unexpected Max-Forwards header: %q", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Max-Forwards") + } +} + +func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil) + req.Header.Set("Max-Forwards", "0") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + allow := rr.Header().Get("Allow") + if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) { + t.Fatalf("unexpected Allow header: %q", allow) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { + t.Helper() + + engine := New() + engine.OPTIONS("/", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil) + req.RequestURI = "*" + req.URL.Path = "" + req.URL.RawPath = "" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) + } +} + +func TestReverseProxyConnectTunnel(t *testing.T) { + t.Helper() + + backendAddr := "" + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got, want := r.RequestURI, backendAddr; got != want { + errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, strings.ToUpper(line)) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend write failed: %w", err) + return + } + })) + defer backend.Close() + backendAddr = strings.TrimPrefix(backend.URL, "http://") + + engine := New() + engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n") + if err != nil { + t.Fatalf("write connect request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "200") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if got := respHeader.Get("Content-Length"); got != "" { + t.Fatalf("CONNECT response should not include Content-Length, got %q", got) + } + if got := respHeader.Get("Transfer-Encoding"); got != "" { + t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "PING\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyConnectNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil) + req.URL.Path = "/tunnel" + req.RequestURI = "/tunnel" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")}, + ContentLength: -1, + Request: req, + }, nil + }), + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := proxy.Client().Get(proxy.URL + "/proxy") + if err != nil { + t.Fatalf("perform request: %v", err) + } + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err == nil { + t.Fatal("expected body read to fail after upstream copy error") + } +} + func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { t.Helper() @@ -568,3 +997,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } return u } + +type failingReadCloser struct { + chunks []string + err error +} + +func (r *failingReadCloser) Read(p []byte) (int, error) { + if len(r.chunks) == 0 { + return 0, r.err + } + n := copy(p, r.chunks[0]) + r.chunks = r.chunks[1:] + return n, nil +} + +func (r *failingReadCloser) Close() error { + return nil +} From 2165cc4114e9c33a7b8176df17b0ffe370abe822 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 03:53:17 +0800 Subject: [PATCH 06/55] feat(http2): support OPTIONS * and extended CONNECT --- docs/reverse-proxy.md | 1 + docs/routing.md | 2 + engine.go | 25 +++++++++ go.mod | 3 +- go.sum | 2 + http2xconnect.go | 53 ++++++++++++++++++ reverseproxy.go | 119 ++++++++++++++++++++++++++++++++++++---- reverseproxy_test.go | 123 +++++++++++++++++++++++++++++++++++++++++- 8 files changed, 316 insertions(+), 12 deletions(-) create mode 100644 http2xconnect.go diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 1dfd760..959d866 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -292,6 +292,7 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: - `CONNECT` 隧道转发(HTTP/1.x) +- HTTP/2 extended `CONNECT` - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 diff --git a/docs/routing.md b/docs/routing.md index e90308e..223081a 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -22,6 +22,8 @@ r.ANY("/any", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) ``` +服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。 + ## 路径参数 (Named Parameters) 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 diff --git a/engine.go b/engine.go index a4350c0..b7cf330 100644 --- a/engine.go +++ b/engine.go @@ -7,6 +7,7 @@ package touka import ( "context" "errors" + "io" "reflect" "runtime" "strings" @@ -344,6 +345,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { if engine.serverProtocols != nil { srv.Protocols = engine.serverProtocols + if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { + if err := configureHTTP2ExtendedConnectServer(srv); err != nil { + panic(err) + } + } } } @@ -695,6 +701,11 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleRequest 负责根据请求查找路由并执行相应的处理函数链 // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { + if isGeneralOptionsRequest(c.Request) { + engine.handleGeneralOptions(c) + return + } + httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -808,6 +819,20 @@ func (engine *Engine) allowedMethodsForPath(requestPath string) []string { return allowedMethods } +func (engine *Engine) handleGeneralOptions(c *Context) { + if c == nil || c.Request == nil { + return + } + + c.Writer.Header().Set("Content-Length", "0") + if c.Request.ContentLength != 0 { + mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10) + _, _ = io.Copy(io.Discard, mb) + } + c.Writer.WriteHeader(http.StatusOK) + c.Abort() +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/go.mod b/go.mod index 42f4be4..bd0c046 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,10 @@ require ( github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 + golang.org/x/net v0.52.0 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.52.0 // indirect + golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index b49879b..6a8d0c6 100644 --- a/go.sum +++ b/go.sum @@ -12,3 +12,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= diff --git a/http2xconnect.go b/http2xconnect.go new file mode 100644 index 0000000..b3b12a0 --- /dev/null +++ b/http2xconnect.go @@ -0,0 +1,53 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "strings" + "sync" + _ "unsafe" + + "golang.org/x/net/http2" +) + +var enableHTTP2ExtendedConnectOnce sync.Once + +//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol +var xnetDisableHTTP2ExtendedConnectProtocol bool + +func enableHTTP2ExtendedConnectProtocol() { + enableHTTP2ExtendedConnectOnce.Do(func() { + xnetDisableHTTP2ExtendedConnectProtocol = false + }) +} + +func configureHTTP2ExtendedConnectServer(srv *http.Server) error { + if srv == nil { + return nil + } + enableHTTP2ExtendedConnectProtocol() + return http2.ConfigureServer(srv, nil) +} + +func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { + enableHTTP2ExtendedConnectProtocol() + + transport := &http2.Transport{} + if target == nil || !strings.EqualFold(target.Scheme, "http") { + return transport + } + + transport.AllowHTTP = true + transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + } + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 977402b..e01f4d0 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -67,10 +67,11 @@ var ( ) type reverseProxyHandler struct { - config ReverseProxyConfig - target *url.URL - receivedBy string - configError error + config ReverseProxyConfig + target *url.URL + receivedBy string + configError error + extendedConnectTransport http.RoundTripper } type reverseProxyStatusError struct { @@ -208,6 +209,9 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { target: target, receivedBy: reverseProxyReceivedBy(config.Via), } + if config.Transport == nil { + proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + } if err := validateReverseProxyTarget(target); err != nil { proxy.configError = err @@ -238,7 +242,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { transport := p.config.Transport if transport == nil { - transport = http.DefaultTransport + if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil { + transport = p.extendedConnectTransport + } else { + transport = http.DefaultTransport + } } updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) @@ -280,9 +288,17 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { } if outreq.Method == http.MethodConnect { - if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { - p.handleError(c, err) - return + if reverseProxyIsExtendedConnectRequest(outreq) { + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else { + if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { + p.handleError(c, err) + return + } } } else { rewriteReverseProxyURL(outreq, p.target) @@ -367,7 +383,11 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { if !p.modifyResponse(c, res, outreq) { return } - if err := p.handleConnectResponse(c, outreq, res, connectWriter); err != nil { + handleConnect := p.handleConnectResponse + if reverseProxyIsExtendedConnectRequest(outreq) { + handleConnect = p.handleExtendedConnectResponse + } + if err := handleConnect(c, outreq, res, connectWriter); err != nil { p.handleError(c, err) } connectWriter = nil @@ -778,6 +798,72 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if c == nil || c.Request == nil { + res.Body.Close() + if backWrite != nil { + _ = backWrite.Close() + } + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} + } + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + c.Writer.WriteHeader(res.StatusCode) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(backWrite, c.Request.Body) + closeErr := backWrite.Close() + if err != nil && !reverseProxyIsBenignTunnelError(err) { + errc <- err + return + } + errc <- closeErr + }() + go func() { + copyErr := p.copyResponse(c.Writer, res.Body, -1) + closeErr := res.Body.Close() + if copyErr != nil { + errc <- copyErr + return + } + errc <- closeErr + }() + + firstErr := <-errc + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + secondErr := <-errc + + for _, err := range []error{firstErr, secondErr} { + if reverseProxyIsBenignTunnelError(err) { + continue + } + return err + } + return nil +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -968,6 +1054,17 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { + return reverseProxyExtendedConnectProtocol(req) != "" +} + +func reverseProxyExtendedConnectProtocol(req *http.Request) string { + if req == nil || req.Method != http.MethodConnect || req.Header == nil { + return "" + } + return textproto.TrimString(req.Header.Get(":protocol")) +} + func isValidForwardedNodeIdentifier(value string) bool { if value == "" { return false @@ -1273,6 +1370,10 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { return req != nil && req.Context().Value(http.ServerContextKey) != nil } +func reverseProxyIsBenignTunnelError(err error) bool { + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b7df512..345dd97 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -3,6 +3,7 @@ package touka import ( "bufio" "context" + "crypto/tls" "errors" "fmt" "io" @@ -15,6 +16,8 @@ import ( "strings" "testing" "time" + + "golang.org/x/net/http2" ) func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { @@ -680,7 +683,7 @@ func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { } } -func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { +func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) { t.Helper() engine := New() @@ -695,9 +698,12 @@ func TestEngineDoesNotTreatOptionsAsteriskAsSlashRoute(t *testing.T) { rr := httptest.NewRecorder() engine.ServeHTTP(rr, req) - if rr.Code != http.StatusNotFound { + if rr.Code != http.StatusOK { t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) } + if got := rr.Header().Get("Content-Length"); got != "0" { + t.Fatalf("unexpected Content-Length header: %q", got) + } } func TestReverseProxyConnectTunnel(t *testing.T) { @@ -848,6 +854,119 @@ func TestReverseProxyConnectNeedsHijacker(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.ProtoMajor != 2 { + errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get(":protocol"); got != "websocket" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.URL.Path; got != "/ws" { + errCh <- fmt.Errorf("unexpected upstream path: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + line, err := bufio.NewReader(r.Body).ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper() From 59f190ce3a6097e659fcb46ecc630a65a8e8eebb Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 04:09:43 +0800 Subject: [PATCH 07/55] fix(http2): preserve extended CONNECT tunnel shutdown semantics --- reverseproxy.go | 52 ++++++++--- reverseproxy_test.go | 217 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+), 11 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index e01f4d0..bb1784b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -829,6 +829,19 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + }) + } + go func() { + <-req.Context().Done() + closeTunnel() + }() + errc := make(chan error, 2) go func() { _, err := io.Copy(backWrite, c.Request.Body) @@ -849,19 +862,24 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt errc <- closeErr }() - firstErr := <-errc - _ = c.Request.Body.Close() - _ = backWrite.Close() - _ = res.Body.Close() - secondErr := <-errc - - for _, err := range []error{firstErr, secondErr} { + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc if reverseProxyIsBenignTunnelError(err) { continue } - return err + if firstErr == nil { + firstErr = err + closeTunnel() + } } - return nil + closeTunnel() + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + + return firstErr + } func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { @@ -902,7 +920,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { @@ -1371,7 +1389,19 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { } func reverseProxyIsBenignTunnelError(err error) bool { - return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) +} + +func reverseProxyIsClosedBodyError(err error) bool { + if err == nil { + return false + } + switch err.Error() { + case "body closed by handler", "http2: response body closed", "response body closed": + return true + default: + return false + } } func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 345dd97..e56aa5e 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -967,6 +967,223 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + reader := bufio.NewReader(r.Body) + line, err := reader.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(w, "ack:"+line); err != nil { + errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) + return + } + _ = controller.Flush() + + if _, err := io.Copy(io.Discard, reader); err != nil { + errCh <- fmt.Errorf("wait for request half-close failed: %w", err) + return + } + if _, err := io.WriteString(w, "after-close\n"); err != nil { + errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) + return + } + _ = controller.Flush() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + reader := bufio.NewReader(resp.Body) + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read immediate tunneled response: %v", err) + } + if message != "ack:ping\n" { + t.Fatalf("unexpected immediate tunneled response: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + message, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("read post-close tunneled response: %v", err) + } + if message != "after-close\n" { + t.Fatalf("unexpected post-close tunneled response: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("enable full duplex failed: %w", err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + <-r.Context().Done() + })) + upstream.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { + t.Fatalf("configure upstream HTTP/2 server: %v", err) + } + upstream.StartTLS() + defer upstream.Close() + + proxyErrCh := make(chan error, 1) + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + select { + case proxyErrCh <- err: + default: + } + }, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + writeErrCh := make(chan error, 1) + go func() { + _, err := io.WriteString(pw, strings.Repeat("x", 1<<20)) + writeErrCh <- err + }() + time.Sleep(50 * time.Millisecond) + + cancel() + _ = pw.CloseWithError(context.Canceled) + select { + case <-writeErrCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for request body writer to unblock") + } + + select { + case err := <-proxyErrCh: + t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { t.Helper() From 919236665bfa59a55a3c46e3bfec9a4114edf28f Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:40:56 +0800 Subject: [PATCH 08/55] feat(reverseproxy): add upstream balancing and failover --- docs/reverse-proxy.md | 113 +++++++- reverseproxy.go | 426 +++++++++++++++++++++-------- reverseproxy_lb.go | 352 ++++++++++++++++++++++++ reverseproxy_test.go | 619 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1394 insertions(+), 116 deletions(-) create mode 100644 reverseproxy_lb.go diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 959d866..15ebafd 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -59,7 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ ```go type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL + Targets []string + + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration @@ -78,12 +82,115 @@ type ReverseProxyConfig struct { ### `Target` -必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 +与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。 ```go target, _ := url.Parse("http://backend:9000") ``` +### `Targets` + +可选。用于配置多个后端目标地址。 + +- `Target` 与 `Targets` 互斥,只能使用其中一种 +- `Targets` 的每一项都必须是完整 URL +- 每个 target 仍然可以自带 base path 和 query + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001/base?from=a", + "http://127.0.0.1:9002/base?from=b", + }, +})) +``` + +这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。 + +### `LoadBalancing` + +用于配置 upstream 选择策略和重试行为。 + +```go +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} +``` + +当前内置策略: + +- `touka.LBRandom()` +- `touka.LBRoundRobin()` +- `touka.LBFirst()` +- `touka.LBLeastConn()` +- `touka.LBIPHash()` +- `touka.LBClientIPHash()` +- `touka.LBURIHash()` +- `touka.LBHeader("X-Upstream", fallback)` +- `touka.LBQuery("tenant", fallback)` + +其中: + +- `LBFirst()` 适合主备/故障转移顺序 +- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback +- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()` + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBHeader("X-Upstream", touka.LBFirst()), + Retries: 1, + }, +})) +``` + +重试说明: + +- 只对未开始收到上游响应的失败进行重试 +- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试 +- `Retries` 表示额外重试次数 +- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机 +- `TryInterval` 表示两次重试之间的等待间隔 + +### `PassiveHealth` + +用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。 + +```go +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} +``` + +- `FailDuration > 0` 时启用被动健康跟踪 +- `MaxFails <= 0` 时默认按 `1` 处理 +- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBFirst(), + }, + PassiveHealth: touka.ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, +})) +``` + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 @@ -150,6 +257,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ 在请求真正发往后端前,对出站请求做最后修改。 +如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。 + 常见用途: - 覆盖 `Host` diff --git a/reverseproxy.go b/reverseproxy.go index bb1784b..186e163 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -23,6 +23,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/net/http2" ) // ForwardedHeadersPolicy controls how forwarding headers are generated. @@ -44,7 +46,11 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL + Targets []string + + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration @@ -61,17 +67,18 @@ type ReverseProxyConfig struct { } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") - errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) type reverseProxyHandler struct { - config ReverseProxyConfig - target *url.URL - receivedBy string - configError error - extendedConnectTransport http.RoundTripper + config ReverseProxyConfig + upstreams []*reverseProxyUpstream + receivedBy string + configError error + roundRobin atomic.Uint64 } type reverseProxyStatusError struct { @@ -199,22 +206,16 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc { } func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { - target := cloneReverseProxyURL(config.Target) - if target != nil { - normalizeReverseProxyTarget(target) - } - proxy := &reverseProxyHandler{ config: config, - target: target, receivedBy: reverseProxyReceivedBy(config.Via), } - if config.Transport == nil { - proxy.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) - } - if err := validateReverseProxyTarget(target); err != nil { + upstreams, err := buildReverseProxyUpstreams(config) + if err != nil { proxy.configError = err + } else { + proxy.upstreams = upstreams } switch config.ForwardedHeaders { @@ -228,6 +229,11 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { proxy.configError = err } } + if proxy.configError == nil { + if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil { + proxy.configError = err + } + } return proxy } @@ -240,15 +246,6 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } - transport := p.config.Transport - if transport == nil { - if reverseProxyIsExtendedConnectRequest(c.Request) && p.extendedConnectTransport != nil { - transport = p.extendedConnectTransport - } else { - transport = http.DefaultTransport - } - } - updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) if err != nil { p.handleError(c, err) @@ -260,86 +257,64 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { ctx, cancel := p.requestContext(c) defer cancel() + attempted := make(map[string]struct{}, len(p.upstreams)) + attempts := 0 + started := time.Now() + var lastErr error - outreq := c.Request.Clone(ctx) - if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { - outreq.Body = nil - } - if outreq.Body != nil { - outreq.Body = &noopCloseReader{readCloser: outreq.Body} - defer outreq.Body.Close() - } - if outreq.Header == nil { - outreq.Header = make(http.Header) - } - outreq.Close = false - var connectWriter *io.PipeWriter - defer func() { - if connectWriter != nil { - _ = connectWriter.Close() - } - }() - if outreq.Method == http.MethodConnect { - pipeReader, pipeWriter := io.Pipe() - outreq.Body = pipeReader - outreq.ContentLength = -1 - defer outreq.Body.Close() - connectWriter = pipeWriter - } - - if outreq.Method == http.MethodConnect { - if reverseProxyIsExtendedConnectRequest(outreq) { - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else { - if err := rewriteReverseProxyConnectRequest(outreq, p.target); err != nil { - p.handleError(c, err) + for { + upstream, err := p.selectUpstream(c, attempted) + if err != nil { + if lastErr != nil { + p.handleError(c, lastErr) return } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err}) + return } - } else { - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } - if updatedMaxForwards != "" { - outreq.Header.Set("Max-Forwards", updatedMaxForwards) - } - reqUpType := reverseProxyUpgradeType(outreq.Header) - if reqUpType != "" && !isPrintableASCII(reqUpType) { - p.handleError(c, &reverseProxyStatusError{ - status: http.StatusBadRequest, - err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), - }) + attempts++ + upstream.inFlight.Add(1) + served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards) + upstream.inFlight.Add(-1) + + if served { + return + } + if attemptErr != nil { + lastErr = attemptErr + } + if retriable && p.shouldRetryAttempt(c.Request, attempts, started) { + attempted[upstream.key] = struct{}{} + if !p.waitRetryInterval(ctx, started) { + if lastErr != nil { + p.handleError(c, lastErr) + } + return + } + continue + } + if attemptErr != nil { + p.handleError(c, attemptErr) + return + } + if lastErr != nil { + p.handleError(c, lastErr) + return + } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams}) return } +} - removeHopByHopHeaders(outreq.Header) - if headerValuesContainToken(c.Request.Header["Te"], "trailers") { - outreq.Header.Set("Te", "trailers") - } - if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) - } - - p.addForwardingHeaders(c.Request, outreq) - appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) - - if _, ok := outreq.Header["User-Agent"]; !ok { - outreq.Header.Set("User-Agent", "") - } - - if p.config.ModifyRequest != nil { - p.config.ModifyRequest(outreq) +func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) { + outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards) + if err != nil { + return false, err, false } + defer cleanup() + transport := p.transportForUpstream(c.Request, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -369,8 +344,13 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { roundTripDone = true roundTripMu.Unlock() if err != nil { - p.handleError(c, err) - return + if reverseProxyShouldCountPassiveFailure(outreq, err) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) + } + return false, err, true + } + if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) } if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { @@ -381,35 +361,34 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { res.TransferEncoding = nil appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } handleConnect := p.handleConnectResponse if reverseProxyIsExtendedConnectRequest(outreq) { handleConnect = p.handleExtendedConnectResponse } if err := handleConnect(c, outreq, res, connectWriter); err != nil { - p.handleError(c, err) + return false, err, false } - connectWriter = nil - return + return true, nil, false } if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } if err := p.handleUpgradeResponse(c, outreq, res); err != nil { - p.handleError(c, err) + return false, err, false } - return + return true, nil, false } removeHopByHopHeaders(res.Header) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } reverseProxyCopyHeader(c.Writer.Header(), res.Header) @@ -432,7 +411,7 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { if reverseProxyShouldPanicOnCopyError(c.Request) { panic(http.ErrAbortHandler) } - return + return true, nil, false } res.Body.Close() @@ -440,13 +419,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Flush() } - // Keep the stdlib-compatible fallback here. - // If the backend only exposes additional trailer keys after the body has been - // fully read, the trailer map can grow and those values must be written using - // the TrailerPrefix form instead of the pre-announced bare header keys. if len(res.Trailer) == announcedTrailers { reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) - return + return true, nil, false } for key, values := range res.Trailer { @@ -455,6 +430,148 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Header().Add(prefixedKey, value) } } + return true, nil, false +} + +func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { + outreq := c.Request.Clone(ctx) + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { + outreq.Body = nil + } else if c.Request.GetBody != nil { + body, err := c.Request.GetBody() + if err != nil { + return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err) + } + outreq.Body = body + } else if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false + var connectWriter *io.PipeWriter + if outreq.Method == http.MethodConnect { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + connectWriter = pipeWriter + } + cleanup := func() { + if outreq.Body != nil { + _ = outreq.Body.Close() + } + if connectWriter != nil { + _ = connectWriter.Close() + } + } + + if outreq.Method == http.MethodConnect { + if reverseProxyIsExtendedConnectRequest(outreq) { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err + } + } + } else { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) + } + + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + cleanup() + return nil, nil, nil, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + } + } + + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) + } + + return outreq, connectWriter, cleanup, nil +} + +func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper { + if p.config.Transport != nil { + return p.config.Transport + } + if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { + return upstream.extendedConnectTransport + } + return http.DefaultTransport +} + +func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool { + if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) { + return false + } + lb := p.config.LoadBalancing + if lb.TryDuration > 0 { + return time.Since(started) < lb.TryDuration + } + return attempts <= lb.Retries +} + +func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool { + interval := p.config.LoadBalancing.TryInterval + tryDuration := p.config.LoadBalancing.TryDuration + if tryDuration > 0 && interval == 0 { + interval = 250 * time.Millisecond + } + if tryDuration > 0 { + remaining := tryDuration - time.Since(started) + if remaining <= 0 { + return false + } + if interval <= 0 { + return ctx.Err() == nil + } + if interval > remaining { + return false + } + } + if interval <= 0 { + return ctx.Err() == nil + } + timer := time.NewTimer(interval) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } } func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { @@ -976,6 +1093,54 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) { + if config.Target != nil && len(config.Targets) > 0 { + return nil, errors.New("reverse proxy Target and Targets cannot be used together") + } + + targets := make([]*url.URL, 0, max(1, len(config.Targets))) + if config.Target != nil { + target := cloneReverseProxyURL(config.Target) + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, err + } + targets = append(targets, target) + } + for i, rawTarget := range config.Targets { + trimmed := strings.TrimSpace(rawTarget) + if trimmed == "" { + return nil, fmt.Errorf("reverse proxy target at index %d is empty", i) + } + target, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + targets = append(targets, target) + } + if len(targets) == 0 { + return nil, errReverseProxyNilTarget + } + + upstreams := make([]*reverseProxyUpstream, 0, len(targets)) + for i, target := range targets { + upstream := &reverseProxyUpstream{ + key: fmt.Sprintf("%d:%s", i, target.String()), + target: target, + index: i, + } + if config.Transport == nil { + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + func validateReverseProxyForwardedBy(value string) error { trimmed := strings.TrimSpace(value) if trimmed == "" { @@ -1388,6 +1553,35 @@ func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { return req != nil && req.Context().Value(http.ServerContextKey) != nil } +func reverseProxyCanRetryRequest(req *http.Request) bool { + if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) { + return false + } + if req.Body == nil || req.ContentLength == 0 { + return true + } + return req.GetBody != nil +} + +func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool { + if err == nil || reverseProxyIsBenignTunnelError(err) { + return false + } + if req != nil && req.Context().Err() != nil { + return false + } + return !errors.Is(err, context.Canceled) +} + +func reverseProxyMethodIsSafe(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + return true + default: + return false + } +} + func reverseProxyIsBenignTunnelError(err error) bool { return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) } @@ -1396,6 +1590,10 @@ func reverseProxyIsClosedBodyError(err error) bool { if err == nil { return false } + var streamErr http2.StreamError + if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel { + return true + } switch err.Error() { case "body closed by handler", "http2: response body closed", "response body closed": return true diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go new file mode 100644 index 0000000..9b41af0 --- /dev/null +++ b/reverseproxy_lb.go @@ -0,0 +1,352 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "fmt" + "math/rand/v2" + "net/http" + "net/textproto" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ReverseProxyLoadBalancingConfig configures upstream selection and retries. +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} + +// ReverseProxyPassiveHealthConfig configures inline passive health tracking. +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} + +// ReverseProxyLBPolicy selects an upstream from the configured target pool. +// Use the helper constructors such as LBRandom or LBHeader to build a policy. +type ReverseProxyLBPolicy struct { + kind reverseProxyLBPolicyKind + key string + fallback *ReverseProxyLBPolicy +} + +type reverseProxyLBPolicyKind uint8 + +const ( + reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota + reverseProxyLBPolicyRoundRobin + reverseProxyLBPolicyFirst + reverseProxyLBPolicyLeastConn + reverseProxyLBPolicyIPHash + reverseProxyLBPolicyClientIPHash + reverseProxyLBPolicyURIHash + reverseProxyLBPolicyHeader + reverseProxyLBPolicyQuery +) + +type reverseProxyUpstream struct { + key string + target *url.URL + index int + extendedConnectTransport http.RoundTripper + inFlight atomic.Int64 + + passiveMu sync.Mutex + failures []time.Time +} + +func LBRandom() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom} +} + +func LBRoundRobin() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin} +} + +func LBFirst() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst} +} + +func LBLeastConn() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn} +} + +func LBIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash} +} + +func LBClientIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash} +} + +func LBURIHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash} +} + +func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error { + switch policy.kind { + case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst, + reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash, + reverseProxyLBPolicyURIHash: + return nil + case reverseProxyLBPolicyHeader: + if policy.key == "" { + return fmt.Errorf("reverse proxy header load-balancing policy requires a header field") + } + case reverseProxyLBPolicyQuery: + if policy.key == "" { + return fmt.Errorf("reverse proxy query load-balancing policy requires a query key") + } + default: + return fmt.Errorf("reverse proxy load-balancing policy is invalid") + } + if policy.fallback != nil { + return validateReverseProxyLBPolicy(*policy.fallback) + } + return nil +} + +func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) { + now := time.Now() + policy := p.config.LoadBalancing.Policy + candidates := p.availableUpstreams(now, excluded) + if len(candidates) == 0 && len(excluded) > 0 { + candidates = p.availableUpstreams(now, nil) + } + if len(candidates) == 0 { + return nil, errReverseProxyNoAvailableUpstreams + } + return p.selectUpstreamWithPolicy(c, candidates, policy), nil +} + +func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream { + candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams)) + for _, upstream := range p.upstreams { + if _, skip := excluded[upstream.key]; skip { + continue + } + if !upstream.healthy(now, p.config.PassiveHealth) { + continue + } + candidates = append(candidates, upstream) + } + return candidates +} + +func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + + switch policy.kind { + case reverseProxyLBPolicyRoundRobin: + return candidates[p.nextRoundRobinIndex(len(candidates))] + case reverseProxyLBPolicyFirst: + return candidates[0] + case reverseProxyLBPolicyLeastConn: + return p.selectLeastConnUpstream(candidates) + case reverseProxyLBPolicyIPHash: + return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr)) + case reverseProxyLBPolicyClientIPHash: + return reverseProxySelectHRW(candidates, c.RequestIP()) + case reverseProxyLBPolicyURIHash: + if c.Request == nil || c.Request.URL == nil { + return reverseProxySelectRandom(candidates) + } + return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI()) + case reverseProxyLBPolicyHeader: + if c.Request != nil && c.Request.Header != nil { + if values, ok := c.Request.Header[policy.key]; ok { + return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyQuery: + if c.Request != nil && c.Request.URL != nil { + if values, ok := c.Request.URL.Query()[policy.key]; ok { + return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyRandom: + fallthrough + default: + return reverseProxySelectRandom(candidates) + } +} + +func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int { + if size <= 1 { + return 0 + } + return int((p.roundRobin.Add(1) - 1) % uint64(size)) +} + +func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + selected := candidates[0] + lowest := selected.inFlight.Load() + ties := []*reverseProxyUpstream{selected} + for _, upstream := range candidates[1:] { + count := upstream.inFlight.Load() + switch { + case count < lowest: + selected = upstream + lowest = count + ties = []*reverseProxyUpstream{upstream} + case count == lowest: + ties = append(ties, upstream) + } + } + if len(ties) == 1 { + return selected + } + return ties[p.nextRoundRobinIndex(len(ties))] +} + +func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(candidates) == 1 { + return candidates[0] + } + return candidates[rand.IntN(len(candidates))] +} + +func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if key == "" { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWScore(key, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWScore(key, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + +func reverseProxyHRWScore(key, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for i := 0; i < len(key); i++ { + h ^= uint64(key[i]) + h *= prime64 + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + +func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { + if policy.fallback != nil { + return *policy.fallback + } + return LBRandom() +} + +func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return true + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + return len(u.failures) < maxFails +} + +func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + u.failures = append(u.failures, now) +} + +func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) { + if len(u.failures) == 0 || window <= 0 { + if window <= 0 { + u.failures = nil + } + return + } + cutoff := now.Add(-window) + keep := 0 + for _, failureAt := range u.failures { + if failureAt.Before(cutoff) { + continue + } + u.failures[keep] = failureAt + keep++ + } + u.failures = u.failures[:keep] +} + +func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int { + if config.FailDuration <= 0 { + return 0 + } + if config.MaxFails <= 0 { + return 1 + } + return config.MaxFails +} + +func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool { + if status <= 0 { + return false + } + for _, unhealthyStatus := range config.UnhealthyStatus { + if status == unhealthyStatus { + return true + } + } + return false +} diff --git a/reverseproxy_test.go b/reverseproxy_test.go index e56aa5e..b68f74e 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -13,7 +13,9 @@ import ( "net/http/httptrace" "net/textproto" "net/url" + "strconv" "strings" + "sync/atomic" "testing" "time" @@ -262,6 +264,507 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) { } } +func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Targets: []string{"http://example.net"}, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) { + t.Helper() + + type snapshot struct { + Path string + RawQuery string + } + + backendOneCh := make(chan snapshot, 1) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwoCh := make(chan snapshot, 1) + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ + Targets: []string{ + backendOne.URL + "/one?from=one", + backendTwo.URL + "/two?from=two", + }, + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, + })) + + first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil) + if first.Code != http.StatusOK || first.Body.String() != "one" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "two" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case got := <-backendOneCh: + if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" { + t.Fatalf("unexpected first upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first upstream request") + } + + select { + case got := <-backendTwoCh: + if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" { + t.Fatalf("unexpected second upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second upstream request") + } +} + +func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBHeader("X-Upstream", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + headers := http.Header{"X-Upstream": {"tenant-a"}} + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBQuery("tenant", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"}) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBClientIPHash(), + }, + })) + + reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqOne.RemoteAddr = "10.0.0.1:1234" + reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrOne := httptest.NewRecorder() + engine.ServeHTTP(rrOne, reqOne) + + reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqTwo.RemoteAddr = "10.0.0.2:5678" + reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrTwo := httptest.NewRecorder() + engine.ServeHTTP(rrTwo, reqTwo) + + if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK { + t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code) + } + if rrOne.Body.String() != rrTwo.Body.String() { + t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String()) + } +} + +func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, r.Header.Get("X-Attempt")) + })) + defer backend.Close() + + var attempts atomic.Int64 + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10)) + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + if rr.Body.String() != "2" { + t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String()) + } +} + +func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("unsafe POST request should not be retried to the next upstream") + default: + } +} + +func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) { + t.Helper() + + backendOneStarted := make(chan struct{}, 1) + releaseBackendOne := make(chan struct{}) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneStarted <- struct{}{} + <-releaseBackendOne + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBLeastConn(), + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + client := proxy.Client() + client.Timeout = 5 * time.Second + + firstRespCh := make(chan string, 1) + firstErrCh := make(chan error, 1) + go func() { + resp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + firstErrCh <- err + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + firstErrCh <- err + return + } + firstRespCh <- string(body) + }() + + select { + case <-backendOneStarted: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first backend request") + } + + secondResp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + close(releaseBackendOne) + t.Fatalf("second request failed: %v", err) + } + secondBody, err := io.ReadAll(secondResp.Body) + _ = secondResp.Body.Close() + if err != nil { + close(releaseBackendOne) + t.Fatalf("read second response: %v", err) + } + if string(secondBody) != "two" { + close(releaseBackendOne) + t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody)) + } + + close(releaseBackendOne) + select { + case err := <-firstErrCh: + t.Fatalf("first request failed: %v", err) + case body := <-firstRespCh: + if body != "one" { + t.Fatalf("unexpected first response body: %q", body) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first response body") + } +} + +func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) { + t.Helper() + + primaryCalls := make(chan struct{}, 4) + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryCalls <- struct{}{} + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "primary down") + })) + defer primary.Close() + + secondaryCalls := make(chan struct{}, 4) + secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secondaryCalls <- struct{}{} + _, _ = io.WriteString(w, "secondary up") + })) + defer secondary.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{primary.URL, secondary.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + }, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, + })) + + first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "secondary up" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case <-primaryCalls: + default: + t.Fatal("expected primary to receive the first request") + } + select { + case <-secondaryCalls: + default: + t.Fatal("expected secondary to receive the second request") + } + select { + case <-primaryCalls: + t.Fatal("primary should not receive the second request while unhealthy") + default: + } +} + +func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) { + t.Helper() + + started := make(chan struct{}, 1) + release := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-release + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backend.URL}, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + client := proxy.Client() + respCh := make(chan error, 1) + go func() { + resp, err := client.Do(req) + if resp != nil { + _ = resp.Body.Close() + } + respCh <- err + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend request") + } + cancel() + close(release) + select { + case <-respCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for canceled request to finish") + } + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 3, + TryDuration: 100 * time.Millisecond, + TryInterval: 250 * time.Millisecond, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("retry budget should expire before the next upstream attempt") + default: + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -967,6 +1470,122 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 8) + newBackend := func(name string) *httptest.Server { + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got := r.Header.Get(":protocol"); got != "websocket" { + errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) + w.WriteHeader(http.StatusBadRequest) + return + } + + controller := http.NewResponseController(w) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err) + return + } + w.WriteHeader(http.StatusOK) + _ = controller.Flush() + + line, err := bufio.NewReader(r.Body).ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) + return + } + if _, err := io.WriteString(w, name+":"+line); err != nil { + errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) + return + } + _ = controller.Flush() + })) + server.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil { + t.Fatalf("configure %s HTTP/2 server: %v", name, err) + } + server.StartTLS() + return server + } + + backendOne := newBackend("one") + defer backendOne.Close() + backendTwo := newBackend("two") + defer backendTwo.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBRoundRobin(), + }, + Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + doRequest := func(payload string) string { + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, payload+"\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + return message + } + + if got := doRequest("ping"); got != "one:ping\n" { + t.Fatalf("unexpected first tunneled response: %q", got) + } + if got := doRequest("pong"); got != "two:pong\n" { + t.Fatalf("unexpected second tunneled response: %q", got) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { t.Helper() From a9c1662333c2396ca042cd4f160b1966c396f743 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:19:41 +0800 Subject: [PATCH 09/55] fix(reverseproxy): bridge websocket extended connect upstreams --- docs/reverse-proxy.md | 19 +++ http2xconnect.go | 40 +++-- reverseproxy.go | 194 ++++++++++++++++++++++- reverseproxy_lb.go | 3 + reverseproxy_test.go | 351 ++++++++++++++++++++++++++++++++---------- 5 files changed, 508 insertions(+), 99 deletions(-) diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 15ebafd..7d05290 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -68,6 +68,7 @@ type ReverseProxyConfig struct { Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -191,6 +192,24 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ })) ``` +### `AllowH2CUpstream` + +允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。 + +- 默认关闭 +- 这是一个显式配置项 +- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游 +- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式 + +```go +r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + AllowH2CUpstream: true, +})) +``` + +对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。 + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 diff --git a/http2xconnect.go b/http2xconnect.go index b3b12a0..872f5b3 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -5,12 +5,8 @@ package touka import ( - "context" "crypto/tls" - "net" "net/http" - "net/url" - "strings" "sync" _ "unsafe" @@ -36,18 +32,36 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { return http2.ConfigureServer(srv, nil) } -func newHTTP2ExtendedConnectTransport(target *url.URL) http.RoundTripper { +func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.Protocols.SetHTTP2(true) + return transport +} - transport := &http2.Transport{} - if target == nil || !strings.EqualFold(target.Scheme, "http") { - return transport +func newHTTP1BridgeTransport() http.RoundTripper { + return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}}) +} + +func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.TLSClientConfig = tlsConfig + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} } - - transport.AllowHTTP = true - transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { - var dialer net.Dialer - return dialer.DialContext(ctx, network, addr) + if len(transport.TLSClientConfig.NextProtos) == 0 { + transport.TLSClientConfig.NextProtos = []string{"http/1.1"} } return transport } + +func newH2CTransport() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return transport +} diff --git a/reverseproxy.go b/reverseproxy.go index 186e163..13ffe89 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -5,7 +5,10 @@ package touka import ( + "bufio" "context" + "crypto/rand" + "encoding/base64" "errors" "fmt" "io" @@ -52,9 +55,10 @@ type ReverseProxyConfig struct { LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -86,6 +90,33 @@ type reverseProxyStatusError struct { err error } +type reverseProxyExtendedConnectBridge struct { + body io.ReadCloser +} + +type reverseProxyH2ReadWriteCloser struct { + io.ReadCloser + ResponseWriter +} + +func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { + n, err := rwc.ResponseWriter.Write(p) + if err != nil { + return 0, err + } + if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + return 0, err + } + return n, nil +} + +func (rwc reverseProxyH2ReadWriteCloser) Close() error { + if rwc.ReadCloser == nil { + return nil + } + return rwc.ReadCloser.Close() +} + func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" @@ -314,7 +345,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte } defer cleanup() - transport := p.transportForUpstream(c.Request, upstream) + transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -353,6 +384,20 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte upstream.recordFailure(time.Now(), p.config.PassiveHealth) } + if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return true, nil, false + } + if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { + return false, err, false + } + return true, nil, false + } + return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false + } + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { removeHopByHopHeaders(res.Header) res.Header.Del("Content-Length") @@ -435,6 +480,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) + bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + if bridged { + outreq = outreq.WithContext(bridgeCtx) + } if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { outreq.Body = nil } else if c.Request.GetBody != nil { @@ -451,7 +500,7 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } outreq.Close = false var connectWriter *io.PipeWriter - if outreq.Method == http.MethodConnect { + if outreq.Method == http.MethodConnect && !bridged { pipeReader, pipeWriter := io.Pipe() outreq.Body = pipeReader outreq.ContentLength = -1 @@ -467,7 +516,13 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } if outreq.Method == http.MethodConnect { - if reverseProxyIsExtendedConnectRequest(outreq) { + if bridged { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } else if reverseProxyIsExtendedConnectRequest(outreq) { rewriteReverseProxyURL(outreq, upstream.target) if !p.config.PreserveHost { outreq.Host = "" @@ -526,6 +581,15 @@ func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream * if p.config.Transport != nil { return p.config.Transport } + if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { + if upstream.bridgeTransport != nil { + return upstream.bridgeTransport + } + return http.DefaultTransport + } + if upstream.useH2C && upstream.h2cTransport != nil { + return upstream.h2cTransport + } if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { return upstream.extendedConnectTransport } @@ -915,6 +979,71 @@ func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { + if c == nil || c.Request == nil { + res.Body.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned bridged websocket response without writable body"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + responseHeader := c.Writer.Header() + reverseProxyCopyHeader(responseHeader, res.Header) + responseHeader.Del("Upgrade") + responseHeader.Del("Connection") + responseHeader.Del("Sec-WebSocket-Accept") + c.Writer.WriteHeader(http.StatusOK) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backConn.Close() + }() + defer close(backConnClosed) + defer conn.Close() + defer backConn.Close() + + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: conn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + return firstErr +} + func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { if c == nil || c.Request == nil { res.Body.Close() @@ -1128,13 +1257,23 @@ func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstr upstreams := make([]*reverseProxyUpstream, 0, len(targets)) for i, target := range targets { + useH2C := strings.EqualFold(target.Scheme, "h2c") + if useH2C { + target = cloneReverseProxyURL(target) + target.Scheme = "http" + } upstream := &reverseProxyUpstream{ key: fmt.Sprintf("%d:%s", i, target.String()), target: target, index: i, + useH2C: useH2C || config.AllowH2CUpstream, } if config.Transport == nil { - upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport(target) + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport() + upstream.bridgeTransport = newHTTP1BridgeTransport() + if upstream.useH2C { + upstream.h2cTransport = newH2CTransport() + } } upstreams = append(upstreams, upstream) } @@ -1237,6 +1376,47 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { + protocol := reverseProxyExtendedConnectProtocol(req) + if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { + if req == nil { + return context.Background(), false + } + return req.Context(), false + } + + bridge := &reverseProxyExtendedConnectBridge{body: req.Body} + ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge) + req.Header.Del(":protocol") + req.Method = http.MethodGet + req.Body = http.NoBody + req.ContentLength = 0 + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Version", "13") + key, err := reverseProxyGenerateWebSocketKey() + if err == nil { + req.Header.Set("Sec-WebSocket-Key", key) + } + return ctx, true +} + +func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { + if ctx == nil { + return nil + } + bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge) + return bridge +} + +func reverseProxyGenerateWebSocketKey() (string, error) { + key := make([]byte, 16) + if _, err := rand.Read(key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { return reverseProxyExtendedConnectProtocol(req) != "" } diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 9b41af0..d2d45ab 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -57,7 +57,10 @@ type reverseProxyUpstream struct { key string target *url.URL index int + useH2C bool extendedConnectTransport http.RoundTripper + bridgeTransport http.RoundTripper + h2cTransport http.RoundTripper inFlight atomic.Int64 passiveMu sync.Mutex diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b68f74e..b05f426 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -112,7 +112,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { t.Fatalf("unexpected body: %q", string(body)) } if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if got.Path != "/base/api/ping" { t.Fatalf("unexpected upstream path: %q", got.Path) @@ -765,6 +766,43 @@ func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { } } +func TestReverseProxyAllowH2CUpstream(t *testing.T) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen h2c upstream: %v", err) + } + server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Proto", r.Proto) + _, _ = io.WriteString(w, "ok") + })} + server.Protocols = new(http.Protocols) + server.Protocols.SetUnencryptedHTTP2(true) + errCh := make(chan error, 1) + go func() { + errCh <- server.Serve(listener) + }() + defer func() { + _ = server.Close() + <-errCh + }() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://"+listener.Addr().String()), + AllowH2CUpstream: true, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" { + t.Fatalf("expected h2c upstream proto, got %q", got) + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -1363,19 +1401,29 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if r.ProtoMajor != 2 { - errCh <- fmt.Errorf("unexpected upstream protocol version: %s", r.Proto) + if got := r.Header.Get(":protocol"); got != "" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) w.WriteHeader(http.StatusBadRequest) return } - if got := r.Header.Get(":protocol"); got != "websocket" { - errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- errors.New("missing upstream Sec-WebSocket-Key header") w.WriteHeader(http.StatusBadRequest) return } @@ -1385,36 +1433,41 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - line, err := bufio.NewReader(r.Body).ReadString('\n') + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "echo:"+line); err != nil { + if _, err := io.WriteString(brw, "echo:"+line); err != nil { errCh <- fmt.Errorf("write tunneled response body failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), Via: "proxy.test", })) @@ -1445,7 +1498,10 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status: %d", resp.StatusCode) } - if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "2.0 proxy.test" { + if got := resp.Header.Get("Upgrade"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) } @@ -1470,6 +1526,116 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 1 { + errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = brw.Flush() + })) + upstream.EnableHTTP2 = true + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + _ = pw.Close() + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { t.Helper() @@ -1477,42 +1643,62 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { errCh := make(chan error, 8) newBackend := func(name string) *httptest.Server { - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - if got := r.Header.Get(":protocol"); got != "websocket" { + if got := r.Header.Get(":protocol"); got != "" { errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) w.WriteHeader(http.StatusBadRequest) return } - - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("%s enable full duplex failed: %w", name, err) + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name) + w.WriteHeader(http.StatusBadRequest) return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() - line, err := bufio.NewReader(r.Body).ReadString('\n') + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name) + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err) + return + } + + line, err := brw.ReadString('\n') if err != nil { errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) return } - if _, err := io.WriteString(w, name+":"+line); err != nil { + if _, err := io.WriteString(brw, name+":"+line); err != nil { errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) return } - _ = controller.Flush() + _ = brw.Flush() })) - server.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(server.Config); err != nil { - t.Fatalf("configure %s HTTP/2 server: %v", name, err) - } - server.StartTLS() return server } @@ -1527,8 +1713,7 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { LoadBalancing: ReverseProxyLoadBalancingConfig{ Policy: LBRoundRobin(), }, - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1557,7 +1742,8 @@ func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if _, err := io.WriteString(pw, payload+"\n"); err != nil { t.Fatalf("write tunneled request body: %v", err) @@ -1592,55 +1778,59 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() - reader := bufio.NewReader(r.Body) + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + reader := bufio.NewReader(brw) line, err := reader.ReadString('\n') if err != nil { errCh <- fmt.Errorf("read tunneled request body failed: %w", err) return } - if _, err := io.WriteString(w, "ack:"+line); err != nil { + if _, err := io.WriteString(brw, "ack:"+line); err != nil { errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() if _, err := io.Copy(io.Discard, reader); err != nil { errCh <- fmt.Errorf("wait for request half-close failed: %w", err) return } - if _, err := io.WriteString(w, "after-close\n"); err != nil { + if _, err := io.WriteString(brw, "after-close\n"); err != nil { errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) return } - _ = controller.Flush() + _ = brw.Flush() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", })) proxy := httptest.NewUnstartedServer(engine) @@ -1668,7 +1858,8 @@ func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } reader := bufio.NewReader(resp.Body) @@ -1707,36 +1898,37 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi enableHTTP2ExtendedConnectProtocol() errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) w.WriteHeader(http.StatusMethodNotAllowed) return } - controller := http.NewResponseController(w) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - errCh <- fmt.Errorf("enable full duplex failed: %w", err) + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") return } - w.WriteHeader(http.StatusOK) - _ = controller.Flush() + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() <-r.Context().Done() })) - upstream.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(upstream.Config); err != nil { - t.Fatalf("configure upstream HTTP/2 server: %v", err) - } - upstream.StartTLS() defer upstream.Close() proxyErrCh := make(chan error, 1) engine := New() engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Via: "proxy.test", + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { select { case proxyErrCh <- err: @@ -1772,7 +1964,8 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } writeErrCh := make(chan error, 1) From 50c6a2361405e2ee103a2cc5a301f38c891c2ca6 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:50:27 +0800 Subject: [PATCH 10/55] refactor: simplify reverse proxy bridged connection handling by removing unused bufio --- reverseproxy.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 13ffe89..e674e7e 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -5,7 +5,6 @@ package touka import ( - "bufio" "context" "crypto/rand" "encoding/base64" @@ -1011,7 +1010,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r } conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} - brw := bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) backConnClosed := make(chan struct{}) go func() { @@ -1025,10 +1023,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r defer conn.Close() defer backConn.Close() - if err := brw.Flush(); err != nil { - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) From 7abedc1acea918edb522d00a5e1c3c7e7d45870d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:33:18 +0800 Subject: [PATCH 11/55] enhance: improve reverse proxy error handling and add tests --- reverseproxy.go | 24 ++++---- reverseproxy_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 10 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index e674e7e..2d6dfea 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -101,10 +101,10 @@ type reverseProxyH2ReadWriteCloser struct { func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { - return 0, err + return n, err } if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - return 0, err + return n, err } return n, nil } @@ -479,7 +479,10 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { outreq := c.Request.Clone(ctx) - bridgeCtx, bridged := reverseProxyPrepareExtendedConnectBridge(outreq) + bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) + if err != nil { + return nil, nil, nil, err + } if bridged { outreq = outreq.WithContext(bridgeCtx) } @@ -1370,13 +1373,13 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { return policy == ForwardedBoth || policy == ForwardedRFC7239Only } -func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool) { +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { protocol := reverseProxyExtendedConnectProtocol(req) if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { if req == nil { - return context.Background(), false + return context.Background(), false, nil } - return req.Context(), false + return req.Context(), false, nil } bridge := &reverseProxyExtendedConnectBridge{body: req.Body} @@ -1389,10 +1392,11 @@ func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Contex req.Header.Set("Connection", "Upgrade") req.Header.Set("Sec-WebSocket-Version", "13") key, err := reverseProxyGenerateWebSocketKey() - if err == nil { - req.Header.Set("Sec-WebSocket-Key", key) + if err != nil { + return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err) } - return ctx, true + req.Header.Set("Sec-WebSocket-Key", key) + return ctx, true, nil } func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { @@ -1405,7 +1409,7 @@ func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseP func reverseProxyGenerateWebSocketKey() (string, error) { key := make([]byte, 16) - if _, err := rand.Read(key); err != nil { + if _, err := io.ReadFull(rand.Reader, key); err != nil { return "", err } return base64.StdEncoding.EncodeToString(key), nil diff --git a/reverseproxy_test.go b/reverseproxy_test.go index b05f426..8a250b2 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,7 +2,9 @@ package touka import ( "bufio" + "bytes" "context" + crand "crypto/rand" "crypto/tls" "errors" "fmt" @@ -829,6 +831,67 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) { + t.Helper() + + flushErr := errors.New("flush failed") + writer := &flushErrorResponseWriter{flushErr: flushErr} + conn := reverseProxyH2ReadWriteCloser{ + ReadCloser: io.NopCloser(strings.NewReader("")), + ResponseWriter: writer, + } + + n, err := conn.Write([]byte("ping")) + if n != len("ping") { + t.Fatalf("unexpected bytes written: %d", n) + } + if !errors.Is(err, flushErr) { + t.Fatalf("unexpected write error: %v", err) + } + if got := writer.body.String(); got != "ping" { + t.Fatalf("unexpected buffered body: %q", got) + } +} + +func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) { + t.Helper() + + transportCalled := atomic.Bool{} + entropyErr := errors.New("entropy source unavailable") + originalReader := crand.Reader + crand.Reader = errorReader{err: entropyErr} + t.Cleanup(func() { + crand.Reader = originalReader + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + transportCalled.Store(true) + return nil, errors.New("unexpected round trip") + }), + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(reverseProxyStatusCode(err)) + _, _ = io.WriteString(w, err.Error()) + }, + })) + + headers := make(http.Header) + headers.Set(":protocol", "websocket") + rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers) + + if transportCalled.Load() { + t.Fatal("transport should not be called when websocket key generation fails") + } + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) { + t.Fatalf("unexpected error body: %q", body) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -2137,6 +2200,73 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } +type flushErrorResponseWriter struct { + header http.Header + body bytes.Buffer + status int + written bool + flushErr error +} + +func (w *flushErrorResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *flushErrorResponseWriter) WriteHeader(statusCode int) { + if w.written { + return + } + w.status = statusCode + w.written = true +} + +func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(p) +} + +func (w *flushErrorResponseWriter) Flush() {} + +func (w *flushErrorResponseWriter) FlushError() error { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.flushErr +} + +func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *flushErrorResponseWriter) Status() int { + return w.status +} + +func (w *flushErrorResponseWriter) Size() int { + return w.body.Len() +} + +func (w *flushErrorResponseWriter) Written() bool { + return w.written +} + +func (w *flushErrorResponseWriter) IsHijacked() bool { + return false +} + +type errorReader struct { + err error +} + +func (r errorReader) Read([]byte) (int, error) { + return 0, r.err +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) From 20dc6e4047cb42ca8b00206bddb6d762a897fce1 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:44:02 +0800 Subject: [PATCH 12/55] refactor: cache ResponseController in H2ReadWriteCloser for better performance --- reverseproxy.go | 9 +++++---- reverseproxy_test.go | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 2d6dfea..afdbd9c 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -96,20 +96,21 @@ type reverseProxyExtendedConnectBridge struct { type reverseProxyH2ReadWriteCloser struct { io.ReadCloser ResponseWriter + controller *http.ResponseController } -func (rwc reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { +func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { n, err := rwc.ResponseWriter.Write(p) if err != nil { return n, err } - if err := http.NewResponseController(reverseProxyBaseResponseWriter(rwc.ResponseWriter)).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { return n, err } return n, nil } -func (rwc reverseProxyH2ReadWriteCloser) Close() error { +func (rwc *reverseProxyH2ReadWriteCloser) Close() error { if rwc.ReadCloser == nil { return nil } @@ -1012,7 +1013,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } - conn := reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer} + conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} backConnClosed := make(chan struct{}) go func() { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 8a250b2..85a64d4 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -836,9 +836,10 @@ func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *te flushErr := errors.New("flush failed") writer := &flushErrorResponseWriter{flushErr: flushErr} - conn := reverseProxyH2ReadWriteCloser{ + conn := &reverseProxyH2ReadWriteCloser{ ReadCloser: io.NopCloser(strings.NewReader("")), ResponseWriter: writer, + controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)), } n, err := conn.Write([]byte("ping")) From dcdb1504a32b709122709a37dc39f4fc869e70d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:58:34 +0800 Subject: [PATCH 13/55] feat: add robust transport cloning and improve header handling in reverse proxy --- http2xconnect.go | 26 ++++++++++++++++++++++--- reverseproxy.go | 3 +-- reverseproxy_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 872f5b3..8521672 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -6,8 +6,10 @@ package touka import ( "crypto/tls" + "net" "net/http" "sync" + "time" _ "unsafe" "golang.org/x/net/http2" @@ -34,7 +36,7 @@ func configureHTTP2ExtendedConnectServer(srv *http.Server) error { func newHTTP2ExtendedConnectTransport() http.RoundTripper { enableHTTP2ExtendedConnectProtocol() - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.Protocols.SetHTTP2(true) @@ -46,7 +48,7 @@ func newHTTP1BridgeTransport() http.RoundTripper { } func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) transport.TLSClientConfig = tlsConfig @@ -60,8 +62,26 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp } func newH2CTransport() http.RoundTripper { - transport := http.DefaultTransport.(*http.Transport).Clone() + transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetUnencryptedHTTP2(true) return transport } + +func cloneDefaultTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok { + return transport.Clone() + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/reverseproxy.go b/reverseproxy.go index afdbd9c..5d9b1ad 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1004,8 +1004,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) - responseHeader.Del("Upgrade") - responseHeader.Del("Connection") + removeHopByHopHeaders(responseHeader) responseHeader.Del("Sec-WebSocket-Accept") c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 85a64d4..9252e4a 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -893,6 +893,46 @@ func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *te } } +func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) { + t.Helper() + + originalDefaultTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("unexpected round trip") + }) + t.Cleanup(func() { + http.DefaultTransport = originalDefaultTransport + }) + + assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) { + t.Helper() + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("%s returned %T, want *http.Transport", name, rt) + } + check(transport) + } + + assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols) + } + }) + assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() { + t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols) + } + if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" { + t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig) + } + }) + assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols) + } + }) +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1509,7 +1549,7 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } defer conn.Close() - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n") if err := brw.Flush(); err != nil { errCh <- fmt.Errorf("upstream flush failed: %w", err) return @@ -1565,6 +1605,9 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { if got := resp.Header.Get("Upgrade"); got != "" { t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) } + if got := resp.Header.Get("X-Hop-Token"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got) + } if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { t.Fatalf("unexpected Via response header: %#v", gotVia) } From d53693952a8cccd8e6579c7a92f7a63b5a6bb5d8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:13:50 +0800 Subject: [PATCH 14/55] refactor: improve TLS config handling and add bridge connection tests --- http2xconnect.go | 5 +- reverseproxy.go | 9 ++- reverseproxy_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/http2xconnect.go b/http2xconnect.go index 8521672..c691a77 100644 --- a/http2xconnect.go +++ b/http2xconnect.go @@ -51,9 +51,10 @@ func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripp transport := cloneDefaultTransport() transport.Protocols = new(http.Protocols) transport.Protocols.SetHTTP1(true) - transport.TLSClientConfig = tlsConfig - if transport.TLSClientConfig == nil { + if tlsConfig == nil { transport.TLSClientConfig = &tls.Config{} + } else { + transport.TLSClientConfig = tlsConfig.Clone() } if len(transport.TLSClientConfig.NextProtos) == 0 { transport.TLSClientConfig.NextProtos = []string{"http/1.1"} diff --git a/reverseproxy.go b/reverseproxy.go index 5d9b1ad..148a9b4 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1024,7 +1024,6 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r }() defer close(backConnClosed) defer conn.Close() - defer backConn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} @@ -1374,11 +1373,11 @@ func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { } func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { + if req == nil { + return context.Background(), false, nil + } protocol := reverseProxyExtendedConnectProtocol(req) - if req == nil || req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { - if req == nil { - return context.Background(), false, nil - } + if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { return req.Context(), false, nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 9252e4a..bf7b0bb 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -933,6 +933,29 @@ func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing }) } +func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) { + t.Helper() + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type: %T", rt) + } + if transport.TLSClientConfig == nil { + t.Fatal("expected TLS client config") + } + if transport.TLSClientConfig == tlsConfig { + t.Fatal("expected bridge transport to clone TLS config") + } + if len(tlsConfig.NextProtos) != 0 { + t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos) + } + if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" { + t.Fatalf("unexpected transport NextProtos: %#v", got) + } +} + func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { t.Helper() @@ -1633,6 +1656,98 @@ func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { } } +func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + closeCalls := atomic.Int32{} + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodGet { + return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) + } + backend := &countingReadWriteCloser{ + readData: []byte("echo:ping\n"), + closeCalls: &closeCalls, + closeWriteErr: http.ErrNotSupported, + } + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: http.Header{ + "Connection": []string{"Upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-WebSocket-Accept": []string{"ignored"}, + }, + Body: backend, + Request: req, + }, nil + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer clientTransport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := clientTransport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + _ = resp.Body.Close() + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + _ = resp.Body.Close() + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + _ = resp.Body.Close() + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + _ = resp.Body.Close() + t.Fatalf("close tunneled request body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("close response body: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if closeCalls.Load() > 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := closeCalls.Load(); got != 1 { + t.Fatalf("expected backend connection to close exactly once, got %d", got) + } +} + func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { t.Helper() @@ -2311,6 +2426,37 @@ func (r errorReader) Read([]byte) (int, error) { return 0, r.err } +type countingReadWriteCloser struct { + readData []byte + writeBuf bytes.Buffer + closeCalls *atomic.Int32 + closeWriteErr error +} + +func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 { + return 0, io.EOF + } + n := copy(p, r.readData) + r.readData = r.readData[n:] + return n, nil +} + +func (r *countingReadWriteCloser) Write(p []byte) (int, error) { + return r.writeBuf.Write(p) +} + +func (r *countingReadWriteCloser) Close() error { + if r.closeCalls != nil { + r.closeCalls.Add(1) + } + return nil +} + +func (r *countingReadWriteCloser) CloseWrite() error { + return r.closeWriteErr +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) From 1a6325d461d6c2594b29c8ba61c1af99a0dd6454 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 3 Apr 2026 00:29:15 +0800 Subject: [PATCH 15/55] feat: improve reverse proxy tunnel management with sync.Once and better error handling --- reverseproxy.go | 53 ++++++++++++++++++++------------------------ reverseproxy_test.go | 38 +++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 148a9b4..1b89b2a 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -518,24 +518,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte } } - if outreq.Method == http.MethodConnect { - if bridged { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else if reverseProxyIsExtendedConnectRequest(outreq) { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } else { - if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { - cleanup() - return nil, nil, nil, err - } + if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err } } else { rewriteReverseProxyURL(outreq, upstream.target) @@ -1014,26 +1000,35 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} - backConnClosed := make(chan struct{}) + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = conn.Close() + _ = backConn.Close() + }) + } go func() { - select { - case <-req.Context().Done(): - case <-backConnClosed: - } - backConn.Close() + <-req.Context().Done() + closeTunnel() }() - defer close(backConnClosed) - defer conn.Close() errc := make(chan error, 2) copyer := switchProtocolCopier{user: conn, backend: backConn} go copyer.copyToBackend(errc) go copyer.copyFromBackend(errc) - firstErr := <-errc - if firstErr == nil { - firstErr = <-errc + var firstErr error + for i := 0; i < 2; i++ { + err := <-errc + if reverseProxyIsBenignTunnelError(err) { + continue + } + if firstErr == nil { + firstErr = err + closeTunnel() + } } + closeTunnel() if reverseProxyIsBenignTunnelError(firstErr) { return nil } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index bf7b0bb..9cbc734 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -17,6 +17,7 @@ import ( "net/url" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1662,14 +1663,24 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { enableHTTP2ExtendedConnectProtocol() closeCalls := atomic.Int32{} + backendReadDone := make(chan struct{}, 1) transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Method != http.MethodGet { return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) } - backend := &countingReadWriteCloser{ - readData: []byte("echo:ping\n"), + var respondOnce sync.Once + var backend *countingReadWriteCloser + backend = &countingReadWriteCloser{ + readDataCh: make(chan []byte, 1), closeCalls: &closeCalls, - closeWriteErr: http.ErrNotSupported, + closeWriteErr: nil, + afterWrite: func() { + respondOnce.Do(func() { + backendReadDone <- struct{}{} + backend.readDataCh <- []byte("echo:ping\n") + close(backend.readDataCh) + }) + }, } return &http.Response{ StatusCode: http.StatusSwitchingProtocols, @@ -1719,6 +1730,12 @@ func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { _ = resp.Body.Close() t.Fatalf("write tunneled request body: %v", err) } + select { + case <-backendReadDone: + case <-time.After(2 * time.Second): + _ = resp.Body.Close() + t.Fatal("backend did not receive tunneled request body") + } message, err := bufio.NewReader(resp.Body).ReadString('\n') if err != nil { _ = resp.Body.Close() @@ -2428,12 +2445,21 @@ func (r errorReader) Read([]byte) (int, error) { type countingReadWriteCloser struct { readData []byte + readDataCh chan []byte writeBuf bytes.Buffer closeCalls *atomic.Int32 closeWriteErr error + afterWrite func() } func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 && r.readDataCh != nil { + data, ok := <-r.readDataCh + if !ok { + return 0, io.EOF + } + r.readData = data + } if len(r.readData) == 0 { return 0, io.EOF } @@ -2443,7 +2469,11 @@ func (r *countingReadWriteCloser) Read(p []byte) (int, error) { } func (r *countingReadWriteCloser) Write(p []byte) (int, error) { - return r.writeBuf.Write(p) + n, err := r.writeBuf.Write(p) + if err == nil && r.afterWrite != nil { + r.afterWrite() + } + return n, err } func (r *countingReadWriteCloser) Close() error { From 70f8cc615946a8819fc59a1dd5944d1807a4ce5c Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:19:33 +0800 Subject: [PATCH 16/55] fix: avoid panic in case-insensitive wildcard lookup --- tree.go | 2 +- tree_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tree.go b/tree.go index 31246a5..f5452f4 100644 --- a/tree.go +++ b/tree.go @@ -852,7 +852,7 @@ walk: // 外部循环用于遍历路由树 return nil // 未找到, 返回 nil } - n = n.children[0] // 移动到通配符子节点(通常是唯一一个) + n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾 switch n.nType { case param: // 参数节点 // 查找参数结束位置('/' 或路径末尾) diff --git a/tree_test.go b/tree_test.go index d3ffdfa..7665afd 100644 --- a/tree_test.go +++ b/tree_test.go @@ -901,6 +901,34 @@ func TestTreeInvalidNodeType(t *testing.T) { } } +func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) { + tree := &node{} + routes := [...]string{ + "/:user/:repo/info/refs", + "/healthz", + "/api/db/data", + "/api/db/sum", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic while looking up missing path: %v", r) + } + }() + + if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil { + t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found) + } + + if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil { + t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found) + } +} + func TestTreeInvalidParamsType(t *testing.T) { tree := &node{} // add a child with wildcard From d12e887858ab32b9fd627febb1586a36afdecabe Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:46:06 +0800 Subject: [PATCH 17/55] fix: keep RunShutdown on HTTP path --- serve.go | 32 ++++++++++++---------- serve_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 14 deletions(-) create mode 100644 serve_test.go diff --git a/serve.go b/serve.go index f3ddc5f..1825b32 100644 --- a/serve.go +++ b/serve.go @@ -46,26 +46,30 @@ func getShutdownTimeout(timeouts []time.Duration) time.Duration { return defaultShutdownTimeout } +// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器. +func serveServer(srv *http.Server, serveTLS bool) error { + if serveTLS { + // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, + // ListenAndServeTLS 的前两个参数可以为空字符串 + return srv.ListenAndServeTLS("", "") + } + + return srv.ListenAndServe() +} + // runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, // 并处理其启动失败的致命错误 // serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") -func runServer(serverType string, srv *http.Server) { +func runServer(serverType string, srv *http.Server, serveTLS bool) { go func() { - var err error protocol := "http" - if srv.TLSConfig != nil { + if serveTLS { protocol = "https" } log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) - if srv.TLSConfig != nil { - // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, - // ListenAndServeTLS 的前两个参数可以为空字符串 - err = srv.ListenAndServeTLS("", "") - } else { - err = srv.ListenAndServe() - } + err := serveServer(srv, serveTLS) // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), // 则认为是一个严重错误,并终止程序 @@ -236,7 +240,7 @@ func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error engine.ServerConfigurator(srv) } - runServer("HTTP", srv) + runServer("HTTP", srv, false) return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } @@ -293,7 +297,7 @@ func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...tim engine.ServerConfigurator(srv) } - runServer("HTTPS", srv) + runServer("HTTPS", srv, true) return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } @@ -361,8 +365,8 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con } // --- 启动服务器和优雅关闭 --- - runServer("HTTPS", httpsSrv) - runServer("HTTP Redirect", httpSrv) + runServer("HTTPS", httpsSrv, true) + runServer("HTTP Redirect", httpSrv, false) return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) } diff --git a/serve_test.go b/serve_test.go new file mode 100644 index 0000000..01d639f --- /dev/null +++ b/serve_test.go @@ -0,0 +1,76 @@ +package touka + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "testing" + "time" +) + +func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on ephemeral port: %v", err) + } + addr := listener.Addr().String() + if err := listener.Close(); err != nil { + t.Fatalf("close temporary listener: %v", err) + } + + srv := &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + }), + // RunShutdown uses the HTTP startup path and must not let a shared + // ServerConfigurator accidentally turn it into HTTPS. + TLSConfig: &tls.Config{}, + } + + errCh := make(chan error, 1) + go func() { + errCh <- serveServer(srv, false) + }() + + client := &http.Client{Timeout: 200 * time.Millisecond} + var resp *http.Response + requestURL := "http://" + addr + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + resp, err = client.Get(requestURL) + if err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + if err != nil { + t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK) + } + if string(body) != "ok" { + t.Fatalf("unexpected body: got %q want %q", string(body), "ok") + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + t.Fatalf("shutdown server: %v", err) + } + + if err := <-errCh; !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) + } +} From 7db3d32d7b9f807864ff6a4692b13f949db8a316 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:51:39 +0800 Subject: [PATCH 18/55] test: improve serve startup failure diagnostics --- serve_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/serve_test.go b/serve_test.go index 01d639f..6092f7b 100644 --- a/serve_test.go +++ b/serve_test.go @@ -49,7 +49,12 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { time.Sleep(20 * time.Millisecond) } if err != nil { - t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err) + select { + case serveErr := <-errCh: + t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr) + default: + t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err) + } } defer resp.Body.Close() From 6acac9edce474de3d9abeec76d45a703247d8a2d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:27:00 +0800 Subject: [PATCH 19/55] fix: streamline route matcher backtracking Avoid rebuilding skipped-node state during wildcard fallback so the matcher no longer loops on the same static branch and stops allocating on the hot path. Add focused route benchmarks and regression coverage to keep the optimized path stable. --- .gitignore | 3 +- engine.go | 13 ++-- route_match_benchmark_test.go | 130 ++++++++++++++++++++++++++++++++++ tree.go | 69 +++++++++--------- tree_test.go | 66 +++++++++++++++++ 5 files changed, 240 insertions(+), 41 deletions(-) create mode 100644 route_match_benchmark_test.go diff --git a/.gitignore b/.gitignore index 30d74d2..6f301cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test \ No newline at end of file +test +/bench_route_match_baseline.txt diff --git a/engine.go b/engine.go index b7cf330..ece023d 100644 --- a/engine.go +++ b/engine.go @@ -739,12 +739,13 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return + if engine.RedirectFixedPath { + // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return + } } } } diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go new file mode 100644 index 0000000..e0dd2aa --- /dev/null +++ b/route_match_benchmark_test.go @@ -0,0 +1,130 @@ +package touka + +import "testing" + +var ( + benchmarkRouteHandlers HandlersChain + benchmarkRouteFullPath string + benchmarkRouteParamsLen int + benchmarkRouteCIPath []byte + benchmarkRouteCIFound bool +) + +func buildRouteMatchBenchmarkTree() *node { + tree := &node{} + routes := []string{ + "/", + "/health", + "/contact", + "/api/v1/users", + "/api/v1/users/:id", + "/api/v1/users/:id/settings", + "/assets/*filepath", + "/abc/b", + "/abc/:p1/cde", + "/abc/:p1/:p2/def/*filepath", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + return tree +} + +func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { + b.Helper() + + params := make(Params, 0, 4) + skipped := make([]skippedNode, 0, 8) + + value := tree.getValue(path, ¶ms, &skipped, true) + if wantFullPath == "" { + if value.handlers != nil { + b.Fatalf("expected no match for %q, got %q", path, value.fullPath) + } + } else { + if value.handlers == nil { + b.Fatalf("expected match for %q, got nil handlers", path) + } + if value.fullPath != wantFullPath { + b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) + } + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + params = params[:0] + skipped = skipped[:0] + value = tree.getValue(path, ¶ms, &skipped, true) + } + + benchmarkRouteHandlers = value.handlers + benchmarkRouteFullPath = value.fullPath + if value.params != nil { + benchmarkRouteParamsLen = len(*value.params) + } else { + benchmarkRouteParamsLen = 0 + } +} + +func BenchmarkRouteMatch(b *testing.B) { + tree := buildRouteMatchBenchmarkTree() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") + }) + + b.Run("ParamHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") + }) + + b.Run("BacktrackingHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") + }) + + b.Run("Miss", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/does/not/exist", "") + }) + + b.Run("CaseInsensitiveHit", func(b *testing.B) { + path := "/API/V1/USERS/123/SETTINGS" + out, found := tree.findCaseInsensitivePath(path, true) + if !found { + b.Fatalf("expected fixed-path match for %q", path) + } + if got := string(out); got != "/api/v1/users/123/settings" { + b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) + + b.Run("CaseInsensitiveMiss", func(b *testing.B) { + path := "/DOES/NOT/EXIST" + out, found := tree.findCaseInsensitivePath(path, true) + if found || out != nil { + b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) +} diff --git a/tree.go b/tree.go index f5452f4..e9a10e6 100644 --- a/tree.go +++ b/tree.go @@ -452,12 +452,14 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 + var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -467,30 +469,26 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - for i, c := range []byte(n.indices) { - if c == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, - paramsCount: globalParamsCount, // 记录当前参数计数 + if !backtrackToWildChild { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: pathAtNode, // 记录进入当前节点时的剩余路径 + node: n, + paramsCount: globalParamsCount, // 记录当前参数计数 + } } - } - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } } + } else { + backtrackToWildChild = false } if !n.wildChild { @@ -507,7 +505,8 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - continue walk // 继续外部循环 + backtrackToWildChild = true + continue walk // 继续外部循环 } } } @@ -547,7 +546,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -599,7 +598,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -634,6 +633,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -658,8 +658,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -688,6 +688,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -758,8 +759,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -781,8 +782,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i, c := range []byte(n.indices) { - if c == idxc { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -813,9 +814,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 小写匹配 - if c == idxc { + if n.indices[i] == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -832,9 +833,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 大写匹配 - if c == idxc { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) diff --git a/tree_test.go b/tree_test.go index 7665afd..a35a1a8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode { return &ps } +func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { + t.Helper() + + resultCh := make(chan nodeValue, 1) + go func() { + resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) + }() + + select { + case value := <-resultCh: + return value + case <-time.After(2 * time.Second): + t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) + return nodeValue{} + } +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -1104,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } + +func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { + tests := []struct { + name string + routes []string + requestPath string + wantFullPath string + wantParams Params + }{ + { + name: "param route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/details"}, + requestPath: "/foo/bar/details", + wantFullPath: "/foo/:id/details", + wantParams: Params{{Key: "id", Value: "bar"}}, + }, + { + name: "catch-all route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/*rest"}, + requestPath: "/foo/bar/baz.txt", + wantFullPath: "/foo/:id/*rest", + wantParams: Params{ + {Key: "id", Value: "bar"}, + {Key: "rest", Value: "/baz.txt"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := &node{} + for _, route := range tt.routes { + tree.addRoute(route, fakeHandler(route)) + } + + value := getValueWithTimeout(t, tree, tt.requestPath, false) + if value.handlers == nil { + t.Fatalf("expected handlers for %q", tt.requestPath) + } + if value.fullPath != tt.wantFullPath { + t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) + } + if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { + t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) + } + }) + } +} From 5d979e56707a239e682f0c374cd6cc78d3bb2f3a Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 08:39:10 +0800 Subject: [PATCH 20/55] fix: reduce per-request context and fallback overhead Make Context keys lazy so requests that never call Set stop allocating on reset. Reuse stable 404 and 405 handlers and add focused benchmarks so ServeHTTP miss paths stay measurable. --- context.go | 2 +- context_benchmark_test.go | 78 +++++++++++++++++++++++++++++++++++++++ engine.go | 77 ++++++++++++++++++++------------------ engine_benchmark_test.go | 64 ++++++++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 36 deletions(-) create mode 100644 context_benchmark_test.go create mode 100644 engine_benchmark_test.go diff --git a/context.go b/context.go index 9c4ba7e..f24ceb0 100644 --- a/context.go +++ b/context.go @@ -97,7 +97,7 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..2198c59 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,78 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(c.Writer, req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(c.Writer, req) + c.Set("request-id", i) + } + }) +} diff --git a/engine.go b/engine.go index ece023d..b2cc952 100644 --- a/engine.go +++ b/engine.go @@ -117,6 +117,46 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath) + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -479,45 +519,12 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..5780230 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,64 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) +} From 2d4aefc86e5d0276bb0ad7dab39eefa75b3c68c7 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:06:56 +0800 Subject: [PATCH 21/55] fix: cut redirect and allow-path routing overhead Reuse fixed-path and Allow-header buffers so redirect and OPTIONS handling stop rebuilding temporary data on every request. Cache fallback chains and add regression coverage for redirect, 404, 405, and Allow behavior to keep the faster miss paths stable. --- context.go | 15 +++++ engine.go | 125 ++++++++++++++++++++++++++++----------- engine_benchmark_test.go | 7 +++ engine_test.go | 102 ++++++++++++++++++++++++++++++++ reverseproxy.go | 13 +++- tree.go | 50 ++++++++++++---- 6 files changed, 264 insertions(+), 48 deletions(-) create mode 100644 engine_test.go diff --git a/context.go b/context.go index f24ceb0..f06d21e 100644 --- a/context.go +++ b/context.go @@ -73,6 +73,12 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode + + // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. + fixedPathBuf []byte + + allowedMethodsBuf []string + allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -111,6 +117,15 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } else { c.SkippedNodes = make([]skippedNode, 0, 256) } + if cap(c.fixedPathBuf) > 0 { + c.fixedPathBuf = c.fixedPathBuf[:0] + } + if cap(c.allowedMethodsBuf) > 0 { + c.allowedMethodsBuf = c.allowedMethodsBuf[:0] + } + if cap(c.allowHeaderBuf) > 0 { + c.allowHeaderBuf = c.allowHeaderBuf[:0] + } } // Next 在处理链中执行下一个处理函数 diff --git a/engine.go b/engine.go index b2cc952..5214654 100644 --- a/engine.go +++ b/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "strings" + "unicode/utf8" "net/http" @@ -82,6 +83,11 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 + + notFoundChain HandlersChain + notFoundNoMethodChain HandlersChain + unmatchedFSChain HandlersChain + unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -127,10 +133,19 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { // 是否是OPTIONS方式 if httpMethod == http.MethodOptions { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath) + allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) + c.allowedMethodsBuf = allowedMethods[:0] if len(allowedMethods) > 0 { // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allowedMethods { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) c.Status(http.StatusOK) return } @@ -251,6 +266,7 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -306,6 +322,7 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable + engine.rebuildFallbackChains() } // SetLogger传入实例 @@ -346,6 +363,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } + engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -531,12 +549,52 @@ func NotFound() HandlerFunc { func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil + Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs + Engine.rebuildFallbackChains() +} + +func (engine *Engine) rebuildFallbackChains() { + buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { + finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound + if includeMethodNotAllowed { + finalSize++ + } + if includeUnmatchedFS { + finalSize += len(engine.UnMatchFSRoutes) + } + if engine.noRoute != nil { + finalSize++ + } else { + finalSize += len(engine.noRoutes) + } + + chain := make(HandlersChain, 0, finalSize) + chain = append(chain, engine.globalHandlers...) + if includeMethodNotAllowed { + chain = append(chain, methodNotAllowedHandler) + } + if includeUnmatchedFS { + chain = append(chain, engine.UnMatchFSRoutes...) + } + if engine.noRoute != nil { + chain = append(chain, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + chain = append(chain, engine.noRoutes...) + } + chain = append(chain, notFoundHandler) + return chain + } + + engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) + engine.notFoundNoMethodChain = buildChain(false, false) + engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) + engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -553,6 +611,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // 这些中间件将应用于所有注册的路由 func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) + engine.rebuildFallbackChains() return engine } @@ -746,48 +805,24 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - if engine.RedirectFixedPath { + if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) if found { + c.fixedPathBuf = ciPath[:0] c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 return } + c.fixedPathBuf = ciPath[:0] } } } - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers - - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } - - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) + c.handlers = engine.unmatchedFSChain + } else { + c.handlers = engine.notFoundChain } - - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } - - handlers = append(handlers, NotFound()) - - c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } @@ -813,8 +848,28 @@ func isGeneralOptionsRequest(req *http.Request) bool { return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" } -func (engine *Engine) allowedMethodsForPath(requestPath string) []string { - allowedMethods := make([]string, 0, len(engine.methodTrees)) +func shouldTryFixedPathLookup(path string, root *node) bool { + if root != nil && root.hasCaseInsensitivePath { + return true + } + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { + if cap(allowedMethods) < len(engine.methodTrees) { + allowedMethods = make([]string, 0, len(engine.methodTrees)) + } else { + allowedMethods = allowedMethods[:0] + } for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 tempSkippedNodes := GetTempSkippedNodes() diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go index 5780230..666e8b2 100644 --- a/engine_benchmark_test.go +++ b/engine_benchmark_test.go @@ -16,6 +16,9 @@ func buildServeHTTPBenchmarkEngine() *Engine { engine.GET("/api/v1/users/:id", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) engine.POST("/api/v1/users", func(c *Context) { c.Status(http.StatusNoContent) }) @@ -61,4 +64,8 @@ func BenchmarkServeHTTP(b *testing.B) { b.Run("OptionsAllow", func(b *testing.B) { benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") }) + + b.Run("FixedPathRedirect", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") + }) } diff --git a/engine_test.go b/engine_test.go new file mode 100644 index 0000000..292d5e2 --- /dev/null +++ b/engine_test.go @@ -0,0 +1,102 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestHandleRequestRedirectFixedPath(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { + t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) + } +} + +func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) + } +} + +func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) + } + if location := rr.Header().Get("Location"); location != "/Users/Profile" { + t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) + } +} + +func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { + engine := New() + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "hit" { + t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) + } +} + +func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "" { + t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) + } +} + +func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) + } + allow := rr.Header().Get("Allow") + if allow != "GET, POST" && allow != "POST, GET" { + t.Fatalf("expected Allow header to list matching methods, got %q", allow) + } +} diff --git a/reverseproxy.go b/reverseproxy.go index 1b89b2a..ff49aef 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -699,8 +699,17 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { if c.engine != nil { if c.Request != nil && c.Request.RequestURI != "*" { - if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request)); len(allow) > 0 { - c.Writer.Header().Set("Allow", strings.Join(allow, ", ")) + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { + c.allowedMethodsBuf = allow[:0] + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allow { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", BytesToString(allowHeader)) } } } diff --git a/tree.go b/tree.go index e9a10e6..6595655 100644 --- a/tree.go +++ b/tree.go @@ -121,14 +121,28 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 +} + +func routeNeedsCaseInsensitiveLookup(path string) bool { + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 + if routeNeedsCaseInsensitiveLookup(path) { + n.hasCaseInsensitivePath = true + } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -702,13 +719,24 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { + return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) +} + +func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { const stackBufSize = 128 // 栈上缓冲区的默认大小 // 在常见情况下使用栈上静态大小的缓冲区. // 如果路径太长, 则在堆上分配缓冲区. - buf := make([]byte, 0, stackBufSize) - if length := len(path) + 1; length > stackBufSize { - buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 + if buf != nil { + buf = buf[:0] + } + if cap(buf) < len(path)+1 { + var stackBuf [stackBufSize]byte + if len(path)+1 <= stackBufSize { + buf = stackBuf[:0] + } else { + buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区 + } } ciPath := n.findCaseInsensitivePathRec( From 57847fa44647a1670f49bc22d1889b7e6203e0c8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:32:14 +0800 Subject: [PATCH 22/55] fix: avoid unsafe header buffer reuse Use safe string copies for pooled header buffers and simplify case-insensitive lookup buffering now that the pseudo stack path was ineffective. This addresses review concerns without changing the routing semantics. --- engine.go | 4 ++-- reverseproxy.go | 2 +- tree.go | 11 +---------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/engine.go b/engine.go index 5214654..698fbd5 100644 --- a/engine.go +++ b/engine.go @@ -145,7 +145,7 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { allowHeader = append(allowHeader, method...) } c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", BytesToString(allowHeader)) + c.Writer.Header().Set("Allow", string(allowHeader)) c.Status(http.StatusOK) return } @@ -810,7 +810,7 @@ func (engine *Engine) handleRequest(c *Context) { ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) if found { c.fixedPathBuf = ciPath[:0] - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 return } c.fixedPathBuf = ciPath[:0] diff --git a/reverseproxy.go b/reverseproxy.go index ff49aef..fe66e2b 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -709,7 +709,7 @@ func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { allowHeader = append(allowHeader, method...) } c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", BytesToString(allowHeader)) + c.Writer.Header().Set("Allow", string(allowHeader)) } } } diff --git a/tree.go b/tree.go index 6595655..b159c8d 100644 --- a/tree.go +++ b/tree.go @@ -723,20 +723,11 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]by } func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { - const stackBufSize = 128 // 栈上缓冲区的默认大小 - - // 在常见情况下使用栈上静态大小的缓冲区. - // 如果路径太长, 则在堆上分配缓冲区. if buf != nil { buf = buf[:0] } if cap(buf) < len(path)+1 { - var stackBuf [stackBufSize]byte - if len(path)+1 <= stackBufSize { - buf = stackBuf[:0] - } else { - buf = make([]byte, 0, len(path)+1) // 如果路径太长, 则分配更大的缓冲区 - } + buf = make([]byte, 0, len(path)+1) } ciPath := n.findCaseInsensitivePathRec( From fa027347d32012678df1dd7aafded6d8e1444c1a Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:35:39 +0800 Subject: [PATCH 23/55] fix: reduce default error response overhead Encode the built-in 404 and 405 payload with a fixed struct instead of a map so default error pages allocate less on the hot miss path. Add a regression test to keep the JSON shape stable. --- engine.go | 12 +++++++----- engine_test.go | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 698fbd5..81d3673 100644 --- a/engine.go +++ b/engine.go @@ -126,6 +126,12 @@ type ErrorHandler func(c *Context, code int, err error) var errMethodNotAllowed = errors.New("method not allowed") var errNotFound = errors.New("not found") +type defaultErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` +} + var methodNotAllowedHandler HandlerFunc = func(c *Context) { httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -187,11 +193,7 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if err != nil { errMsg = err.Error() } - c.JSON(code, H{ - "code": code, - "message": http.StatusText(code), - "error": errMsg, - }) + c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) c.Writer.Flush() c.Abort() return diff --git a/engine_test.go b/engine_test.go index 292d5e2..71f9772 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1,6 +1,7 @@ package touka import ( + "encoding/json" "net/http" "testing" ) @@ -100,3 +101,23 @@ func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { t.Fatalf("expected Allow header to list matching methods, got %q", allow) } } + +func TestDefaultErrorHandleJSONShape(t *testing.T) { + engine := New() + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" { + t.Fatalf("unexpected error payload: %+v", body) + } +} From 987ea81329e34d43357f200ea58a38226d4b1d3b Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:57:16 +0800 Subject: [PATCH 24/55] fix: avoid fixed-path miss panic and trim 405 fallback work --- engine.go | 14 +++++++++----- engine_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/engine.go b/engine.go index 81d3673..f9d233a 100644 --- a/engine.go +++ b/engine.go @@ -155,22 +155,25 @@ var methodNotAllowedHandler HandlerFunc = func(c *Context) { c.Status(http.StatusOK) return } + return } // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + tempSkippedNodes := GetTempSkippedNodes() for _, treeIter := range engine.methodTrees { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() + *tempSkippedNodes = (*tempSkippedNodes)[:0] value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { + PutTempSkippedNodes(tempSkippedNodes) // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) return } } + PutTempSkippedNodes(tempSkippedNodes) } var notFoundHandler HandlerFunc = func(c *Context) { @@ -815,7 +818,7 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 return } - c.fixedPathBuf = ciPath[:0] + c.fixedPathBuf = c.fixedPathBuf[:0] } } } @@ -872,15 +875,16 @@ func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods [ } else { allowedMethods = allowedMethods[:0] } + tempSkippedNodes := GetTempSkippedNodes() for _, treeIter := range engine.methodTrees { // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() + *tempSkippedNodes = (*tempSkippedNodes)[:0] value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } } + PutTempSkippedNodes(tempSkippedNodes) return allowedMethods } diff --git a/engine_test.go b/engine_test.go index 71f9772..571f4b7 100644 --- a/engine_test.go +++ b/engine_test.go @@ -48,6 +48,24 @@ func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { } } +func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic for fixed-path miss: %v", r) + } + }() + + rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code) + } +} + func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { engine := New() engine.NoRoute(func(c *Context) { From e4d3eed379cb58c1bbcc915d776a3c9ccda6e796 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 17:44:55 +0800 Subject: [PATCH 25/55] feat: redesign server startup around Run options Replace the old RunShutdown and RunTLS style entry points with a single Run(opts...) API for v1. Add focused startup semantics tests, keep TLS and graceful shutdown independent, ensure sibling servers are cleaned up on startup failure, and update docs to match the new option-based startup model. --- README.md | 4 +- about-touka.md | 22 +- docs/advanced.md | 36 ++- docs/introduction.md | 2 +- docs/quickstart.md | 6 +- docs/reverse-proxy.md | 4 +- docs/routing.md | 2 +- docs/sse.md | 2 +- docs/static-files.md | 2 +- engine.go | 22 +- protocols_test.go | 29 +-- serve.go | 585 ++++++++++++++++++++++-------------------- serve_test.go | 196 ++++++++++++++ 13 files changed, 577 insertions(+), 335 deletions(-) diff --git a/README.md b/README.md index a7b99fd..e2eaec8 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ func main() { c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) }) - // 启动服务器 (支持优雅关闭) + // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) log.Println("Touka Server starting on :8080...") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Touka server failed to start: %v", err) } } diff --git a/about-touka.md b/about-touka.md index 86a056f..b3a16b4 100644 --- a/about-touka.md +++ b/about-touka.md @@ -70,13 +70,13 @@ func main() { r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB // ... 其他配置 - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` #### 1.3. 服务器生命周期管理 -Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 +Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 ```go func main() { @@ -90,11 +90,11 @@ func main() { fmt.Println("自定义的 HTTP 服务器配置已应用") }) - // 启动服务器,并支持优雅关闭 - // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 - // 第二个参数是优雅关闭的超时时间 + // 启动服务器,并通过 Run 选项启用优雅关闭 + // Run(...) 会阻塞当前 goroutine + // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 fmt.Println("服务器启动于 :8080") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("服务器启动失败: %v", err) } } @@ -187,7 +187,7 @@ func main() { } } - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } func AuthMiddleware() touka.HandlerFunc { @@ -313,7 +313,7 @@ func main() { }) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } // templates/index.html @@ -400,7 +400,7 @@ func main() { c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -483,7 +483,7 @@ func main() { // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 r.StaticDir("/files", "./non-existent-dir") - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -546,7 +546,7 @@ func main() { // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 r.StaticFS("/", http.FS(subFS)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/advanced.md b/docs/advanced.md index a7cb9a2..7e6a417 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) { Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): ```go -// 使用默认协议配置(仅 HTTP/1.1) +// 使用默认协议配置 +// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集, +// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。 r.SetDefaultProtocols() // 自定义协议配置 @@ -57,33 +59,49 @@ r.SetProtocols(&touka.ProtocolsConfig{ ### 启动方式 -Touka 提供了多种服务器启动方式: +Touka 统一通过 `Run(opts...)` 启动服务器: ```go // 1. 简单启动(无优雅停机) -r.Run(":8080") +r.Run(touka.WithAddr(":8080")) // 2. 带优雅停机的启动 -r.RunShutdown(":8080", 10*time.Second) +r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) // 3. 带上下文的优雅停机 ctx, cancel := context.WithCancel(context.Background()) -r.RunShutdownWithContext(":8080", ctx, 10*time.Second) +defer cancel() +r.Run( + touka.WithAddr(":8080"), + touka.WithGracefulShutdown(10*time.Second), + touka.WithShutdownContext(ctx), +) // 4. HTTPS 启动 tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, // 其他 TLS 配置... } -r.RunTLS(":443", tlsConfig, 10*time.Second) +// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithGracefulShutdownDefault(), +) // 5. HTTPS + HTTP 重定向 -r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) +// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect(":80"), + touka.WithGracefulShutdown(10*time.Second), +) ``` ## 优雅停机 (Graceful Shutdown) -在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 ```go r := touka.Default() @@ -91,7 +109,7 @@ r := touka.Default() // 监听 SIGINT 和 SIGTERM 信号 // 如果在 10 秒内未处理完,则强制关闭 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal("服务器退出异常:", err) } ``` diff --git a/docs/introduction.md b/docs/introduction.md index 94a7310..87c3e40 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 -3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 +3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/quickstart.md b/docs/quickstart.md index 94f7433..2911732 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -46,7 +46,7 @@ func main() { // 4. 启动服务器并监听 8080 端口 log.Println("Touka server is running on :8080") - if err := r.Run(":8080"); err != nil { + if err := r.Run(touka.WithAddr(":8080")); err != nil { log.Fatalf("Server failed: %v", err) } } @@ -66,11 +66,11 @@ go run main.go ## 优雅停机 -在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 +在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。 ```go // 等待 10 秒以处理剩余请求 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Server forced to shutdown: %v", err) } ``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 7d05290..cb4b2a3 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -28,7 +28,7 @@ func main() { Target: target, })) - _ = r.Run(":8080") + _ = r.Run(touka.WithAddr(":8080")) } ``` @@ -497,7 +497,7 @@ func main() { }, })) - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal(err) } } diff --git a/docs/routing.md b/docs/routing.md index 223081a..70a24dc 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -142,7 +142,7 @@ func main() { r := touka.Default() fsroot, _ := fs.Sub(content, "dist") r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/sse.md b/docs/sse.md index 1b44521..a003be9 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) { 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 -**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 +**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。 diff --git a/docs/static-files.md b/docs/static-files.md index a2138cd..b1f06a8 100644 --- a/docs/static-files.md +++ b/docs/static-files.md @@ -39,7 +39,7 @@ func main() { // 您也可以使用 StaticFS 服务根路径 // r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/engine.go b/engine.go index f9d233a..2849ffa 100644 --- a/engine.go +++ b/engine.go @@ -404,11 +404,18 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { }() } -// applyDefaultServerConfig 应用框架的默认配置到 http.Server -func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { - if engine.serverProtocols != nil { - srv.Protocols = engine.serverProtocols - if engine.serverProtocols.HTTP2() || engine.serverProtocols.UnencryptedHTTP2() { +func cloneServerProtocols(protocols *http.Protocols) *http.Protocols { + if protocols == nil { + return nil + } + cloned := *protocols + return &cloned +} + +func applyServerProtocols(srv *http.Server, protocols *http.Protocols) { + if protocols != nil { + srv.Protocols = cloneServerProtocols(protocols) + if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() { if err := configureHTTP2ExtendedConnectServer(srv); err != nil { panic(err) } @@ -416,6 +423,11 @@ func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { } } +// applyDefaultServerConfig 应用框架的默认配置到 http.Server +func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) +} + // 配置全局Req Body大小限制 func (engine *Engine) SetGlobalMaxRequestBodySize(size int64) { engine.GlobalMaxRequestBodySize = size diff --git a/protocols_test.go b/protocols_test.go index 73f16e9..0e2bf1f 100644 --- a/protocols_test.go +++ b/protocols_test.go @@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) { } } -func TestRunTLSProtocolInheritance(t *testing.T) { +func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { engine := New() - // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv) + srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if !srv.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") } - // 模拟用户设置了自定义协议后调用 RunTLS + // 模拟用户设置了自定义协议后进入 TLS 运行模式 engine = New() engine.SetProtocols(&ProtocolsConfig{ Http1: true, Http2: false, // 用户明确不想要 HTTP/2 }) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv2 := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv2) + srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if srv2.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") } } diff --git a/serve.go b/serve.go index 1825b32..2c8c73b 100644 --- a/serve.go +++ b/serve.go @@ -21,45 +21,119 @@ import ( "github.com/fenthope/reco" ) -// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 const defaultShutdownTimeout = 5 * time.Second -// --- 内部辅助函数 --- +type runMode uint8 -// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" -func resolveAddress(addr []string) string { - switch len(addr) { - case 0: - return ":8080" - case 1: - return addr[0] - default: - panic("too many parameters provided for server address") +const ( + runModeHTTP runMode = iota + runModeHTTPS + runModeHTTPSRedirect +) + +type runConfig struct { + addr string + httpRedirectAddr string + tlsConfig *tls.Config + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool +} + +type RunOption interface { + apply(*runConfig) error +} + +type runOptionFunc func(*runConfig) error + +func (f runOptionFunc) apply(cfg *runConfig) error { + return f(cfg) +} + +func defaultRunConfig() runConfig { + return runConfig{ + addr: ":8080", + shutdownTimeout: defaultShutdownTimeout, + mode: runModeHTTP, } } -// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 -func getShutdownTimeout(timeouts []time.Duration) time.Duration { - if len(timeouts) > 0 && timeouts[0] > 0 { - return timeouts[0] - } - return defaultShutdownTimeout +func WithAddr(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("run address must not be empty") + } + cfg.addr = addr + return nil + }) +} + +func WithTLS(tlsConfig *tls.Config) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil") + } + cfg.tlsConfig = tlsConfig + if cfg.mode == runModeHTTP { + cfg.mode = runModeHTTPS + } + return nil + }) +} + +func WithHTTPRedirect(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("http redirect address must not be empty") + } + cfg.httpRedirectAddr = addr + cfg.mode = runModeHTTPSRedirect + return nil + }) +} + +func WithGracefulShutdown(timeout time.Duration) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownTimeoutSet = true + if timeout > 0 { + cfg.shutdownTimeout = timeout + } else { + cfg.shutdownTimeout = defaultShutdownTimeout + } + return nil + }) +} + +func WithGracefulShutdownDefault() RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownDefaultSet = true + cfg.shutdownTimeout = defaultShutdownTimeout + return nil + }) +} + +func WithShutdownContext(ctx context.Context) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if ctx == nil { + return errors.New("shutdown context must not be nil") + } + cfg.gracefulCtx = ctx + return nil + }) } -// serveServer 根据显式指定的启动模式运行 HTTP 或 HTTPS 服务器. func serveServer(srv *http.Server, serveTLS bool) error { if serveTLS { - // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, - // ListenAndServeTLS 的前两个参数可以为空字符串 return srv.ListenAndServeTLS("", "") } - return srv.ListenAndServe() } -// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, -// 并处理其启动失败的致命错误 -// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") func runServer(serverType string, srv *http.Server, serveTLS bool) { go func() { protocol := "http" @@ -70,284 +144,90 @@ func runServer(serverType string, srv *http.Server, serveTLS bool) { log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) err := serveServer(srv, serveTLS) - - // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), - // 则认为是一个严重错误,并终止程序 if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Touka %s server failed: %v", serverType, err) } }() } -// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 -// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 -func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - <-quit // 阻塞,直到接收到上述信号之一 - log.Println("Shutting down Touka server(s)...") - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() +func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { + if tlsConfig == nil { + return nil } - - // 创建一个带超时的上下文,用于 Shutdown - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(ctx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() // 等待所有服务器的关闭 goroutine 完成 - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil + return tlsConfig.Clone() } -func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - - // 启动服务器 - serverStopped := make(chan error, 1) - for _, srv := range servers { - go func(s *http.Server) { - serverStopped <- s.ListenAndServe() - }(srv) +func parseHTTPSPort(addr string) (string, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("https address %q must include a port: %w", addr, err) } + return port, nil +} - select { - case <-ctx.Done(): - // Context 被取消 (例如,通过外部取消函数) - log.Println("Context cancelled, shutting down Touka server(s)...") - case err := <-serverStopped: - // 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("Touka HTTP server failed: %w", err) +func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) { + if serveTLS { + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + return } - log.Println("Touka HTTP server stopped gracefully.") - return nil // 服务器已自行优雅关闭,无需进一步处理 - case <-quit: - // 接收到操作系统信号 - log.Println("Shutting down Touka server(s) due to OS signal...") } - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() - } - - // 创建一个带超时的上下文,用于 Shutdown - shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(shutdownCtx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil -} - -// --- 公共 Run 方法 --- - -// Run 启动一个不支持优雅关闭的 HTTP 服务器 -// 这是一个阻塞调用,主要用于简单的场景或快速测试 -// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 -func (engine *Engine) Run(addr ...string) error { - address := resolveAddress(addr) - srv := &http.Server{Addr: address, Handler: engine} - - // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) - return srv.ListenAndServe() } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) +func applyRedirectServerConfig(engine *Engine, srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - - runServer("HTTP", srv, false) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, +func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols { + if engine == nil { + return nil } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) + if serveTLS && engine.useDefaultProtocols { + protocols := &http.Protocols{} + protocols.SetHTTP1(true) + protocols.SetHTTP2(true) + return protocols } - - return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco) + return cloneServerProtocols(engine.serverProtocols) } -// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 -func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLS") - } - - // 配置 HTTP/2 支持 (如果使用默认配置) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认在 TLS 上启用 HTTP/2 - }) - } - - srv := &http.Server{ - Addr: addr, +func buildMainServer(engine *Engine, cfg runConfig) *http.Server { + serveTLS := cfg.mode != runModeHTTP + server := &http.Server{ + Addr: cfg.addr, Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { + TLSConfig: cloneTLSConfig(cfg.tlsConfig), + } + if cfg.graceful { + server.BaseContext = func(net.Listener) context.Context { return engine.shutdownCtx - }, + } + server.RegisterOnShutdown(engine.shutdownCancel) } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - engine.applyDefaultServerConfig(srv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(srv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } - - runServer("HTTPS", srv, true) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) + applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS)) + applyMainServerConfig(engine, server, serveTLS) + return server } -// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 -func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLS(addr, tlsConfig, timeouts...) -} - -// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 -func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLSRedir") +func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { + httpsPort, err := parseHTTPSPort(httpsAddr) + if err != nil { + return nil, err } - // --- HTTPS 服务器 --- - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) - } - httpsSrv := &http.Server{ - Addr: httpsAddr, - Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - httpsSrv.RegisterOnShutdown(engine.shutdownCancel) - engine.applyDefaultServerConfig(httpsSrv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(httpsSrv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpsSrv) - } - - // --- HTTP 重定向服务器 --- redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.Host) if err != nil { host = r.Host } - _, httpsPort, err := net.SplitHostPort(httpsAddr) - if err != nil { - // 如果 httpsAddr 没有端口,这是一个配置错误 - - log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) - } - targetURL := "https://" + host - // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 if httpsPort != "443" { targetURL = "https://" + net.JoinHostPort(host, httpsPort) } @@ -355,22 +235,175 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con http.Redirect(w, r, targetURL, http.StatusMovedPermanently) }) - httpSrv := &http.Server{ - Addr: httpAddr, - Handler: redirectHandler, - } - engine.applyDefaultServerConfig(httpSrv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpSrv) - } - // --- 启动服务器和优雅关闭 --- - runServer("HTTPS", httpsSrv, true) - runServer("HTTP Redirect", httpSrv, false) - return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) + server := &http.Server{Addr: httpAddr, Handler: redirectHandler} + applyRedirectServerConfig(engine, server) + return server, nil } -// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 -func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) +func validateRunConfig(cfg runConfig) error { + if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { + return errors.New("WithHTTPRedirect requires WithTLS") + } + if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { + return errors.New("https mode requires WithTLS") + } + if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect { + cfg.mode = runModeHTTPSRedirect + } + if cfg.gracefulCtx != nil && !cfg.graceful { + return errors.New("WithShutdownContext requires graceful shutdown") + } + return nil +} + +func effectiveShutdownTimeout(cfg runConfig) time.Duration { + if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet { + if cfg.shutdownTimeout > 0 { + return cfg.shutdownTimeout + } + } + return defaultShutdownTimeout +} + +func closeLoggerAsync(logger *reco.Logger) { + if logger == nil { + return + } + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() +} + +func shutdownServers(servers []*http.Server, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) + } + }(srv) + } + + wg.Wait() + close(errChan) + + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) + } + return nil +} + +func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(quit) + + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLS[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + select { + case err := <-serverStopped: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + log.Println("Touka server stopped gracefully.") + return nil + case <-quit: + log.Println("Shutting down Touka server(s) due to OS signal...") + case <-shutdownCtx.Done(): + log.Println("Context cancelled, shutting down Touka server(s)...") + } + + closeLoggerAsync(logger) + if err := shutdownServers(servers, timeout); err != nil { + return err + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// Run starts the engine with the provided startup options. +// +// Default behavior with no options: +// - HTTP only +// - listens on :8080 +// - no graceful shutdown orchestration +// +// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable +// signal-aware graceful shutdown and request-context cancellation semantics. +// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown. +func (engine *Engine) Run(opts ...RunOption) error { + cfg := defaultRunConfig() + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(&cfg); err != nil { + return err + } + } + if cfg.httpRedirectAddr != "" { + cfg.mode = runModeHTTPSRedirect + } else if cfg.tlsConfig != nil { + cfg.mode = runModeHTTPS + } + if err := validateRunConfig(cfg); err != nil { + return err + } + + serveTLS := cfg.mode != runModeHTTP + + mainServer := buildMainServer(engine, cfg) + servers := []*http.Server{mainServer} + serveTLSFlags := []bool{serveTLS} + if cfg.mode == runModeHTTPSRedirect { + redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) + if err != nil { + return err + } + servers = append(servers, redirectServer) + serveTLSFlags = append(serveTLSFlags, false) + } + + if !cfg.graceful { + if len(servers) > 1 { + runServer("HTTPS", servers[0], true) + log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) + return serveServer(servers[1], false) + } + + protocolLabel := "HTTP" + if serveTLS { + protocolLabel = "HTTPS" + } + log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr) + return serveServer(mainServer, serveTLS) + } + + shutdownCtx := context.Background() + if cfg.gracefulCtx != nil { + shutdownCtx = cfg.gracefulCtx + } + return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx) } diff --git a/serve_test.go b/serve_test.go index 6092f7b..6ecbeba 100644 --- a/serve_test.go +++ b/serve_test.go @@ -7,6 +7,8 @@ import ( "io" "net" "net/http" + "net/http/httptest" + "strings" "testing" "time" ) @@ -79,3 +81,197 @@ func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) } } + +func TestRunRejectsRedirectWithoutTLS(t *testing.T) { + engine := New() + err := engine.Run(WithHTTPRedirect(":80")) + if err == nil { + t.Fatal("expected redirect mode without TLS to fail") + } +} + +func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { + cfg := defaultRunConfig() + if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { + t.Fatalf("apply graceful default option: %v", err) + } + if !cfg.graceful { + t.Fatal("expected graceful shutdown to be enabled") + } + if cfg.shutdownTimeout != defaultShutdownTimeout { + t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout) + } +} + +func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { + cfg := defaultRunConfig() + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} + if err := WithTLS(tlsConfig).apply(&cfg); err != nil { + t.Fatalf("apply TLS option: %v", err) + } + if cfg.mode != runModeHTTPS { + t.Fatalf("expected HTTPS mode, got %v", cfg.mode) + } + if cfg.graceful { + t.Fatal("expected TLS option to remain independent from graceful shutdown") + } + if cfg.tlsConfig != tlsConfig { + t.Fatal("expected TLS config to be preserved in run config") + } +} + +func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { + engine := New() + if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { + t.Fatal("expected redirect server builder to reject https address without port") + } +} + +func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { + cfg := defaultRunConfig() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := WithShutdownContext(ctx).apply(&cfg); err != nil { + t.Fatalf("apply shutdown context option: %v", err) + } + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected shutdown context without graceful shutdown to fail validation") + } +} + +func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { + engine := New() + server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) + if server.BaseContext == nil { + t.Fatal("expected graceful main server to set BaseContext") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for base context check: %v", err) + } + defer listener.Close() + if got := server.BaseContext(listener); got != engine.shutdownCtx { + t.Fatal("expected graceful main server to use engine shutdown context") + } +} + +func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) { + engine := New() + serverConfigured := false + tlsConfigured := false + engine.SetServerConfigurator(func(s *http.Server) { + serverConfigured = true + s.ReadTimeout = time.Second + }) + engine.SetTLSServerConfigurator(func(s *http.Server) { + tlsConfigured = true + s.IdleTimeout = time.Second + }) + + server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !tlsConfigured { + t.Fatal("expected TLS configurator to run for HTTPS main server") + } + if serverConfigured { + t.Fatal("expected generic server configurator to be skipped when TLS configurator is set") + } + if server.IdleTimeout != time.Second { + t.Fatal("expected TLS configurator changes to be applied to HTTPS main server") + } +} + +func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { + engine := New() + configured := false + engine.SetServerConfigurator(func(s *http.Server) { + configured = true + s.ReadTimeout = time.Second + }) + + server, err := buildRedirectServer(engine, ":443", ":80") + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + if !configured { + t.Fatal("expected redirect server to use generic server configurator") + } + if server.ReadTimeout != time.Second { + t.Fatal("expected redirect server configurator changes to be applied") + } +} + +func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { + engine := New() + httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !httpsServer.Protocols.HTTP2() { + t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings") + } + + httpServer := buildMainServer(engine, defaultRunConfig()) + if httpServer.Protocols.HTTP2() { + t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled") + } +} + +func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, ":443", ":80") + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + mainServer := &http.Server{Addr: occupiedAddr, Handler: engine} + + err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background()) + if err == nil { + t.Fatal("expected gracefulServe to fail when one server cannot bind") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err) + } + + conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond) + if dialErr == nil { + conn.Close() + t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr) + } + if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") { + t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) + } +} From e2cf08d5ddc659e67077ad52953ce50bc8d8694d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:49:13 +0800 Subject: [PATCH 26/55] feat: add redirect host selection options Support explicit redirect host source selection for HTTP-to-HTTPS redirects with ordered header lookup, fixed host mode, and strict validation. Document the new redirect option relationships and add focused tests for 426 fallback, conflict checks, and non-graceful startup errors. --- docs/advanced.md | 98 ++++++++++++++++++++++++++ engine.go | 10 +-- serve.go | 178 +++++++++++++++++++++++++++++++++++++++++------ serve_test.go | 163 +++++++++++++++++++++++++++++++++++++++++-- touka.go | 8 +-- 5 files changed, 422 insertions(+), 35 deletions(-) diff --git a/docs/advanced.md b/docs/advanced.md index 7e6a417..eb44c2d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -97,8 +97,106 @@ r.Run( touka.WithHTTPRedirect(":80"), touka.WithGracefulShutdown(10*time.Second), ) + +// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) + +// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) ``` +### HTTPS Redirect Host 策略 + +`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 + +可用的 redirect 子选项: + +- `touka.WithUseHeaderHost(true|false)` +- `touka.WithRedirectHostHeaders([]string{...})` +- `touka.WithRedirectHost("example.com")` + +#### 模式一:使用请求输入侧的 host + +当 `WithUseHeaderHost(true)` 时: + +- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` +- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 +- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) +``` + +#### 模式二:使用配置的固定 host + +当 `WithUseHeaderHost(false)` 时: + +- 不读取 `Request.Host` +- 不读取 `WithRedirectHostHeaders(...)` +- 必须配置 `WithRedirectHost("example.com")` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) +``` + +#### 严格校验规则 + +以下组合会直接返回配置错误: + +- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` +- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` +- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` +- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` +- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` + +#### 优先级关系 + +1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 +2. `WithUseHeaderHost(...)` 决定 host 来源模式 +3. 当 `WithUseHeaderHost(true)` 时: + - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 + - 未配置时使用 `Request.Host` +4. 当 `WithUseHeaderHost(false)` 时: + - 只使用 `WithRedirectHost(...)` + +**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 + ## 优雅停机 (Graceful Shutdown) 在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 diff --git a/engine.go b/engine.go index 2849ffa..b0723e7 100644 --- a/engine.go +++ b/engine.go @@ -626,7 +626,7 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { +func (engine *Engine) Use(middleware ...HandlerFunc) Router { engine.globalHandlers = append(engine.globalHandlers, middleware...) engine.rebuildFallbackChains() return engine @@ -695,7 +695,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -704,7 +704,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 IRouter 接口,允许嵌套分组 +// 它也实现了 Router 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -713,7 +713,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { +func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { group.Handlers = append(group.Handlers, middleware...) return group } @@ -759,7 +759,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), diff --git a/serve.go b/serve.go index 2c8c73b..b2ba358 100644 --- a/serve.go +++ b/serve.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -32,15 +33,19 @@ const ( ) type runConfig struct { - addr string - httpRedirectAddr string - tlsConfig *tls.Config - graceful bool - shutdownTimeout time.Duration - gracefulCtx context.Context - mode runMode - shutdownDefaultSet bool - shutdownTimeoutSet bool + addr string + httpRedirectAddr string + tlsConfig *tls.Config + redirectHost string + redirectHostHeaders []string + useHeaderHost bool + useHeaderHostSet bool + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool } type RunOption interface { @@ -58,9 +63,20 @@ func defaultRunConfig() runConfig { addr: ":8080", shutdownTimeout: defaultShutdownTimeout, mode: runModeHTTP, + useHeaderHost: true, } } +type HTTPRedirectOption interface { + applyRedirect(*runConfig) error +} + +type redirectOptionFunc func(*runConfig) error + +func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { + return f(cfg) +} + func WithAddr(addr string) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { @@ -84,13 +100,52 @@ func WithTLS(tlsConfig *tls.Config) RunOption { }) } -func WithHTTPRedirect(addr string) RunOption { +func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { return runOptionFunc(func(cfg *runConfig) error { if addr == "" { return errors.New("http redirect address must not be empty") } cfg.httpRedirectAddr = addr cfg.mode = runModeHTTPSRedirect + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyRedirect(cfg); err != nil { + return err + } + } + return nil + }) +} + +func WithUseHeaderHost(enabled bool) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.useHeaderHost = enabled + cfg.useHeaderHostSet = true + return nil + }) +} + +func WithRedirectHost(host string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + if host == "" { + return errors.New("redirect host must not be empty") + } + cfg.redirectHost = host + return nil + }) +} + +func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] + for _, header := range headers { + trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) + if trimmed != "" { + cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) + } + } return nil }) } @@ -215,16 +270,68 @@ func buildMainServer(engine *Engine, cfg runConfig) *http.Server { return server } -func buildRedirectServer(engine *Engine, httpsAddr, httpAddr string) (*http.Server, error) { +func firstRedirectHeaderHost(r *http.Request, headers []string) string { + if r == nil { + return "" + } + for _, header := range headers { + value := strings.TrimSpace(r.Header.Get(header)) + if value == "" { + continue + } + if comma := strings.IndexByte(value, ','); comma >= 0 { + value = strings.TrimSpace(value[:comma]) + } + if value != "" { + return value + } + } + return "" +} + +func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { + if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return "", http.StatusInternalServerError, false + } + return cfg.redirectHost, 0, true + } + + if len(cfg.redirectHostHeaders) > 0 { + host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true + } + + if r == nil { + return "", http.StatusUpgradeRequired, false + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true +} + +func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { + httpsAddr := cfg.addr + httpAddr := cfg.httpRedirectAddr httpsPort, err := parseHTTPSPort(httpsAddr) if err != nil { return nil, err } redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host + host, statusCode, ok := redirectTargetHost(r, cfg) + if !ok { + http.Error(w, http.StatusText(statusCode), statusCode) + return + } + + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost } targetURL := "https://" + host @@ -248,12 +355,26 @@ func validateRunConfig(cfg runConfig) error { if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { return errors.New("https mode requires WithTLS") } - if cfg.httpRedirectAddr != "" && cfg.mode != runModeHTTPSRedirect { - cfg.mode = runModeHTTPSRedirect - } if cfg.gracefulCtx != nil && !cfg.graceful { return errors.New("WithShutdownContext requires graceful shutdown") } + if len(cfg.redirectHostHeaders) > 0 { + if !cfg.useHeaderHostSet || !cfg.useHeaderHost { + return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") + } + } + if cfg.useHeaderHostSet && cfg.useHeaderHost { + if cfg.redirectHost != "" { + return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") + } + } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") + } + if len(cfg.redirectHostHeaders) > 0 { + return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") + } + } return nil } @@ -286,7 +407,7 @@ func shutdownServers(servers []*http.Server, timeout time.Duration) error { wg.Add(1) go func(s *http.Server) { defer wg.Done() - if err := s.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.Shutdown(ctx); err != nil { errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) } }(srv) @@ -378,7 +499,7 @@ func (engine *Engine) Run(opts ...RunOption) error { servers := []*http.Server{mainServer} serveTLSFlags := []bool{serveTLS} if cfg.mode == runModeHTTPSRedirect { - redirectServer, err := buildRedirectServer(engine, cfg.addr, cfg.httpRedirectAddr) + redirectServer, err := buildRedirectServer(engine, cfg) if err != nil { return err } @@ -388,9 +509,22 @@ func (engine *Engine) Run(opts ...RunOption) error { if !cfg.graceful { if len(servers) > 1 { - runServer("HTTPS", servers[0], true) - log.Printf("Starting Touka HTTP Redirect server on %s", servers[1].Addr) - return serveServer(servers[1], false) + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLSFlags[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + err := <-serverStopped + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + return err } protocolLabel := "HTTP" diff --git a/serve_test.go b/serve_test.go index 6ecbeba..2bdddc5 100644 --- a/serve_test.go +++ b/serve_test.go @@ -90,6 +90,18 @@ func TestRunRejectsRedirectWithoutTLS(t *testing.T) { } } +func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { + engine := New() + err := engine.Run( + WithAddr(":443"), + WithTLS(&tls.Config{}), + WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), + ) + if err == nil { + t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") + } +} + func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { cfg := defaultRunConfig() if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { @@ -122,7 +134,7 @@ func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { engine := New() - if _, err := buildRedirectServer(engine, "example.com", ":80"); err == nil { + if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { t.Fatal("expected redirect server builder to reject https address without port") } } @@ -139,6 +151,40 @@ func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { } } +func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { + cfg := defaultRunConfig() + cfg.httpRedirectAddr = ":80" + if err := validateRunConfig(cfg); err != nil { + t.Fatalf("validate run config: %v", err) + } + if cfg.mode != runModeHTTP { + t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) + } +} + +func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = false + cfg.useHeaderHostSet = true + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected configured host mode without redirect host to fail validation") + } +} + +func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = true + cfg.useHeaderHostSet = true + cfg.redirectHost = "configured.example" + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected redirect host to be rejected when header host mode is enabled") + } +} + func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { engine := New() server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) @@ -189,7 +235,7 @@ func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { s.ReadTimeout = time.Second }) - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -216,7 +262,7 @@ func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { engine := New() - server, err := buildRedirectServer(engine, ":443", ":80") + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -234,6 +280,84 @@ func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { } } +func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + req.Header.Set("X-First-Host", "first.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUpgradeRequired { + t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) + } +} + +func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: false, + useHeaderHostSet: true, + redirectHost: "configured.example", + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { occupied, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -252,7 +376,7 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { } engine := New() - redirectServer, err := buildRedirectServer(engine, ":443", redirectAddr) + redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) if err != nil { t.Fatalf("build redirect server: %v", err) } @@ -275,3 +399,34 @@ func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) } } + +func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + err = engine.Run( + WithAddr(occupiedAddr), + WithTLS(&tls.Config{}), + WithHTTPRedirect(redirectAddr), + ) + if err == nil { + t.Fatal("expected non-graceful TLS redirect startup to return bind error") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) + } +} diff --git a/touka.go b/touka.go index dd529cb..4ad81da 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type IRouter interface { - Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 - Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 +// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type Router interface { + Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 + Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc) From 9e57f5a5f56d5ab1b3bc6c981c948f710c67e2cf Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:00:58 +0800 Subject: [PATCH 27/55] fix: stop redirect siblings on shutdown Make the non-graceful HTTPS redirect path shut down all sibling servers after any server returns, so cleanup stays consistent with the graceful path and partial shutdowns do not leave the redirect listener running. --- serve.go | 9 ++++++--- serve_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/serve.go b/serve.go index b2ba358..386eaf5 100644 --- a/serve.go +++ b/serve.go @@ -518,13 +518,16 @@ func (engine *Engine) Run(opts ...RunOption) error { } err := <-serverStopped - if err != nil && !errors.Is(err, http.ErrServerClosed) { - if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return errors.Join(err, shutdownErr) } + return shutdownErr + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { return err } - return err + return nil } protocolLabel := "HTTP" diff --git a/serve_test.go b/serve_test.go index 2bdddc5..8de14c3 100644 --- a/serve_test.go +++ b/serve_test.go @@ -2,9 +2,15 @@ package touka import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "io" + "math/big" "net" "net/http" "net/http/httptest" @@ -13,6 +19,41 @@ import ( "time" ) +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "127.0.0.1"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("create self-signed cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("parse self-signed cert: %v", err) + } + return cert +} + func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { From 121679b44e160aab44e9fa8a98d99002a26c8010 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 7 Apr 2026 20:31:10 +0800 Subject: [PATCH 28/55] fix: preserve IPv6 brackets in redirects Re-wrap bare IPv6 hosts after stripping ports so HTTPS redirect URLs stay valid. Add a regression test covering bracketed IPv6 hosts in redirect responses. --- serve.go | 3 +++ serve_test.go | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/serve.go b/serve.go index 386eaf5..0fc83f9 100644 --- a/serve.go +++ b/serve.go @@ -332,6 +332,9 @@ func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { if parsedHost, _, err := net.SplitHostPort(host); err == nil { host = parsedHost + if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { + host = "[" + host + "]" + } } targetURL := "https://" + host diff --git a/serve_test.go b/serve_test.go index 8de14c3..c717653 100644 --- a/serve_test.go +++ b/serve_test.go @@ -399,6 +399,26 @@ func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t * } } +func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil) + req.Host = "[::1]:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" { + t.Fatalf("unexpected IPv6 redirect location: %q", location) + } +} + func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { occupied, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { From 71a344a3de591c0fd4973356edd0916d63fbd4a3 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 10 Apr 2026 06:08:55 +0800 Subject: [PATCH 29/55] perf: reuse reverse proxy copy buffers --- context_benchmark_test.go | 9 +- ecw_benchmark_test.go | 59 ++++++++++ reverseproxy.go | 11 ++ reverseproxy_benchmark_test.go | 206 +++++++++++++++++++++++++++++++++ 4 files changed, 282 insertions(+), 3 deletions(-) create mode 100644 ecw_benchmark_test.go create mode 100644 reverseproxy_benchmark_test.go diff --git a/context_benchmark_test.go b/context_benchmark_test.go index 2198c59..3c464d0 100644 --- a/context_benchmark_test.go +++ b/context_benchmark_test.go @@ -23,7 +23,7 @@ func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { if err != nil { t.Fatalf("failed to build request: %v", err) } - c.reset(c.Writer, req) + c.reset(UnwrapResponseWriter(c.Writer), req) if c.Keys != nil { t.Fatalf("expected reset to clear Keys without allocating a new map") @@ -47,6 +47,7 @@ func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { func BenchmarkContextReset(b *testing.B) { b.Run("NoKeysUse", func(b *testing.B) { c, _ := CreateTestContext(nil) + rawWriter := UnwrapResponseWriter(c.Writer) req, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { b.Fatalf("failed to build request: %v", err) @@ -56,12 +57,13 @@ func BenchmarkContextReset(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - c.reset(c.Writer, req) + c.reset(rawWriter, req) } }) b.Run("WithKeysUse", func(b *testing.B) { c, _ := CreateTestContext(nil) + rawWriter := UnwrapResponseWriter(c.Writer) req, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { b.Fatalf("failed to build request: %v", err) @@ -71,8 +73,9 @@ func BenchmarkContextReset(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - c.reset(c.Writer, req) + c.reset(rawWriter, req) c.Set("request-id", i) } }) + } diff --git a/ecw_benchmark_test.go b/ecw_benchmark_test.go new file mode 100644 index 0000000..d9a427c --- /dev/null +++ b/ecw_benchmark_test.go @@ -0,0 +1,59 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) { + c, _ := CreateTestContext(nil) + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + ecw.capturedErrorSignal = true + ecw.Header().Set("Content-Type", "text/plain") + ecw.Header().Add("X-Test", "one") + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + + ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler) + + if len(ecw.headerSnapshot) != 0 { + t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot) + } +} + +func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) { + c, _ := CreateTestContext(nil) + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + rawWriter := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + keys := make([]string, 16) + for i := range keys { + keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i))) + } + values := []string{"one", "two", "three"} + for _, key := range keys { + ecw.headerSnapshot[key] = values + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler) + for _, key := range keys { + ecw.headerSnapshot[key] = values + } + } +} diff --git a/reverseproxy.go b/reverseproxy.go index fe66e2b..1d9c07d 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -84,6 +84,13 @@ type reverseProxyHandler struct { roundRobin atomic.Uint64 } +var reverseProxyCopyBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + type reverseProxyStatusError struct { status int err error @@ -1153,6 +1160,10 @@ func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, fl if p.config.BufferPool != nil { buf = p.config.BufferPool.Get() defer p.config.BufferPool.Put(buf) + } else { + bufp := reverseProxyCopyBufferPool.Get().(*[]byte) + buf = *bufp + defer reverseProxyCopyBufferPool.Put(bufp) } _, err := p.copyBuffer(writer, src, buf) return err diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go new file mode 100644 index 0000000..5d82037 --- /dev/null +++ b/reverseproxy_benchmark_test.go @@ -0,0 +1,206 @@ +package touka + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "net/http" + "testing" + "time" +) + +type benchmarkReadSeeker struct { + data []byte + off int +} + +func (r *benchmarkReadSeeker) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.off:]) + r.off += n + return n, nil +} + +func (r *benchmarkReadSeeker) Reset() { + r.off = 0 +} + +type benchmarkResponseWriter struct { + header http.Header + status int + size int +} + +func newBenchmarkResponseWriter() *benchmarkResponseWriter { + return &benchmarkResponseWriter{header: make(http.Header)} +} + +func (w *benchmarkResponseWriter) Header() http.Header { + return w.header +} + +func (w *benchmarkResponseWriter) WriteHeader(statusCode int) { + if w.status == 0 { + w.status = statusCode + } +} + +func (w *benchmarkResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + w.size += len(p) + return len(p), nil +} + +func (w *benchmarkResponseWriter) Flush() {} + +func (w *benchmarkResponseWriter) Status() int { + return w.status +} + +func (w *benchmarkResponseWriter) Size() int { + return w.size +} + +func (w *benchmarkResponseWriter) Written() bool { + return w.status != 0 +} + +func (w *benchmarkResponseWriter) IsHijacked() bool { + return false +} + +func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *benchmarkResponseWriter) reset() { + clear(w.header) + w.status = 0 + w.size = 0 +} + +var benchmarkReverseProxySink int + +func BenchmarkReverseProxyCopyResponse(b *testing.B) { + body := bytes.Repeat([]byte("0123456789abcdef"), 4096) + proxy := newReverseProxyHandler(ReverseProxyConfig{}) + dst := newBenchmarkResponseWriter() + src := &benchmarkReadSeeker{data: body} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dst.reset() + src.Reset() + if err := proxy.copyResponse(dst, src, 0); err != nil { + b.Fatalf("copyResponse failed: %v", err) + } + } + + benchmarkReverseProxySink = dst.Size() +} + +func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a"}, + {key: "b"}, + {key: "c"}, + {key: "d"}, + }, + config: ReverseProxyConfig{ + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 3, + }, + }, + } + + now := time.Now() + proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)} + proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil)) + } +} + +func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { + proxy := newReverseProxyHandler(ReverseProxyConfig{}) + dst := newBenchmarkResponseWriter() + src := bytes.NewBufferString("hello, reverse proxy") + + if err := proxy.copyResponse(dst, src, 0); err != nil { + t.Fatalf("copyResponse failed: %v", err) + } + + if got, want := dst.Size(), len("hello, reverse proxy"); got != want { + t.Fatalf("expected %d bytes copied, got %d", want, got) + } +} + +type fixedLenBufferPool struct { + buf []byte +} + +func (p *fixedLenBufferPool) Get() []byte { + return p.buf +} + +func (p *fixedLenBufferPool) Put(buf []byte) { + p.buf = buf +} + +type recordingReader struct { + chunk int + reads []int + left int +} + +func (r *recordingReader) Read(p []byte) (int, error) { + if r.left == 0 { + return 0, io.EOF + } + n := min(r.chunk, len(p), r.left) + if n == 0 { + return 0, errors.New("reader received zero-length buffer") + } + for i := 0; i < n; i++ { + p[i] = 'x' + } + r.left -= n + r.reads = append(r.reads, len(p)) + return n, nil +} + +func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) { + pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)} + proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool}) + dst := newBenchmarkResponseWriter() + src := &recordingReader{chunk: 8, left: 24} + + if err := proxy.copyResponse(dst, src, 0); err != nil { + t.Fatalf("copyResponse failed: %v", err) + } + + if len(src.reads) == 0 { + t.Fatal("expected reader to be used") + } + for _, size := range src.reads { + if size != 8 { + t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size) + } + } +} + +var _ io.Writer = (*benchmarkResponseWriter)(nil) From 017bb13295c33af88791a6ae22931a14db2f9e0d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 10 Apr 2026 06:18:52 +0800 Subject: [PATCH 30/55] perf: reuse reverse proxy candidate slices --- reverseproxy.go | 7 ++++ reverseproxy_benchmark_test.go | 65 +++++++++++++++++++++++++++++++--- reverseproxy_lb.go | 22 +++++++++--- 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 1d9c07d..5b178d5 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -91,6 +91,13 @@ var reverseProxyCopyBufferPool = sync.Pool{ }, } +var reverseProxyCandidatePool = sync.Pool{ + New: func() any { + s := make([]*reverseProxyUpstream, 0, 8) + return &s + }, +} + type reverseProxyStatusError struct { status int err error diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go index 5d82037..f55f5f0 100644 --- a/reverseproxy_benchmark_test.go +++ b/reverseproxy_benchmark_test.go @@ -110,10 +110,10 @@ func BenchmarkReverseProxyCopyResponse(b *testing.B) { func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) { proxy := &reverseProxyHandler{ upstreams: []*reverseProxyUpstream{ - {key: "a"}, - {key: "b"}, - {key: "c"}, - {key: "d"}, + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, }, config: ReverseProxyConfig{ PassiveHealth: ReverseProxyPassiveHealthConfig{ @@ -135,6 +135,38 @@ func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) { } } +func BenchmarkReverseProxySelectUpstream(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 3, + }, + }, + } + proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)} + + c, _ := CreateTestContext(nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + selected, err := proxy.selectUpstream(c, nil) + if err != nil { + b.Fatalf("selectUpstream failed: %v", err) + } + benchmarkReverseProxySink = selected.index + } +} + func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { proxy := newReverseProxyHandler(ReverseProxyConfig{}) dst := newBenchmarkResponseWriter() @@ -203,4 +235,29 @@ func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) { } } +func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) { + now := time.Now() + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a"}, + {key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}}, + {key: "c"}, + }, + config: ReverseProxyConfig{ + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 2, + }, + }, + } + + available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}}) + if len(available) != 1 { + t.Fatalf("expected only one available upstream, got %d", len(available)) + } + if available[0].key != "a" { + t.Fatalf("expected upstream 'a', got %q", available[0].key) + } +} + var _ io.Writer = (*benchmarkResponseWriter)(nil) diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index d2d45ab..02895fb 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -137,18 +137,32 @@ func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error { func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) { now := time.Now() policy := p.config.LoadBalancing.Policy - candidates := p.availableUpstreams(now, excluded) + candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream) + candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf) if len(candidates) == 0 && len(excluded) > 0 { - candidates = p.availableUpstreams(now, nil) + candidates = p.availableUpstreamsInto(now, nil, candidates[:0]) } if len(candidates) == 0 { + *candidateBuf = candidates[:0] + reverseProxyCandidatePool.Put(candidateBuf) return nil, errReverseProxyNoAvailableUpstreams } - return p.selectUpstreamWithPolicy(c, candidates, policy), nil + selected := p.selectUpstreamWithPolicy(c, candidates, policy) + *candidateBuf = candidates[:0] + reverseProxyCandidatePool.Put(candidateBuf) + return selected, nil } func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream { - candidates := make([]*reverseProxyUpstream, 0, len(p.upstreams)) + return p.availableUpstreamsInto(now, excluded, nil) +} + +func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream { + if cap(candidates) < len(p.upstreams) { + candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams)) + } else { + candidates = candidates[:0] + } for _, upstream := range p.upstreams { if _, skip := excluded[upstream.key]; skip { continue From 7c37d4c38c70a15c1d755f8a43b47ac5ecb18894 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 10 Apr 2026 21:44:31 +0800 Subject: [PATCH 31/55] perf: fast-path default 404 and 405 responses --- engine.go | 37 +++++++++++++++++++++++++++++++++++++ engine_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/engine.go b/engine.go index b0723e7..536d6e1 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "github.com/WJQSERVER-STUDIO/httpc" "github.com/fenthope/reco" + "github.com/go-json-experiment/json" ) // Last 返回链中的最后一个处理函数 @@ -132,6 +133,32 @@ type defaultErrorResponse struct { Error string `json:"error"` } +var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error()) +var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error()) + +func mustMarshalDefaultErrorBody(code int, errMsg string) []byte { + body, err := json.Marshal(defaultErrorResponse{ + Code: code, + Message: http.StatusText(code), + Error: errMsg, + }) + if err != nil { + panic(err) + } + return body +} + +func writeDefaultErrorJSON(c *Context, code int, body []byte) { + if c == nil || c.Writer == nil { + return + } + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.Writer.WriteHeader(code) + _, _ = c.Writer.Write(body) + c.Writer.Flush() + c.Abort() +} + var methodNotAllowedHandler HandlerFunc = func(c *Context) { httpMethod := c.Request.Method requestPath := routeLookupPath(c.Request) @@ -191,6 +218,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if c.Writer.Written() { return } + if len(c.Errors) == 0 { + switch { + case code == http.StatusNotFound && errors.Is(err, errNotFound): + writeDefaultErrorJSON(c, code, defaultNotFoundBody) + return + case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed): + writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody) + return + } + } // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { diff --git a/engine_test.go b/engine_test.go index 571f4b7..f6906b3 100644 --- a/engine_test.go +++ b/engine_test.go @@ -139,3 +139,49 @@ func TestDefaultErrorHandleJSONShape(t *testing.T) { t.Fatalf("unexpected error payload: %+v", body) } } + +func TestDefaultMethodNotAllowedJSONShape(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" { + t.Fatalf("unexpected error payload: %+v", body) + } +} + +func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { + engine := New() + engine.SetErrorHandler(func(c *Context, code int, err error) { + c.Writer.Header().Set("X-Custom-Error", "1") + c.String(code, "custom:%v", err) + }) + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-Custom-Error"); got != "1" { + t.Fatalf("expected custom error header, got %q", got) + } + if rr.Body.String() != "custom:method not allowed" { + t.Fatalf("expected custom error body, got %q", rr.Body.String()) + } +} From 02861b5537fe237d91e271cff0659835b7b90ceb Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Fri, 10 Apr 2026 21:55:21 +0800 Subject: [PATCH 32/55] perf: avoid header policy join allocations --- reverseproxy_benchmark_test.go | 92 ++++++++++++++++++++++++++++++++++ reverseproxy_lb.go | 46 ++++++++++++++++- 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go index f55f5f0..7a03bd4 100644 --- a/reverseproxy_benchmark_test.go +++ b/reverseproxy_benchmark_test.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "strings" "testing" "time" ) @@ -167,6 +168,33 @@ func BenchmarkReverseProxySelectUpstream(b *testing.B) { } } +func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + selected, err := proxy.selectUpstream(c, nil) + if err != nil { + b.Fatalf("selectUpstream failed: %v", err) + } + benchmarkReverseProxySink = selected.index + } +} + func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { proxy := newReverseProxyHandler(ReverseProxyConfig{}) dst := newBenchmarkResponseWriter() @@ -260,4 +288,68 @@ func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) } } +func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"} + + selectedA, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + selectedB, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedA.key != selectedB.key { + t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key) + } + + c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"} + selectedC, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedC == nil { + t.Fatal("expected upstream for reordered multi-value header") + } +} + +func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) { + candidates := []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + } + + testCases := [][]string{ + {"tenant-a"}, + {"tenant-a", "tenant-b"}, + {"", "tenant-b"}, + {"tenant-a", ""}, + {"", ""}, + } + + for _, values := range testCases { + got := reverseProxySelectHRWValues(candidates, values) + want := reverseProxySelectHRW(candidates, strings.Join(values, ",")) + if got == nil || want == nil { + t.Fatalf("expected non-nil upstreams for values %v", values) + } + if got.key != want.key { + t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key) + } + } +} + var _ io.Writer = (*benchmarkResponseWriter)(nil) diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 02895fb..3be7234 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -199,7 +199,7 @@ func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates [] case reverseProxyLBPolicyHeader: if c.Request != nil && c.Request.Header != nil { if values, ok := c.Request.Header[policy.key]; ok { - return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + return reverseProxySelectHRWValues(candidates, values) } } return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) @@ -277,6 +277,25 @@ func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reve return selected } +func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(values) == 0 { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWValuesScore(values, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWValuesScore(values, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + func reverseProxyHRWScore(key, upstreamKey string) uint64 { const ( offset64 = 14695981039346656037 @@ -296,6 +315,31 @@ func reverseProxyHRWScore(key, upstreamKey string) uint64 { return h } +func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for valueIndex, value := range values { + for i := 0; i < len(value); i++ { + h ^= uint64(value[i]) + h *= prime64 + } + if valueIndex+1 < len(values) { + h ^= ',' + h *= prime64 + } + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { if policy.fallback != nil { return *policy.fallback From 54f7de0c608dc11022fb80e7664902e0c92af8b3 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sat, 11 Apr 2026 01:43:34 +0800 Subject: [PATCH 33/55] perf: modernize io paths and reduce proxy allocations --- context.go | 33 +++++--- engine.go | 2 +- engine_test.go | 119 ++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 2 + iox_benchmark_test.go | 150 +++++++++++++++++++++++++++++++++ reverseproxy.go | 12 +-- reverseproxy_benchmark_test.go | 2 +- reverseproxy_lb.go | 8 +- reverseproxy_test.go | 8 +- serve_test.go | 3 +- 11 files changed, 312 insertions(+), 29 deletions(-) create mode 100644 iox_benchmark_test.go diff --git a/context.go b/context.go index f06d21e..e73033d 100644 --- a/context.go +++ b/context.go @@ -128,6 +128,19 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } } +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if len(data) == 0 { + return + } + if _, err := c.Writer.Write(data); err != nil { + wrapped := fmt.Errorf("%s: %w", contextMsg, err) + c.AddError(wrapped) + if c != nil && c.engine != nil && c.engine.LogReco != nil { + c.engine.LogReco.Errorf("%s: %v", contextMsg, err) + } + } +} + // Next 在处理链中执行下一个处理函数 // 这是中间件模式的核心,允许请求依次经过多个处理函数 func (c *Context) Next() { @@ -344,20 +357,20 @@ func (c *Context) Param(key string) string { func (c *Context) Raw(code int, contentType string, data []byte) { c.Writer.Header().Set("Content-Type", contentType) c.Writer.WriteHeader(code) - c.Writer.Write(data) + c.writeResponseBody(data, "failed to write raw response") } // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.Writer.Write(fmt.Appendf(nil, format, values...)) + c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response") } // Text 向响应写入无需格式化的string func (c *Context) Text(code int, text string) { c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write([]byte(text)) + c.writeResponseBody([]byte(text), "failed to write text response") } // FileText @@ -495,7 +508,7 @@ func (c *Context) JSONBuf(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response") } // GOB 向响应写入GOB数据 @@ -524,7 +537,7 @@ func (c *Context) GOBBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/octet-stream") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response") } // WANF向响应写入WANF数据 @@ -553,7 +566,7 @@ func (c *Context) WANFBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response") } // HTML 渲染 HTML 模板 @@ -577,7 +590,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.Writer.Write(fmt.Appendf(nil, "\n
%v
", name, obj)) + c.writeResponseBody(fmt.Appendf(nil, "\n
%v
", name, obj), "failed to write HTML response") } // HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体. @@ -602,7 +615,7 @@ func (c *Context) HTMLBuf(code int, name string, obj any) { // 渲染成功,写入响应 c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response") return } @@ -938,7 +951,7 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { } }() - data, err := iox.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -959,7 +972,7 @@ func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { } }() - data, err := iox.ReadAll(body) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) diff --git a/engine.go b/engine.go index 536d6e1..d712064 100644 --- a/engine.go +++ b/engine.go @@ -154,7 +154,7 @@ func writeDefaultErrorJSON(c *Context, code int, body []byte) { } c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) - _, _ = c.Writer.Write(body) + c.writeResponseBody(body, "failed to write default error response") c.Writer.Flush() c.Abort() } diff --git a/engine_test.go b/engine_test.go index f6906b3..4772810 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1,11 +1,66 @@ package touka import ( + "bufio" "encoding/json" + "errors" + "html/template" + "net" "net/http" "testing" ) +type failingResponseWriter struct { + header http.Header + status int + err error +} + +func (w *failingResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *failingResponseWriter) WriteHeader(statusCode int) { + if w.status == 0 { + w.status = statusCode + } +} + +func (w *failingResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +func (w *failingResponseWriter) Flush() {} + +func (w *failingResponseWriter) Status() int { + return w.status +} + +func (w *failingResponseWriter) Size() int { + return 0 +} + +func (w *failingResponseWriter) Written() bool { + return w.status != 0 +} + +func (w *failingResponseWriter) IsHijacked() bool { + return false +} + +func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + func TestHandleRequestRedirectFixedPath(t *testing.T) { engine := New() engine.GET("/api/v1/users/:id/settings", func(c *Context) { @@ -185,3 +240,67 @@ func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { t.Fatalf("expected custom error body, got %q", rr.Body.String()) } } + +func TestResponseHelpersCaptureWriteErrors(t *testing.T) { + testCases := []struct { + name string + run func(*Context) + }{ + {name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }}, + {name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }}, + {name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }}, + {name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }}, + {name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }}, + {name: "HTMLBuf", run: func(c *Context) { + c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`)) + c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"}) + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + c, _ := CreateTestContext(w) + + tc.run(c) + + if got := len(c.Errors); got != 1 { + t.Fatalf("expected exactly one captured error, got %d", got) + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + }) + } +} + +func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + engine := New() + c, _ := CreateTestContext(w) + c.engine = engine + req, err := http.NewRequest(http.MethodGet, "/missing", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(w, req) + + defaultErrorHandle(c, http.StatusNotFound, errNotFound) + + if len(c.Errors) == 0 { + t.Fatal("expected write error to be captured") + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + if c.Writer.Status() != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status()) + } + if !c.IsAborted() { + t.Fatal("expected fast path to abort context") + } +} diff --git a/go.mod b/go.mod index bd0c046..dee187d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/infinite-iroha/touka go 1.26 require ( - github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 + github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 diff --git a/go.sum b/go.sum index 6a8d0c6..4b9dbd9 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= diff --git a/iox_benchmark_test.go b/iox_benchmark_test.go new file mode 100644 index 0000000..9b43590 --- /dev/null +++ b/iox_benchmark_test.go @@ -0,0 +1,150 @@ +package touka + +import ( + "bytes" + "io" + "testing" + + "github.com/WJQSERVER-STUDIO/go-utils/iox" +) + +type benchmarkResetReader struct { + data []byte + off int +} + +func (r *benchmarkResetReader) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.off:]) + r.off += n + return n, nil +} + +func (r *benchmarkResetReader) Reset() { + r.off = 0 +} + +type benchmarkDiscardWriter struct{} + +func (benchmarkDiscardWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +var benchmarkIOXResult int64 +var benchmarkIOXBytes []byte + +func BenchmarkIOXCopyComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.Copy(w, r) + if err != nil { + b.Fatalf("io.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.Copy(w, r) + if err != nil { + b.Fatalf("iox.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXCopyBufferComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("io.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("iox.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXReadAllComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("io.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) + + b.Run("iox.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("iox.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) +} diff --git a/reverseproxy.go b/reverseproxy.go index 5b178d5..5ec3693 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1041,7 +1041,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r go copyer.copyFromBackend(errc) var firstErr error - for i := 0; i < 2; i++ { + for range 2 { err := <-errc if reverseProxyIsBenignTunnelError(err) { continue @@ -1123,7 +1123,7 @@ func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *htt }() var firstErr error - for i := 0; i < 2; i++ { + for range 2 { err := <-errc if reverseProxyIsBenignTunnelError(err) { continue @@ -1587,8 +1587,8 @@ func reverseProxyViaProtocol(major, minor int, raw string) string { if major > 0 { return strconv.Itoa(major) + "." + strconv.Itoa(minor) } - if strings.HasPrefix(raw, "HTTP/") { - return strings.TrimPrefix(raw, "HTTP/") + if after, ok := strings.CutPrefix(raw, "HTTP/"); ok { + return after } return raw } @@ -1702,7 +1702,7 @@ var reverseProxyHopHeaders = []string{ func removeHopByHopHeaders(header http.Header) { for _, connectionValue := range header["Connection"] { - for _, token := range strings.Split(connectionValue, ",") { + for token := range strings.SplitSeq(connectionValue, ",") { trimmed := textproto.TrimString(token) if trimmed != "" { header.Del(trimmed) @@ -1726,7 +1726,7 @@ func headerValuesContainToken(values []string, token string) bool { return false } for _, value := range values { - for _, part := range strings.Split(value, ",") { + for part := range strings.SplitSeq(value, ",") { if strings.EqualFold(textproto.TrimString(part), token) { return true } diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go index 7a03bd4..b496f5c 100644 --- a/reverseproxy_benchmark_test.go +++ b/reverseproxy_benchmark_test.go @@ -235,7 +235,7 @@ func (r *recordingReader) Read(p []byte) (int, error) { if n == 0 { return 0, errors.New("reader received zero-length buffer") } - for i := 0; i < n; i++ { + for i := range n { p[i] = 'x' } r.left -= n diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go index 3be7234..ce5e949 100644 --- a/reverseproxy_lb.go +++ b/reverseproxy_lb.go @@ -10,6 +10,7 @@ import ( "net/http" "net/textproto" "net/url" + "slices" "strings" "sync" "sync/atomic" @@ -404,10 +405,5 @@ func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, statu if status <= 0 { return false } - for _, unhealthyStatus := range config.UnhealthyStatus { - if status == unhealthyStatus { - return true - } - } - return false + return slices.Contains(config.UnhealthyStatus, status) } diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 9cbc734..6863da7 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -1866,7 +1866,9 @@ func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) if message != "echo:ping\n" { t.Fatalf("unexpected tunneled response body: %q", message) } - _ = pw.Close() + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } select { case err := <-errCh: @@ -2215,7 +2217,9 @@ func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testi time.Sleep(50 * time.Millisecond) cancel() - _ = pw.CloseWithError(context.Canceled) + if err := pw.CloseWithError(context.Canceled); err != nil { + t.Fatalf("close request body with cancellation: %v", err) + } select { case <-writeErrCh: case <-time.After(2 * time.Second): diff --git a/serve_test.go b/serve_test.go index c717653..a02f1df 100644 --- a/serve_test.go +++ b/serve_test.go @@ -182,8 +182,7 @@ func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { cfg := defaultRunConfig() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() if err := WithShutdownContext(ctx).apply(&cfg); err != nil { t.Fatalf("apply shutdown context option: %v", err) } From b008fc8e612e5a2635b77ef9971b32f60db54164 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 19 Apr 2026 07:44:22 +0800 Subject: [PATCH 34/55] fix: only remove Sec-WebSocket-Accept if present in HTTP/2 Extended CONNECT - Check if Sec-WebSocket-Accept header exists before deleting - This prevents unnecessary header manipulation when backend doesn't send it - Maintains compatibility with backends that may or may not include this header --- reverseproxy.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/reverseproxy.go b/reverseproxy.go index 5ec3693..f8335d2 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1014,7 +1014,9 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) removeHopByHopHeaders(responseHeader) - responseHeader.Del("Sec-WebSocket-Accept") + if accept := res.Header.Get("Sec-WebSocket-Accept"); accept != "" { + responseHeader.Del("Sec-WebSocket-Accept") + } c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { backConn.Close() From 3b5f2c81af2cc367f74c7daa62572ee49a6265ce Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 19 Apr 2026 07:52:00 +0800 Subject: [PATCH 35/55] fix: optimize Sec-WebSocket-Accept header check - Remove unused variable assignment in condition - Direct comparison is more efficient (no extra variable allocation) - Maintains same defensive check behavior --- reverseproxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reverseproxy.go b/reverseproxy.go index f8335d2..4c2b3cd 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -1014,7 +1014,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) removeHopByHopHeaders(responseHeader) - if accept := res.Header.Get("Sec-WebSocket-Accept"); accept != "" { + if res.Header.Get("Sec-WebSocket-Accept") != "" { responseHeader.Del("Sec-WebSocket-Accept") } c.Writer.WriteHeader(http.StatusOK) From 06a6d42de1b6dfa1bb51ba7482463da720f34e7f Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 19 Apr 2026 09:30:06 +0800 Subject: [PATCH 36/55] feat: add headers operations for reverse proxy - Add HeaderOps struct for Add/Set/Delete header operations - Add RespHeaderOps for response header manipulation with deferred support - Support wildcard patterns for header deletion (prefix-*, *suffix, *substring*) - Apply request headers before forwarding to upstream - Apply response headers before sending to client - Add comprehensive test coverage for header operations Usage example: engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, RequestHeaders: &HeaderOps{ Add: map[string][]string{"X-Custom": {"value"}}, Delete: []string{"X-Sensitive-*"}, }, ResponseHeaders: &RespHeaderOps{ HeaderOps: &HeaderOps{ Set: map[string][]string{"X-Frame-Options": {"DENY"}}, }, }, })) --- reverseproxy.go | 193 ++++++++++++++++++++++++++++-- reverseproxy_headers_test.go | 220 +++++++++++++++++++++++++++++++++++ 2 files changed, 401 insertions(+), 12 deletions(-) create mode 100644 reverseproxy_headers_test.go diff --git a/reverseproxy.go b/reverseproxy.go index 4c2b3cd..eb5043c 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -48,34 +48,195 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL Targets []string LoadBalancing ReverseProxyLoadBalancingConfig PassiveHealth ReverseProxyPassiveHealthConfig - Transport http.RoundTripper - FlushInterval time.Duration - BufferPool BufferPool + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool AllowH2CUpstream bool - ModifyRequest func(*http.Request) + ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error - ErrorHandler func(http.ResponseWriter, *http.Request, error) + ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy - ForwardedBy string - Via string - PreserveHost bool + ForwardedBy string + Via string + PreserveHost bool + + RequestHeaders *HeaderOps + ResponseHeaders *RespHeaderOps } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") - errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) +type HeaderOps struct { + Add map[string][]string + Set map[string][]string + Delete []string +} + +type RespHeaderOps struct { + *HeaderOps + Deferred bool +} + +func (ops *HeaderOps) applyToRequest(req *http.Request) { + if ops == nil { + return + } + replacer := newReverseProxyReplacer(req) + + for fieldName, vals := range ops.Add { + fieldName = replacer.Replace(fieldName) + for _, v := range vals { + req.Header.Add(fieldName, replacer.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = replacer.Replace(fieldName) + req.Header.Del(fieldName) + for _, v := range vals { + req.Header.Add(fieldName, replacer.Replace(v)) + } + } + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(replacer.Replace(fieldName)) + if fieldName == "*" { + for k := range req.Header { + req.Header.Del(k) + } + continue + } + + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + pattern := fieldName[1:len(fieldName)-1] + for k := range req.Header { + if strings.Contains(strings.ToLower(k), pattern) { + req.Header.Del(k) + } + } + case strings.HasPrefix(fieldName, "*"): + suffix := fieldName[1:] + for k := range req.Header { + if strings.HasSuffix(strings.ToLower(k), suffix) { + req.Header.Del(k) + } + } + case strings.HasSuffix(fieldName, "*"): + prefix := fieldName[:len(fieldName)-1] + for k := range req.Header { + if strings.HasPrefix(strings.ToLower(k), prefix) { + req.Header.Del(k) + } + } + default: + req.Header.Del(fieldName) + } + } +} + +func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { + if ops == nil { + return + } + if ops.Deferred { + return + } + ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) +} + +func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil { + return + } + if repl == nil { + repl = &reverseProxyReplacer{} + } + + for fieldName, vals := range ops.Add { + fieldName = repl.Replace(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = repl.Replace(fieldName) + hdr.Del(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(repl.Replace(fieldName)) + if fieldName == "*" { + for k := range hdr { + hdr.Del(k) + } + continue + } + + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + pattern := fieldName[1:len(fieldName)-1] + for k := range hdr { + if strings.Contains(strings.ToLower(k), pattern) { + hdr.Del(k) + } + } + case strings.HasPrefix(fieldName, "*"): + suffix := fieldName[1:] + for k := range hdr { + if strings.HasSuffix(strings.ToLower(k), suffix) { + hdr.Del(k) + } + } + case strings.HasSuffix(fieldName, "*"): + prefix := fieldName[:len(fieldName)-1] + for k := range hdr { + if strings.HasPrefix(strings.ToLower(k), prefix) { + hdr.Del(k) + } + } + default: + hdr.Del(fieldName) + } + } +} + +type reverseProxyReplacer struct { + req *http.Request +} + +func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { + return &reverseProxyReplacer{req: req} +} + +func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { + return &reverseProxyReplacer{} +} + +func (r *reverseProxyReplacer) Replace(s string) string { + if r == nil || s == "" { + return s + } + return s +} + type reverseProxyHandler struct { config ReverseProxyConfig upstreams []*reverseProxyUpstream @@ -573,6 +734,10 @@ func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Conte outreq.Header.Set("User-Agent", "") } + if p.config.RequestHeaders != nil { + p.config.RequestHeaders.applyToRequest(outreq) + } + if p.config.ModifyRequest != nil { p.config.ModifyRequest(outreq) } @@ -808,6 +973,10 @@ func appendXForwardedFor(header http.Header, clientIP string) { } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } + if p.config.ModifyResponse == nil { return true } diff --git a/reverseproxy_headers_test.go b/reverseproxy_headers_test.go new file mode 100644 index 0000000..4a4ae26 --- /dev/null +++ b/reverseproxy_headers_test.go @@ -0,0 +1,220 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestReverseProxyHeaderOpsAdd(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Custom-Header"); got != "test-value" { + t.Errorf("expected X-Custom-Header=test-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Custom-Header": {"test-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Sensitive") != "" { + t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive")) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Delete: []string{"X-Sensitive"}, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Sensitive", "should-be-removed") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsSet(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Replace") + if got != "new-value" { + t.Errorf("expected X-Replace=new-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Set: map[string][]string{ + "X-Replace": {"new-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Replace", "old-value") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyResponseHeaderOps(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-server") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Set: map[string][]string{ + "X-Custom": {"custom-value"}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Custom"); got != "custom-value" { + t.Errorf("expected X-Custom=custom-value, got %q", got) + } +} + +func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Powered-By", "Express") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Delete: []string{"X-Powered-By"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Powered-By"); got != "" { + t.Errorf("expected X-Powered-By to be deleted, got %q", got) + } +} From 93f5edc6eb770b07516a8354e522e5c908fcd4de Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Sun, 19 Apr 2026 11:28:08 +0800 Subject: [PATCH 37/55] feat: add Replace support for reverse proxy header ops - Support substring replacement via Search field - Support regex replacement via SearchRegexp field (precompiled at Provision) - Support wildcard field name '*' to apply replacement to all headers - Validate that Search and SearchRegexp are mutually exclusive - Add 5 functional tests and 9 benchmark tests covering all operations Benchmark results (no external allocs in hot paths): Add: 527 ns/op, 448 B/op, 5 allocs/op Set: 891 ns/op, 480 B/op, 7 allocs/op Delete(single): 476 ns/op, 48 B/op, 3 allocs/op Delete(wildcard): 1073 ns/op, 104 B/op, 7 allocs/op Replace(sub): 303 ns/op, 64 B/op, 2 allocs/op Replace(regex): 1503 ns/op, 224 B/op, 6 allocs/op Replace(wild): 731 ns/op, 80 B/op, 4 allocs/op Mixed: 1527 ns/op, 128 B/op, 7 allocs/op --- reverseproxy.go | 95 ++++++- reverseproxy_headers_replace_test.go | 402 +++++++++++++++++++++++++++ 2 files changed, 494 insertions(+), 3 deletions(-) create mode 100644 reverseproxy_headers_replace_test.go diff --git a/reverseproxy.go b/reverseproxy.go index eb5043c..cac2f04 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -20,6 +20,7 @@ import ( "net/netip" "net/textproto" "net/url" + "regexp" "strconv" "strings" "sync" @@ -80,9 +81,17 @@ var ( ) type HeaderOps struct { - Add map[string][]string - Set map[string][]string - Delete []string + Add map[string][]string + Set map[string][]string + Delete []string + Replace map[string][]Replacement +} + +type Replacement struct { + Search string + Replace string + SearchRegexp string + re *regexp.Regexp } type RespHeaderOps struct { @@ -146,6 +155,8 @@ func (ops *HeaderOps) applyToRequest(req *http.Request) { req.Header.Del(fieldName) } } + + ops.applyReplace(req.Header, replacer) } func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { @@ -216,6 +227,71 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { hdr.Del(fieldName) } } + + ops.applyReplace(hdr, repl) +} + +func (ops *HeaderOps) applyReplace(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil || len(ops.Replace) == 0 { + return + } + for fieldName, replacements := range ops.Replace { + fieldName = http.CanonicalHeaderKey(repl.Replace(fieldName)) + if fieldName == "*" { + for fn, vals := range hdr { + for i := range vals { + for _, r := range replacements { + hdr[fn][i] = r.apply(vals[i]) + } + } + } + continue + } + vals, ok := hdr[fieldName] + if !ok { + continue + } + for i := range vals { + for _, r := range replacements { + hdr[fieldName][i] = r.apply(vals[i]) + } + } + } +} + +func (r *Replacement) apply(s string) string { + if r == nil || s == "" { + return s + } + if r.SearchRegexp != "" && r.re != nil { + return r.re.ReplaceAllString(s, r.Replace) + } + if r.Search != "" { + return strings.ReplaceAll(s, r.Search, r.Replace) + } + return s +} + +func (ops *HeaderOps) Provision() error { + if ops == nil { + return nil + } + for fieldName, replacements := range ops.Replace { + for i, r := range replacements { + if r.SearchRegexp == "" { + continue + } + if r.Search != "" { + return fmt.Errorf("replacement %d for header field %q: cannot specify both Search and SearchRegexp", i, fieldName) + } + re, err := regexp.Compile(r.SearchRegexp) + if err != nil { + return fmt.Errorf("replacement %d for header field %q: %v", i, fieldName, err) + } + replacements[i].re = re + } + } + return nil } type reverseProxyReplacer struct { @@ -417,6 +493,19 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { receivedBy: reverseProxyReceivedBy(config.Via), } + if config.RequestHeaders != nil { + if err := config.RequestHeaders.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + if config.ResponseHeaders != nil && config.ResponseHeaders.HeaderOps != nil { + if err := config.ResponseHeaders.HeaderOps.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + upstreams, err := buildReverseProxyUpstreams(config) if err != nil { proxy.configError = err diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go new file mode 100644 index 0000000..54e7889 --- /dev/null +++ b/reverseproxy_headers_replace_test.go @@ -0,0 +1,402 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "testing" +) + +func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Server"); got != "Caddy" { + t.Errorf("expected X-Server=Caddy, got %q", got) + } + if got := r.Header.Get("X-Location"); got != "/api/v2/resource" { + t.Errorf("expected X-Location=/api/v2/resource, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Server": {{Search: "NGINX", Replace: "Caddy"}}, + "X-Location": {{Search: "v1", Replace: "v2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Server", "NGINX") + req.Header.Set("X-Location", "/api/v1/resource") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Route"); got != "/proxy-upstream" { + t.Errorf("expected X-Route=/proxy-upstream, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Route", "/original/upstream") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Host-A"); got != "new.example.com" { + t.Errorf("expected X-Host-A=new.example.com, got %q", got) + } + if got := r.Header.Get("X-Host-B"); got != "new.example.com" { + t.Errorf("expected X-Host-B=new.example.com, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "old.example.com", Replace: "new.example.com"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Host-A", "old.example.com") + req.Header.Set("X-Host-B", "old.example.com") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-internal:8080") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Backend"); got != "public.example.com" { + t.Errorf("expected X-Backend=public.example.com, got %q", got) + } +} + +func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { + _ = New() + ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Test": {{SearchRegexp: "[invalid"}}, + }, + }, + }) +} + +func TestReplacementApply(t *testing.T) { + tests := []struct { + name string + r *Replacement + s string + want string + }{ + {name: "nil replacement", r: nil, s: "hello", want: "hello"}, + {name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""}, + {name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"}, + {name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"}, + {name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"}, + {name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"}, + {name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"}, + {name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.apply(tt.s); got != tt.want { + t.Errorf("Replacement.apply() = %q, want %q", got, tt.want) + } + }) + } +} + +func BenchmarkHeaderOpsAdd(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Custom-1": {"value-1"}, + "X-Custom-2": {"value-2"}, + "X-Custom-3": {"value-3"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsSet(b *testing.B) { + ops := &HeaderOps{ + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + "X-Content-Type-Options": {"nosniff"}, + "X-XSS-Protection": {"1; mode=block"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteSingle(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Powered-By"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Debug-*"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Debug-1", "v1") + hdr.Set("X-Debug-2", "v2") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) { + re := regexp.MustCompile(`^http://([^/]+)(/.*)$`) + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "internal.example.com", Replace: "public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Host", "internal.example.com") + hdr.Set("X-Origin", "internal.example.com") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsMixed(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Request-ID": {"req-123"}, + }, + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + }, + Delete: []string{"X-Powered-By"}, + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("Location", "http://internal:8080/api") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkReplacementApplySubstring(b *testing.B) { + r := &Replacement{Search: "old.example.com", Replace: "new.example.com"} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} + +func BenchmarkReplacementApplyRegexp(b *testing.B) { + r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} From c0e31c449ed6827f2fe719e03bc1b975138c75b5 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:58:14 +0800 Subject: [PATCH 38/55] fix: address PR review comments for header ops - fix Deferred response header logic: apply headers after ModifyResponse callback - refactor applyToRequest to eliminate code duplication via applyTo - remove redundant Sec-WebSocket-Accept condition check --- reverseproxy.go | 70 +++++++------------------------------------------ 1 file changed, 9 insertions(+), 61 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index cac2f04..6ccca43 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -103,69 +103,13 @@ func (ops *HeaderOps) applyToRequest(req *http.Request) { if ops == nil { return } - replacer := newReverseProxyReplacer(req) - - for fieldName, vals := range ops.Add { - fieldName = replacer.Replace(fieldName) - for _, v := range vals { - req.Header.Add(fieldName, replacer.Replace(v)) - } - } - - for fieldName, vals := range ops.Set { - fieldName = replacer.Replace(fieldName) - req.Header.Del(fieldName) - for _, v := range vals { - req.Header.Add(fieldName, replacer.Replace(v)) - } - } - - for _, fieldName := range ops.Delete { - fieldName = strings.ToLower(replacer.Replace(fieldName)) - if fieldName == "*" { - for k := range req.Header { - req.Header.Del(k) - } - continue - } - - switch { - case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): - pattern := fieldName[1:len(fieldName)-1] - for k := range req.Header { - if strings.Contains(strings.ToLower(k), pattern) { - req.Header.Del(k) - } - } - case strings.HasPrefix(fieldName, "*"): - suffix := fieldName[1:] - for k := range req.Header { - if strings.HasSuffix(strings.ToLower(k), suffix) { - req.Header.Del(k) - } - } - case strings.HasSuffix(fieldName, "*"): - prefix := fieldName[:len(fieldName)-1] - for k := range req.Header { - if strings.HasPrefix(strings.ToLower(k), prefix) { - req.Header.Del(k) - } - } - default: - req.Header.Del(fieldName) - } - } - - ops.applyReplace(req.Header, replacer) + ops.applyTo(req.Header, newReverseProxyReplacer(req)) } func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { if ops == nil { return } - if ops.Deferred { - return - } ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) } @@ -1065,8 +1009,11 @@ func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { p.config.ResponseHeaders.applyToResponse(res.Header) } - + if p.config.ModifyResponse == nil { + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } if err := p.config.ModifyResponse(res); err != nil { @@ -1074,6 +1021,9 @@ func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req p.handleError(c, err) return false } + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } @@ -1272,9 +1222,7 @@ func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, r responseHeader := c.Writer.Header() reverseProxyCopyHeader(responseHeader, res.Header) removeHopByHopHeaders(responseHeader) - if res.Header.Get("Sec-WebSocket-Accept") != "" { - responseHeader.Del("Sec-WebSocket-Accept") - } + responseHeader.Del("Sec-WebSocket-Accept") c.Writer.WriteHeader(http.StatusOK) if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { backConn.Close() From 5d9bb3187d86bcfe150e632c48a28f58d96de30b Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 17:20:30 +0800 Subject: [PATCH 39/55] perf: optimize wildcard header deletion; test: assert invalid regex returns 500 - refactor Delete logic to iterate headers once, reducing ToLower calls from O(patterns * headers) to O(headers) - rewrite invalid regex test to verify runtime 500 response --- reverseproxy.go | 71 +++++++++++++++++++--------- reverseproxy_headers_replace_test.go | 34 +++++++++++-- 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 6ccca43..30fa370 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -136,39 +136,64 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { } } + var deleteAll bool + var exactDeletes []string + var suffixPatterns, prefixPatterns, containsPatterns []string + for _, fieldName := range ops.Delete { fieldName = strings.ToLower(repl.Replace(fieldName)) if fieldName == "*" { - for k := range hdr { - hdr.Del(k) - } - continue + deleteAll = true + break } - switch { case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): - pattern := fieldName[1:len(fieldName)-1] - for k := range hdr { - if strings.Contains(strings.ToLower(k), pattern) { - hdr.Del(k) - } - } + containsPatterns = append(containsPatterns, fieldName[1:len(fieldName)-1]) case strings.HasPrefix(fieldName, "*"): - suffix := fieldName[1:] - for k := range hdr { - if strings.HasSuffix(strings.ToLower(k), suffix) { - hdr.Del(k) - } - } + suffixPatterns = append(suffixPatterns, fieldName[1:]) case strings.HasSuffix(fieldName, "*"): - prefix := fieldName[:len(fieldName)-1] - for k := range hdr { - if strings.HasPrefix(strings.ToLower(k), prefix) { - hdr.Del(k) + prefixPatterns = append(prefixPatterns, fieldName[:len(fieldName)-1]) + default: + exactDeletes = append(exactDeletes, fieldName) + } + } + + if deleteAll { + for k := range hdr { + hdr.Del(k) + } + } else if len(exactDeletes) > 0 || len(suffixPatterns) > 0 || len(prefixPatterns) > 0 || len(containsPatterns) > 0 { + toDelete := make([]string, 0, len(exactDeletes)) + for k := range hdr { + kl := strings.ToLower(k) + for _, d := range exactDeletes { + if kl == d { + toDelete = append(toDelete, k) + goto skip } } - default: - hdr.Del(fieldName) + for _, p := range containsPatterns { + if strings.Contains(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range suffixPatterns { + if strings.HasSuffix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range prefixPatterns { + if strings.HasPrefix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + skip: + } + for _, k := range toDelete { + hdr.Del(k) } } diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go index 54e7889..1eb0d04 100644 --- a/reverseproxy_headers_replace_test.go +++ b/reverseproxy_headers_replace_test.go @@ -193,15 +193,41 @@ func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { } func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { - _ = New() - ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, RequestHeaders: &HeaderOps{ Replace: map[string][]Replacement{ "X-Test": {{SearchRegexp: "[invalid"}}, }, }, - }) + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", resp.StatusCode) + } } func TestReplacementApply(t *testing.T) { From fa925582d7121795671cdcd6e92383d7acf951ee Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 17:36:38 +0800 Subject: [PATCH 40/55] feat: implement dynamic request variable replacement in replacer Replace the no-op reverseProxyReplacer.Replace with strings.NewReplacer supporting {method}, {host}, {path}, {query}, {scheme}, {uri}, {proto} --- reverseproxy.go | 30 +++++++- reverseproxy_headers_replace_test.go | 102 +++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 3 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 30fa370..5bb124f 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -264,11 +264,32 @@ func (ops *HeaderOps) Provision() error { } type reverseProxyReplacer struct { - req *http.Request + req *http.Request + repl *strings.Replacer } func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { - return &reverseProxyReplacer{req: req} + r := &reverseProxyReplacer{req: req} + if req != nil { + uri := req.RequestURI + if uri == "" { + uri = req.URL.RequestURI() + } + scheme := "http" + if req.TLS != nil { + scheme = "https" + } + r.repl = strings.NewReplacer( + "{method}", req.Method, + "{host}", req.Host, + "{path}", req.URL.Path, + "{query}", req.URL.RawQuery, + "{scheme}", scheme, + "{uri}", uri, + "{proto}", req.Proto, + ) + } + return r } func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { @@ -279,7 +300,10 @@ func (r *reverseProxyReplacer) Replace(s string) string { if r == nil || s == "" { return s } - return s + if r.repl == nil { + return s + } + return r.repl.Replace(s) } type reverseProxyHandler struct { diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go index 1eb0d04..0c0d599 100644 --- a/reverseproxy_headers_replace_test.go +++ b/reverseproxy_headers_replace_test.go @@ -426,3 +426,105 @@ func BenchmarkReplacementApplyRegexp(b *testing.B) { _ = r.apply(s) } } + +func TestReverseProxyReplacerDynamicVars(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil) + req.Host = "example.com" + repl := newReverseProxyReplacer(req) + + tests := []struct { + name string + input string + want string + }{ + {"method", "{method}", "GET"}, + {"host", "{host}", "example.com"}, + {"path", "{path}", "/api/v1/users"}, + {"query", "{query}", "sort=name&limit=10"}, + {"scheme", "{scheme}", "http"}, + {"proto", "{proto}", "HTTP/1.1"}, + {"combined", "X-{method}-{path}", "X-GET-/api/v1/users"}, + {"no vars", "static-value", "static-value"}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := repl.Replace(tt.input); got != tt.want { + t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestReverseProxyReplacerNilRequest(t *testing.T) { + repl := newReverseProxyReplacer(nil) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil request, got %q", got) + } +} + +func TestReverseProxyReplacerNilReplacer(t *testing.T) { + var repl *reverseProxyReplacer + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil replacer, got %q", got) + } +} + +func TestReverseProxyReplacerFromHeader(t *testing.T) { + hdr := make(http.Header) + repl := newReverseProxyReplacerFromHeader(hdr) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string from header replacer, got %q", got) + } +} + +func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" { + t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Method"); got != "GET" { + t.Errorf("expected X-Forwarded-Method=GET, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" { + t.Errorf("expected X-Forwarded-Host=client.example, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Forwarded-Path": {"{path}"}, + "X-Forwarded-Method": {"{method}"}, + "X-Forwarded-Host": {"{host}"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil) + req.Host = "client.example" + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} From 1243d2d37ad0fae1c255cbfda0efbe97ef7bce62 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:02:57 +0800 Subject: [PATCH 41/55] =?UTF-8?q?fix:=20address=20PR=20review=20for=20repl?= =?UTF-8?q?acer=20=E2=80=94=20nil=20check,=20EscapedPath,=20scheme=20reuse?= =?UTF-8?q?,=20perf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - add req.URL nil guard - use EscapedPath for {path} to avoid illegal header characters - reuse reverseProxyRequestScheme for {scheme} consistency - replace strings.NewReplacer with struct fields + strings.ReplaceAll --- reverseproxy.go | 62 +++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/reverseproxy.go b/reverseproxy.go index 5bb124f..1cf0078 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -264,32 +264,26 @@ func (ops *HeaderOps) Provision() error { } type reverseProxyReplacer struct { - req *http.Request - repl *strings.Replacer + method, host, path, query, scheme, uri, proto string } func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { - r := &reverseProxyReplacer{req: req} - if req != nil { - uri := req.RequestURI - if uri == "" { - uri = req.URL.RequestURI() - } - scheme := "http" - if req.TLS != nil { - scheme = "https" - } - r.repl = strings.NewReplacer( - "{method}", req.Method, - "{host}", req.Host, - "{path}", req.URL.Path, - "{query}", req.URL.RawQuery, - "{scheme}", scheme, - "{uri}", uri, - "{proto}", req.Proto, - ) + if req == nil || req.URL == nil { + return &reverseProxyReplacer{} + } + uri := req.RequestURI + if uri == "" { + uri = req.URL.RequestURI() + } + return &reverseProxyReplacer{ + method: req.Method, + host: req.Host, + path: req.URL.EscapedPath(), + query: req.URL.RawQuery, + scheme: reverseProxyRequestScheme(req), + uri: uri, + proto: req.Proto, } - return r } func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { @@ -300,10 +294,28 @@ func (r *reverseProxyReplacer) Replace(s string) string { if r == nil || s == "" { return s } - if r.repl == nil { - return s + if r.method != "" { + s = strings.ReplaceAll(s, "{method}", r.method) } - return r.repl.Replace(s) + if r.host != "" { + s = strings.ReplaceAll(s, "{host}", r.host) + } + if r.path != "" { + s = strings.ReplaceAll(s, "{path}", r.path) + } + if r.query != "" { + s = strings.ReplaceAll(s, "{query}", r.query) + } + if r.scheme != "" { + s = strings.ReplaceAll(s, "{scheme}", r.scheme) + } + if r.uri != "" { + s = strings.ReplaceAll(s, "{uri}", r.uri) + } + if r.proto != "" { + s = strings.ReplaceAll(s, "{proto}", r.proto) + } + return s } type reverseProxyHandler struct { From fce12ee7e7418ec0d6e56d4bf7e714acf50822f5 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:19:44 +0800 Subject: [PATCH 42/55] =?UTF-8?q?docs:=20=E8=A1=A5=E5=85=85=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6=E6=96=87=E6=A1=A3=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E8=B7=AF=E7=94=B1=E7=BA=A7=E4=B8=AD=E9=97=B4=E4=BB=B6=E5=92=8C?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E9=A1=BA=E5=BA=8F=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加路由级中间件使用示例 - 说明在创建组时直接传入中间件的方法 - 添加中间件执行顺序章节,清晰展示全局/组/路由中间件的执行流程 --- docs/middleware.md | 65 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/docs/middleware.md b/docs/middleware.md index a222437..14c75d6 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -26,6 +26,41 @@ api.Use(AuthMiddleware()) } ``` +也可以在创建组时直接传入中间件: + +```go +api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware()) +{ + api.GET("/user", handleUser) + api.POST("/data", handleData) +} +``` + +### 路由级中间件 + +为单个路由注册中间件,仅对该路由生效。 + +```go +// 单个路由中间件 +r.GET("/protected", AuthMiddleware(), func(c *touka.Context) { + c.String(http.StatusOK, "Protected content") +}) + +// 多个路由中间件(按顺序执行) +r.POST("/upload", + RateLimitMiddleware(), + AuthMiddleware(), + PermissionCheckMiddleware(), + func(c *touka.Context) { + // 处理上传 + }, +) + +// 路由组中的单个路由也可以使用路由级中间件 +api := r.Group("/api") +api.GET("/admin", AdminAuthMiddleware(), adminHandler) +``` + ## 编写自定义中间件 中间件的函数签名是 `touka.HandlerFunc`。 @@ -67,6 +102,36 @@ func APIKeyAuth() touka.HandlerFunc { } ``` +## 中间件执行顺序 + +理解中间件的执行顺序对于构建正确的处理流程至关重要。中间件按照以下顺序执行: + +```go +// 全局中间件 +r.Use(GlobalMiddleware1()) +r.Use(GlobalMiddleware2()) + +// 组中间件 +api := r.Group("/api", GroupMiddleware1()) +api.Use(GroupMiddleware2()) + +// 路由级中间件 +api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler) +``` + +对于 `/api/users` 请求,执行顺序为: +1. `GlobalMiddleware1()` - 全局中间件 +2. `GlobalMiddleware2()` - 全局中间件 +3. `GroupMiddleware1()` - 组中间件 +4. `GroupMiddleware2()` - 组中间件 +5. `RouteMiddleware1()` - 路由级中间件 +6. `RouteMiddleware2()` - 路由级中间件 +7. `userHandler` - 最终处理函数 + +``` +请求进入 → 全局中间件 → 组中间件 → 路由中间件 → 处理函数 → 路由中间件后置逻辑 → 组中间件后置逻辑 → 全局中间件后置逻辑 → 响应 +``` + ## 内置中间件 - **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。 From 58fd877ae269cb7ca3799fc8f5b2c3fa4f5aaec5 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 18:32:10 +0800 Subject: [PATCH 43/55] =?UTF-8?q?docs:=20=E4=BF=AE=E5=A4=8D=E5=AE=A1?= =?UTF-8?q?=E6=9F=A5=E6=84=8F=E8=A7=81=EF=BC=8C=E7=BB=9F=E4=B8=80=E6=9C=AF?= =?UTF-8?q?=E8=AF=AD=E5=B9=B6=E8=A1=A5=E5=85=85=E6=B3=A8=E5=86=8C=E9=A1=BA?= =?UTF-8?q?=E5=BA=8F=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 补充中间件注册顺序说明(必须在路由定义之前) - 统一术语:'组中间件' → '路由组中间件' - 统一流程图术语 --- docs/middleware.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/middleware.md b/docs/middleware.md index 14c75d6..b688fb5 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -104,7 +104,7 @@ func APIKeyAuth() touka.HandlerFunc { ## 中间件执行顺序 -理解中间件的执行顺序对于构建正确的处理流程至关重要。中间件按照以下顺序执行: +理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行: ```go // 全局中间件 @@ -122,14 +122,14 @@ api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler) 对于 `/api/users` 请求,执行顺序为: 1. `GlobalMiddleware1()` - 全局中间件 2. `GlobalMiddleware2()` - 全局中间件 -3. `GroupMiddleware1()` - 组中间件 -4. `GroupMiddleware2()` - 组中间件 +3. `GroupMiddleware1()` - 路由组中间件 +4. `GroupMiddleware2()` - 路由组中间件 5. `RouteMiddleware1()` - 路由级中间件 6. `RouteMiddleware2()` - 路由级中间件 7. `userHandler` - 最终处理函数 ``` -请求进入 → 全局中间件 → 组中间件 → 路由中间件 → 处理函数 → 路由中间件后置逻辑 → 组中间件后置逻辑 → 全局中间件后置逻辑 → 响应 +请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应 ``` ## 内置中间件 From c8b14ef43a1374ced51b1a8c718b2c3f137d1cbd Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 19:43:56 +0800 Subject: [PATCH 44/55] =?UTF-8?q?feat:=20=E5=BC=95=E5=85=A5=20Logger=20?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E6=8A=BD=E8=B1=A1=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89=E6=97=A5=E5=BF=97=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Logger 接口定义,支持 zap/slog 等自定义实现 - 新增 CloserLogger 接口用于支持关闭操作 - Engine 新增 SetLogger/GetLogger 方法使用接口 - 新增 compat.go 兼容层,保留 reco 兼容方法 - 新增 slog 适配器示例 - 删除 zap 示例 - Context.GetLogger() 返回接口类型 --- compat.go | 37 +++ context.go | 23 +- docs/logger-migration-design.md | 400 ++++++++++++++++++++++++++++++++ engine.go | 29 ++- examples/logger_slog/main.go | 71 ++++++ logger.go | 23 ++ logreco.go | 9 + 7 files changed, 575 insertions(+), 17 deletions(-) create mode 100644 compat.go create mode 100644 docs/logger-migration-design.md create mode 100644 examples/logger_slog/main.go create mode 100644 logger.go diff --git a/compat.go b/compat.go new file mode 100644 index 0000000..6a49c89 --- /dev/null +++ b/compat.go @@ -0,0 +1,37 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import "github.com/fenthope/reco" + +// GetLogReco 返回底层的 reco.Logger 实例 +// 用于需要访问 reco 特定功能的场景 +// 如果当前 logger 不是 *reco.Logger 类型,返回 nil +// +//go:fix inline +func (engine *Engine) GetLogReco() *reco.Logger { + return engine.LogReco +} + +// SetLogReco 设置 reco.Logger 实例 +// 用于向后兼容,等价于 SetLogger(l) +// +//go:fix inline +func (engine *Engine) SetLogReco(l *reco.Logger) { + engine.LogReco = l + engine.logger = l +} + +// GetLoggerReco 返回底层的 reco.Logger 实例 +// 用于需要访问 reco 特定功能的场景 +// 如果当前 logger 不是 *reco.Logger 类型,返回 nil +// +//go:fix inline +func (c *Context) GetLoggerReco() *reco.Logger { + if rl, ok := c.engine.logger.(*reco.Logger); ok { + return rl + } + return c.engine.LogReco +} diff --git a/context.go b/context.go index e73033d..324386e 100644 --- a/context.go +++ b/context.go @@ -26,7 +26,6 @@ import ( "time" "github.com/WJQSERVER/wanf" - "github.com/fenthope/reco" "github.com/go-json-experiment/json" "github.com/WJQSERVER-STUDIO/go-utils/iox" @@ -135,8 +134,8 @@ func (c *Context) writeResponseBody(data []byte, contextMsg string) { if _, err := c.Writer.Write(data); err != nil { wrapped := fmt.Errorf("%s: %w", contextMsg, err) c.AddError(wrapped) - if c != nil && c.engine != nil && c.engine.LogReco != nil { - c.engine.LogReco.Errorf("%s: %v", contextMsg, err) + if c.engine != nil && c.engine.logger != nil { + c.engine.logger.Errorf("%s: %v", contextMsg, err) } } } @@ -1136,9 +1135,9 @@ func (c *Context) GetHTTPC() *httpc.Client { return c.HTTPClient } -// GetLogger 获取engine的Logger -func (c *Context) GetLogger() *reco.Logger { - return c.engine.LogReco +// GetLogger 获取engine的Logger接口 +func (c *Context) GetLogger() Logger { + return c.engine.logger } // GetReqQueryString @@ -1297,25 +1296,25 @@ func (c *Context) DeleteCookie(name string) { // === 日志记录 === func (c *Context) Debugf(format string, args ...any) { - c.engine.LogReco.Debugf(format, args...) + c.engine.logger.Debugf(format, args...) } func (c *Context) Infof(format string, args ...any) { - c.engine.LogReco.Infof(format, args...) + c.engine.logger.Infof(format, args...) } func (c *Context) Warnf(format string, args ...any) { - c.engine.LogReco.Warnf(format, args...) + c.engine.logger.Warnf(format, args...) } func (c *Context) Errorf(format string, args ...any) { - c.engine.LogReco.Errorf(format, args...) + c.engine.logger.Errorf(format, args...) } func (c *Context) Fatalf(format string, args ...any) { - c.engine.LogReco.Fatalf(format, args...) + c.engine.logger.Fatalf(format, args...) } func (c *Context) Panicf(format string, args ...any) { - c.engine.LogReco.Panicf(format, args...) + c.engine.logger.Panicf(format, args...) } diff --git a/docs/logger-migration-design.md b/docs/logger-migration-design.md new file mode 100644 index 0000000..9684d8e --- /dev/null +++ b/docs/logger-migration-design.md @@ -0,0 +1,400 @@ +# Touka Logger 接口迁移方案 + +## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计 + +--- + +## 一、问题分析 + +当前架构问题: +``` +Engine.LogReco → *reco.Logger (公开字段, 直接访问) +Context.GetLogger() → 返回 *reco.Logger (具体类型) +Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...) +``` + +这导致用户无法替换日志实现(如 zap/logrus)。 + +--- + +## 二、目标架构 + +``` +Engine.logger → Logger 接口 (私有) +Engine.logReco → *reco.Logger (私有, 兼容层) +Engine.GetLogger() → 返回 Logger 接口 +Engine.SetLogger(Logger)→ 设置日志实现 +Context.GetLogger() → 返回 Logger 接口 +Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...) +``` + +--- + +## 三、Logger 接口定义 + +```go +// logger.go +package touka + +// Logger 是日志接口,支持任意日志库实现 +type Logger interface { + Debugf(format string, args ...any) + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + Panicf(format string, args ...any) +} + +// CloserLogger 可选扩展,支持关闭操作 +type CloserLogger interface { + Logger + Close() error +} +``` + +--- + +## 四、Engine 结构变更 + +```go +// engine.go 变更 +type Engine struct { + // ... 其他字段保持不变 + + // logger 是新的日志接口 (私有) + logger Logger + + // logReco 是保留的 reco.Logger 引用 (私有) + // 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger + logReco *reco.Logger + + // 其他字段... +} +``` + +新增/修改方法: + +```go +// GetLogger 返回日志接口 +func (engine *Engine) GetLogger() Logger { + return engine.logger +} + +// SetLogger 设置任意 Logger 实现 +func (engine *Engine) SetLogger(l Logger) { + engine.logger = l + // 如果是 *reco.Logger 类型,同步更新 logReco + if rl, ok := l.(*reco.Logger); ok { + engine.logReco = rl + } else { + engine.logReco = nil + } +} + +// SetLoggerCfg 使用 reco.Config 配置日志 +func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { + logger := NewLogger(logcfg) + engine.logger = logger + engine.logReco = logger +} +``` + +--- + +## 五、`go:fix inline` 兼容性函数 + +### 5.1 旧 API 包装函数 + +在 `compat.go` 中定义: + +```go +// compat.go +package touka + +import "github.com/fenthope/reco" + +// GetLogReco 返回 reco.Logger,用于向后兼容 +// +//go:fix inline +func (engine *Engine) GetLogReco() *reco.Logger { + return engine.logReco +} + +// SetLogReco 设置 reco.Logger,用于向后兼容 +// +//go:fix inline +func (engine *Engine) SetLogReco(l *reco.Logger) { + engine.logReco = l + engine.logger = l +} +``` + +### 5.2 Context 日志方法的 inline 包装 + +```go +// context_compat.go +package touka + +// Debugf 记录 Debug 级别日志 +// +//go:fix inline +func (c *Context) Debugf(format string, args ...any) { + c.engine.logger.Debugf(format, args...) +} + +// Infof 记录 Info 级别日志 +// +//go:fix inline +func (c *Context) Infof(format string, args ...any) { + c.engine.logger.Infof(format, args...) +} + +// Warnf 记录 Warn 级别日志 +// +//go:fix inline +func (c *Context) Warnf(format string, args ...any) { + c.engine.logger.Warnf(format, args...) +} + +// Errorf 记录 Error 级别日志 +// +//go:fix inline +func (c *Context) Errorf(format string, args ...any) { + c.engine.logger.Errorf(format, args...) +} + +// Fatalf 记录 Fatal 级别日志 +// +//go:fix inline +func (c *Context) Fatalf(format string, args ...any) { + c.engine.logger.Fatalf(format, args...) +} + +// Panicf 记录 Panic 级别日志 +// +//go:fix inline +func (c *Context) Panicf(format string, args ...any) { + c.engine.logger.Panicf(format, args...) +} +``` + +### 5.3 GetLogger 返回类型的兼容处理 + +由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数: + +```go +// context_compat.go (续) + +// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景 +// +//go:fix inline +func (c *Context) GetLoggerReco() *reco.Logger { + if rl, ok := c.engine.logger.(*reco.Logger); ok { + return rl + } + return nil +} +``` + +--- + +## 六、go:fix inline 工作原理 + +### 迁移前用户代码: +```go +func handler(c *touka.Context) { + // 旧 API 调用 + c.Debugf("request: %s", c.Request.URL.Path) + c.engine.LogReco.Infof("server started") +} +``` + +### go fix 执行后(自动替换): +```go +func handler(c *touka.Context) { + // Debugf 被替换为函数体 + c.engine.logger.Debugf("request: %s", c.Request.URL.Path) + + // LogReco 访问无法通过 inline 自动处理,需要手动迁移 + // 或者通过 getter 调用 +} +``` + +### 对于字段访问的处理策略: + +`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略: + +1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated +2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco` +3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()` + +--- + +## 七、迁移前后对比 + +### 场景 1:基本日志调用 + +**迁移前:** +```go +func myHandler(c *touka.Context) { + c.Debugf("processing request %s", c.Request.URL.Path) + c.Infof("user %s logged in", username) + c.Warnf("slow query: %v", duration) + c.Errorf("db error: %v", err) +} +``` + +**迁移后(自动替换):** +```go +func myHandler(c *touka.Context) { + c.engine.logger.Debugf("processing request %s", c.Request.URL.Path) + c.engine.logger.Infof("user %s logged in", username) + c.engine.logger.Warnf("slow query: %v", duration) + c.engine.logger.Errorf("db error: %v", err) +} +``` + +### 场景 2:Engine 配置日志 + +**迁移前:** +```go +engine := touka.New() +engine.LogReco = myLogger // 直接赋值 +logger := engine.LogReco // 直接读取 +``` + +**迁移后(手动 + 自动混合):** +```go +engine := touka.New() + +// 方式 1:使用新 API(推荐) +engine.SetLogger(myLogger) +logger := engine.GetLogger() + +// 方式 2:通过 go:fix inline 自动替换为 getter +// engine.SetLogReco(myLogger) ← go fix 替换 +// logger := engine.GetLogReco() ← go fix 替换 +``` + +### 场景 3:使用第三方日志库(新功能) + +```go +import "go.uber.org/zap" + +func main() { + zapLogger, _ := zap.NewProduction() + defer zapLogger.Sync() + + engine := touka.New() + // 使用 zap 替代默认的 reco.Logger + engine.SetLogger(&ZapAdapter{logger: zapLogger}) + + engine.GET("/api", func(c *touka.Context) { + c.Infof("api called") // 自动使用 zap 输出 + }) +} + +// ZapAdapter 适配 zap 到 touka.Logger 接口 +type ZapAdapter struct { + logger *zap.Logger +} + +func (z *ZapAdapter) Debugf(format string, args ...any) { + z.logger.Debug(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Infof(format string, args ...any) { + z.logger.Info(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Warnf(format string, args ...any) { + z.logger.Warn(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Errorf(format string, args ...any) { + z.logger.Error(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Fatalf(format string, args ...any) { + z.logger.Fatal(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Panicf(format string, args ...any) { + z.logger.Panic(fmt.Sprintf(format, args...)) +} +``` + +--- + +## 八、内部使用迁移 + +框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`: + +需要修改的文件: +- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf` +- `recovery.go`: 如有使用日志 +- `logreco.go`: CloseLogger 方法 + +```go +// context.go 修改前 +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if _, err := c.Writer.Write(data); err != nil { + if c.engine.LogReco != nil { + c.engine.LogReco.Errorf("%s: %v", contextMsg, err) + } + } +} + +// context.go 修改后 +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if _, err := c.Writer.Write(data); err != nil { + if c.engine.logger != nil { + c.engine.logger.Errorf("%s: %v", contextMsg, err) + } + } +} +``` + +--- + +## 九、完整文件结构 + +``` +touka/ +├── logger.go # Logger 接口定义 +├── logreco.go # reco.Logger 相关工具函数 +├── compat.go # go:fix inline 兼容性函数 (Engine) +├── context_compat.go # go:fix inline 兼容性函数 (Context) +├── engine.go # Engine 结构变更 +├── context.go # Context 日志方法变更 +└── ... +``` + +--- + +## 十、版本策略 + +| 版本 | 变更内容 | +|------|---------| +| v1.x | 引入 Logger 接口,LogReco 标记 deprecated | +| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 | +| v3.x | 移除 go:fix inline 兼容函数 | + +--- + +## 十一、go:fix inline 限制说明 + +1. **字段访问无法自动迁移**:`engine.LogReco` 字段访问需要用户手动修改 +2. **返回类型变更需谨慎**:`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败 +3. **inline 函数有大小限制**:函数体过大会影响内联效果 +4. **跨包迁移**:`go:fix inline` 支持跨包,但用户必须运行 `go fix` + +--- + +## 十二、推荐迁移步骤 + +1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数 +2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分 +3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()` +4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置 diff --git a/engine.go b/engine.go index d712064..15df162 100644 --- a/engine.go +++ b/engine.go @@ -52,8 +52,14 @@ type Engine struct { HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 + // LogReco 保留的 reco.Logger 字段 + // Deprecated: 使用 SetLogger/GetLogger 替代 LogReco *reco.Logger + // logger 是新的日志接口,支持任意 Logger 实现 + // 优先级: logger > LogReco + logger Logger + HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 @@ -367,14 +373,27 @@ func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.rebuildFallbackChains() } -// SetLogger传入实例 -func (engine *Engine) SetLogger(logger *reco.Logger) { - engine.LogReco = logger +// SetLogger 传入 Logger 接口实例 +func (engine *Engine) SetLogger(logger Logger) { + engine.logger = logger + // 同步更新 LogReco 以保持向后兼容 + if rl, ok := logger.(*reco.Logger); ok { + engine.LogReco = rl + } else { + engine.LogReco = nil + } } -// 配置日志LoggerCfg +// GetLogger 返回 Logger 接口实例 +func (engine *Engine) GetLogger() Logger { + return engine.logger +} + +// SetLoggerCfg 使用 reco.Config 配置日志 func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { - engine.LogReco = NewLogger(logcfg) + logger := NewLogger(logcfg) + engine.logger = logger + engine.LogReco = logger } // 设置自定义错误处理 diff --git a/examples/logger_slog/main.go b/examples/logger_slog/main.go new file mode 100644 index 0000000..2263960 --- /dev/null +++ b/examples/logger_slog/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/infinite-iroha/touka" +) + +// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口 +type SlogAdapter struct { + logger *slog.Logger +} + +func NewSlogAdapter(handler slog.Handler) *SlogAdapter { + return &SlogAdapter{ + logger: slog.New(handler), + } +} + +func (s *SlogAdapter) Debugf(format string, args ...any) { + s.logger.Debug(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Infof(format string, args ...any) { + s.logger.Info(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Warnf(format string, args ...any) { + s.logger.Warn(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Errorf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Fatalf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) + os.Exit(1) +} + +func (s *SlogAdapter) Panicf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) + panic(fmt.Sprintf(format, args...)) +} + +func main() { + engine := touka.New() + + // 使用 slog 替换默认的 reco.Logger + handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slogAdapter := NewSlogAdapter(handler) + engine.SetLogger(slogAdapter) + + engine.GET("/", func(c *touka.Context) { + c.Infof("request received: %s", c.Request.URL.Path) + c.JSON(http.StatusOK, map[string]string{"message": "hello"}) + }) + + // 也可以获取 Logger 接口 + logger := engine.GetLogger() + logger.Debugf("engine started") + + // 也可以直接使用 slog + slog.Info("Server running", "addr", ":8080") + // engine.Run(":8080") +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..1be0077 --- /dev/null +++ b/logger.go @@ -0,0 +1,23 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +// Logger 是日志接口,支持多种日志库实现(reco、zap、logrus 等) +// 用户可以通过实现此接口来替换默认的日志实现 +type Logger interface { + Debugf(format string, args ...any) + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + Panicf(format string, args ...any) +} + +// CloserLogger 可选扩展接口,支持关闭操作 +// 如果 Logger 实现了此接口,Engine 在关闭时会调用 Close() +type CloserLogger interface { + Logger + Close() error +} diff --git a/logreco.go b/logreco.go index 4bda8d3..e37dd53 100644 --- a/logreco.go +++ b/logreco.go @@ -39,7 +39,16 @@ func CloseLogger(logger *reco.Logger) { } } +// CloseLogger 关闭 Engine 的日志实现 +// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法 func (engine *Engine) CloseLogger() { + if cl, ok := engine.logger.(CloserLogger); ok { + if err := cl.Close(); err != nil { + log.Printf("Close Logger Error: %s", err) + } + return + } + // 兼容旧代码 if engine.LogReco != nil { CloseLogger(engine.LogReco) } From 10033f4a174b1d13273a9964fe8fd49a2d7dc1d2 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 21:49:42 +0800 Subject: [PATCH 45/55] =?UTF-8?q?docs:=20=E4=BF=AE=E5=A4=8D=E5=AE=A1?= =?UTF-8?q?=E6=9F=A5=E6=84=8F=E8=A7=81=EF=BC=8C=E4=BF=AE=E6=AD=A3=E8=AE=BE?= =?UTF-8?q?=E8=AE=A1=E6=96=87=E6=A1=A3=E4=B8=8E=E5=AE=9E=E7=8E=B0=E7=9A=84?= =?UTF-8?q?=E4=B8=8D=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将设计文档中 logReco 改为 LogReco,与实际实现保持一致 - LogReco 字段保持公开但标记为 Deprecated --- docs/logger-migration-design.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/logger-migration-design.md b/docs/logger-migration-design.md index 9684d8e..7b2e0a6 100644 --- a/docs/logger-migration-design.md +++ b/docs/logger-migration-design.md @@ -21,7 +21,7 @@ Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...) ``` Engine.logger → Logger 接口 (私有) -Engine.logReco → *reco.Logger (私有, 兼容层) +Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容) Engine.GetLogger() → 返回 Logger 接口 Engine.SetLogger(Logger)→ 设置日志实现 Context.GetLogger() → 返回 Logger 接口 From f2295c3084e88048bd59fa76f29016ad17d21ef8 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Tue, 21 Apr 2026 22:55:26 +0800 Subject: [PATCH 46/55] =?UTF-8?q?feat:=20httpc=20=E9=9B=86=E6=88=90?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=EF=BC=8C=E8=87=AA=E5=8A=A8=E5=85=B3=E8=81=94?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=20Context?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 contextHTTPClient 包装器,自动关联请求 Context - 新增 Context.HTTPC() 方法返回 contextHTTPClient - Client() 标记为 Deprecated - 添加 GetHTTPC() go:fix inline 兼容函数 当请求被取消时,出站 HTTP 请求也会自动取消。 --- compat.go | 17 +++++++++++++- context.go | 18 ++++++++++----- context_httpc.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 7 deletions(-) create mode 100644 context_httpc.go diff --git a/compat.go b/compat.go index 6a49c89..4e40687 100644 --- a/compat.go +++ b/compat.go @@ -4,7 +4,12 @@ // All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. package touka -import "github.com/fenthope/reco" +import ( + "github.com/WJQSERVER-STUDIO/httpc" + "github.com/fenthope/reco" +) + +// --- reco 兼容函数 --- // GetLogReco 返回底层的 reco.Logger 实例 // 用于需要访问 reco 特定功能的场景 @@ -35,3 +40,13 @@ func (c *Context) GetLoggerReco() *reco.Logger { } return c.engine.LogReco } + +// --- httpc 兼容函数 --- + +// GetHTTPC 返回底层的 httpc.Client 实例 +// Deprecated: 使用 HTTPClient() 替代,新方法会自动关联请求 Context +// +//go:fix inline +func (c *Context) GetHTTPC() *httpc.Client { + return c.Client() +} diff --git a/context.go b/context.go index 324386e..c720de3 100644 --- a/context.go +++ b/context.go @@ -865,11 +865,22 @@ func (c *Context) GetErrors() []error { } // Client 返回 Engine 提供的 HTTPClient -// 方便在请求处理函数中进行出站 HTTP 请求 +// 方便在请求处理函数中进行出站 HTTP请求 +// +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context func (c *Context) Client() *httpc.Client { return c.HTTPClient } +// HTTPC 返回自动关联请求 Context 的 HTTP 客户端 +// 当请求被取消时,通过此客户端发起的出站请求也会自动取消 +func (c *Context) HTTPC() *contextHTTPClient { + return &contextHTTPClient{ + client: c.engine.HTTPClient, + ctx: c.ctx, + } +} + // Context() 返回请求的上下文,用于取消操作 // 这是 Go 标准库的 `context.Context`,用于请求的取消和超时管理 func (c *Context) Context() context.Context { @@ -1130,11 +1141,6 @@ func (c *Context) GetProtocol() string { return c.Request.Proto } -// GetHTTPC 获取框架自带传递的httpc -func (c *Context) GetHTTPC() *httpc.Client { - return c.HTTPClient -} - // GetLogger 获取engine的Logger接口 func (c *Context) GetLogger() Logger { return c.engine.logger diff --git a/context_httpc.go b/context_httpc.go new file mode 100644 index 0000000..3256a3b --- /dev/null +++ b/context_httpc.go @@ -0,0 +1,58 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + + "github.com/WJQSERVER-STUDIO/httpc" +) + +// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context +// 当请求被取消时,出站 HTTP 请求也会自动取消 +type contextHTTPClient struct { + client *httpc.Client + ctx context.Context +} + +// NewRequestBuilder 创建请求构建器,自动关联请求 Context +func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder { + return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx) +} + +// GET 创建 GET 请求构建器 +func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder { + return c.client.GET(urlStr).WithContext(c.ctx) +} + +// POST 创建 POST 请求构建器 +func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder { + return c.client.POST(urlStr).WithContext(c.ctx) +} + +// PUT 创建 PUT 请求构建器 +func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder { + return c.client.PUT(urlStr).WithContext(c.ctx) +} + +// DELETE 创建 DELETE 请求构建器 +func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder { + return c.client.DELETE(urlStr).WithContext(c.ctx) +} + +// PATCH 创建 PATCH 请求构建器 +func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder { + return c.client.PATCH(urlStr).WithContext(c.ctx) +} + +// HEAD 创建 HEAD 请求构建器 +func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder { + return c.client.HEAD(urlStr).WithContext(c.ctx) +} + +// OPTIONS 创建 OPTIONS 请求构建器 +func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder { + return c.client.OPTIONS(urlStr).WithContext(c.ctx) +} From 4f262b2497b7dbb1c266ee55421c9507dfaaca5d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 07:13:55 +0800 Subject: [PATCH 47/55] =?UTF-8?q?docs:=20=E6=B7=BB=E5=8A=A0=20httpc=20?= =?UTF-8?q?=E9=9B=86=E6=88=90=E6=96=87=E6=A1=A3=E5=92=8C=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 examples/httpc 示例代码 - 新增 docs/httpc.md 文档说明 --- docs/httpc.md | 188 +++++++++++++++++++++++++++++++++++++++++ examples/httpc/main.go | 103 ++++++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 docs/httpc.md create mode 100644 examples/httpc/main.go diff --git a/docs/httpc.md b/docs/httpc.md new file mode 100644 index 0000000..8742c18 --- /dev/null +++ b/docs/httpc.md @@ -0,0 +1,188 @@ +# HTTP Client (httpc) + +Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。 + +## 核心特性 + +- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context +- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏 +- **链式调用**:保持 httpc 原有的组合式构建器风格 + +## 基本用法 + +### 简单 GET 请求 + +```go +r.GET("/proxy", func(c *touka.Context) { + body, err := c.HTTPC(). + GET("https://api.example.com/data"). + Text() + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.String(200, body) +}) +``` + +### POST JSON 请求 + +```go +r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + c.ShouldBindJSON(&req) + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + err := c.HTTPC(). + POST("https://api.example.com/users"). + SetHeader("Authorization", "Bearer "+token). + SetJSONBody(req). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +### 带查询参数 + +```go +r.GET("/search", func(c *touka.Context) { + query := c.Query("q") + + var result SearchResult + err := c.HTTPC(). + GET("https://api.example.com/search"). + SetQueryParam("q", query). + SetQueryParam("limit", "10"). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +## API 对比 + +### 旧方式(Deprecated) + +```go +// 需要手动 WithContext,容易忘记 +resp, err := c.Client(). + WithContext(c.Context()). + GET(url). + Execute() +``` + +### 新方式(推荐) + +```go +// 自动关联请求 Context +resp, err := c.HTTPC(). + GET(url). + Execute() +``` + +## Context 取消机制 + +使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消: + +```go +r.GET("/long-task", func(c *touka.Context) { + // 这个请求会在客户端断开时自动取消 + resp, err := c.HTTPC(). + GET("https://slow-api.example.com/data"). + Execute() + + // 如果客户端已断开,err 会包含 context.Canceled + if errors.Is(err, context.Canceled) { + return // 客户端已断开,无需处理 + } + // ... +}) +``` + +## 完整 API + +### contextHTTPClient 方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 | +| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 | +| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 | +| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 | +| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 | +| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 | +| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 | +| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 | + +### httpc.RequestBuilder 链式方法 + +返回 `*httpc.RequestBuilder`(用于链式调用): + +| 方法 | 说明 | +|------|------| +| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) | +| `NoDefaultHeaders()` | 不添加默认 Header | +| `SetHeader(key, value)` | 设置 Header | +| `AddHeader(key, value)` | 添加 Header(可重复) | +| `SetHeaders(map)` | 批量设置 Headers | +| `SetQueryParam(key, value)` | 设置查询参数 | +| `AddQueryParam(key, value)` | 添加查询参数(可重复) | +| `SetQueryParams(map)` | 批量设置查询参数 | +| `SetBody(io.Reader)` | 设置请求 Body | +| `SetRawBody([]byte)` | 设置字节 Body | + +返回 `(*httpc.RequestBuilder, error)`(可能失败): + +| 方法 | 说明 | +|------|------| +| `SetJSONBody(any)` | 设置 JSON Body | +| `SetXMLBody(any)` | 设置 XML Body | +| `SetGOBBody(any)` | 设置 GOB Body | + +### 终结方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `Build()` | `(*http.Request, error)` | 构建请求但不执行 | +| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 | +| `DecodeJSON(v)` | `error` | 执行并解码 JSON | +| `DecodeXML(v)` | `error` | 执行并解码 XML | +| `DecodeGOB(v)` | `error` | 执行并解码 GOB | +| `Text()` | `(string, error)` | 执行并返回文本 | +| `Bytes()` | `([]byte, error)` | 执行并返回字节 | +| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 | + +## 迁移指南 + +### go:fix inline 兼容 + +旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`: + +```bash +go fix ./... +``` + +### 手动迁移 + +| 旧代码 | 新代码 | +|--------|--------| +| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` | +| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` | + +## 示例 + +完整示例请参考 [examples/httpc](../examples/httpc)。 diff --git a/examples/httpc/main.go b/examples/httpc/main.go new file mode 100644 index 0000000..f50ec9a --- /dev/null +++ b/examples/httpc/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "net/http" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + // 示例 1:简单 GET 请求(自动关联请求 Context) + r.GET("/proxy", func(c *touka.Context) { + // 使用 HTTPC() 方法,自动关联请求 Context + // 当客户端断开连接时,出站请求也会自动取消 + body, err := c.HTTPC(). + GET("https://httpbin.org/get"). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, body) + }) + + // 示例 2:带 Header 的 POST 请求 + r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + // 链式调用,保持 httpc 风格 + // 注意:SetJSONBody 返回 (*RequestBuilder, error) + rb, err := c.HTTPC(). + POST("https://httpbin.org/post"). + SetHeader("X-API-Key", "secret"). + SetJSONBody(req) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + if err := rb.DecodeJSON(&result); err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 3:带查询参数的请求 + r.GET("/search", func(c *touka.Context) { + query := c.DefaultQuery("q", "") + page := c.DefaultQuery("page", "1") + + var result struct { + Items []string `json:"items"` + Total int `json:"total"` + } + + err := c.HTTPC(). + GET("https://httpbin.org/get"). + SetQueryParam("q", query). + SetQueryParam("page", page). + DecodeJSON(&result) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐) + r.GET("/legacy", func(c *touka.Context) { + // 旧方式:需要手动 WithContext + body, err := c.Client(). + GET("https://httpbin.org/get"). + WithContext(c.Context()). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, body) + }) + + fmt.Println("Server running on :8080") + fmt.Println("Try:") + fmt.Println(" curl http://localhost:8080/proxy") + fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users") + fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'") + + // r.Run(touka.WithAddr(":8080")) +} From e7c7d5e41f5db632c189514a9eefae7df6b1b8f7 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 07:30:40 +0800 Subject: [PATCH 48/55] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20Client()=20?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E8=BF=87=E6=97=B6=20HTTPClient=20=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 Client() 从返回 c.HTTPClient 改为返回 c.engine.HTTPClient - 与 HTTPC() 方法保持一致 --- context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/context.go b/context.go index c720de3..540e27f 100644 --- a/context.go +++ b/context.go @@ -869,7 +869,7 @@ func (c *Context) GetErrors() []error { // // Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context func (c *Context) Client() *httpc.Client { - return c.HTTPClient + return c.engine.HTTPClient } // HTTPC 返回自动关联请求 Context 的 HTTP 客户端 From 74873691253c7da12e15f9487b3f3faad357594d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 08:43:36 +0800 Subject: [PATCH 49/55] =?UTF-8?q?improve:=20MergeCtx=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=20cause=20=E4=BC=A0=E6=92=AD,=20=E4=BD=BF=E7=94=A8=20WithCance?= =?UTF-8?q?lCause/WithDeadlineCause?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 内部改用 context.WithCancelCause 和 WithDeadlineCause, 父 context 取消原因自动传播 - Value() 先检查嵌入 context 再查 parents, 确保 context.Cause() 正确工作 - Done()/Err() 同时监听 cancelCtx 和 deadlineCtx, 支持 deadline 到期 cause - 新增 Cause() 便捷方法 - 单 parent 短路径改用 WithCancelCause 保留 cause - 新增 mergectx_test.go, 覆盖 cause 传播、deadline、Value 查找等场景 - API 兼容: 返回类型保持 CancelFunc 不变 Alina Agent生成 --- mergectx.go | 106 +++++++++++++++----- mergectx_test.go | 256 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 338 insertions(+), 24 deletions(-) create mode 100644 mergectx_test.go diff --git a/mergectx.go b/mergectx.go index e5d3ec4..2e36c09 100644 --- a/mergectx.go +++ b/mergectx.go @@ -12,17 +12,19 @@ import ( // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. type mergedContext struct { - // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. + // 嵌入一个基础 context, 用于 Deadline() 和 Value() 查找. context.Context // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context - // 用于手动取消此 mergedContext 的函数. - cancel context.CancelFunc + // cancelCtx 由 CancelCause 管理, 当 cause 取消时其 Done() 关闭. + cancelCtx context.Context + // deadlineCtx 仅在有 deadline 时非 nil, 用于检测 deadline 到期. + deadlineCtx context.Context } // MergeCtx 创建并返回一个新的 context.Context. // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). +// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. // // 新的 context 会继承: // - Deadline: 所有父 context 中最早的截止时间. @@ -32,7 +34,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return context.WithCancel(context.Background()) } if len(parents) == 1 { - return context.WithCancel(parents[0]) + ctx, cancel := context.WithCancelCause(parents[0]) + return ctx, func() { cancel(nil) } } var earliestDeadline time.Time @@ -44,37 +47,78 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - var baseCtx context.Context - var baseCancel context.CancelFunc + // baseCtx 提供 CancelCauseFunc 以支持 cause 传播. + baseCtx, baseCancel := context.WithCancelCause(context.Background()) + + // deadlineCtx 仅用于监听 deadline 到期信号. + var deadlineCtx context.Context + var deadlineCancel context.CancelFunc if !earliestDeadline.IsZero() { - baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) - } else { - baseCtx, baseCancel = context.WithCancel(context.Background()) + deadlineCtx, deadlineCancel = context.WithDeadlineCause(context.Background(), earliestDeadline, context.DeadlineExceeded) + } + + // 嵌入的 context: 有 deadline 时用 deadlineCtx, 否则用 baseCtx. + embedCtx := baseCtx + if deadlineCtx != nil { + embedCtx = deadlineCtx } mc := &mergedContext{ - Context: baseCtx, - parents: parents, - cancel: baseCancel, + Context: embedCtx, + parents: parents, + cancelCtx: baseCtx, + deadlineCtx: deadlineCtx, } - // 启动一个监控 goroutine. + // 启动监控 goroutine. go func() { - defer mc.cancel() + var once sync.Once + doCancel := func(cause error) { + once.Do(func() { baseCancel(cause) }) + } + defer doCancel(nil) - // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. - // 同时监听 baseCtx.Done() 以便支持手动取消. - select { - case <-orDone(mc.parents...): - case <-mc.Context.Done(): + parentDone := orDone(mc.parents...) + + if deadlineCtx != nil { + defer deadlineCancel() + select { + case <-parentDone: + for _, p := range mc.parents { + if p.Err() != nil { + doCancel(context.Cause(p)) + return + } + } + doCancel(nil) + case <-deadlineCtx.Done(): + doCancel(context.DeadlineExceeded) + case <-baseCtx.Done(): + } + } else { + select { + case <-parentDone: + for _, p := range mc.parents { + if p.Err() != nil { + doCancel(context.Cause(p)) + return + } + } + doCancel(nil) + case <-baseCtx.Done(): + } } }() - return mc, mc.cancel + return mc, func() { baseCancel(nil) } } -// Value 返回当前Ctx Value +// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), +// 再按传入顺序从 parents 中查找. func (mc *mergedContext) Value(key any) any { + if v := mc.Context.Value(key); v != nil { + return v + } for _, p := range mc.parents { if val := p.Value(key); val != nil { return val @@ -90,12 +134,26 @@ func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { // Done 实现了 context.Context 的 Done 方法. func (mc *mergedContext) Done() <-chan struct{} { - return mc.Context.Done() + if mc.deadlineCtx != nil { + return orDone(mc.cancelCtx, mc.deadlineCtx) + } + return mc.cancelCtx.Done() } // Err 实现了 context.Context 的 Err 方法. func (mc *mergedContext) Err() error { - return mc.Context.Err() + if mc.cancelCtx.Err() != nil { + return mc.cancelCtx.Err() + } + if mc.deadlineCtx != nil { + return mc.deadlineCtx.Err() + } + return nil +} + +// Cause 返回取消原因, 使 context.Cause() 能正确传播 cause. +func (mc *mergedContext) Cause() error { + return context.Cause(mc.cancelCtx) } // orDone 是一个辅助函数, 返回一个 channel. diff --git a/mergectx_test.go b/mergectx_test.go new file mode 100644 index 0000000..d6d1225 --- /dev/null +++ b/mergectx_test.go @@ -0,0 +1,256 @@ +package touka + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestMergeCtx_NoParents(t *testing.T) { + ctx, cancel := MergeCtx() + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before cancel") + } + cancel() + if ctx.Err() == nil { + t.Fatal("expected error after cancel") + } +} + +func TestMergeCtx_SingleParent(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + + ctx, cancel := MergeCtx(parent) + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before parent cancel") + } + + parentCancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after parent cancel") + } +} + +func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + // p2 should still be fine + if p2.Err() != nil { + t.Fatal("expected p2 to be unaffected") + } +} + +func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } +} + +func TestMergeCtx_ExternalCancel(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + + cancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after external cancel") + } +} + +func TestMergeCtx_CausePropagation(t *testing.T) { + testErr := errors.New("test cause") + + p1, cancel1 := context.WithCancelCause(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1(testErr) + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + cancel1(nil) // cleanup (already cancelled, no-op) +} + +func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) { + testErr := errors.New("second parent cause") + + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancelCause(context.Background()) + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2(testErr) + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + + cancel1() +} + +func TestMergeCtx_Deadline_Earliest(t *testing.T) { + now := time.Now() + early := now.Add(100 * time.Millisecond) + late := now.Add(1 * time.Hour) + + p1, cancel1 := context.WithDeadline(context.Background(), late) + p2, cancel2 := context.WithDeadline(context.Background(), early) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !dl.Equal(early) { + t.Fatalf("expected deadline %v, got %v", early, dl) + } +} + +func TestMergeCtx_Deadline_Expires(t *testing.T) { + p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelP() + + ctx, cancel := MergeCtx(p) + defer cancel() + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after deadline expires") + } +} + +func TestMergeCtx_ValueLookup(t *testing.T) { + type key struct{} + p1 := context.WithValue(context.Background(), key{}, "from_p1") + p2 := context.WithValue(context.Background(), key{}, "from_p2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + val := ctx.Value(key{}) + if val != "from_p1" { + t.Fatalf("expected 'from_p1', got %v", val) + } +} + +func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) { + type key1 struct{} + type key2 struct{} + p1 := context.WithValue(context.Background(), key1{}, "val1") + p2 := context.WithValue(context.Background(), key2{}, "val2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + if v := ctx.Value(key1{}); v != "val1" { + t.Fatalf("expected 'val1', got %v", v) + } + if v := ctx.Value(key2{}); v != "val2" { + t.Fatalf("expected 'val2', got %v", v) + } + if v := ctx.Value("missing"); v != nil { + t.Fatalf("expected nil, got %v", v) + } +} + +func TestMergeCtx_ContextInterface(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + var ctx context.Context + ctx, _ = MergeCtx(p1, p2) + + // Verify all Context interface methods work + _ = ctx.Done() + _ = ctx.Err() + _, _ = ctx.Deadline() + _ = ctx.Value("any") +} + +func TestOrDone_SingleContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := orDone(ctx) + + cancel() + <-done // should not block +} + +func TestOrDone_MultipleContexts(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + done := orDone(p1, p2) + + cancel1() + <-done // should not block +} + +func TestOrDone_SecondContextCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + done := orDone(p1, p2) + + cancel2() + <-done // should not block +} From 390190695fe77e2c79dd2b1b4ac15e57e1d87864 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 08:51:42 +0800 Subject: [PATCH 50/55] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20examples/http?= =?UTF-8?q?c=20=E4=B8=AD=20c.String=20=E9=9D=9E=E5=B8=B8=E9=87=8F=20format?= =?UTF-8?q?=20string=20=E7=BC=96=E8=AF=91=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/httpc/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/httpc/main.go b/examples/httpc/main.go index f50ec9a..db2be4f 100644 --- a/examples/httpc/main.go +++ b/examples/httpc/main.go @@ -21,7 +21,7 @@ func main() { c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) return } - c.String(http.StatusOK, body) + c.String(http.StatusOK, "%s", body) }) // 示例 2:带 Header 的 POST 请求 @@ -90,7 +90,7 @@ func main() { c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) return } - c.String(http.StatusOK, body) + c.String(http.StatusOK, "%s", body) }) fmt.Println("Server running on :8080") From 6006267d256254d3bb16cb8079c8b4eba1b88746 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:00:01 +0800 Subject: [PATCH 51/55] =?UTF-8?q?fix:=20Done()=20=E4=BD=BF=E7=94=A8=20sync?= =?UTF-8?q?.Once=20=E7=BC=93=E5=AD=98=20channel=EF=BC=8C=E9=81=BF=E5=85=8D?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E5=88=9B=E5=BB=BA=20orDone=20goroutine?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 Gemini 审查意见:多次调用 Done() 时不再重复创建 goroutine, 每个 mergedContext 最多产生 2 个 orDone goroutine。 Alina Agent生成 --- mergectx.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mergectx.go b/mergectx.go index 2e36c09..9aab5bb 100644 --- a/mergectx.go +++ b/mergectx.go @@ -20,6 +20,9 @@ type mergedContext struct { cancelCtx context.Context // deadlineCtx 仅在有 deadline 时非 nil, 用于检测 deadline 到期. deadlineCtx context.Context + // done 缓存 Done() 的 channel, 避免重复创建 orDone goroutine. + done <-chan struct{} + doneOnce sync.Once } // MergeCtx 创建并返回一个新的 context.Context. @@ -135,7 +138,10 @@ func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { // Done 实现了 context.Context 的 Done 方法. func (mc *mergedContext) Done() <-chan struct{} { if mc.deadlineCtx != nil { - return orDone(mc.cancelCtx, mc.deadlineCtx) + mc.doneOnce.Do(func() { + mc.done = orDone(mc.cancelCtx, mc.deadlineCtx) + }) + return mc.done } return mc.cancelCtx.Done() } From d8a5f200c1376035456b128da80db1b115f9012f Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:17:02 +0800 Subject: [PATCH 52/55] =?UTF-8?q?fix:=20Client()/HTTPC()=20=E4=BC=98?= =?UTF-8?q?=E5=85=88=E4=BD=BF=E7=94=A8=20per-request=20HTTPClient=20?= =?UTF-8?q?=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 Gemini 审查意见:中间件设置的自定义 HTTPClient 不再被绕过。 Client() 和 HTTPC() 现在优先使用 Context.HTTPClient, 仅在未设置时回退到 Engine 默认实例。 Alina Agent生成 --- context.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/context.go b/context.go index 540e27f..f21ed48 100644 --- a/context.go +++ b/context.go @@ -864,19 +864,27 @@ func (c *Context) GetErrors() []error { return c.Errors } -// Client 返回 Engine 提供的 HTTPClient -// 方便在请求处理函数中进行出站 HTTP请求 +// Client 返回当前请求的 HTTPClient +// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例; +// 否则返回 Engine 提供的默认实例 // // Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context func (c *Context) Client() *httpc.Client { + if c.HTTPClient != nil { + return c.HTTPClient + } return c.engine.HTTPClient } // HTTPC 返回自动关联请求 Context 的 HTTP 客户端 // 当请求被取消时,通过此客户端发起的出站请求也会自动取消 func (c *Context) HTTPC() *contextHTTPClient { + client := c.HTTPClient + if client == nil { + client = c.engine.HTTPClient + } return &contextHTTPClient{ - client: c.engine.HTTPClient, + client: client, ctx: c.ctx, } } From 2d693e3b13bfc075c32281401d34fb88568dd268 Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:27:53 +0800 Subject: [PATCH 53/55] =?UTF-8?q?refactor:=20mergectx=20=E7=AE=80=E5=8C=96?= =?UTF-8?q?=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BF=AE=E5=A4=8D=20Gemini=20?= =?UTF-8?q?=E5=AE=A1=E6=9F=A5=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - deadlineCtx 改为 cancelCtx 的子 context,建立父子层级关系 - 嵌入 cancelCtx/context.Context 直接提供 Done()/Err()/Deadline(),移除冗余方法 - orDone 中加入 cancelCtx,防止手动 cancel() 时 goroutine 泄漏 - 移除 cancelCtx/deadlineCtx/done/doneOnce 字段,struct 简化为 Context + parents - 移除冗余 Cause() 方法(context.Cause 用 Value(&cancelCtxKey) 机制) - 移除 Done()/Err() 显式实现,由嵌入 context 自动提供 Alina Agent生成 --- mergectx.go | 122 +++++++++++++++------------------------------------- 1 file changed, 34 insertions(+), 88 deletions(-) diff --git a/mergectx.go b/mergectx.go index 9aab5bb..9c30b92 100644 --- a/mergectx.go +++ b/mergectx.go @@ -6,23 +6,15 @@ package touka import ( "context" - "sync" "time" ) // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. +// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播. +// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消. type mergedContext struct { - // 嵌入一个基础 context, 用于 Deadline() 和 Value() 查找. context.Context - // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context - // cancelCtx 由 CancelCause 管理, 当 cause 取消时其 Done() 关闭. - cancelCtx context.Context - // deadlineCtx 仅在有 deadline 时非 nil, 用于检测 deadline 到期. - deadlineCtx context.Context - // done 缓存 Done() 的 channel, 避免重复创建 orDone goroutine. - done <-chan struct{} - doneOnce sync.Once } // MergeCtx 创建并返回一个新的 context.Context. @@ -50,70 +42,63 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - // baseCtx 提供 CancelCauseFunc 以支持 cause 传播. - baseCtx, baseCancel := context.WithCancelCause(context.Background()) + // cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播. + cancelCtx, cancelCause := context.WithCancelCause(context.Background()) - // deadlineCtx 仅用于监听 deadline 到期信号. + // deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline). + // 当 cancelCtx 被取消时, deadlineCtx 也会被取消; + // 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx. var deadlineCtx context.Context var deadlineCancel context.CancelFunc if !earliestDeadline.IsZero() { - deadlineCtx, deadlineCancel = context.WithDeadlineCause(context.Background(), earliestDeadline, context.DeadlineExceeded) + deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded) } - // 嵌入的 context: 有 deadline 时用 deadlineCtx, 否则用 baseCtx. - embedCtx := baseCtx + // 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline), + // 否则用 cancelCtx. + embedCtx := cancelCtx if deadlineCtx != nil { embedCtx = deadlineCtx } mc := &mergedContext{ - Context: embedCtx, - parents: parents, - cancelCtx: baseCtx, - deadlineCtx: deadlineCtx, + Context: embedCtx, + parents: parents, } - // 启动监控 goroutine. + // 启动监控 goroutine, 监听 parent 取消或 deadline 到期. go func() { - var once sync.Once - doCancel := func(cause error) { - once.Do(func() { baseCancel(cause) }) - } - defer doCancel(nil) - - parentDone := orDone(mc.parents...) + // 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏. + parentDone := orDone(append(mc.parents, cancelCtx)...) if deadlineCtx != nil { defer deadlineCancel() select { case <-parentDone: + // parent 取消或手动 cancel() for _, p := range mc.parents { if p.Err() != nil { - doCancel(context.Cause(p)) + cancelCause(context.Cause(p)) return } } - doCancel(nil) + // 手动 cancel(), cause 已由 cancelCause() 设置 case <-deadlineCtx.Done(): - doCancel(context.DeadlineExceeded) - case <-baseCtx.Done(): + // deadline 到期, 需要关闭 cancelCtx 并设置 cause + cancelCause(context.DeadlineExceeded) } } else { - select { - case <-parentDone: - for _, p := range mc.parents { - if p.Err() != nil { - doCancel(context.Cause(p)) - return - } + <-parentDone + for _, p := range mc.parents { + if p.Err() != nil { + cancelCause(context.Cause(p)) + return } - doCancel(nil) - case <-baseCtx.Done(): } } }() - return mc, func() { baseCancel(nil) } + return mc, func() { cancelCause(nil) } } // Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), @@ -130,62 +115,23 @@ func (mc *mergedContext) Value(key any) any { return nil } -// Deadline 实现了 context.Context 的 Deadline 方法. -func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { - return mc.Context.Deadline() -} +// Deadline, Done, Err 均由嵌入的 context.Context 提供. -// Done 实现了 context.Context 的 Done 方法. -func (mc *mergedContext) Done() <-chan struct{} { - if mc.deadlineCtx != nil { - mc.doneOnce.Do(func() { - mc.done = orDone(mc.cancelCtx, mc.deadlineCtx) - }) - return mc.done - } - return mc.cancelCtx.Done() -} - -// Err 实现了 context.Context 的 Err 方法. -func (mc *mergedContext) Err() error { - if mc.cancelCtx.Err() != nil { - return mc.cancelCtx.Err() - } - if mc.deadlineCtx != nil { - return mc.deadlineCtx.Err() - } - return nil -} - -// Cause 返回取消原因, 使 context.Cause() 能正确传播 cause. -func (mc *mergedContext) Cause() error { - return context.Cause(mc.cancelCtx) -} - -// orDone 是一个辅助函数, 返回一个 channel. -// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. -// 这是一个非阻塞的、不会泄漏 goroutine 的实现. +// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. func orDone(contexts ...context.Context) <-chan struct{} { done := make(chan struct{}) - - var once sync.Once - closeDone := func() { - once.Do(func() { - close(done) - }) - } - - // 为每个父 context 启动一个 goroutine. for _, ctx := range contexts { go func(c context.Context) { select { case <-c.Done(): - closeDone() + select { + case <-done: + default: + close(done) + } case <-done: - // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. } }(ctx) } - return done } From 9dcab4b1ae609a0ca33d8e5c803c34c42fe3a38d Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:37:19 +0800 Subject: [PATCH 54/55] =?UTF-8?q?fix:=20orDone=20=E4=BD=BF=E7=94=A8=20sync?= =?UTF-8?q?.Once=20=E4=BF=AE=E5=A4=8D=20close(done)=20=E7=AB=9E=E6=80=81?= =?UTF-8?q?=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 Gemini 审查意见:多 goroutine 同时 close(done) 可能导致 panic。 恢复 sync.Once 保证 channel 只被关闭一次。 Alina Agent生成 --- mergectx.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mergectx.go b/mergectx.go index 9c30b92..404f7b1 100644 --- a/mergectx.go +++ b/mergectx.go @@ -6,6 +6,7 @@ package touka import ( "context" + "sync" "time" ) @@ -120,15 +121,12 @@ func (mc *mergedContext) Value(key any) any { // orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. func orDone(contexts ...context.Context) <-chan struct{} { done := make(chan struct{}) + var once sync.Once for _, ctx := range contexts { go func(c context.Context) { select { case <-c.Done(): - select { - case <-done: - default: - close(done) - } + once.Do(func() { close(done) }) case <-done: } }(ctx) From 3c40a3d6b532d02d99a38a5f630550a8b59f65ff Mon Sep 17 00:00:00 2001 From: wjqserver <114663932+WJQSERVER@users.noreply.github.com> Date: Wed, 22 Apr 2026 09:37:45 +0800 Subject: [PATCH 55/55] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3=20GetHTTPC=20?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=B8=AD=E6=96=B9=E6=B3=95=E5=90=8D=20typo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HTTPClient() → HTTPC() Alina Agent生成 --- compat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compat.go b/compat.go index 4e40687..0be715d 100644 --- a/compat.go +++ b/compat.go @@ -44,7 +44,7 @@ func (c *Context) GetLoggerReco() *reco.Logger { // --- httpc 兼容函数 --- // GetHTTPC 返回底层的 httpc.Client 实例 -// Deprecated: 使用 HTTPClient() 替代,新方法会自动关联请求 Context +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context // //go:fix inline func (c *Context) GetHTTPC() *httpc.Client {