diff --git a/.gitignore b/.gitignore index 30d74d2..6f301cd 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test \ No newline at end of file +test +/bench_route_match_baseline.txt diff --git a/README.md b/README.md index a7b99fd..e2eaec8 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ func main() { c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) }) - // 启动服务器 (支持优雅关闭) + // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) log.Println("Touka Server starting on :8080...") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Touka server failed to start: %v", err) } } diff --git a/about-touka.md b/about-touka.md index 86a056f..b3a16b4 100644 --- a/about-touka.md +++ b/about-touka.md @@ -70,13 +70,13 @@ func main() { r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB // ... 其他配置 - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` #### 1.3. 服务器生命周期管理 -Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 +Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 ```go func main() { @@ -90,11 +90,11 @@ func main() { fmt.Println("自定义的 HTTP 服务器配置已应用") }) - // 启动服务器,并支持优雅关闭 - // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 - // 第二个参数是优雅关闭的超时时间 + // 启动服务器,并通过 Run 选项启用优雅关闭 + // Run(...) 会阻塞当前 goroutine + // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 fmt.Println("服务器启动于 :8080") - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("服务器启动失败: %v", err) } } @@ -187,7 +187,7 @@ func main() { } } - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } func AuthMiddleware() touka.HandlerFunc { @@ -313,7 +313,7 @@ func main() { }) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } // templates/index.html @@ -400,7 +400,7 @@ func main() { c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) }) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -483,7 +483,7 @@ func main() { // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 r.StaticDir("/files", "./non-existent-dir") - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` @@ -546,7 +546,7 @@ func main() { // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 r.StaticFS("/", http.FS(subFS)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/compat.go b/compat.go new file mode 100644 index 0000000..0be715d --- /dev/null +++ b/compat.go @@ -0,0 +1,52 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "github.com/WJQSERVER-STUDIO/httpc" + "github.com/fenthope/reco" +) + +// --- reco 兼容函数 --- + +// GetLogReco 返回底层的 reco.Logger 实例 +// 用于需要访问 reco 特定功能的场景 +// 如果当前 logger 不是 *reco.Logger 类型,返回 nil +// +//go:fix inline +func (engine *Engine) GetLogReco() *reco.Logger { + return engine.LogReco +} + +// SetLogReco 设置 reco.Logger 实例 +// 用于向后兼容,等价于 SetLogger(l) +// +//go:fix inline +func (engine *Engine) SetLogReco(l *reco.Logger) { + engine.LogReco = l + engine.logger = l +} + +// GetLoggerReco 返回底层的 reco.Logger 实例 +// 用于需要访问 reco 特定功能的场景 +// 如果当前 logger 不是 *reco.Logger 类型,返回 nil +// +//go:fix inline +func (c *Context) GetLoggerReco() *reco.Logger { + if rl, ok := c.engine.logger.(*reco.Logger); ok { + return rl + } + return c.engine.LogReco +} + +// --- httpc 兼容函数 --- + +// GetHTTPC 返回底层的 httpc.Client 实例 +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context +// +//go:fix inline +func (c *Context) GetHTTPC() *httpc.Client { + return c.Client() +} diff --git a/context.go b/context.go index 2e4d2bb..f21ed48 100644 --- a/context.go +++ b/context.go @@ -26,7 +26,6 @@ import ( "time" "github.com/WJQSERVER/wanf" - "github.com/fenthope/reco" "github.com/go-json-experiment/json" "github.com/WJQSERVER-STUDIO/go-utils/iox" @@ -44,6 +43,8 @@ type Context struct { handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) index int8 // 当前执行到处理链的哪个位置 + requestBodyPrepared bool + mu sync.RWMutex Keys map[string]any // 用于在中间件之间传递数据 @@ -71,6 +72,12 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode + + // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. + fixedPathBuf []byte + + allowedMethodsBuf []string + allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -95,19 +102,42 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 + c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize + c.requestBodyPrepared = false if cap(c.SkippedNodes) > 0 { c.SkippedNodes = c.SkippedNodes[:0] } else { c.SkippedNodes = make([]skippedNode, 0, 256) } + if cap(c.fixedPathBuf) > 0 { + c.fixedPathBuf = c.fixedPathBuf[:0] + } + if cap(c.allowedMethodsBuf) > 0 { + c.allowedMethodsBuf = c.allowedMethodsBuf[:0] + } + if cap(c.allowHeaderBuf) > 0 { + c.allowHeaderBuf = c.allowHeaderBuf[:0] + } +} + +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if len(data) == 0 { + return + } + if _, err := c.Writer.Write(data); err != nil { + wrapped := fmt.Errorf("%s: %w", contextMsg, err) + c.AddError(wrapped) + if c.engine != nil && c.engine.logger != nil { + c.engine.logger.Errorf("%s: %v", contextMsg, err) + } + } } // Next 在处理链中执行下一个处理函数 @@ -237,6 +267,18 @@ func (c *Context) SetMaxRequestBodySize(size int64) { c.MaxRequestBodySize = size } +func (c *Context) prepareRequestBody() io.ReadCloser { + if c.Request == nil || c.Request.Body == nil { + return nil + } + if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 { + return c.Request.Body + } + c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + c.requestBodyPrepared = true + return c.Request.Body +} + // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -258,7 +300,39 @@ func (c *Context) DefaultQuery(key, defaultValue string) string { // 懒加载解析表单数据,并进行缓存 func (c *Context) PostForm(key string) string { if c.formCache == nil { - c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + } + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err != nil { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + default: + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + if !errors.Is(err, http.ErrNotMultipart) { + c.AddError(fmt.Errorf("parse form error: %w", err)) + c.formCache = make(url.Values) + return "" + } + } + } c.formCache = c.Request.PostForm } return c.formCache.Get(key) @@ -282,20 +356,20 @@ func (c *Context) Param(key string) string { func (c *Context) Raw(code int, contentType string, data []byte) { c.Writer.Header().Set("Content-Type", contentType) c.Writer.WriteHeader(code) - c.Writer.Write(data) + c.writeResponseBody(data, "failed to write raw response") } // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.Writer.Write(fmt.Appendf(nil, format, values...)) + c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response") } // Text 向响应写入无需格式化的string func (c *Context) Text(code int, text string) { c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write([]byte(text)) + c.writeResponseBody([]byte(text), "failed to write text response") } // FileText @@ -338,8 +412,11 @@ func (c *Context) FileText(code int, filePath string) { } c.SetHeader("Content-Type", "text/plain; charset=utf-8") - - c.SetBodyStream(file, int(fileInfo.Size())) + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) + c.Writer.WriteHeader(code) + if _, err := iox.Copy(c.Writer, file); err != nil { + c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err)) + } } /* @@ -430,7 +507,7 @@ func (c *Context) JSONBuf(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response") } // GOB 向响应写入GOB数据 @@ -459,7 +536,7 @@ func (c *Context) GOBBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/octet-stream") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response") } // WANF向响应写入WANF数据 @@ -488,7 +565,7 @@ func (c *Context) WANFBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response") } // HTML 渲染 HTML 模板 @@ -512,7 +589,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.Writer.Write(fmt.Appendf(nil, "\n
%v", name, obj)) + c.writeResponseBody(fmt.Appendf(nil, "\n
%v", name, obj), "failed to write HTML response") } // HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体. @@ -537,7 +614,7 @@ func (c *Context) HTMLBuf(code int, name string, obj any) { // 渲染成功,写入响应 c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.WriteHeader(code) - c.Writer.Write(buf.Bytes()) + c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response") return } @@ -557,10 +634,16 @@ func (c *Context) Redirect(code int, location string) { // ShouldBindJSON 尝试将请求体绑定到 JSON 对象 func (c *Context) ShouldBindJSON(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - err := json.UnmarshalRead(c.Request.Body, obj) + err := json.UnmarshalRead(body, obj) if err != nil { return fmt.Errorf("json binding error: %w", err) } @@ -569,10 +652,16 @@ func (c *Context) ShouldBindJSON(obj any) error { // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder, err := wanf.NewStreamDecoder(c.Request.Body) + decoder, err := wanf.NewStreamDecoder(body) if err != nil { return fmt.Errorf("failed to create WANF decoder: %w", err) } @@ -585,10 +674,16 @@ func (c *Context) ShouldBindWANF(obj any) error { // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 func (c *Context) ShouldBindGOB(obj any) error { - if c.Request.Body == nil { + var body io.ReadCloser + if c.MaxRequestBodySize > 0 { + body = c.prepareRequestBody() + } else { + body = c.Request.Body + } + if body == nil { return errors.New("request body is empty") } - decoder := gob.NewDecoder(c.Request.Body) + decoder := gob.NewDecoder(body) if err := decoder.Decode(obj); err != nil { return fmt.Errorf("GOB binding error: %w", err) } @@ -705,6 +800,10 @@ func setFieldValue(field reflect.Value, values []string) error { // ShouldBindForm 尝试将表单数据绑定到结构体 // 支持 application/x-www-form-urlencoded 和 multipart/form-data func (c *Context) ShouldBindForm(obj any) error { + if c.MaxRequestBodySize > 0 { + c.prepareRequestBody() + } + contentType := c.Request.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -713,7 +812,7 @@ func (c *Context) ShouldBindForm(obj any) error { switch mediaType { case "multipart/form-data": - if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { return fmt.Errorf("parse multipart form error: %w", err) } case "application/x-www-form-urlencoded": @@ -727,6 +826,7 @@ func (c *Context) ShouldBindForm(obj any) error { if err := bindForm(c.Request.Form, obj); err != nil { return fmt.Errorf("form binding error: %w", err) } + c.formCache = c.Request.PostForm return nil } @@ -764,10 +864,29 @@ func (c *Context) GetErrors() []error { return c.Errors } -// Client 返回 Engine 提供的 HTTPClient -// 方便在请求处理函数中进行出站 HTTP 请求 +// Client 返回当前请求的 HTTPClient +// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例; +// 否则返回 Engine 提供的默认实例 +// +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context func (c *Context) Client() *httpc.Client { - return c.HTTPClient + if c.HTTPClient != nil { + return c.HTTPClient + } + return c.engine.HTTPClient +} + +// HTTPC 返回自动关联请求 Context 的 HTTP 客户端 +// 当请求被取消时,通过此客户端发起的出站请求也会自动取消 +func (c *Context) HTTPC() *contextHTTPClient { + client := c.HTTPClient + if client == nil { + client = c.engine.HTTPClient + } + return &contextHTTPClient{ + client: client, + ctx: c.ctx, + } } // Context() 返回请求的上下文,用于取消操作 @@ -827,37 +946,30 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) { // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // 注意:请求体只能读取一次 func (c *Context) GetReqBody() io.ReadCloser { + if c.MaxRequestBodySize > 0 { + return c.prepareRequestBody() + } + if c.Request == nil || c.Request.Body == nil { + return nil + } return c.Request.Body } // GetReqBodyFull 读取并返回请求体的所有内容 // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -867,31 +979,18 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { // 类似 GetReqBodyFull, 返回 *bytes.Buffer func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { - if c.Request.Body == nil { + body := c.GetReqBody() + if body == nil { return nil, nil } + defer func() { + err := body.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() - var limitBytesReader io.ReadCloser - - if c.MaxRequestBodySize > 0 { - limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } else { - limitBytesReader = c.Request.Body - defer func() { - err := limitBytesReader.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - } - - data, err := iox.ReadAll(limitBytesReader) + data, err := io.ReadAll(body) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -1050,14 +1149,9 @@ func (c *Context) GetProtocol() string { return c.Request.Proto } -// GetHTTPC 获取框架自带传递的httpc -func (c *Context) GetHTTPC() *httpc.Client { - return c.HTTPClient -} - -// GetLogger 获取engine的Logger -func (c *Context) GetLogger() *reco.Logger { - return c.engine.LogReco +// GetLogger 获取engine的Logger接口 +func (c *Context) GetLogger() Logger { + return c.engine.logger } // GetReqQueryString @@ -1216,25 +1310,25 @@ func (c *Context) DeleteCookie(name string) { // === 日志记录 === func (c *Context) Debugf(format string, args ...any) { - c.engine.LogReco.Debugf(format, args...) + c.engine.logger.Debugf(format, args...) } func (c *Context) Infof(format string, args ...any) { - c.engine.LogReco.Infof(format, args...) + c.engine.logger.Infof(format, args...) } func (c *Context) Warnf(format string, args ...any) { - c.engine.LogReco.Warnf(format, args...) + c.engine.logger.Warnf(format, args...) } func (c *Context) Errorf(format string, args ...any) { - c.engine.LogReco.Errorf(format, args...) + c.engine.logger.Errorf(format, args...) } func (c *Context) Fatalf(format string, args ...any) { - c.engine.LogReco.Fatalf(format, args...) + c.engine.logger.Fatalf(format, args...) } func (c *Context) Panicf(format string, args ...any) { - c.engine.LogReco.Panicf(format, args...) + c.engine.logger.Panicf(format, args...) } diff --git a/context_benchmark_test.go b/context_benchmark_test.go new file mode 100644 index 0000000..3c464d0 --- /dev/null +++ b/context_benchmark_test.go @@ -0,0 +1,81 @@ +package touka + +import ( + "net/http" + "testing" +) + +func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { + c, _ := CreateTestContext(nil) + if c.Keys != nil { + t.Fatalf("expected fresh test context Keys to be nil before first Set") + } + + c.Set("answer", 42) + if c.Keys == nil { + t.Fatalf("expected Set to allocate Keys map") + } + if value, exists := c.Get("answer"); !exists || value != 42 { + t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) + } + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(UnwrapResponseWriter(c.Writer), req) + + if c.Keys != nil { + t.Fatalf("expected reset to clear Keys without allocating a new map") + } + if value, exists := c.Get("answer"); exists || value != nil { + t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) + } + + ctxValue := c.Value("missing") + if ctxValue != nil { + t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) + } + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected MustGet to panic for missing key after reset") + } + }() + _ = c.MustGet("answer") +} + +func BenchmarkContextReset(b *testing.B) { + b.Run("NoKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + rawWriter := UnwrapResponseWriter(c.Writer) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(rawWriter, req) + } + }) + + b.Run("WithKeysUse", func(b *testing.B) { + c, _ := CreateTestContext(nil) + rawWriter := UnwrapResponseWriter(c.Writer) + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c.reset(rawWriter, req) + c.Set("request-id", i) + } + }) + +} diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go new file mode 100644 index 0000000..1e7696a --- /dev/null +++ b/context_bodylimit_test.go @@ -0,0 +1,174 @@ +package touka + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +type zeroNilThenEOFReader struct { + readCalls int +} + +func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) { + r.readCalls++ + if r.readCalls == 1 { + return 0, nil + } + return 0, io.EOF +} + +func (r *zeroNilThenEOFReader) Close() error { + return nil +} + +func TestFileTextUsesProvidedStatusCode(t *testing.T) { + t.Helper() + + dir := t.TempDir() + filePath := filepath.Join(dir, "hello.txt") + if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil { + t.Fatalf("write temp file: %v", err) + } + + rr := httptest.NewRecorder() + c, _ := CreateTestContext(rr) + + c.FileText(http.StatusCreated, filePath) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) + } + if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { + t.Fatalf("unexpected content type: %q", got) + } + if body := rr.Body.String(); body != "hello touka" { + t.Fatalf("unexpected body: %q", body) + } +} + +func TestMaxBytesReaderAllowsExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4) + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected exact limit read to succeed, got %v", err) + } + if string(data) != "abcd" { + t.Fatalf("unexpected data: %q", string(data)) + } +} + +func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4) + defer reader.Close() + + _, err := io.ReadAll(reader) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1) + defer reader.Close() + + buf := make([]byte, 1) + n, err := reader.Read(buf) + if n != 0 || err != nil { + t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) + } + + n, err = reader.Read(buf) + if n != 0 || !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err) + } +} + +func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) { + t.Helper() + + reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0) + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("expected zero limit to leave body unlimited, got %v", err) + } + if string(data) != "abc" { + t.Fatalf("unexpected data: %q", string(data)) + } +} + +func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader(`{"name":"abcdef"}`) + req := httptest.NewRequest(http.MethodPost, "/json", body) + req.Header.Set("Content-Type", "application/json") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(8) + + var payload struct { + Name string `json:"name"` + } + + err := c.ShouldBindJSON(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + var payload struct { + Name string `form:"name"` + } + + err := c.ShouldBindForm(&payload) + if !errors.Is(err, ErrBodyTooLarge) { + t.Fatalf("expected ErrBodyTooLarge, got %v", err) + } +} + +func TestPostFormHonorsMaxRequestBodySize(t *testing.T) { + t.Helper() + + body := strings.NewReader("name=abcdef") + req := httptest.NewRequest(http.MethodPost, "/form", body) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) + c.SetMaxRequestBodySize(4) + + if got := c.PostForm("name"); got != "" { + t.Fatalf("expected empty value on over-limit form body, got %q", got) + } + if len(c.Errors) == 0 { + t.Fatal("expected parse error to be recorded") + } + if !errors.Is(c.Errors[0], ErrBodyTooLarge) { + t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0]) + } +} diff --git a/context_httpc.go b/context_httpc.go new file mode 100644 index 0000000..3256a3b --- /dev/null +++ b/context_httpc.go @@ -0,0 +1,58 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + + "github.com/WJQSERVER-STUDIO/httpc" +) + +// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context +// 当请求被取消时,出站 HTTP 请求也会自动取消 +type contextHTTPClient struct { + client *httpc.Client + ctx context.Context +} + +// NewRequestBuilder 创建请求构建器,自动关联请求 Context +func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder { + return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx) +} + +// GET 创建 GET 请求构建器 +func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder { + return c.client.GET(urlStr).WithContext(c.ctx) +} + +// POST 创建 POST 请求构建器 +func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder { + return c.client.POST(urlStr).WithContext(c.ctx) +} + +// PUT 创建 PUT 请求构建器 +func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder { + return c.client.PUT(urlStr).WithContext(c.ctx) +} + +// DELETE 创建 DELETE 请求构建器 +func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder { + return c.client.DELETE(urlStr).WithContext(c.ctx) +} + +// PATCH 创建 PATCH 请求构建器 +func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder { + return c.client.PATCH(urlStr).WithContext(c.ctx) +} + +// HEAD 创建 HEAD 请求构建器 +func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder { + return c.client.HEAD(urlStr).WithContext(c.ctx) +} + +// OPTIONS 创建 OPTIONS 请求构建器 +func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder { + return c.client.OPTIONS(urlStr).WithContext(c.ctx) +} diff --git a/docs/advanced.md b/docs/advanced.md index a7cb9a2..eb44c2d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -44,7 +44,9 @@ r.SetTLSServerConfigurator(func(server *http.Server) { Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): ```go -// 使用默认协议配置(仅 HTTP/1.1) +// 使用默认协议配置 +// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集, +// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。 r.SetDefaultProtocols() // 自定义协议配置 @@ -57,33 +59,147 @@ r.SetProtocols(&touka.ProtocolsConfig{ ### 启动方式 -Touka 提供了多种服务器启动方式: +Touka 统一通过 `Run(opts...)` 启动服务器: ```go // 1. 简单启动(无优雅停机) -r.Run(":8080") +r.Run(touka.WithAddr(":8080")) // 2. 带优雅停机的启动 -r.RunShutdown(":8080", 10*time.Second) +r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) // 3. 带上下文的优雅停机 ctx, cancel := context.WithCancel(context.Background()) -r.RunShutdownWithContext(":8080", ctx, 10*time.Second) +defer cancel() +r.Run( + touka.WithAddr(":8080"), + touka.WithGracefulShutdown(10*time.Second), + touka.WithShutdownContext(ctx), +) // 4. HTTPS 启动 tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, // 其他 TLS 配置... } -r.RunTLS(":443", tlsConfig, 10*time.Second) +// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithGracefulShutdownDefault(), +) // 5. HTTPS + HTTP 重定向 -r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) +// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect(":80"), + touka.WithGracefulShutdown(10*time.Second), +) + +// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) + +// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) ``` +### HTTPS Redirect Host 策略 + +`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 + +可用的 redirect 子选项: + +- `touka.WithUseHeaderHost(true|false)` +- `touka.WithRedirectHostHeaders([]string{...})` +- `touka.WithRedirectHost("example.com")` + +#### 模式一:使用请求输入侧的 host + +当 `WithUseHeaderHost(true)` 时: + +- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` +- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 +- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(true), + touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), + ), +) +``` + +#### 模式二:使用配置的固定 host + +当 `WithUseHeaderHost(false)` 时: + +- 不读取 `Request.Host` +- 不读取 `WithRedirectHostHeaders(...)` +- 必须配置 `WithRedirectHost("example.com")` + +示例: + +```go +r.Run( + touka.WithAddr(":443"), + touka.WithTLS(tlsConfig), + touka.WithHTTPRedirect( + ":80", + touka.WithUseHeaderHost(false), + touka.WithRedirectHost("example.com"), + ), +) +``` + +#### 严格校验规则 + +以下组合会直接返回配置错误: + +- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` +- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` +- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` +- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` +- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` + +#### 优先级关系 + +1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 +2. `WithUseHeaderHost(...)` 决定 host 来源模式 +3. 当 `WithUseHeaderHost(true)` 时: + - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 + - 未配置时使用 `Request.Host` +4. 当 `WithUseHeaderHost(false)` 时: + - 只使用 `WithRedirectHost(...)` + +**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 + ## 优雅停机 (Graceful Shutdown) -在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 ```go r := touka.Default() @@ -91,7 +207,7 @@ r := touka.Default() // 监听 SIGINT 和 SIGTERM 信号 // 如果在 10 秒内未处理完,则强制关闭 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal("服务器退出异常:", err) } ``` diff --git a/docs/httpc.md b/docs/httpc.md new file mode 100644 index 0000000..8742c18 --- /dev/null +++ b/docs/httpc.md @@ -0,0 +1,188 @@ +# HTTP Client (httpc) + +Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。 + +## 核心特性 + +- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context +- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏 +- **链式调用**:保持 httpc 原有的组合式构建器风格 + +## 基本用法 + +### 简单 GET 请求 + +```go +r.GET("/proxy", func(c *touka.Context) { + body, err := c.HTTPC(). + GET("https://api.example.com/data"). + Text() + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.String(200, body) +}) +``` + +### POST JSON 请求 + +```go +r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + c.ShouldBindJSON(&req) + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + err := c.HTTPC(). + POST("https://api.example.com/users"). + SetHeader("Authorization", "Bearer "+token). + SetJSONBody(req). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +### 带查询参数 + +```go +r.GET("/search", func(c *touka.Context) { + query := c.Query("q") + + var result SearchResult + err := c.HTTPC(). + GET("https://api.example.com/search"). + SetQueryParam("q", query). + SetQueryParam("limit", "10"). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +## API 对比 + +### 旧方式(Deprecated) + +```go +// 需要手动 WithContext,容易忘记 +resp, err := c.Client(). + WithContext(c.Context()). + GET(url). + Execute() +``` + +### 新方式(推荐) + +```go +// 自动关联请求 Context +resp, err := c.HTTPC(). + GET(url). + Execute() +``` + +## Context 取消机制 + +使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消: + +```go +r.GET("/long-task", func(c *touka.Context) { + // 这个请求会在客户端断开时自动取消 + resp, err := c.HTTPC(). + GET("https://slow-api.example.com/data"). + Execute() + + // 如果客户端已断开,err 会包含 context.Canceled + if errors.Is(err, context.Canceled) { + return // 客户端已断开,无需处理 + } + // ... +}) +``` + +## 完整 API + +### contextHTTPClient 方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 | +| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 | +| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 | +| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 | +| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 | +| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 | +| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 | +| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 | + +### httpc.RequestBuilder 链式方法 + +返回 `*httpc.RequestBuilder`(用于链式调用): + +| 方法 | 说明 | +|------|------| +| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) | +| `NoDefaultHeaders()` | 不添加默认 Header | +| `SetHeader(key, value)` | 设置 Header | +| `AddHeader(key, value)` | 添加 Header(可重复) | +| `SetHeaders(map)` | 批量设置 Headers | +| `SetQueryParam(key, value)` | 设置查询参数 | +| `AddQueryParam(key, value)` | 添加查询参数(可重复) | +| `SetQueryParams(map)` | 批量设置查询参数 | +| `SetBody(io.Reader)` | 设置请求 Body | +| `SetRawBody([]byte)` | 设置字节 Body | + +返回 `(*httpc.RequestBuilder, error)`(可能失败): + +| 方法 | 说明 | +|------|------| +| `SetJSONBody(any)` | 设置 JSON Body | +| `SetXMLBody(any)` | 设置 XML Body | +| `SetGOBBody(any)` | 设置 GOB Body | + +### 终结方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `Build()` | `(*http.Request, error)` | 构建请求但不执行 | +| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 | +| `DecodeJSON(v)` | `error` | 执行并解码 JSON | +| `DecodeXML(v)` | `error` | 执行并解码 XML | +| `DecodeGOB(v)` | `error` | 执行并解码 GOB | +| `Text()` | `(string, error)` | 执行并返回文本 | +| `Bytes()` | `([]byte, error)` | 执行并返回字节 | +| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 | + +## 迁移指南 + +### go:fix inline 兼容 + +旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`: + +```bash +go fix ./... +``` + +### 手动迁移 + +| 旧代码 | 新代码 | +|--------|--------| +| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` | +| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` | + +## 示例 + +完整示例请参考 [examples/httpc](../examples/httpc)。 diff --git a/docs/introduction.md b/docs/introduction.md index 94a7310..87c3e40 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 -3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 +3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/logger-migration-design.md b/docs/logger-migration-design.md new file mode 100644 index 0000000..7b2e0a6 --- /dev/null +++ b/docs/logger-migration-design.md @@ -0,0 +1,400 @@ +# Touka Logger 接口迁移方案 + +## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计 + +--- + +## 一、问题分析 + +当前架构问题: +``` +Engine.LogReco → *reco.Logger (公开字段, 直接访问) +Context.GetLogger() → 返回 *reco.Logger (具体类型) +Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...) +``` + +这导致用户无法替换日志实现(如 zap/logrus)。 + +--- + +## 二、目标架构 + +``` +Engine.logger → Logger 接口 (私有) +Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容) +Engine.GetLogger() → 返回 Logger 接口 +Engine.SetLogger(Logger)→ 设置日志实现 +Context.GetLogger() → 返回 Logger 接口 +Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...) +``` + +--- + +## 三、Logger 接口定义 + +```go +// logger.go +package touka + +// Logger 是日志接口,支持任意日志库实现 +type Logger interface { + Debugf(format string, args ...any) + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + Panicf(format string, args ...any) +} + +// CloserLogger 可选扩展,支持关闭操作 +type CloserLogger interface { + Logger + Close() error +} +``` + +--- + +## 四、Engine 结构变更 + +```go +// engine.go 变更 +type Engine struct { + // ... 其他字段保持不变 + + // logger 是新的日志接口 (私有) + logger Logger + + // logReco 是保留的 reco.Logger 引用 (私有) + // 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger + logReco *reco.Logger + + // 其他字段... +} +``` + +新增/修改方法: + +```go +// GetLogger 返回日志接口 +func (engine *Engine) GetLogger() Logger { + return engine.logger +} + +// SetLogger 设置任意 Logger 实现 +func (engine *Engine) SetLogger(l Logger) { + engine.logger = l + // 如果是 *reco.Logger 类型,同步更新 logReco + if rl, ok := l.(*reco.Logger); ok { + engine.logReco = rl + } else { + engine.logReco = nil + } +} + +// SetLoggerCfg 使用 reco.Config 配置日志 +func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { + logger := NewLogger(logcfg) + engine.logger = logger + engine.logReco = logger +} +``` + +--- + +## 五、`go:fix inline` 兼容性函数 + +### 5.1 旧 API 包装函数 + +在 `compat.go` 中定义: + +```go +// compat.go +package touka + +import "github.com/fenthope/reco" + +// GetLogReco 返回 reco.Logger,用于向后兼容 +// +//go:fix inline +func (engine *Engine) GetLogReco() *reco.Logger { + return engine.logReco +} + +// SetLogReco 设置 reco.Logger,用于向后兼容 +// +//go:fix inline +func (engine *Engine) SetLogReco(l *reco.Logger) { + engine.logReco = l + engine.logger = l +} +``` + +### 5.2 Context 日志方法的 inline 包装 + +```go +// context_compat.go +package touka + +// Debugf 记录 Debug 级别日志 +// +//go:fix inline +func (c *Context) Debugf(format string, args ...any) { + c.engine.logger.Debugf(format, args...) +} + +// Infof 记录 Info 级别日志 +// +//go:fix inline +func (c *Context) Infof(format string, args ...any) { + c.engine.logger.Infof(format, args...) +} + +// Warnf 记录 Warn 级别日志 +// +//go:fix inline +func (c *Context) Warnf(format string, args ...any) { + c.engine.logger.Warnf(format, args...) +} + +// Errorf 记录 Error 级别日志 +// +//go:fix inline +func (c *Context) Errorf(format string, args ...any) { + c.engine.logger.Errorf(format, args...) +} + +// Fatalf 记录 Fatal 级别日志 +// +//go:fix inline +func (c *Context) Fatalf(format string, args ...any) { + c.engine.logger.Fatalf(format, args...) +} + +// Panicf 记录 Panic 级别日志 +// +//go:fix inline +func (c *Context) Panicf(format string, args ...any) { + c.engine.logger.Panicf(format, args...) +} +``` + +### 5.3 GetLogger 返回类型的兼容处理 + +由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数: + +```go +// context_compat.go (续) + +// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景 +// +//go:fix inline +func (c *Context) GetLoggerReco() *reco.Logger { + if rl, ok := c.engine.logger.(*reco.Logger); ok { + return rl + } + return nil +} +``` + +--- + +## 六、go:fix inline 工作原理 + +### 迁移前用户代码: +```go +func handler(c *touka.Context) { + // 旧 API 调用 + c.Debugf("request: %s", c.Request.URL.Path) + c.engine.LogReco.Infof("server started") +} +``` + +### go fix 执行后(自动替换): +```go +func handler(c *touka.Context) { + // Debugf 被替换为函数体 + c.engine.logger.Debugf("request: %s", c.Request.URL.Path) + + // LogReco 访问无法通过 inline 自动处理,需要手动迁移 + // 或者通过 getter 调用 +} +``` + +### 对于字段访问的处理策略: + +`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略: + +1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated +2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco` +3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()` + +--- + +## 七、迁移前后对比 + +### 场景 1:基本日志调用 + +**迁移前:** +```go +func myHandler(c *touka.Context) { + c.Debugf("processing request %s", c.Request.URL.Path) + c.Infof("user %s logged in", username) + c.Warnf("slow query: %v", duration) + c.Errorf("db error: %v", err) +} +``` + +**迁移后(自动替换):** +```go +func myHandler(c *touka.Context) { + c.engine.logger.Debugf("processing request %s", c.Request.URL.Path) + c.engine.logger.Infof("user %s logged in", username) + c.engine.logger.Warnf("slow query: %v", duration) + c.engine.logger.Errorf("db error: %v", err) +} +``` + +### 场景 2:Engine 配置日志 + +**迁移前:** +```go +engine := touka.New() +engine.LogReco = myLogger // 直接赋值 +logger := engine.LogReco // 直接读取 +``` + +**迁移后(手动 + 自动混合):** +```go +engine := touka.New() + +// 方式 1:使用新 API(推荐) +engine.SetLogger(myLogger) +logger := engine.GetLogger() + +// 方式 2:通过 go:fix inline 自动替换为 getter +// engine.SetLogReco(myLogger) ← go fix 替换 +// logger := engine.GetLogReco() ← go fix 替换 +``` + +### 场景 3:使用第三方日志库(新功能) + +```go +import "go.uber.org/zap" + +func main() { + zapLogger, _ := zap.NewProduction() + defer zapLogger.Sync() + + engine := touka.New() + // 使用 zap 替代默认的 reco.Logger + engine.SetLogger(&ZapAdapter{logger: zapLogger}) + + engine.GET("/api", func(c *touka.Context) { + c.Infof("api called") // 自动使用 zap 输出 + }) +} + +// ZapAdapter 适配 zap 到 touka.Logger 接口 +type ZapAdapter struct { + logger *zap.Logger +} + +func (z *ZapAdapter) Debugf(format string, args ...any) { + z.logger.Debug(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Infof(format string, args ...any) { + z.logger.Info(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Warnf(format string, args ...any) { + z.logger.Warn(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Errorf(format string, args ...any) { + z.logger.Error(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Fatalf(format string, args ...any) { + z.logger.Fatal(fmt.Sprintf(format, args...)) +} + +func (z *ZapAdapter) Panicf(format string, args ...any) { + z.logger.Panic(fmt.Sprintf(format, args...)) +} +``` + +--- + +## 八、内部使用迁移 + +框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`: + +需要修改的文件: +- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf` +- `recovery.go`: 如有使用日志 +- `logreco.go`: CloseLogger 方法 + +```go +// context.go 修改前 +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if _, err := c.Writer.Write(data); err != nil { + if c.engine.LogReco != nil { + c.engine.LogReco.Errorf("%s: %v", contextMsg, err) + } + } +} + +// context.go 修改后 +func (c *Context) writeResponseBody(data []byte, contextMsg string) { + if _, err := c.Writer.Write(data); err != nil { + if c.engine.logger != nil { + c.engine.logger.Errorf("%s: %v", contextMsg, err) + } + } +} +``` + +--- + +## 九、完整文件结构 + +``` +touka/ +├── logger.go # Logger 接口定义 +├── logreco.go # reco.Logger 相关工具函数 +├── compat.go # go:fix inline 兼容性函数 (Engine) +├── context_compat.go # go:fix inline 兼容性函数 (Context) +├── engine.go # Engine 结构变更 +├── context.go # Context 日志方法变更 +└── ... +``` + +--- + +## 十、版本策略 + +| 版本 | 变更内容 | +|------|---------| +| v1.x | 引入 Logger 接口,LogReco 标记 deprecated | +| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 | +| v3.x | 移除 go:fix inline 兼容函数 | + +--- + +## 十一、go:fix inline 限制说明 + +1. **字段访问无法自动迁移**:`engine.LogReco` 字段访问需要用户手动修改 +2. **返回类型变更需谨慎**:`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败 +3. **inline 函数有大小限制**:函数体过大会影响内联效果 +4. **跨包迁移**:`go:fix inline` 支持跨包,但用户必须运行 `go fix` + +--- + +## 十二、推荐迁移步骤 + +1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数 +2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分 +3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()` +4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置 diff --git a/docs/middleware.md b/docs/middleware.md index a222437..b688fb5 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -26,6 +26,41 @@ api.Use(AuthMiddleware()) } ``` +也可以在创建组时直接传入中间件: + +```go +api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware()) +{ + api.GET("/user", handleUser) + api.POST("/data", handleData) +} +``` + +### 路由级中间件 + +为单个路由注册中间件,仅对该路由生效。 + +```go +// 单个路由中间件 +r.GET("/protected", AuthMiddleware(), func(c *touka.Context) { + c.String(http.StatusOK, "Protected content") +}) + +// 多个路由中间件(按顺序执行) +r.POST("/upload", + RateLimitMiddleware(), + AuthMiddleware(), + PermissionCheckMiddleware(), + func(c *touka.Context) { + // 处理上传 + }, +) + +// 路由组中的单个路由也可以使用路由级中间件 +api := r.Group("/api") +api.GET("/admin", AdminAuthMiddleware(), adminHandler) +``` + ## 编写自定义中间件 中间件的函数签名是 `touka.HandlerFunc`。 @@ -67,6 +102,36 @@ func APIKeyAuth() touka.HandlerFunc { } ``` +## 中间件执行顺序 + +理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行: + +```go +// 全局中间件 +r.Use(GlobalMiddleware1()) +r.Use(GlobalMiddleware2()) + +// 组中间件 +api := r.Group("/api", GroupMiddleware1()) +api.Use(GroupMiddleware2()) + +// 路由级中间件 +api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler) +``` + +对于 `/api/users` 请求,执行顺序为: +1. `GlobalMiddleware1()` - 全局中间件 +2. `GlobalMiddleware2()` - 全局中间件 +3. `GroupMiddleware1()` - 路由组中间件 +4. `GroupMiddleware2()` - 路由组中间件 +5. `RouteMiddleware1()` - 路由级中间件 +6. `RouteMiddleware2()` - 路由级中间件 +7. `userHandler` - 最终处理函数 + +``` +请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应 +``` + ## 内置中间件 - **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。 diff --git a/docs/quickstart.md b/docs/quickstart.md index 94f7433..2911732 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -46,7 +46,7 @@ func main() { // 4. 启动服务器并监听 8080 端口 log.Println("Touka server is running on :8080") - if err := r.Run(":8080"); err != nil { + if err := r.Run(touka.WithAddr(":8080")); err != nil { log.Fatalf("Server failed: %v", err) } } @@ -66,11 +66,11 @@ go run main.go ## 优雅停机 -在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 +在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。 ```go // 等待 10 秒以处理剩余请求 -if err := r.RunShutdown(":8080", 10*time.Second); err != nil { +if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatalf("Server forced to shutdown: %v", err) } ``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 5dfcbd1..cb4b2a3 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -28,7 +28,7 @@ func main() { Target: target, })) - _ = r.Run(":8080") + _ = r.Run(touka.WithAddr(":8080")) } ``` @@ -59,11 +59,16 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ ```go type ReverseProxyConfig struct { - Target *url.URL + Target *url.URL + Targets []string + + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool + AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -78,12 +83,133 @@ type ReverseProxyConfig struct { ### `Target` -必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 +与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。 ```go target, _ := url.Parse("http://backend:9000") ``` +### `Targets` + +可选。用于配置多个后端目标地址。 + +- `Target` 与 `Targets` 互斥,只能使用其中一种 +- `Targets` 的每一项都必须是完整 URL +- 每个 target 仍然可以自带 base path 和 query + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001/base?from=a", + "http://127.0.0.1:9002/base?from=b", + }, +})) +``` + +这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。 + +### `LoadBalancing` + +用于配置 upstream 选择策略和重试行为。 + +```go +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} +``` + +当前内置策略: + +- `touka.LBRandom()` +- `touka.LBRoundRobin()` +- `touka.LBFirst()` +- `touka.LBLeastConn()` +- `touka.LBIPHash()` +- `touka.LBClientIPHash()` +- `touka.LBURIHash()` +- `touka.LBHeader("X-Upstream", fallback)` +- `touka.LBQuery("tenant", fallback)` + +其中: + +- `LBFirst()` 适合主备/故障转移顺序 +- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback +- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()` + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBHeader("X-Upstream", touka.LBFirst()), + Retries: 1, + }, +})) +``` + +重试说明: + +- 只对未开始收到上游响应的失败进行重试 +- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试 +- `Retries` 表示额外重试次数 +- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机 +- `TryInterval` 表示两次重试之间的等待间隔 + +### `PassiveHealth` + +用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。 + +```go +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} +``` + +- `FailDuration > 0` 时启用被动健康跟踪 +- `MaxFails <= 0` 时默认按 `1` 处理 +- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Targets: []string{ + "http://127.0.0.1:9001", + "http://127.0.0.1:9002", + }, + LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ + Policy: touka.LBFirst(), + }, + PassiveHealth: touka.ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, +})) +``` + +### `AllowH2CUpstream` + +允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。 + +- 默认关闭 +- 这是一个显式配置项 +- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游 +- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式 + +```go +r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + AllowH2CUpstream: true, +})) +``` + +对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。 + ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 @@ -150,6 +276,8 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ 在请求真正发往后端前,对出站请求做最后修改。 +如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。 + 常见用途: - 覆盖 `Host` @@ -242,11 +370,20 @@ const ( r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "edge-1", })) ``` +如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。 + +- IPv4:`203.0.113.43` +- IPv6 / 带端口:`[2001:db8::17]:443` +- 匿名标识:`_gateway-1` +- 未知:`unknown` + +像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。 + `Via` 不是“留空即禁用”的开关。当前实现中: - 如果 `Via` 非空,则使用该值追加 `Via` @@ -282,11 +419,14 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: +- `CONNECT` 隧道转发(HTTP/1.x) +- HTTP/2 extended `CONNECT` - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 - Trailer 透传 - 1xx 响应透传 +- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理 例如,代理 WebSocket 服务: @@ -341,7 +481,7 @@ func main() { r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "gateway-1", + ForwardedBy: "_gateway-1", Via: "gateway-1", FlushInterval: 100 * time.Millisecond, ModifyRequest: func(req *http.Request) { @@ -357,7 +497,7 @@ func main() { }, })) - if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { log.Fatal(err) } } diff --git a/docs/routing.md b/docs/routing.md index e90308e..70a24dc 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -22,6 +22,8 @@ r.ANY("/any", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) ``` +服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。 + ## 路径参数 (Named Parameters) 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 @@ -140,7 +142,7 @@ func main() { r := touka.Default() fsroot, _ := fs.Sub(content, "dist") r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/docs/sse.md b/docs/sse.md index 1b44521..a003be9 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) { 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 -**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 +**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。 diff --git a/docs/static-files.md b/docs/static-files.md index a2138cd..b1f06a8 100644 --- a/docs/static-files.md +++ b/docs/static-files.md @@ -39,7 +39,7 @@ func main() { // 您也可以使用 StaticFS 服务根路径 // r.StaticFS("/", http.FS(fsroot)) - r.Run(":8080") + r.Run(touka.WithAddr(":8080")) } ``` diff --git a/ecw.go b/ecw.go index 754571f..dedbe27 100644 --- a/ecw.go +++ b/ecw.go @@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := ecw.w.(http.Hijacker) if !ok { - return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") + return nil, nil, http.ErrNotSupported } return hijacker.Hijack() } diff --git a/ecw_benchmark_test.go b/ecw_benchmark_test.go new file mode 100644 index 0000000..d9a427c --- /dev/null +++ b/ecw_benchmark_test.go @@ -0,0 +1,59 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) { + c, _ := CreateTestContext(nil) + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + ecw.capturedErrorSignal = true + ecw.Header().Set("Content-Type", "text/plain") + ecw.Header().Add("X-Test", "one") + + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + + ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler) + + if len(ecw.headerSnapshot) != 0 { + t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot) + } +} + +func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) { + c, _ := CreateTestContext(nil) + ecw := AcquireErrorCapturingResponseWriter(c) + defer ReleaseErrorCapturingResponseWriter(ecw) + + rawWriter := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + + keys := make([]string, 16) + for i := range keys { + keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i))) + } + values := []string{"one", "two", "three"} + for _, key := range keys { + ecw.headerSnapshot[key] = values + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler) + for _, key := range keys { + ecw.headerSnapshot[key] = values + } + } +} diff --git a/engine.go b/engine.go index c2eae91..15df162 100644 --- a/engine.go +++ b/engine.go @@ -7,9 +7,11 @@ package touka import ( "context" "errors" + "io" "reflect" "runtime" "strings" + "unicode/utf8" "net/http" @@ -17,6 +19,7 @@ import ( "github.com/WJQSERVER-STUDIO/httpc" "github.com/fenthope/reco" + "github.com/go-json-experiment/json" ) // Last 返回链中的最后一个处理函数 @@ -49,8 +52,14 @@ type Engine struct { HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 + // LogReco 保留的 reco.Logger 字段 + // Deprecated: 使用 SetLogger/GetLogger 替代 LogReco *reco.Logger + // logger 是新的日志接口,支持任意 Logger 实现 + // 优先级: logger > LogReco + logger Logger + HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 @@ -81,6 +90,11 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 + + notFoundChain HandlersChain + notFoundNoMethodChain HandlersChain + unmatchedFSChain HandlersChain + unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -116,6 +130,90 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) +var errMethodNotAllowed = errors.New("method not allowed") +var errNotFound = errors.New("not found") + +type defaultErrorResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` +} + +var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error()) +var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error()) + +func mustMarshalDefaultErrorBody(code int, errMsg string) []byte { + body, err := json.Marshal(defaultErrorResponse{ + Code: code, + Message: http.StatusText(code), + Error: errMsg, + }) + if err != nil { + panic(err) + } + return body +} + +func writeDefaultErrorJSON(c *Context, code int, body []byte) { + if c == nil || c.Writer == nil { + return + } + c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") + c.Writer.WriteHeader(code) + c.writeResponseBody(body, "failed to write default error response") + c.Writer.Flush() + c.Abort() +} + +var methodNotAllowedHandler HandlerFunc = func(c *Context) { + httpMethod := c.Request.Method + requestPath := routeLookupPath(c.Request) + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) + c.allowedMethodsBuf = allowedMethods[:0] + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allowedMethods { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", string(allowHeader)) + c.Status(http.StatusOK) + return + } + return + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + tempSkippedNodes := GetTempSkippedNodes() + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + *tempSkippedNodes = (*tempSkippedNodes)[:0] + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + if value.handlers != nil { + PutTempSkippedNodes(tempSkippedNodes) + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) + return + } + } + PutTempSkippedNodes(tempSkippedNodes) +} + +var notFoundHandler HandlerFunc = func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) +} + // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -126,16 +224,22 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if c.Writer.Written() { return } + if len(c.Errors) == 0 { + switch { + case code == http.StatusNotFound && errors.Is(err, errNotFound): + writeDefaultErrorJSON(c, code, defaultNotFoundBody) + return + case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed): + writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody) + return + } + } // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { errMsg = err.Error() } - c.JSON(code, H{ - "code": code, - "message": http.StatusText(code), - "error": errMsg, - }) + c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) c.Writer.Flush() c.Abort() return @@ -210,6 +314,7 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -265,16 +370,30 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable + engine.rebuildFallbackChains() } -// SetLogger传入实例 -func (engine *Engine) SetLogger(logger *reco.Logger) { - engine.LogReco = logger +// SetLogger 传入 Logger 接口实例 +func (engine *Engine) SetLogger(logger Logger) { + engine.logger = logger + // 同步更新 LogReco 以保持向后兼容 + if rl, ok := logger.(*reco.Logger); ok { + engine.LogReco = rl + } else { + engine.LogReco = nil + } } -// 配置日志LoggerCfg +// GetLogger 返回 Logger 接口实例 +func (engine *Engine) GetLogger() Logger { + return engine.logger +} + +// SetLoggerCfg 使用 reco.Config 配置日志 func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { - engine.LogReco = NewLogger(logcfg) + logger := NewLogger(logcfg) + engine.logger = logger + engine.LogReco = logger } // 设置自定义错误处理 @@ -305,6 +424,7 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } + engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -340,11 +460,28 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { }() } +func cloneServerProtocols(protocols *http.Protocols) *http.Protocols { + if protocols == nil { + return nil + } + cloned := *protocols + return &cloned +} + +func applyServerProtocols(srv *http.Server, protocols *http.Protocols) { + if protocols != nil { + srv.Protocols = cloneServerProtocols(protocols) + if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() { + if err := configureHTTP2ExtendedConnectServer(srv); err != nil { + panic(err) + } + } + } +} + // applyDefaultServerConfig 应用框架的默认配置到 http.Server func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { - if engine.serverProtocols != nil { - srv.Protocols = engine.serverProtocols - } + applyServerProtocols(srv, engine.serverProtocols) } // 配置全局Req Body大小限制 @@ -473,66 +610,64 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return func(c *Context) { - httpMethod := c.Request.Method - requestPath := c.Request.URL.Path - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := []string{} - for _, treeIter := range engine.methodTrees { - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - allowedMethods = append(allowedMethods, treeIter.method) - } - } - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) - c.Status(http.StatusOK) - return - } - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - tempSkippedNodes := GetTempSkippedNodes() - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - PutTempSkippedNodes(tempSkippedNodes) - if value.handlers != nil { - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) - return - } - } - } + return methodNotAllowedHandler } // 404最后处理 func NotFound() HandlerFunc { - return func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) - } + return notFoundHandler } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil + Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs + Engine.rebuildFallbackChains() +} + +func (engine *Engine) rebuildFallbackChains() { + buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { + finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound + if includeMethodNotAllowed { + finalSize++ + } + if includeUnmatchedFS { + finalSize += len(engine.UnMatchFSRoutes) + } + if engine.noRoute != nil { + finalSize++ + } else { + finalSize += len(engine.noRoutes) + } + + chain := make(HandlersChain, 0, finalSize) + chain = append(chain, engine.globalHandlers...) + if includeMethodNotAllowed { + chain = append(chain, methodNotAllowedHandler) + } + if includeUnmatchedFS { + chain = append(chain, engine.UnMatchFSRoutes...) + } + if engine.noRoute != nil { + chain = append(chain, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + chain = append(chain, engine.noRoutes...) + } + chain = append(chain, notFoundHandler) + return chain + } + + engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) + engine.notFoundNoMethodChain = buildChain(false, false) + engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) + engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -547,8 +682,9 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { +func (engine *Engine) Use(middleware ...HandlerFunc) Router { engine.globalHandlers = append(engine.globalHandlers, middleware...) + engine.rebuildFallbackChains() return engine } @@ -615,7 +751,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -624,7 +760,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRoute } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 IRouter 接口,允许嵌套分组 +// 它也实现了 Router 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -633,7 +769,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { +func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { group.Handlers = append(group.Handlers, middleware...) return group } @@ -679,7 +815,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), @@ -704,8 +840,13 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleRequest 负责根据请求查找路由并执行相应的处理函数链 // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { + if isGeneralOptionsRequest(c.Request) { + engine.handleGeneralOptions(c) + return + } + httpMethod := c.Request.Method - requestPath := c.Request.URL.Path + requestPath := routeLookupPath(c.Request) // 查找对应的路由树的根节点 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 @@ -725,7 +866,7 @@ func (engine *Engine) handleRequest(c *Context) { } // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 + if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ redirectPath := requestPath @@ -737,51 +878,98 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - // 尝试不区分大小写的查找 - // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 - ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) - if found && engine.RedirectFixedPath { - c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 - return + if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { + // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. + ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) + if found { + c.fixedPathBuf = ciPath[:0] + c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 + return + } + c.fixedPathBuf = c.fixedPathBuf[:0] } } } - // 构建处理链 - // 组合全局中间件和路由处理函数 - handlers := engine.globalHandlers - - // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 - // 则在全局中间件之后添加 MethodNotAllowed 处理器 - if engine.HandleMethodNotAllowed { - handlers = append(handlers, MethodNotAllowed()) - } - - // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed - // 则在处理链的最后添加 UnMatchFS 处理器 if engine.unMatchFS.ServeUnmatchedAsFS { - /* - var unMatchFSHandle = c.engine.unMatchFileServer - handlers = append(handlers, unMatchFSHandle) - */ - handlers = append(handlers, engine.UnMatchFSRoutes...) + c.handlers = engine.unmatchedFSChain + } else { + c.handlers = engine.notFoundChain } - - // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS - // 则在处理链的最后添加 NoRoute 处理器 - if engine.noRoute != nil { - handlers = append(handlers, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - handlers = append(handlers, engine.noRoutes...) - } - - handlers = append(handlers, NotFound()) - - c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } +func routeLookupPath(req *http.Request) string { + if req == nil { + return "" + } + + if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") { + return "/" + req.RequestURI + } + if isGeneralOptionsRequest(req) { + return "" + } + if req.URL == nil { + return "" + } + return req.URL.Path +} + +func isGeneralOptionsRequest(req *http.Request) bool { + return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" +} + +func shouldTryFixedPathLookup(path string, root *node) bool { + if root != nil && root.hasCaseInsensitivePath { + return true + } + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { + if cap(allowedMethods) < len(engine.methodTrees) { + allowedMethods = make([]string, 0, len(engine.methodTrees)) + } else { + allowedMethods = allowedMethods[:0] + } + tempSkippedNodes := GetTempSkippedNodes() + for _, treeIter := range engine.methodTrees { + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + *tempSkippedNodes = (*tempSkippedNodes)[:0] + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + PutTempSkippedNodes(tempSkippedNodes) + return allowedMethods +} + +func (engine *Engine) handleGeneralOptions(c *Context) { + if c == nil || c.Request == nil { + return + } + + c.Writer.Header().Set("Content-Length", "0") + if c.Request.ContentLength != 0 { + mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10) + _, _ = io.Copy(io.Discard, mb) + } + c.Writer.WriteHeader(http.StatusOK) + c.Abort() +} + // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go new file mode 100644 index 0000000..666e8b2 --- /dev/null +++ b/engine_benchmark_test.go @@ -0,0 +1,71 @@ +package touka + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var benchmarkStatusCode int + +func buildServeHTTPBenchmarkEngine() *Engine { + engine := New() + engine.GET("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/api/v1/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + return engine +} + +func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { + b.Helper() + + req, err := http.NewRequest(method, path, nil) + if err != nil { + b.Fatalf("failed to build request: %v", err) + } + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr = httptest.NewRecorder() + engine.ServeHTTP(rr, req) + } + + benchmarkStatusCode = rr.Code +} + +func BenchmarkServeHTTP(b *testing.B) { + engine := buildServeHTTPBenchmarkEngine() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") + }) + + b.Run("NotFound", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") + }) + + b.Run("MethodNotAllowed", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") + }) + + b.Run("OptionsAllow", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") + }) + + b.Run("FixedPathRedirect", func(b *testing.B) { + benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") + }) +} diff --git a/engine_test.go b/engine_test.go new file mode 100644 index 0000000..4772810 --- /dev/null +++ b/engine_test.go @@ -0,0 +1,306 @@ +package touka + +import ( + "bufio" + "encoding/json" + "errors" + "html/template" + "net" + "net/http" + "testing" +) + +type failingResponseWriter struct { + header http.Header + status int + err error +} + +func (w *failingResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *failingResponseWriter) WriteHeader(statusCode int) { + if w.status == 0 { + w.status = statusCode + } +} + +func (w *failingResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + if w.err != nil { + return 0, w.err + } + return len(p), nil +} + +func (w *failingResponseWriter) Flush() {} + +func (w *failingResponseWriter) Status() int { + return w.status +} + +func (w *failingResponseWriter) Size() int { + return 0 +} + +func (w *failingResponseWriter) Written() bool { + return w.status != 0 +} + +func (w *failingResponseWriter) IsHijacked() bool { + return false +} + +func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func TestHandleRequestRedirectFixedPath(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { + t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) + } +} + +func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { + engine := New() + engine.GET("/api/v1/users/:id/settings", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) + } +} + +func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) + } + if location := rr.Header().Get("Location"); location != "/Users/Profile" { + t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) + } +} + +func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) { + engine := New() + engine.GET("/Users/Profile", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic for fixed-path miss: %v", r) + } + }() + + rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code) + } +} + +func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { + engine := New() + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "hit" { + t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) + } +} + +func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.NoRoute(func(c *Context) { + c.Writer.Header().Set("X-NoRoute", "hit") + c.Next() + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-NoRoute"); got != "" { + t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) + } +} + +func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + engine.POST("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) + } + allow := rr.Header().Get("Allow") + if allow != "GET, POST" && allow != "POST, GET" { + t.Fatalf("expected Allow header to list matching methods, got %q", allow) + } +} + +func TestDefaultErrorHandleJSONShape(t *testing.T) { + engine := New() + rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) + if rr.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" { + t.Fatalf("unexpected error payload: %+v", body) + } +} + +func TestDefaultMethodNotAllowedJSONShape(t *testing.T) { + engine := New() + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + + var body struct { + Code int `json:"code"` + Message string `json:"message"` + Error string `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) + } + if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" { + t.Fatalf("unexpected error payload: %+v", body) + } +} + +func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { + engine := New() + engine.SetErrorHandler(func(c *Context, code int, err error) { + c.Writer.Header().Set("X-Custom-Error", "1") + c.String(code, "custom:%v", err) + }) + engine.GET("/users", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) + if rr.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) + } + if got := rr.Header().Get("X-Custom-Error"); got != "1" { + t.Fatalf("expected custom error header, got %q", got) + } + if rr.Body.String() != "custom:method not allowed" { + t.Fatalf("expected custom error body, got %q", rr.Body.String()) + } +} + +func TestResponseHelpersCaptureWriteErrors(t *testing.T) { + testCases := []struct { + name string + run func(*Context) + }{ + {name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }}, + {name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }}, + {name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }}, + {name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }}, + {name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }}, + {name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }}, + {name: "HTMLBuf", run: func(c *Context) { + c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`)) + c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"}) + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + c, _ := CreateTestContext(w) + + tc.run(c) + + if got := len(c.Errors); got != 1 { + t.Fatalf("expected exactly one captured error, got %d", got) + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + }) + } +} + +func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) { + writerErr := errors.New("write failed") + w := &failingResponseWriter{err: writerErr} + engine := New() + c, _ := CreateTestContext(w) + c.engine = engine + req, err := http.NewRequest(http.MethodGet, "/missing", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + c.reset(w, req) + + defaultErrorHandle(c, http.StatusNotFound, errNotFound) + + if len(c.Errors) == 0 { + t.Fatal("expected write error to be captured") + } + if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { + t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) + } + if c.Writer.Status() != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status()) + } + if !c.IsAborted() { + t.Fatal("expected fast path to abort context") + } +} diff --git a/examples/httpc/main.go b/examples/httpc/main.go new file mode 100644 index 0000000..db2be4f --- /dev/null +++ b/examples/httpc/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "net/http" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + // 示例 1:简单 GET 请求(自动关联请求 Context) + r.GET("/proxy", func(c *touka.Context) { + // 使用 HTTPC() 方法,自动关联请求 Context + // 当客户端断开连接时,出站请求也会自动取消 + body, err := c.HTTPC(). + GET("https://httpbin.org/get"). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, "%s", body) + }) + + // 示例 2:带 Header 的 POST 请求 + r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + // 链式调用,保持 httpc 风格 + // 注意:SetJSONBody 返回 (*RequestBuilder, error) + rb, err := c.HTTPC(). + POST("https://httpbin.org/post"). + SetHeader("X-API-Key", "secret"). + SetJSONBody(req) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + if err := rb.DecodeJSON(&result); err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 3:带查询参数的请求 + r.GET("/search", func(c *touka.Context) { + query := c.DefaultQuery("q", "") + page := c.DefaultQuery("page", "1") + + var result struct { + Items []string `json:"items"` + Total int `json:"total"` + } + + err := c.HTTPC(). + GET("https://httpbin.org/get"). + SetQueryParam("q", query). + SetQueryParam("page", page). + DecodeJSON(&result) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐) + r.GET("/legacy", func(c *touka.Context) { + // 旧方式:需要手动 WithContext + body, err := c.Client(). + GET("https://httpbin.org/get"). + WithContext(c.Context()). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, "%s", body) + }) + + fmt.Println("Server running on :8080") + fmt.Println("Try:") + fmt.Println(" curl http://localhost:8080/proxy") + fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users") + fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'") + + // r.Run(touka.WithAddr(":8080")) +} diff --git a/examples/logger_slog/main.go b/examples/logger_slog/main.go new file mode 100644 index 0000000..2263960 --- /dev/null +++ b/examples/logger_slog/main.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/infinite-iroha/touka" +) + +// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口 +type SlogAdapter struct { + logger *slog.Logger +} + +func NewSlogAdapter(handler slog.Handler) *SlogAdapter { + return &SlogAdapter{ + logger: slog.New(handler), + } +} + +func (s *SlogAdapter) Debugf(format string, args ...any) { + s.logger.Debug(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Infof(format string, args ...any) { + s.logger.Info(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Warnf(format string, args ...any) { + s.logger.Warn(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Errorf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) +} + +func (s *SlogAdapter) Fatalf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) + os.Exit(1) +} + +func (s *SlogAdapter) Panicf(format string, args ...any) { + s.logger.Error(fmt.Sprintf(format, args...)) + panic(fmt.Sprintf(format, args...)) +} + +func main() { + engine := touka.New() + + // 使用 slog 替换默认的 reco.Logger + handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slogAdapter := NewSlogAdapter(handler) + engine.SetLogger(slogAdapter) + + engine.GET("/", func(c *touka.Context) { + c.Infof("request received: %s", c.Request.URL.Path) + c.JSON(http.StatusOK, map[string]string{"message": "hello"}) + }) + + // 也可以获取 Logger 接口 + logger := engine.GetLogger() + logger.Debugf("engine started") + + // 也可以直接使用 slog + slog.Info("Server running", "addr", ":8080") + // engine.Run(":8080") +} diff --git a/go.mod b/go.mod index 42f4be4..dee187d 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,15 @@ module github.com/infinite-iroha/touka go 1.26 require ( - github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 + github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 + golang.org/x/net v0.52.0 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.52.0 // indirect + golang.org/x/text v0.35.0 // indirect ) diff --git a/go.sum b/go.sum index b49879b..4b9dbd9 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck= +github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= @@ -12,3 +14,5 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= diff --git a/http2xconnect.go b/http2xconnect.go new file mode 100644 index 0000000..c691a77 --- /dev/null +++ b/http2xconnect.go @@ -0,0 +1,88 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "time" + _ "unsafe" + + "golang.org/x/net/http2" +) + +var enableHTTP2ExtendedConnectOnce sync.Once + +//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol +var xnetDisableHTTP2ExtendedConnectProtocol bool + +func enableHTTP2ExtendedConnectProtocol() { + enableHTTP2ExtendedConnectOnce.Do(func() { + xnetDisableHTTP2ExtendedConnectProtocol = false + }) +} + +func configureHTTP2ExtendedConnectServer(srv *http.Server) error { + if srv == nil { + return nil + } + enableHTTP2ExtendedConnectProtocol() + return http2.ConfigureServer(srv, nil) +} + +func newHTTP2ExtendedConnectTransport() http.RoundTripper { + enableHTTP2ExtendedConnectProtocol() + transport := cloneDefaultTransport() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + transport.Protocols.SetHTTP2(true) + return transport +} + +func newHTTP1BridgeTransport() http.RoundTripper { + return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}}) +} + +func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { + transport := cloneDefaultTransport() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetHTTP1(true) + if tlsConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } else { + transport.TLSClientConfig = tlsConfig.Clone() + } + if len(transport.TLSClientConfig.NextProtos) == 0 { + transport.TLSClientConfig.NextProtos = []string{"http/1.1"} + } + return transport +} + +func newH2CTransport() http.RoundTripper { + transport := cloneDefaultTransport() + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return transport +} + +func cloneDefaultTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok { + return transport.Clone() + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} diff --git a/iox_benchmark_test.go b/iox_benchmark_test.go new file mode 100644 index 0000000..9b43590 --- /dev/null +++ b/iox_benchmark_test.go @@ -0,0 +1,150 @@ +package touka + +import ( + "bytes" + "io" + "testing" + + "github.com/WJQSERVER-STUDIO/go-utils/iox" +) + +type benchmarkResetReader struct { + data []byte + off int +} + +func (r *benchmarkResetReader) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.off:]) + r.off += n + return n, nil +} + +func (r *benchmarkResetReader) Reset() { + r.off = 0 +} + +type benchmarkDiscardWriter struct{} + +func (benchmarkDiscardWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +var benchmarkIOXResult int64 +var benchmarkIOXBytes []byte + +func BenchmarkIOXCopyComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.Copy(w, r) + if err != nil { + b.Fatalf("io.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.Copy", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.Copy(w, r) + if err != nil { + b.Fatalf("iox.Copy failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXCopyBufferComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := io.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("io.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) + + b.Run("iox.CopyBuffer", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + w := benchmarkDiscardWriter{} + buf := make([]byte, 32*1024) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + n, err := iox.CopyBuffer(w, r, buf) + if err != nil { + b.Fatalf("iox.CopyBuffer failed: %v", err) + } + benchmarkIOXResult = n + } + }) +} + +func BenchmarkIOXReadAllComparison(b *testing.B) { + payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) + + b.Run("io.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("io.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) + + b.Run("iox.ReadAll", func(b *testing.B) { + r := &benchmarkResetReader{data: payload} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Reset() + data, err := io.ReadAll(r) + if err != nil { + b.Fatalf("iox.ReadAll failed: %v", err) + } + benchmarkIOXBytes = data + } + }) +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..1be0077 --- /dev/null +++ b/logger.go @@ -0,0 +1,23 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +// Logger 是日志接口,支持多种日志库实现(reco、zap、logrus 等) +// 用户可以通过实现此接口来替换默认的日志实现 +type Logger interface { + Debugf(format string, args ...any) + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + Panicf(format string, args ...any) +} + +// CloserLogger 可选扩展接口,支持关闭操作 +// 如果 Logger 实现了此接口,Engine 在关闭时会调用 Close() +type CloserLogger interface { + Logger + Close() error +} diff --git a/logreco.go b/logreco.go index 4bda8d3..e37dd53 100644 --- a/logreco.go +++ b/logreco.go @@ -39,7 +39,16 @@ func CloseLogger(logger *reco.Logger) { } } +// CloseLogger 关闭 Engine 的日志实现 +// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法 func (engine *Engine) CloseLogger() { + if cl, ok := engine.logger.(CloserLogger); ok { + if err := cl.Close(); err != nil { + log.Printf("Close Logger Error: %s", err) + } + return + } + // 兼容旧代码 if engine.LogReco != nil { CloseLogger(engine.LogReco) } diff --git a/maxreader.go b/maxreader.go index c6201e6..4d3fb2c 100644 --- a/maxreader.go +++ b/maxreader.go @@ -23,19 +23,21 @@ type maxBytesReader struct { n int64 // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. read atomic.Int64 + // emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读. + emptyAtLimit atomic.Bool } // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // // 如果 r 为 nil, 会 panic. -// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. +// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r. func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { if r == nil { panic("NewMaxBytesReader called with a nil reader") } - // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. - if n < 0 { + // 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser. + if n <= 0 { return r } return &maxBytesReader{ @@ -46,48 +48,53 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. func (mbr *maxBytesReader) Read(p []byte) (int, error) { - // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. - readSoFar := mbr.read.Load() - - // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. - if readSoFar >= mbr.n { - return 0, ErrBodyTooLarge + if len(p) == 0 { + return 0, nil } - // 计算当前还可以读取多少字节. + // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. + readSoFar := mbr.read.Load() remaining := mbr.n - readSoFar + if remaining < 0 { + return 0, ErrBodyTooLarge + } + if remaining == 0 { + var probe [1]byte + n, err := mbr.r.Read(probe[:]) + if n > 0 { + mbr.read.Add(int64(n)) + return 0, ErrBodyTooLarge + } + if err != nil { + return 0, err + } + if mbr.emptyAtLimit.Swap(true) { + return 0, ErrBodyTooLarge + } + return 0, nil + } + mbr.emptyAtLimit.Store(false) - // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. - // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. - if int64(len(p)) > remaining { - p = p[:remaining] + // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 + if int64(len(p))-1 > remaining { + p = p[:remaining+1] } // 从底层 Reader 读取数据. n, err := mbr.r.Read(p) - // 如果实际读取到了数据, 更新原子计数器. - if n > 0 { - readSoFar = mbr.read.Add(int64(n)) - } - - // 如果底层 Read 返回错误 (例如 io.EOF). - if err != nil { - // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. - // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. + if int64(n) <= remaining { + if n > 0 { + mbr.read.Add(int64(n)) + } return n, err } - // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. - // 这是处理"跨越"限制情况的关键. - if readSoFar > mbr.n { - // 返回实际读取的字节数 n, 并附上超限错误. - // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. - return n, ErrBodyTooLarge + // 读取结果跨过了限制,只向上层暴露允许的部分。 + if remaining > 0 { + mbr.read.Add(remaining) } - - // 一切正常, 返回读取的字节数和 nil 错误. - return n, nil + return int(remaining), ErrBodyTooLarge } // Close 方法关闭底层的 ReadCloser, 保证资源释放. diff --git a/mergectx.go b/mergectx.go index e5d3ec4..404f7b1 100644 --- a/mergectx.go +++ b/mergectx.go @@ -11,18 +11,16 @@ import ( ) // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. +// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播. +// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消. type mergedContext struct { - // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. context.Context - // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context - // 用于手动取消此 mergedContext 的函数. - cancel context.CancelFunc } // MergeCtx 创建并返回一个新的 context.Context. // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). +// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. // // 新的 context 会继承: // - Deadline: 所有父 context 中最早的截止时间. @@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return context.WithCancel(context.Background()) } if len(parents) == 1 { - return context.WithCancel(parents[0]) + ctx, cancel := context.WithCancelCause(parents[0]) + return ctx, func() { cancel(nil) } } var earliestDeadline time.Time @@ -44,37 +43,71 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - var baseCtx context.Context - var baseCancel context.CancelFunc + // cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播. + cancelCtx, cancelCause := context.WithCancelCause(context.Background()) + + // deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline). + // 当 cancelCtx 被取消时, deadlineCtx 也会被取消; + // 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx. + var deadlineCtx context.Context + var deadlineCancel context.CancelFunc if !earliestDeadline.IsZero() { - baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) - } else { - baseCtx, baseCancel = context.WithCancel(context.Background()) + deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded) + } + + // 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline), + // 否则用 cancelCtx. + embedCtx := cancelCtx + if deadlineCtx != nil { + embedCtx = deadlineCtx } mc := &mergedContext{ - Context: baseCtx, + Context: embedCtx, parents: parents, - cancel: baseCancel, } - // 启动一个监控 goroutine. + // 启动监控 goroutine, 监听 parent 取消或 deadline 到期. go func() { - defer mc.cancel() + // 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏. + parentDone := orDone(append(mc.parents, cancelCtx)...) - // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. - // 同时监听 baseCtx.Done() 以便支持手动取消. - select { - case <-orDone(mc.parents...): - case <-mc.Context.Done(): + if deadlineCtx != nil { + defer deadlineCancel() + select { + case <-parentDone: + // parent 取消或手动 cancel() + for _, p := range mc.parents { + if p.Err() != nil { + cancelCause(context.Cause(p)) + return + } + } + // 手动 cancel(), cause 已由 cancelCause() 设置 + case <-deadlineCtx.Done(): + // deadline 到期, 需要关闭 cancelCtx 并设置 cause + cancelCause(context.DeadlineExceeded) + } + } else { + <-parentDone + for _, p := range mc.parents { + if p.Err() != nil { + cancelCause(context.Cause(p)) + return + } + } } }() - return mc, mc.cancel + return mc, func() { cancelCause(nil) } } -// Value 返回当前Ctx Value +// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), +// 再按传入顺序从 parents 中查找. func (mc *mergedContext) Value(key any) any { + if v := mc.Context.Value(key); v != nil { + return v + } for _, p := range mc.parents { if val := p.Value(key); val != nil { return val @@ -83,45 +116,20 @@ func (mc *mergedContext) Value(key any) any { return nil } -// Deadline 实现了 context.Context 的 Deadline 方法. -func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { - return mc.Context.Deadline() -} +// Deadline, Done, Err 均由嵌入的 context.Context 提供. -// Done 实现了 context.Context 的 Done 方法. -func (mc *mergedContext) Done() <-chan struct{} { - return mc.Context.Done() -} - -// Err 实现了 context.Context 的 Err 方法. -func (mc *mergedContext) Err() error { - return mc.Context.Err() -} - -// orDone 是一个辅助函数, 返回一个 channel. -// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. -// 这是一个非阻塞的、不会泄漏 goroutine 的实现. +// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. func orDone(contexts ...context.Context) <-chan struct{} { done := make(chan struct{}) - var once sync.Once - closeDone := func() { - once.Do(func() { - close(done) - }) - } - - // 为每个父 context 启动一个 goroutine. for _, ctx := range contexts { go func(c context.Context) { select { case <-c.Done(): - closeDone() + once.Do(func() { close(done) }) case <-done: - // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. } }(ctx) } - return done } diff --git a/mergectx_test.go b/mergectx_test.go new file mode 100644 index 0000000..d6d1225 --- /dev/null +++ b/mergectx_test.go @@ -0,0 +1,256 @@ +package touka + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestMergeCtx_NoParents(t *testing.T) { + ctx, cancel := MergeCtx() + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before cancel") + } + cancel() + if ctx.Err() == nil { + t.Fatal("expected error after cancel") + } +} + +func TestMergeCtx_SingleParent(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + + ctx, cancel := MergeCtx(parent) + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before parent cancel") + } + + parentCancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after parent cancel") + } +} + +func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + // p2 should still be fine + if p2.Err() != nil { + t.Fatal("expected p2 to be unaffected") + } +} + +func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } +} + +func TestMergeCtx_ExternalCancel(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + + cancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after external cancel") + } +} + +func TestMergeCtx_CausePropagation(t *testing.T) { + testErr := errors.New("test cause") + + p1, cancel1 := context.WithCancelCause(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1(testErr) + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + cancel1(nil) // cleanup (already cancelled, no-op) +} + +func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) { + testErr := errors.New("second parent cause") + + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancelCause(context.Background()) + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2(testErr) + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + + cancel1() +} + +func TestMergeCtx_Deadline_Earliest(t *testing.T) { + now := time.Now() + early := now.Add(100 * time.Millisecond) + late := now.Add(1 * time.Hour) + + p1, cancel1 := context.WithDeadline(context.Background(), late) + p2, cancel2 := context.WithDeadline(context.Background(), early) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !dl.Equal(early) { + t.Fatalf("expected deadline %v, got %v", early, dl) + } +} + +func TestMergeCtx_Deadline_Expires(t *testing.T) { + p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelP() + + ctx, cancel := MergeCtx(p) + defer cancel() + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after deadline expires") + } +} + +func TestMergeCtx_ValueLookup(t *testing.T) { + type key struct{} + p1 := context.WithValue(context.Background(), key{}, "from_p1") + p2 := context.WithValue(context.Background(), key{}, "from_p2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + val := ctx.Value(key{}) + if val != "from_p1" { + t.Fatalf("expected 'from_p1', got %v", val) + } +} + +func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) { + type key1 struct{} + type key2 struct{} + p1 := context.WithValue(context.Background(), key1{}, "val1") + p2 := context.WithValue(context.Background(), key2{}, "val2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + if v := ctx.Value(key1{}); v != "val1" { + t.Fatalf("expected 'val1', got %v", v) + } + if v := ctx.Value(key2{}); v != "val2" { + t.Fatalf("expected 'val2', got %v", v) + } + if v := ctx.Value("missing"); v != nil { + t.Fatalf("expected nil, got %v", v) + } +} + +func TestMergeCtx_ContextInterface(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + var ctx context.Context + ctx, _ = MergeCtx(p1, p2) + + // Verify all Context interface methods work + _ = ctx.Done() + _ = ctx.Err() + _, _ = ctx.Deadline() + _ = ctx.Value("any") +} + +func TestOrDone_SingleContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := orDone(ctx) + + cancel() + <-done // should not block +} + +func TestOrDone_MultipleContexts(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + done := orDone(p1, p2) + + cancel1() + <-done // should not block +} + +func TestOrDone_SecondContextCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + done := orDone(p1, p2) + + cancel2() + <-done // should not block +} diff --git a/protocols_test.go b/protocols_test.go index 73f16e9..0e2bf1f 100644 --- a/protocols_test.go +++ b/protocols_test.go @@ -70,42 +70,25 @@ func TestApplyDefaultServerConfig(t *testing.T) { } } -func TestRunTLSProtocolInheritance(t *testing.T) { +func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { engine := New() - // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv) + srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if !srv.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") } - // 模拟用户设置了自定义协议后调用 RunTLS + // 模拟用户设置了自定义协议后进入 TLS 运行模式 engine = New() engine.SetProtocols(&ProtocolsConfig{ Http1: true, Http2: false, // 用户明确不想要 HTTP/2 }) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, - }) - } - - srv2 := &http.Server{TLSConfig: &tls.Config{}} - engine.applyDefaultServerConfig(srv2) + srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) if srv2.Protocols.HTTP2() { - t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") } } diff --git a/respw.go b/respw.go index dd94db3..ef5cc3c 100644 --- a/respw.go +++ b/respw.go @@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // 尝试从底层 ResponseWriter 获取 Hijacker 接口 hj, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("http.Hijacker interface not supported") + return nil, nil, http.ErrNotSupported } // 调用底层的 Hijack 方法 diff --git a/reverseproxy.go b/reverseproxy.go index 1730b1e..1cf0078 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -6,6 +6,8 @@ package touka import ( "context" + "crypto/rand" + "encoding/base64" "errors" "fmt" "io" @@ -14,14 +16,18 @@ import ( "net" "net/http" "net/http/httptrace" + "net/http/httputil" "net/netip" "net/textproto" "net/url" + "regexp" "strconv" "strings" "sync" "sync/atomic" "time" + + "golang.org/x/net/http2" ) // ForwardedHeadersPolicy controls how forwarding headers are generated. @@ -44,32 +50,294 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { Target *url.URL + Targets []string - Transport http.RoundTripper + LoadBalancing ReverseProxyLoadBalancingConfig + PassiveHealth ReverseProxyPassiveHealthConfig + + Transport http.RoundTripper FlushInterval time.Duration - BufferPool BufferPool + BufferPool BufferPool + AllowH2CUpstream bool - ModifyRequest func(*http.Request) + ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error - ErrorHandler func(http.ResponseWriter, *http.Request, error) + ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy - ForwardedBy string - Via string - PreserveHost bool + ForwardedBy string + Via string + PreserveHost bool + + RequestHeaders *HeaderOps + ResponseHeaders *RespHeaderOps } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") + errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") ) +type HeaderOps struct { + Add map[string][]string + Set map[string][]string + Delete []string + Replace map[string][]Replacement +} + +type Replacement struct { + Search string + Replace string + SearchRegexp string + re *regexp.Regexp +} + +type RespHeaderOps struct { + *HeaderOps + Deferred bool +} + +func (ops *HeaderOps) applyToRequest(req *http.Request) { + if ops == nil { + return + } + ops.applyTo(req.Header, newReverseProxyReplacer(req)) +} + +func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { + if ops == nil { + return + } + ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) +} + +func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil { + return + } + if repl == nil { + repl = &reverseProxyReplacer{} + } + + for fieldName, vals := range ops.Add { + fieldName = repl.Replace(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + for fieldName, vals := range ops.Set { + fieldName = repl.Replace(fieldName) + hdr.Del(fieldName) + for _, v := range vals { + hdr.Add(fieldName, repl.Replace(v)) + } + } + + var deleteAll bool + var exactDeletes []string + var suffixPatterns, prefixPatterns, containsPatterns []string + + for _, fieldName := range ops.Delete { + fieldName = strings.ToLower(repl.Replace(fieldName)) + if fieldName == "*" { + deleteAll = true + break + } + switch { + case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): + containsPatterns = append(containsPatterns, fieldName[1:len(fieldName)-1]) + case strings.HasPrefix(fieldName, "*"): + suffixPatterns = append(suffixPatterns, fieldName[1:]) + case strings.HasSuffix(fieldName, "*"): + prefixPatterns = append(prefixPatterns, fieldName[:len(fieldName)-1]) + default: + exactDeletes = append(exactDeletes, fieldName) + } + } + + if deleteAll { + for k := range hdr { + hdr.Del(k) + } + } else if len(exactDeletes) > 0 || len(suffixPatterns) > 0 || len(prefixPatterns) > 0 || len(containsPatterns) > 0 { + toDelete := make([]string, 0, len(exactDeletes)) + for k := range hdr { + kl := strings.ToLower(k) + for _, d := range exactDeletes { + if kl == d { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range containsPatterns { + if strings.Contains(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range suffixPatterns { + if strings.HasSuffix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + for _, p := range prefixPatterns { + if strings.HasPrefix(kl, p) { + toDelete = append(toDelete, k) + goto skip + } + } + skip: + } + for _, k := range toDelete { + hdr.Del(k) + } + } + + ops.applyReplace(hdr, repl) +} + +func (ops *HeaderOps) applyReplace(hdr http.Header, repl *reverseProxyReplacer) { + if ops == nil || len(ops.Replace) == 0 { + return + } + for fieldName, replacements := range ops.Replace { + fieldName = http.CanonicalHeaderKey(repl.Replace(fieldName)) + if fieldName == "*" { + for fn, vals := range hdr { + for i := range vals { + for _, r := range replacements { + hdr[fn][i] = r.apply(vals[i]) + } + } + } + continue + } + vals, ok := hdr[fieldName] + if !ok { + continue + } + for i := range vals { + for _, r := range replacements { + hdr[fieldName][i] = r.apply(vals[i]) + } + } + } +} + +func (r *Replacement) apply(s string) string { + if r == nil || s == "" { + return s + } + if r.SearchRegexp != "" && r.re != nil { + return r.re.ReplaceAllString(s, r.Replace) + } + if r.Search != "" { + return strings.ReplaceAll(s, r.Search, r.Replace) + } + return s +} + +func (ops *HeaderOps) Provision() error { + if ops == nil { + return nil + } + for fieldName, replacements := range ops.Replace { + for i, r := range replacements { + if r.SearchRegexp == "" { + continue + } + if r.Search != "" { + return fmt.Errorf("replacement %d for header field %q: cannot specify both Search and SearchRegexp", i, fieldName) + } + re, err := regexp.Compile(r.SearchRegexp) + if err != nil { + return fmt.Errorf("replacement %d for header field %q: %v", i, fieldName, err) + } + replacements[i].re = re + } + } + return nil +} + +type reverseProxyReplacer struct { + method, host, path, query, scheme, uri, proto string +} + +func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { + if req == nil || req.URL == nil { + return &reverseProxyReplacer{} + } + uri := req.RequestURI + if uri == "" { + uri = req.URL.RequestURI() + } + return &reverseProxyReplacer{ + method: req.Method, + host: req.Host, + path: req.URL.EscapedPath(), + query: req.URL.RawQuery, + scheme: reverseProxyRequestScheme(req), + uri: uri, + proto: req.Proto, + } +} + +func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { + return &reverseProxyReplacer{} +} + +func (r *reverseProxyReplacer) Replace(s string) string { + if r == nil || s == "" { + return s + } + if r.method != "" { + s = strings.ReplaceAll(s, "{method}", r.method) + } + if r.host != "" { + s = strings.ReplaceAll(s, "{host}", r.host) + } + if r.path != "" { + s = strings.ReplaceAll(s, "{path}", r.path) + } + if r.query != "" { + s = strings.ReplaceAll(s, "{query}", r.query) + } + if r.scheme != "" { + s = strings.ReplaceAll(s, "{scheme}", r.scheme) + } + if r.uri != "" { + s = strings.ReplaceAll(s, "{uri}", r.uri) + } + if r.proto != "" { + s = strings.ReplaceAll(s, "{proto}", r.proto) + } + return s +} + type reverseProxyHandler struct { config ReverseProxyConfig - target *url.URL + upstreams []*reverseProxyUpstream receivedBy string configError error + roundRobin atomic.Uint64 +} + +var reverseProxyCopyBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +var reverseProxyCandidatePool = sync.Pool{ + New: func() any { + s := make([]*reverseProxyUpstream, 0, 8) + return &s + }, } type reverseProxyStatusError struct { @@ -77,6 +345,34 @@ type reverseProxyStatusError struct { err error } +type reverseProxyExtendedConnectBridge struct { + body io.ReadCloser +} + +type reverseProxyH2ReadWriteCloser struct { + io.ReadCloser + ResponseWriter + controller *http.ResponseController +} + +func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { + n, err := rwc.ResponseWriter.Write(p) + if err != nil { + return n, err + } + if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + return n, err + } + return n, nil +} + +func (rwc *reverseProxyH2ReadWriteCloser) Close() error { + if rwc.ReadCloser == nil { + return nil + } + return rwc.ReadCloser.Close() +} + func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" @@ -197,19 +493,29 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc { } func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { - target := cloneReverseProxyURL(config.Target) - if target != nil { - normalizeReverseProxyTarget(target) - } - proxy := &reverseProxyHandler{ config: config, - target: target, receivedBy: reverseProxyReceivedBy(config.Via), } - if err := validateReverseProxyTarget(target); err != nil { + if config.RequestHeaders != nil { + if err := config.RequestHeaders.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + if config.ResponseHeaders != nil && config.ResponseHeaders.HeaderOps != nil { + if err := config.ResponseHeaders.HeaderOps.Provision(); err != nil { + proxy.configError = err + return proxy + } + } + + upstreams, err := buildReverseProxyUpstreams(config) + if err != nil { proxy.configError = err + } else { + proxy.upstreams = upstreams } switch config.ForwardedHeaders { @@ -217,6 +523,17 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { default: proxy.config.ForwardedHeaders = ForwardedBoth } + proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) + if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { + if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { + proxy.configError = err + } + } + if proxy.configError == nil { + if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil { + proxy.configError = err + } + } return proxy } @@ -229,62 +546,75 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } - transport := p.config.Transport - if transport == nil { - transport = http.DefaultTransport + updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) + if err != nil { + p.handleError(c, err) + return + } + if handledLocally { + return } ctx, cancel := p.requestContext(c) defer cancel() + attempted := make(map[string]struct{}, len(p.upstreams)) + attempts := 0 + started := time.Now() + var lastErr error - outreq := c.Request.Clone(ctx) - if c.Request.ContentLength == 0 { - outreq.Body = nil - } - if outreq.Body != nil { - outreq.Body = &noopCloseReader{readCloser: outreq.Body} - defer outreq.Body.Close() - } - if outreq.Header == nil { - outreq.Header = make(http.Header) - } - outreq.Close = false + for { + upstream, err := p.selectUpstream(c, attempted) + if err != nil { + if lastErr != nil { + p.handleError(c, lastErr) + return + } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err}) + return + } - rewriteReverseProxyURL(outreq, p.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + attempts++ + upstream.inFlight.Add(1) + served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards) + upstream.inFlight.Add(-1) - reqUpType := reverseProxyUpgradeType(outreq.Header) - if reqUpType != "" && !isPrintableASCII(reqUpType) { - p.handleError(c, &reverseProxyStatusError{ - status: http.StatusBadRequest, - err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), - }) + if served { + return + } + if attemptErr != nil { + lastErr = attemptErr + } + if retriable && p.shouldRetryAttempt(c.Request, attempts, started) { + attempted[upstream.key] = struct{}{} + if !p.waitRetryInterval(ctx, started) { + if lastErr != nil { + p.handleError(c, lastErr) + } + return + } + continue + } + if attemptErr != nil { + p.handleError(c, attemptErr) + return + } + if lastErr != nil { + p.handleError(c, lastErr) + return + } + p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams}) return } +} - removeHopByHopHeaders(outreq.Header) - if headerValuesContainToken(c.Request.Header["Te"], "trailers") { - outreq.Header.Set("Te", "trailers") - } - if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) - } - - p.addForwardingHeaders(c.Request, outreq) - appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) - - if _, ok := outreq.Header["User-Agent"]; !ok { - outreq.Header.Set("User-Agent", "") - } - - if p.config.ModifyRequest != nil { - p.config.ModifyRequest(outreq) +func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) { + outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards) + if err != nil { + return false, err, false } + defer cleanup() + transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -314,26 +644,65 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { roundTripDone = true roundTripMu.Unlock() if err != nil { - p.handleError(c, err) - return + if reverseProxyShouldCountPassiveFailure(outreq, err) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) + } + return false, err, true + } + if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) { + upstream.recordFailure(time.Now(), p.config.PassiveHealth) + } + + if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return true, nil, false + } + if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { + return false, err, false + } + return true, nil, false + } + return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false + } + + if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { + removeHopByHopHeaders(res.Header) + res.Header.Del("Content-Length") + res.Header.Del("Transfer-Encoding") + res.ContentLength = -1 + res.TransferEncoding = nil + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return true, nil, false + } + handleConnect := p.handleConnectResponse + if reverseProxyIsExtendedConnectRequest(outreq) { + handleConnect = p.handleExtendedConnectResponse + } + if err := handleConnect(c, outreq, res, connectWriter); err != nil { + return false, err, false + } + return true, nil, false } if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } if err := p.handleUpgradeResponse(c, outreq, res); err != nil { - p.handleError(c, err) + return false, err, false } - return + return true, nil, false } removeHopByHopHeaders(res.Header) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return + return true, nil, false } reverseProxyCopyHeader(c.Writer.Header(), res.Header) @@ -353,7 +722,10 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) - return + if reverseProxyShouldPanicOnCopyError(c.Request) { + panic(http.ErrAbortHandler) + } + return true, nil, false } res.Body.Close() @@ -361,13 +733,9 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Flush() } - // Keep the stdlib-compatible fallback here. - // If the backend only exposes additional trailer keys after the body has been - // fully read, the trailer map can grow and those values must be written using - // the TrailerPrefix form instead of the pre-announced bare header keys. if len(res.Trailer) == announcedTrailers { reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) - return + return true, nil, false } for key, values := range res.Trailer { @@ -376,6 +744,249 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { c.Writer.Header().Add(prefixedKey, value) } } + return true, nil, false +} + +func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { + outreq := c.Request.Clone(ctx) + bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) + if err != nil { + return nil, nil, nil, err + } + if bridged { + outreq = outreq.WithContext(bridgeCtx) + } + if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { + outreq.Body = nil + } else if c.Request.GetBody != nil { + body, err := c.Request.GetBody() + if err != nil { + return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err) + } + outreq.Body = body + } else if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false + var connectWriter *io.PipeWriter + if outreq.Method == http.MethodConnect && !bridged { + pipeReader, pipeWriter := io.Pipe() + outreq.Body = pipeReader + outreq.ContentLength = -1 + connectWriter = pipeWriter + } + cleanup := func() { + if outreq.Body != nil { + _ = outreq.Body.Close() + } + if connectWriter != nil { + _ = connectWriter.Close() + } + } + + if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) { + if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { + cleanup() + return nil, nil, nil, err + } + } else { + rewriteReverseProxyURL(outreq, upstream.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + } + if updatedMaxForwards != "" { + outreq.Header.Set("Max-Forwards", updatedMaxForwards) + } + + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + cleanup() + return nil, nil, nil, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + } + } + + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.RequestHeaders != nil { + p.config.RequestHeaders.applyToRequest(outreq) + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) + } + + return outreq, connectWriter, cleanup, nil +} + +func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper { + if p.config.Transport != nil { + return p.config.Transport + } + if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { + if upstream.bridgeTransport != nil { + return upstream.bridgeTransport + } + return http.DefaultTransport + } + if upstream.useH2C && upstream.h2cTransport != nil { + return upstream.h2cTransport + } + if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { + return upstream.extendedConnectTransport + } + return http.DefaultTransport +} + +func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool { + if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) { + return false + } + lb := p.config.LoadBalancing + if lb.TryDuration > 0 { + return time.Since(started) < lb.TryDuration + } + return attempts <= lb.Retries +} + +func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool { + interval := p.config.LoadBalancing.TryInterval + tryDuration := p.config.LoadBalancing.TryDuration + if tryDuration > 0 && interval == 0 { + interval = 250 * time.Millisecond + } + if tryDuration > 0 { + remaining := tryDuration - time.Since(started) + if remaining <= 0 { + return false + } + if interval <= 0 { + return ctx.Err() == nil + } + if interval > remaining { + return false + } + } + if interval <= 0 { + return ctx.Err() == nil + } + timer := time.NewTimer(interval) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { + if c == nil || c.Request == nil { + return "", false, nil + } + + switch c.Request.Method { + case http.MethodOptions, http.MethodTrace: + default: + return "", false, nil + } + + rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) + if rawValue == "" { + return "", false, nil + } + + value, err := strconv.Atoi(rawValue) + if err != nil || value < 0 { + return "", false, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), + } + } + if value == 0 { + switch c.Request.Method { + case http.MethodTrace: + return "", true, p.writeLocalTraceResponse(c) + case http.MethodOptions: + p.writeLocalOptionsResponse(c) + return "", true, nil + } + } + + return strconv.Itoa(value - 1), false, nil +} + +func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { + if c == nil || c.Request == nil { + return nil + } + + traceReq := c.Request.Clone(c.Request.Context()) + traceReq.Body = nil + traceReq.ContentLength = 0 + traceReq.TransferEncoding = nil + traceReq.RequestURI = c.Request.RequestURI + if traceReq.RequestURI == "" && traceReq.URL != nil { + traceReq.RequestURI = traceReq.URL.RequestURI() + } + traceReq.Header = traceReq.Header.Clone() + for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { + traceReq.Header.Del(key) + } + + dump, err := httputil.DumpRequest(traceReq, false) + if err != nil { + return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} + } + + c.Writer.Header().Set("Content-Type", "message/http") + c.Writer.WriteHeader(http.StatusOK) + _, err = c.Writer.Write(dump) + return err +} + +func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { + if c == nil { + return + } + + if c.engine != nil { + if c.Request != nil && c.Request.RequestURI != "*" { + if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { + c.allowedMethodsBuf = allow[:0] + allowHeader := c.allowHeaderBuf[:0] + for i, method := range allow { + if i > 0 { + allowHeader = append(allowHeader, ',', ' ') + } + allowHeader = append(allowHeader, method...) + } + c.allowHeaderBuf = allowHeader[:0] + c.Writer.Header().Set("Allow", string(allowHeader)) + } + } + } + c.Writer.WriteHeader(http.StatusOK) } func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { @@ -456,7 +1067,14 @@ func appendXForwardedFor(header http.Header, clientIP string) { } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } + if p.config.ModifyResponse == nil { + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } if err := p.config.ModifyResponse(res); err != nil { @@ -464,6 +1082,9 @@ func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req p.handleError(c, err) return false } + if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { + p.config.ResponseHeaders.applyToResponse(res.Header) + } return true } @@ -522,7 +1143,11 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} } defer clientConn.Close() @@ -561,6 +1186,231 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques return firstErr } +func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), + } + } + backRead := res.Body + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backRead.Close() + _ = backWrite.Close() + status := http.StatusBadGateway + if errors.Is(err, http.ErrNotSupported) { + status = http.StatusNotImplemented + } + return &reverseProxyStatusError{status: status, err: err} + } + + defer clientConn.Close() + defer backRead.Close() + defer backWrite.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backRead.Close() + _ = backWrite.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + go func() { + if _, err := io.Copy(clientConn, backRead); err != nil { + errc <- err + return + } + if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone + }() + go func() { + if _, err := io.Copy(backWrite, clientConn); err != nil { + errc <- err + return + } + errc <- backWrite.Close() + }() + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + +func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { + if c == nil || c.Request == nil { + res.Body.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned bridged websocket response without writable body"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + responseHeader := c.Writer.Header() + reverseProxyCopyHeader(responseHeader, res.Header) + removeHopByHopHeaders(responseHeader) + responseHeader.Del("Sec-WebSocket-Accept") + c.Writer.WriteHeader(http.StatusOK) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} + + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = conn.Close() + _ = backConn.Close() + }) + } + go func() { + <-req.Context().Done() + closeTunnel() + }() + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: conn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + var firstErr error + for range 2 { + err := <-errc + if reverseProxyIsBenignTunnelError(err) { + continue + } + if firstErr == nil { + firstErr = err + closeTunnel() + } + } + closeTunnel() + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + return firstErr +} + +func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { + if c == nil || c.Request == nil { + res.Body.Close() + if backWrite != nil { + _ = backWrite.Close() + } + return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} + } + if backWrite == nil { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), + } + } + + controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) + if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + c.Writer.WriteHeader(res.StatusCode) + if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + res.Body.Close() + _ = backWrite.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + var closeOnce sync.Once + closeTunnel := func() { + closeOnce.Do(func() { + _ = c.Request.Body.Close() + _ = backWrite.Close() + _ = res.Body.Close() + }) + } + go func() { + <-req.Context().Done() + closeTunnel() + }() + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(backWrite, c.Request.Body) + closeErr := backWrite.Close() + if err != nil && !reverseProxyIsBenignTunnelError(err) { + errc <- err + return + } + errc <- closeErr + }() + go func() { + copyErr := p.copyResponse(c.Writer, res.Body, -1) + closeErr := res.Body.Close() + if copyErr != nil { + errc <- copyErr + return + } + errc <- closeErr + }() + + var firstErr error + for range 2 { + err := <-errc + if reverseProxyIsBenignTunnelError(err) { + continue + } + if firstErr == nil { + firstErr = err + closeTunnel() + } + } + closeTunnel() + if reverseProxyIsBenignTunnelError(firstErr) { + return nil + } + + return firstErr + +} + func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -586,6 +1436,10 @@ func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, fl if p.config.BufferPool != nil { buf = p.config.BufferPool.Get() defer p.config.BufferPool.Put(buf) + } else { + bufp := reverseProxyCopyBufferPool.Get().(*[]byte) + buf = *bufp + defer reverseProxyCopyBufferPool.Put(bufp) } _, err := p.copyBuffer(writer, src, buf) return err @@ -599,7 +1453,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { @@ -638,6 +1492,10 @@ func reverseProxyStatusCode(err error) int { if errors.As(err, &statusErr) && statusErr.status > 0 { return statusErr.status } + var netErr net.Error + if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { + return http.StatusGatewayTimeout + } return http.StatusBadGateway } @@ -651,6 +1509,75 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } +func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) { + if config.Target != nil && len(config.Targets) > 0 { + return nil, errors.New("reverse proxy Target and Targets cannot be used together") + } + + targets := make([]*url.URL, 0, max(1, len(config.Targets))) + if config.Target != nil { + target := cloneReverseProxyURL(config.Target) + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, err + } + targets = append(targets, target) + } + for i, rawTarget := range config.Targets { + trimmed := strings.TrimSpace(rawTarget) + if trimmed == "" { + return nil, fmt.Errorf("reverse proxy target at index %d is empty", i) + } + target, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + normalizeReverseProxyTarget(target) + if err := validateReverseProxyTarget(target); err != nil { + return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) + } + targets = append(targets, target) + } + if len(targets) == 0 { + return nil, errReverseProxyNilTarget + } + + upstreams := make([]*reverseProxyUpstream, 0, len(targets)) + for i, target := range targets { + useH2C := strings.EqualFold(target.Scheme, "h2c") + if useH2C { + target = cloneReverseProxyURL(target) + target.Scheme = "http" + } + upstream := &reverseProxyUpstream{ + key: fmt.Sprintf("%d:%s", i, target.String()), + target: target, + index: i, + useH2C: useH2C || config.AllowH2CUpstream, + } + if config.Transport == nil { + upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport() + upstream.bridgeTransport = newHTTP1BridgeTransport() + if upstream.useH2C { + upstream.h2cTransport = newH2CTransport() + } + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +func validateReverseProxyForwardedBy(value string) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil + } + if !isValidForwardedNodeIdentifier(trimmed) { + return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value) + } + return nil +} + func normalizeReverseProxyTarget(target *url.URL) { switch strings.ToLower(target.Scheme) { case "ws": @@ -732,6 +1659,136 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { return strings.Join(pairs, ";") } +func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { + return policy == ForwardedBoth || policy == ForwardedRFC7239Only +} + +func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { + if req == nil { + return context.Background(), false, nil + } + protocol := reverseProxyExtendedConnectProtocol(req) + if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { + return req.Context(), false, nil + } + + bridge := &reverseProxyExtendedConnectBridge{body: req.Body} + ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge) + req.Header.Del(":protocol") + req.Method = http.MethodGet + req.Body = http.NoBody + req.ContentLength = 0 + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Version", "13") + key, err := reverseProxyGenerateWebSocketKey() + if err != nil { + return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err) + } + req.Header.Set("Sec-WebSocket-Key", key) + return ctx, true, nil +} + +func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { + if ctx == nil { + return nil + } + bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge) + return bridge +} + +func reverseProxyGenerateWebSocketKey() (string, error) { + key := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + +func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { + return reverseProxyExtendedConnectProtocol(req) != "" +} + +func reverseProxyExtendedConnectProtocol(req *http.Request) string { + if req == nil || req.Method != http.MethodConnect || req.Header == nil { + return "" + } + return textproto.TrimString(req.Header.Get(":protocol")) +} + +func isValidForwardedNodeIdentifier(value string) bool { + if value == "" { + return false + } + if strings.HasPrefix(value, "[") { + closing := strings.IndexByte(value, ']') + if closing <= 1 { + return false + } + addr, err := netip.ParseAddr(value[1:closing]) + if err != nil || !addr.Is6() { + return false + } + if closing == len(value)-1 { + return true + } + if value[closing+1] != ':' { + return false + } + return isValidForwardedNodePort(value[closing+2:]) + } + + host, port, hasPort := strings.Cut(value, ":") + if hasPort { + switch { + case host == "unknown", isValidForwardedObfuscatedIdentifier(host): + return isValidForwardedNodePort(port) + default: + addr, err := netip.ParseAddr(host) + return err == nil && addr.Is4() && isValidForwardedNodePort(port) + } + } + + if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) { + return true + } + addr, err := netip.ParseAddr(value) + return err == nil && addr.Is4() +} + +func isValidForwardedNodePort(value string) bool { + if value == "" { + return false + } + if isValidForwardedObfuscatedIdentifier(value) { + return true + } + if len(value) > 5 { + return false + } + port, err := strconv.Atoi(value) + return err == nil && port > 0 && port <= 65535 +} + +func isValidForwardedObfuscatedIdentifier(value string) bool { + if len(value) < 2 || value[0] != '_' { + return false + } + for i := 1; i < len(value); i++ { + b := value[i] + if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') { + continue + } + switch b { + case '.', '_', '-': + continue + default: + return false + } + } + return true +} + func formatForwardedFor(clientIP string) string { addr, err := netip.ParseAddr(clientIP) if err != nil { @@ -799,8 +1856,8 @@ func reverseProxyViaProtocol(major, minor int, raw string) string { if major > 0 { return strconv.Itoa(major) + "." + strconv.Itoa(minor) } - if strings.HasPrefix(raw, "HTTP/") { - return strings.TrimPrefix(raw, "HTTP/") + if after, ok := strings.CutPrefix(raw, "HTTP/"); ok { + return after } return raw } @@ -817,6 +1874,47 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) { } } +func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error { + connectTarget, err := reverseProxyConnectTarget(target) + if err != nil { + return &reverseProxyStatusError{status: http.StatusBadRequest, err: err} + } + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = "" + req.URL.RawPath = "" + req.URL.RawQuery = "" + req.URL.Opaque = connectTarget + req.Host = connectTarget + return nil +} + +func reverseProxyConnectTarget(target *url.URL) (string, error) { + if target == nil { + return "", errReverseProxyNilTarget + } + host := target.Hostname() + if host == "" { + return "", errReverseProxyInvalidTarget + } + port := target.Port() + if port == "" { + switch strings.ToLower(target.Scheme) { + case "http": + port = "80" + case "https": + port = "443" + default: + return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme) + } + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum <= 0 || portNum > 65535 { + return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port) + } + return net.JoinHostPort(host, port), nil +} + func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { if base.RawPath == "" && incoming.RawPath == "" { return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" @@ -873,7 +1971,7 @@ var reverseProxyHopHeaders = []string{ func removeHopByHopHeaders(header http.Header) { for _, connectionValue := range header["Connection"] { - for _, token := range strings.Split(connectionValue, ",") { + for token := range strings.SplitSeq(connectionValue, ",") { trimmed := textproto.TrimString(token) if trimmed != "" { header.Del(trimmed) @@ -897,7 +1995,7 @@ func headerValuesContainToken(values []string, token string) bool { return false } for _, value := range values { - for _, part := range strings.Split(value, ",") { + for part := range strings.SplitSeq(value, ",") { if strings.EqualFold(textproto.TrimString(part), token) { return true } @@ -919,6 +2017,59 @@ func cleanReverseProxyQueryParams(rawQuery string) string { return values.Encode() } +func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { + return req != nil && req.Context().Value(http.ServerContextKey) != nil +} + +func reverseProxyCanRetryRequest(req *http.Request) bool { + if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) { + return false + } + if req.Body == nil || req.ContentLength == 0 { + return true + } + return req.GetBody != nil +} + +func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool { + if err == nil || reverseProxyIsBenignTunnelError(err) { + return false + } + if req != nil && req.Context().Err() != nil { + return false + } + return !errors.Is(err, context.Canceled) +} + +func reverseProxyMethodIsSafe(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + return true + default: + return false + } +} + +func reverseProxyIsBenignTunnelError(err error) bool { + return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) +} + +func reverseProxyIsClosedBodyError(err error) bool { + if err == nil { + return false + } + var streamErr http2.StreamError + if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel { + return true + } + switch err.Error() { + case "body closed by handler", "http2: response body closed", "response body closed": + return true + default: + return false + } +} + func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go new file mode 100644 index 0000000..b496f5c --- /dev/null +++ b/reverseproxy_benchmark_test.go @@ -0,0 +1,355 @@ +package touka + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "net/http" + "strings" + "testing" + "time" +) + +type benchmarkReadSeeker struct { + data []byte + off int +} + +func (r *benchmarkReadSeeker) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.off:]) + r.off += n + return n, nil +} + +func (r *benchmarkReadSeeker) Reset() { + r.off = 0 +} + +type benchmarkResponseWriter struct { + header http.Header + status int + size int +} + +func newBenchmarkResponseWriter() *benchmarkResponseWriter { + return &benchmarkResponseWriter{header: make(http.Header)} +} + +func (w *benchmarkResponseWriter) Header() http.Header { + return w.header +} + +func (w *benchmarkResponseWriter) WriteHeader(statusCode int) { + if w.status == 0 { + w.status = statusCode + } +} + +func (w *benchmarkResponseWriter) Write(p []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + w.size += len(p) + return len(p), nil +} + +func (w *benchmarkResponseWriter) Flush() {} + +func (w *benchmarkResponseWriter) Status() int { + return w.status +} + +func (w *benchmarkResponseWriter) Size() int { + return w.size +} + +func (w *benchmarkResponseWriter) Written() bool { + return w.status != 0 +} + +func (w *benchmarkResponseWriter) IsHijacked() bool { + return false +} + +func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *benchmarkResponseWriter) reset() { + clear(w.header) + w.status = 0 + w.size = 0 +} + +var benchmarkReverseProxySink int + +func BenchmarkReverseProxyCopyResponse(b *testing.B) { + body := bytes.Repeat([]byte("0123456789abcdef"), 4096) + proxy := newReverseProxyHandler(ReverseProxyConfig{}) + dst := newBenchmarkResponseWriter() + src := &benchmarkReadSeeker{data: body} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dst.reset() + src.Reset() + if err := proxy.copyResponse(dst, src, 0); err != nil { + b.Fatalf("copyResponse failed: %v", err) + } + } + + benchmarkReverseProxySink = dst.Size() +} + +func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 3, + }, + }, + } + + now := time.Now() + proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)} + proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil)) + } +} + +func BenchmarkReverseProxySelectUpstream(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 3, + }, + }, + } + proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)} + + c, _ := CreateTestContext(nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + selected, err := proxy.selectUpstream(c, nil) + if err != nil { + b.Fatalf("selectUpstream failed: %v", err) + } + benchmarkReverseProxySink = selected.index + } +} + +func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + {key: "d", index: 3}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + selected, err := proxy.selectUpstream(c, nil) + if err != nil { + b.Fatalf("selectUpstream failed: %v", err) + } + benchmarkReverseProxySink = selected.index + } +} + +func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { + proxy := newReverseProxyHandler(ReverseProxyConfig{}) + dst := newBenchmarkResponseWriter() + src := bytes.NewBufferString("hello, reverse proxy") + + if err := proxy.copyResponse(dst, src, 0); err != nil { + t.Fatalf("copyResponse failed: %v", err) + } + + if got, want := dst.Size(), len("hello, reverse proxy"); got != want { + t.Fatalf("expected %d bytes copied, got %d", want, got) + } +} + +type fixedLenBufferPool struct { + buf []byte +} + +func (p *fixedLenBufferPool) Get() []byte { + return p.buf +} + +func (p *fixedLenBufferPool) Put(buf []byte) { + p.buf = buf +} + +type recordingReader struct { + chunk int + reads []int + left int +} + +func (r *recordingReader) Read(p []byte) (int, error) { + if r.left == 0 { + return 0, io.EOF + } + n := min(r.chunk, len(p), r.left) + if n == 0 { + return 0, errors.New("reader received zero-length buffer") + } + for i := range n { + p[i] = 'x' + } + r.left -= n + r.reads = append(r.reads, len(p)) + return n, nil +} + +func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) { + pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)} + proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool}) + dst := newBenchmarkResponseWriter() + src := &recordingReader{chunk: 8, left: 24} + + if err := proxy.copyResponse(dst, src, 0); err != nil { + t.Fatalf("copyResponse failed: %v", err) + } + + if len(src.reads) == 0 { + t.Fatal("expected reader to be used") + } + for _, size := range src.reads { + if size != 8 { + t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size) + } + } +} + +func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) { + now := time.Now() + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a"}, + {key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}}, + {key: "c"}, + }, + config: ReverseProxyConfig{ + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + MaxFails: 2, + }, + }, + } + + available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}}) + if len(available) != 1 { + t.Fatalf("expected only one available upstream, got %d", len(available)) + } + if available[0].key != "a" { + t.Fatalf("expected upstream 'a', got %q", available[0].key) + } +} + +func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) { + proxy := &reverseProxyHandler{ + upstreams: []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + }, + config: ReverseProxyConfig{ + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, + }, + } + + c, _ := CreateTestContext(nil) + c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"} + + selectedA, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + selectedB, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedA.key != selectedB.key { + t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key) + } + + c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"} + selectedC, err := proxy.selectUpstream(c, nil) + if err != nil { + t.Fatalf("selectUpstream failed: %v", err) + } + if selectedC == nil { + t.Fatal("expected upstream for reordered multi-value header") + } +} + +func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) { + candidates := []*reverseProxyUpstream{ + {key: "a", index: 0}, + {key: "b", index: 1}, + {key: "c", index: 2}, + } + + testCases := [][]string{ + {"tenant-a"}, + {"tenant-a", "tenant-b"}, + {"", "tenant-b"}, + {"tenant-a", ""}, + {"", ""}, + } + + for _, values := range testCases { + got := reverseProxySelectHRWValues(candidates, values) + want := reverseProxySelectHRW(candidates, strings.Join(values, ",")) + if got == nil || want == nil { + t.Fatalf("expected non-nil upstreams for values %v", values) + } + if got.key != want.key { + t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key) + } + } +} + +var _ io.Writer = (*benchmarkResponseWriter)(nil) diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go new file mode 100644 index 0000000..0c0d599 --- /dev/null +++ b/reverseproxy_headers_replace_test.go @@ -0,0 +1,530 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "testing" +) + +func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Server"); got != "Caddy" { + t.Errorf("expected X-Server=Caddy, got %q", got) + } + if got := r.Header.Get("X-Location"); got != "/api/v2/resource" { + t.Errorf("expected X-Location=/api/v2/resource, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Server": {{Search: "NGINX", Replace: "Caddy"}}, + "X-Location": {{Search: "v1", Replace: "v2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Server", "NGINX") + req.Header.Set("X-Location", "/api/v1/resource") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Route"); got != "/proxy-upstream" { + t.Errorf("expected X-Route=/proxy-upstream, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Route", "/original/upstream") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Host-A"); got != "new.example.com" { + t.Errorf("expected X-Host-A=new.example.com, got %q", got) + } + if got := r.Header.Get("X-Host-B"); got != "new.example.com" { + t.Errorf("expected X-Host-B=new.example.com, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "old.example.com", Replace: "new.example.com"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Host-A", "old.example.com") + req.Header.Set("X-Host-B", "old.example.com") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-internal:8080") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Backend"); got != "public.example.com" { + t.Errorf("expected X-Backend=public.example.com, got %q", got) + } +} + +func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Replace: map[string][]Replacement{ + "X-Test": {{SearchRegexp: "[invalid"}}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", resp.StatusCode) + } +} + +func TestReplacementApply(t *testing.T) { + tests := []struct { + name string + r *Replacement + s string + want string + }{ + {name: "nil replacement", r: nil, s: "hello", want: "hello"}, + {name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""}, + {name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"}, + {name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"}, + {name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"}, + {name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"}, + {name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"}, + {name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.apply(tt.s); got != tt.want { + t.Errorf("Replacement.apply() = %q, want %q", got, tt.want) + } + }) + } +} + +func BenchmarkHeaderOpsAdd(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Custom-1": {"value-1"}, + "X-Custom-2": {"value-2"}, + "X-Custom-3": {"value-3"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsSet(b *testing.B) { + ops := &HeaderOps{ + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + "X-Content-Type-Options": {"nosniff"}, + "X-XSS-Protection": {"1; mode=block"}, + }, + } + hdr := make(http.Header) + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr = make(http.Header) + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteSingle(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Powered-By"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) { + ops := &HeaderOps{ + Delete: []string{"X-Debug-*"}, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Debug-1", "v1") + hdr.Set("X-Debug-2", "v2") + hdr.Set("X-Keep", "value") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) { + re := regexp.MustCompile(`^http://([^/]+)(/.*)$`) + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("Location", "http://internal:8080/api/v1/users") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) { + ops := &HeaderOps{ + Replace: map[string][]Replacement{ + "*": {{Search: "internal.example.com", Replace: "public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Host", "internal.example.com") + hdr.Set("X-Origin", "internal.example.com") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkHeaderOpsMixed(b *testing.B) { + ops := &HeaderOps{ + Add: map[string][]string{ + "X-Request-ID": {"req-123"}, + }, + Set: map[string][]string{ + "X-Frame-Options": {"DENY"}, + }, + Delete: []string{"X-Powered-By"}, + Replace: map[string][]Replacement{ + "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, + }, + } + repl := &reverseProxyReplacer{} + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hdr := make(http.Header) + hdr.Set("X-Powered-By", "Express") + hdr.Set("Location", "http://internal:8080/api") + ops.applyTo(hdr, repl) + } +} + +func BenchmarkReplacementApplySubstring(b *testing.B) { + r := &Replacement{Search: "old.example.com", Replace: "new.example.com"} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} + +func BenchmarkReplacementApplyRegexp(b *testing.B) { + r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)} + s := "https://old.example.com/api/v1/resource" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = r.apply(s) + } +} + +func TestReverseProxyReplacerDynamicVars(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil) + req.Host = "example.com" + repl := newReverseProxyReplacer(req) + + tests := []struct { + name string + input string + want string + }{ + {"method", "{method}", "GET"}, + {"host", "{host}", "example.com"}, + {"path", "{path}", "/api/v1/users"}, + {"query", "{query}", "sort=name&limit=10"}, + {"scheme", "{scheme}", "http"}, + {"proto", "{proto}", "HTTP/1.1"}, + {"combined", "X-{method}-{path}", "X-GET-/api/v1/users"}, + {"no vars", "static-value", "static-value"}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := repl.Replace(tt.input); got != tt.want { + t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestReverseProxyReplacerNilRequest(t *testing.T) { + repl := newReverseProxyReplacer(nil) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil request, got %q", got) + } +} + +func TestReverseProxyReplacerNilReplacer(t *testing.T) { + var repl *reverseProxyReplacer + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string with nil replacer, got %q", got) + } +} + +func TestReverseProxyReplacerFromHeader(t *testing.T) { + hdr := make(http.Header) + repl := newReverseProxyReplacerFromHeader(hdr) + if got := repl.Replace("{method}"); got != "{method}" { + t.Errorf("expected unchanged string from header replacer, got %q", got) + } +} + +func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" { + t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Method"); got != "GET" { + t.Errorf("expected X-Forwarded-Method=GET, got %q", got) + } + if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" { + t.Errorf("expected X-Forwarded-Host=client.example, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Forwarded-Path": {"{path}"}, + "X-Forwarded-Method": {"{method}"}, + "X-Forwarded-Host": {"{host}"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil) + req.Host = "client.example" + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} diff --git a/reverseproxy_headers_test.go b/reverseproxy_headers_test.go new file mode 100644 index 0000000..4a4ae26 --- /dev/null +++ b/reverseproxy_headers_test.go @@ -0,0 +1,220 @@ +package touka + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestReverseProxyHeaderOpsAdd(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Custom-Header"); got != "test-value" { + t.Errorf("expected X-Custom-Header=test-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Add: map[string][]string{ + "X-Custom-Header": {"test-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Sensitive") != "" { + t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive")) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Delete: []string{"X-Sensitive"}, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Sensitive", "should-be-removed") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyHeaderOpsSet(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Replace") + if got != "new-value" { + t.Errorf("expected X-Replace=new-value, got %q", got) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + RequestHeaders: &HeaderOps{ + Set: map[string][]string{ + "X-Replace": {"new-value"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) + req.Header.Set("X-Replace", "old-value") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestReverseProxyResponseHeaderOps(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "backend-server") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Set: map[string][]string{ + "X-Custom": {"custom-value"}, + }, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Custom"); got != "custom-value" { + t.Errorf("expected X-Custom=custom-value, got %q", got) + } +} + +func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Powered-By", "Express") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/test", ReverseProxy(ReverseProxyConfig{ + Target: target, + ResponseHeaders: &RespHeaderOps{ + HeaderOps: &HeaderOps{ + Delete: []string{"X-Powered-By"}, + }, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := http.Get(proxy.URL + "/test") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + if got := resp.Header.Get("X-Powered-By"); got != "" { + t.Errorf("expected X-Powered-By to be deleted, got %q", got) + } +} diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go new file mode 100644 index 0000000..ce5e949 --- /dev/null +++ b/reverseproxy_lb.go @@ -0,0 +1,409 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "fmt" + "math/rand/v2" + "net/http" + "net/textproto" + "net/url" + "slices" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ReverseProxyLoadBalancingConfig configures upstream selection and retries. +type ReverseProxyLoadBalancingConfig struct { + Policy ReverseProxyLBPolicy + Retries int + TryDuration time.Duration + TryInterval time.Duration +} + +// ReverseProxyPassiveHealthConfig configures inline passive health tracking. +type ReverseProxyPassiveHealthConfig struct { + FailDuration time.Duration + MaxFails int + UnhealthyStatus []int +} + +// ReverseProxyLBPolicy selects an upstream from the configured target pool. +// Use the helper constructors such as LBRandom or LBHeader to build a policy. +type ReverseProxyLBPolicy struct { + kind reverseProxyLBPolicyKind + key string + fallback *ReverseProxyLBPolicy +} + +type reverseProxyLBPolicyKind uint8 + +const ( + reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota + reverseProxyLBPolicyRoundRobin + reverseProxyLBPolicyFirst + reverseProxyLBPolicyLeastConn + reverseProxyLBPolicyIPHash + reverseProxyLBPolicyClientIPHash + reverseProxyLBPolicyURIHash + reverseProxyLBPolicyHeader + reverseProxyLBPolicyQuery +) + +type reverseProxyUpstream struct { + key string + target *url.URL + index int + useH2C bool + extendedConnectTransport http.RoundTripper + bridgeTransport http.RoundTripper + h2cTransport http.RoundTripper + inFlight atomic.Int64 + + passiveMu sync.Mutex + failures []time.Time +} + +func LBRandom() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom} +} + +func LBRoundRobin() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin} +} + +func LBFirst() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst} +} + +func LBLeastConn() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn} +} + +func LBIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash} +} + +func LBClientIPHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash} +} + +func LBURIHash() ReverseProxyLBPolicy { + return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash} +} + +func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { + policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)} + if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { + policy.fallback = &fallback + } + return policy +} + +func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error { + switch policy.kind { + case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst, + reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash, + reverseProxyLBPolicyURIHash: + return nil + case reverseProxyLBPolicyHeader: + if policy.key == "" { + return fmt.Errorf("reverse proxy header load-balancing policy requires a header field") + } + case reverseProxyLBPolicyQuery: + if policy.key == "" { + return fmt.Errorf("reverse proxy query load-balancing policy requires a query key") + } + default: + return fmt.Errorf("reverse proxy load-balancing policy is invalid") + } + if policy.fallback != nil { + return validateReverseProxyLBPolicy(*policy.fallback) + } + return nil +} + +func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) { + now := time.Now() + policy := p.config.LoadBalancing.Policy + candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream) + candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf) + if len(candidates) == 0 && len(excluded) > 0 { + candidates = p.availableUpstreamsInto(now, nil, candidates[:0]) + } + if len(candidates) == 0 { + *candidateBuf = candidates[:0] + reverseProxyCandidatePool.Put(candidateBuf) + return nil, errReverseProxyNoAvailableUpstreams + } + selected := p.selectUpstreamWithPolicy(c, candidates, policy) + *candidateBuf = candidates[:0] + reverseProxyCandidatePool.Put(candidateBuf) + return selected, nil +} + +func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream { + return p.availableUpstreamsInto(now, excluded, nil) +} + +func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream { + if cap(candidates) < len(p.upstreams) { + candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams)) + } else { + candidates = candidates[:0] + } + for _, upstream := range p.upstreams { + if _, skip := excluded[upstream.key]; skip { + continue + } + if !upstream.healthy(now, p.config.PassiveHealth) { + continue + } + candidates = append(candidates, upstream) + } + return candidates +} + +func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + + switch policy.kind { + case reverseProxyLBPolicyRoundRobin: + return candidates[p.nextRoundRobinIndex(len(candidates))] + case reverseProxyLBPolicyFirst: + return candidates[0] + case reverseProxyLBPolicyLeastConn: + return p.selectLeastConnUpstream(candidates) + case reverseProxyLBPolicyIPHash: + return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr)) + case reverseProxyLBPolicyClientIPHash: + return reverseProxySelectHRW(candidates, c.RequestIP()) + case reverseProxyLBPolicyURIHash: + if c.Request == nil || c.Request.URL == nil { + return reverseProxySelectRandom(candidates) + } + return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI()) + case reverseProxyLBPolicyHeader: + if c.Request != nil && c.Request.Header != nil { + if values, ok := c.Request.Header[policy.key]; ok { + return reverseProxySelectHRWValues(candidates, values) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyQuery: + if c.Request != nil && c.Request.URL != nil { + if values, ok := c.Request.URL.Query()[policy.key]; ok { + return reverseProxySelectHRW(candidates, strings.Join(values, ",")) + } + } + return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) + case reverseProxyLBPolicyRandom: + fallthrough + default: + return reverseProxySelectRandom(candidates) + } +} + +func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int { + if size <= 1 { + return 0 + } + return int((p.roundRobin.Add(1) - 1) % uint64(size)) +} + +func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + selected := candidates[0] + lowest := selected.inFlight.Load() + ties := []*reverseProxyUpstream{selected} + for _, upstream := range candidates[1:] { + count := upstream.inFlight.Load() + switch { + case count < lowest: + selected = upstream + lowest = count + ties = []*reverseProxyUpstream{upstream} + case count == lowest: + ties = append(ties, upstream) + } + } + if len(ties) == 1 { + return selected + } + return ties[p.nextRoundRobinIndex(len(ties))] +} + +func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(candidates) == 1 { + return candidates[0] + } + return candidates[rand.IntN(len(candidates))] +} + +func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if key == "" { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWScore(key, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWScore(key, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + +func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream { + if len(candidates) == 0 { + return nil + } + if len(values) == 0 { + return reverseProxySelectRandom(candidates) + } + selected := candidates[0] + bestScore := reverseProxyHRWValuesScore(values, selected.key) + for _, upstream := range candidates[1:] { + score := reverseProxyHRWValuesScore(values, upstream.key) + if score > bestScore { + selected = upstream + bestScore = score + } + } + return selected +} + +func reverseProxyHRWScore(key, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for i := 0; i < len(key); i++ { + h ^= uint64(key[i]) + h *= prime64 + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + +func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 { + const ( + offset64 = 14695981039346656037 + prime64 = 1099511628211 + ) + h := uint64(offset64) + for valueIndex, value := range values { + for i := 0; i < len(value); i++ { + h ^= uint64(value[i]) + h *= prime64 + } + if valueIndex+1 < len(values) { + h ^= ',' + h *= prime64 + } + } + h ^= 0xff + h *= prime64 + for i := 0; i < len(upstreamKey); i++ { + h ^= uint64(upstreamKey[i]) + h *= prime64 + } + return h +} + +func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { + if policy.fallback != nil { + return *policy.fallback + } + return LBRandom() +} + +func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return true + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + return len(u.failures) < maxFails +} + +func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) { + maxFails := reverseProxyPassiveMaxFails(config) + if config.FailDuration <= 0 || maxFails <= 0 { + return + } + + u.passiveMu.Lock() + defer u.passiveMu.Unlock() + u.pruneFailuresLocked(now, config.FailDuration) + u.failures = append(u.failures, now) +} + +func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) { + if len(u.failures) == 0 || window <= 0 { + if window <= 0 { + u.failures = nil + } + return + } + cutoff := now.Add(-window) + keep := 0 + for _, failureAt := range u.failures { + if failureAt.Before(cutoff) { + continue + } + u.failures[keep] = failureAt + keep++ + } + u.failures = u.failures[:keep] +} + +func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int { + if config.FailDuration <= 0 { + return 0 + } + if config.MaxFails <= 0 { + return 1 + } + return config.MaxFails +} + +func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool { + if status <= 0 { + return false + } + return slices.Contains(config.UnhealthyStatus, status) +} diff --git a/reverseproxy_test.go b/reverseproxy_test.go index f82aff9..6863da7 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,6 +2,10 @@ package touka import ( "bufio" + "bytes" + "context" + crand "crypto/rand" + "crypto/tls" "errors" "fmt" "io" @@ -11,9 +15,14 @@ import ( "net/http/httptrace" "net/textproto" "net/url" + "strconv" "strings" + "sync" + "sync/atomic" "testing" "time" + + "golang.org/x/net/http2" ) func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { @@ -70,7 +79,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, ForwardedHeaders: ForwardedBoth, - ForwardedBy: "proxy-node", + ForwardedBy: "_proxy-node", Via: "proxy.test", })) @@ -106,7 +115,8 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { t.Fatalf("unexpected body: %q", string(body)) } if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) } if got.Path != "/base/api/ping" { t.Fatalf("unexpected upstream path: %q", got.Path) @@ -144,7 +154,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { if !strings.Contains(got.Forwarded, "for=198.51.100.10") { t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) } - if !strings.Contains(got.Forwarded, "by=proxy-node") { + if !strings.Contains(got.Forwarded, "by=_proxy-node") { t.Fatalf("forwarded header missing by token: %q", got.Forwarded) } if !strings.Contains(got.Forwarded, "host=client.example") { @@ -170,6 +180,61 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } +func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) { + t.Helper() + + forwardedCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedCh <- r.Header.Get("Forwarded") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + ForwardedHeaders: ForwardedBoth, + ForwardedBy: " _proxy-node ", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + req.RemoteAddr = "198.51.100.10:4567" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case forwarded := <-forwardedCh: + if !strings.Contains(forwarded, "by=_proxy-node") { + t.Fatalf("unexpected Forwarded header: %q", forwarded) + } + if strings.Contains(forwarded, `by=" _proxy-node "`) { + t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Forwarded header") + } +} + func TestReverseProxyDefaultViaFallback(t *testing.T) { t.Helper() @@ -203,6 +268,544 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) { } } +func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Targets: []string{"http://example.net"}, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusInternalServerError { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) { + t.Helper() + + type snapshot struct { + Path string + RawQuery string + } + + backendOneCh := make(chan snapshot, 1) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwoCh := make(chan snapshot, 1) + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ + Targets: []string{ + backendOne.URL + "/one?from=one", + backendTwo.URL + "/two?from=two", + }, + LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, + })) + + first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil) + if first.Code != http.StatusOK || first.Body.String() != "one" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "two" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case got := <-backendOneCh: + if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" { + t.Fatalf("unexpected first upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first upstream request") + } + + select { + case got := <-backendTwoCh: + if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" { + t.Fatalf("unexpected second upstream request: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second upstream request") + } +} + +func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBHeader("X-Upstream", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + headers := http.Header{"X-Upstream": {"tenant-a"}} + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBQuery("tenant", LBFirst()), + }, + })) + + fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { + t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) + } + + firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) + if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { + t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) + } + if firstSticky.Body.String() != secondSticky.Body.String() { + t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) + } +} + +func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) { + t.Helper() + + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"}) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBClientIPHash(), + }, + })) + + reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqOne.RemoteAddr = "10.0.0.1:1234" + reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrOne := httptest.NewRecorder() + engine.ServeHTTP(rrOne, reqOne) + + reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) + reqTwo.RemoteAddr = "10.0.0.2:5678" + reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10") + rrTwo := httptest.NewRecorder() + engine.ServeHTTP(rrTwo, reqTwo) + + if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK { + t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code) + } + if rrOne.Body.String() != rrTwo.Body.String() { + t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String()) + } +} + +func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, r.Header.Get("X-Attempt")) + })) + defer backend.Close() + + var attempts atomic.Int64 + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10)) + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + if rr.Body.String() != "2" { + t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String()) + } +} + +func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 1, + }, + })) + + rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("unsafe POST request should not be retried to the next upstream") + default: + } +} + +func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) { + t.Helper() + + backendOneStarted := make(chan struct{}, 1) + releaseBackendOne := make(chan struct{}) + backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendOneStarted <- struct{}{} + <-releaseBackendOne + _, _ = io.WriteString(w, "one") + })) + defer backendOne.Close() + + backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "two") + })) + defer backendTwo.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBLeastConn(), + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + client := proxy.Client() + client.Timeout = 5 * time.Second + + firstRespCh := make(chan string, 1) + firstErrCh := make(chan error, 1) + go func() { + resp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + firstErrCh <- err + return + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + firstErrCh <- err + return + } + firstRespCh <- string(body) + }() + + select { + case <-backendOneStarted: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first backend request") + } + + secondResp, err := client.Get(proxy.URL + "/proxy") + if err != nil { + close(releaseBackendOne) + t.Fatalf("second request failed: %v", err) + } + secondBody, err := io.ReadAll(secondResp.Body) + _ = secondResp.Body.Close() + if err != nil { + close(releaseBackendOne) + t.Fatalf("read second response: %v", err) + } + if string(secondBody) != "two" { + close(releaseBackendOne) + t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody)) + } + + close(releaseBackendOne) + select { + case err := <-firstErrCh: + t.Fatalf("first request failed: %v", err) + case body := <-firstRespCh: + if body != "one" { + t.Fatalf("unexpected first response body: %q", body) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first response body") + } +} + +func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) { + t.Helper() + + primaryCalls := make(chan struct{}, 4) + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + primaryCalls <- struct{}{} + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "primary down") + })) + defer primary.Close() + + secondaryCalls := make(chan struct{}, 4) + secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secondaryCalls <- struct{}{} + _, _ = io.WriteString(w, "secondary up") + })) + defer secondary.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{primary.URL, secondary.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + }, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + UnhealthyStatus: []int{http.StatusServiceUnavailable}, + }, + })) + + first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" { + t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) + } + second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if second.Code != http.StatusOK || second.Body.String() != "secondary up" { + t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) + } + + select { + case <-primaryCalls: + default: + t.Fatal("expected primary to receive the first request") + } + select { + case <-secondaryCalls: + default: + t.Fatal("expected secondary to receive the second request") + } + select { + case <-primaryCalls: + t.Fatal("primary should not receive the second request while unhealthy") + default: + } +} + +func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) { + t.Helper() + + started := make(chan struct{}, 1) + release := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-release + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backend.URL}, + PassiveHealth: ReverseProxyPassiveHealthConfig{ + FailDuration: time.Minute, + }, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + client := proxy.Client() + respCh := make(chan error, 1) + go func() { + resp, err := client.Do(req) + if resp != nil { + _ = resp.Body.Close() + } + respCh <- err + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend request") + } + cancel() + close(release) + select { + case <-respCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for canceled request to finish") + } + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String()) + } +} + +func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { + t.Helper() + + backendCalls := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalls <- struct{}{} + _, _ = io.WriteString(w, "ok") + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Targets: []string{"http://127.0.0.1:1", backend.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBFirst(), + Retries: 3, + TryDuration: 100 * time.Millisecond, + TryInterval: 250 * time.Millisecond, + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case <-backendCalls: + t.Fatal("retry budget should expire before the next upstream attempt") + default: + } +} + +func TestReverseProxyAllowH2CUpstream(t *testing.T) { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen h2c upstream: %v", err) + } + server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Proto", r.Proto) + _, _ = io.WriteString(w, "ok") + })} + server.Protocols = new(http.Protocols) + server.Protocols.SetUnencryptedHTTP2(true) + errCh := make(chan error, 1) + go func() { + errCh <- server.Serve(listener) + }() + defer func() { + _ = server.Close() + <-errCh + }() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://"+listener.Addr().String()), + AllowH2CUpstream: true, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusOK || rr.Body.String() != "ok" { + t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" { + t.Fatalf("expected h2c upstream proto, got %q", got) + } +} + func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -229,6 +832,148 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } +func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) { + t.Helper() + + flushErr := errors.New("flush failed") + writer := &flushErrorResponseWriter{flushErr: flushErr} + conn := &reverseProxyH2ReadWriteCloser{ + ReadCloser: io.NopCloser(strings.NewReader("")), + ResponseWriter: writer, + controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)), + } + + n, err := conn.Write([]byte("ping")) + if n != len("ping") { + t.Fatalf("unexpected bytes written: %d", n) + } + if !errors.Is(err, flushErr) { + t.Fatalf("unexpected write error: %v", err) + } + if got := writer.body.String(); got != "ping" { + t.Fatalf("unexpected buffered body: %q", got) + } +} + +func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) { + t.Helper() + + transportCalled := atomic.Bool{} + entropyErr := errors.New("entropy source unavailable") + originalReader := crand.Reader + crand.Reader = errorReader{err: entropyErr} + t.Cleanup(func() { + crand.Reader = originalReader + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + transportCalled.Store(true) + return nil, errors.New("unexpected round trip") + }), + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(reverseProxyStatusCode(err)) + _, _ = io.WriteString(w, err.Error()) + }, + })) + + headers := make(http.Header) + headers.Set(":protocol", "websocket") + rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers) + + if transportCalled.Load() { + t.Fatal("transport should not be called when websocket key generation fails") + } + if rr.Code != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", rr.Code) + } + if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) { + t.Fatalf("unexpected error body: %q", body) + } +} + +func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) { + t.Helper() + + originalDefaultTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("unexpected round trip") + }) + t.Cleanup(func() { + http.DefaultTransport = originalDefaultTransport + }) + + assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) { + t.Helper() + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("%s returned %T, want *http.Transport", name, rt) + } + check(transport) + } + + assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols) + } + }) + assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() { + t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols) + } + if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" { + t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig) + } + }) + assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) { + if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() { + t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols) + } + }) +} + +func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) { + t.Helper() + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type: %T", rt) + } + if transport.TLSClientConfig == nil { + t.Fatal("expected TLS client config") + } + if transport.TLSClientConfig == tlsConfig { + t.Fatal("expected bridge transport to clone TLS config") + } + if len(tlsConfig.NextProtos) != 0 { + t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos) + } + if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" { + t.Fatalf("unexpected transport NextProtos: %#v", got) + } +} + +func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, context.DeadlineExceeded + }), + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { t.Helper() @@ -452,6 +1197,1081 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { } } +func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.RequestURI = "/trace" + req.Header.Set("Max-Forwards", "0") + req.Header.Set("Authorization", "secret") + req.Header.Set("Cookie", "a=b") + req.Header.Set("Forwarded", "for=192.0.2.1") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got != "message/http" { + t.Fatalf("unexpected content type: %q", got) + } + if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") { + t.Fatalf("trace body missing request line: %q", string(body)) + } + if strings.Contains(string(body), "Authorization:") { + t.Fatalf("trace body leaked authorization header: %q", string(body)) + } + if strings.Contains(string(body), "Cookie:") { + t.Fatalf("trace body leaked cookie header: %q", string(body)) + } + if strings.Contains(string(body), "Forwarded:") { + t.Fatalf("trace body leaked forwarded header: %q", string(body)) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) { + t.Helper() + + maxForwardsCh := make(chan string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + maxForwardsCh <- r.Header.Get("Max-Forwards") + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) + req.Header.Set("Max-Forwards", "2") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case got := <-maxForwardsCh: + if got != "1" { + t.Fatalf("unexpected Max-Forwards header: %q", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Max-Forwards") + } +} + +func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { + t.Helper() + + called := make(chan struct{}, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + engine := New() + engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) }) + engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil) + req.Header.Set("Max-Forwards", "0") + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rr.Code) + } + allow := rr.Header().Get("Allow") + if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) { + t.Fatalf("unexpected Allow header: %q", allow) + } + + select { + case <-called: + t.Fatal("backend should not be called when Max-Forwards is zero") + default: + } +} + +func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) { + t.Helper() + + engine := New() + engine.OPTIONS("/", func(c *Context) { + c.Status(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil) + req.RequestURI = "*" + req.URL.Path = "" + req.URL.RawPath = "" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) + } + if got := rr.Header().Get("Content-Length"); got != "0" { + t.Fatalf("unexpected Content-Length header: %q", got) + } +} + +func TestReverseProxyConnectTunnel(t *testing.T) { + t.Helper() + + backendAddr := "" + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + errCh <- fmt.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got, want := r.RequestURI, backendAddr; got != want { + errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, strings.ToUpper(line)) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend write failed: %w", err) + return + } + })) + defer backend.Close() + backendAddr = strings.TrimPrefix(backend.URL, "http://") + + engine := New() + engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, backend.URL), + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n") + if err != nil { + t.Fatalf("write connect request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "200") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if got := respHeader.Get("Content-Length"); got != "" { + t.Fatalf("CONNECT response should not include Content-Length, got %q", got) + } + if got := respHeader.Get("Transfer-Encoding"); got != "" { + t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "PING\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyConnectNeedsHijacker(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("backend response writer does not support hijack") + } + conn, brw, err := hj.Hijack() + if err != nil { + t.Fatalf("backend hijack failed: %v", err) + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n") + _ = brw.Flush() + })) + defer backend.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) + + req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil) + req.URL.Path = "/tunnel" + req.RequestURI = "/tunnel" + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotImplemented { + t.Fatalf("unexpected status: %d", rr.Code) + } +} + +func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got := r.Header.Get(":protocol"); got != "" { + errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- errors.New("missing upstream Sec-WebSocket-Key header") + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.URL.Path; got != "/ws" { + errCh <- fmt.Errorf("unexpected upstream path: %q", got) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = brw.Flush() + })) + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got := resp.Header.Get("Upgrade"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) + } + if got := resp.Header.Get("X-Hop-Token"); got != "" { + t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + closeCalls := atomic.Int32{} + backendReadDone := make(chan struct{}, 1) + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodGet { + return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) + } + var respondOnce sync.Once + var backend *countingReadWriteCloser + backend = &countingReadWriteCloser{ + readDataCh: make(chan []byte, 1), + closeCalls: &closeCalls, + closeWriteErr: nil, + afterWrite: func() { + respondOnce.Do(func() { + backendReadDone <- struct{}{} + backend.readDataCh <- []byte("echo:ping\n") + close(backend.readDataCh) + }) + }, + } + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: http.Header{ + "Connection": []string{"Upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-WebSocket-Accept": []string{"ignored"}, + }, + Body: backend, + Request: req, + }, nil + }) + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer clientTransport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := clientTransport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + _ = resp.Body.Close() + t.Fatalf("write tunneled request body: %v", err) + } + select { + case <-backendReadDone: + case <-time.After(2 * time.Second): + _ = resp.Body.Close() + t.Fatal("backend did not receive tunneled request body") + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + _ = resp.Body.Close() + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + _ = resp.Body.Close() + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + _ = resp.Body.Close() + t.Fatalf("close tunneled request body: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("close response body: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if closeCalls.Load() > 0 { + break + } + time.Sleep(10 * time.Millisecond) + } + if got := closeCalls.Load(); got != 1 { + t.Fatalf("expected backend connection to close exactly once, got %d", got) + } +} + +func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 1 { + errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto) + w.WriteHeader(http.StatusBadRequest) + return + } + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "echo:"+line); err != nil { + errCh <- fmt.Errorf("write tunneled response body failed: %w", err) + return + } + _ = brw.Flush() + })) + upstream.EnableHTTP2 = true + upstream.StartTLS() + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled response body: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 8) + newBackend := func(name string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if got := r.Header.Get(":protocol"); got != "" { + errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) + w.WriteHeader(http.StatusBadRequest) + return + } + if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { + errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { + errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name) + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) + return + } + if _, err := io.WriteString(brw, name+":"+line); err != nil { + errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) + return + } + _ = brw.Flush() + })) + return server + } + + backendOne := newBackend("one") + defer backendOne.Close() + backendTwo := newBackend("two") + defer backendTwo.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Targets: []string{backendOne.URL, backendTwo.URL}, + LoadBalancing: ReverseProxyLoadBalancingConfig{ + Policy: LBRoundRobin(), + }, + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + doRequest := func(payload string) string { + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + if _, err := io.WriteString(pw, payload+"\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + message, err := bufio.NewReader(resp.Body).ReadString('\n') + if err != nil { + t.Fatalf("read tunneled response body: %v", err) + } + return message + } + + if got := doRequest("ping"); got != "one:ping\n" { + t.Fatalf("unexpected first tunneled response: %q", got) + } + if got := doRequest("pong"); got != "two:pong\n" { + t.Fatalf("unexpected second tunneled response: %q", got) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("upstream flush failed: %w", err) + return + } + + reader := bufio.NewReader(brw) + line, err := reader.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("read tunneled request body failed: %w", err) + return + } + if _, err := io.WriteString(brw, "ack:"+line); err != nil { + errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) + return + } + _ = brw.Flush() + + if _, err := io.Copy(io.Discard, reader); err != nil { + errCh <- fmt.Errorf("wait for request half-close failed: %w", err) + return + } + if _, err := io.WriteString(brw, "after-close\n"); err != nil { + errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) + return + } + _ = brw.Flush() + })) + defer upstream.Close() + + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + pr, pw := io.Pipe() + req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + + reader := bufio.NewReader(resp.Body) + if _, err := io.WriteString(pw, "ping\n"); err != nil { + t.Fatalf("write tunneled request body: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read immediate tunneled response: %v", err) + } + if message != "ack:ping\n" { + t.Fatalf("unexpected immediate tunneled response: %q", message) + } + if err := pw.Close(); err != nil { + t.Fatalf("close tunneled request body: %v", err) + } + + message, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("read post-close tunneled response: %v", err) + } + if message != "after-close\n" { + t.Fatalf("unexpected post-close tunneled response: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) { + t.Helper() + + enableHTTP2ExtendedConnectProtocol() + + errCh := make(chan error, 4) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("upstream response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("upstream hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + _ = brw.Flush() + + <-r.Context().Done() + })) + defer upstream.Close() + + proxyErrCh := make(chan error, 1) + engine := New() + engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, upstream.URL), + Via: "proxy.test", + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + select { + case proxyErrCh <- err: + default: + } + }, + })) + + proxy := httptest.NewUnstartedServer(engine) + proxy.EnableHTTP2 = true + if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { + t.Fatalf("configure proxy HTTP/2 server: %v", err) + } + proxy.StartTLS() + defer proxy.Close() + + transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer transport.CloseIdleConnections() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr) + if err != nil { + t.Fatalf("new CONNECT request: %v", err) + } + req.Header.Set(":protocol", "websocket") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip extended CONNECT: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + } + + writeErrCh := make(chan error, 1) + go func() { + _, err := io.WriteString(pw, strings.Repeat("x", 1<<20)) + writeErrCh <- err + }() + time.Sleep(50 * time.Millisecond) + + cancel() + if err := pw.CloseWithError(context.Canceled); err != nil { + t.Fatalf("close request body with cancellation: %v", err) + } + select { + case <-writeErrCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for request body writer to unblock") + } + + select { + case err := <-proxyErrCh: + t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err) + case <-time.After(200 * time.Millisecond): + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { + t.Helper() + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")}, + ContentLength: -1, + Request: req, + }, nil + }), + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + resp, err := proxy.Client().Get(proxy.URL + "/proxy") + if err != nil { + t.Fatalf("perform request: %v", err) + } + _, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err == nil { + t.Fatal("expected body read to fail after upstream copy error") + } +} + func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { t.Helper() @@ -560,6 +2380,117 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } +type flushErrorResponseWriter struct { + header http.Header + body bytes.Buffer + status int + written bool + flushErr error +} + +func (w *flushErrorResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *flushErrorResponseWriter) WriteHeader(statusCode int) { + if w.written { + return + } + w.status = statusCode + w.written = true +} + +func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(p) +} + +func (w *flushErrorResponseWriter) Flush() {} + +func (w *flushErrorResponseWriter) FlushError() error { + if !w.written { + w.WriteHeader(http.StatusOK) + } + return w.flushErr +} + +func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} + +func (w *flushErrorResponseWriter) Status() int { + return w.status +} + +func (w *flushErrorResponseWriter) Size() int { + return w.body.Len() +} + +func (w *flushErrorResponseWriter) Written() bool { + return w.written +} + +func (w *flushErrorResponseWriter) IsHijacked() bool { + return false +} + +type errorReader struct { + err error +} + +func (r errorReader) Read([]byte) (int, error) { + return 0, r.err +} + +type countingReadWriteCloser struct { + readData []byte + readDataCh chan []byte + writeBuf bytes.Buffer + closeCalls *atomic.Int32 + closeWriteErr error + afterWrite func() +} + +func (r *countingReadWriteCloser) Read(p []byte) (int, error) { + if len(r.readData) == 0 && r.readDataCh != nil { + data, ok := <-r.readDataCh + if !ok { + return 0, io.EOF + } + r.readData = data + } + if len(r.readData) == 0 { + return 0, io.EOF + } + n := copy(p, r.readData) + r.readData = r.readData[n:] + return n, nil +} + +func (r *countingReadWriteCloser) Write(p []byte) (int, error) { + n, err := r.writeBuf.Write(p) + if err == nil && r.afterWrite != nil { + r.afterWrite() + } + return n, err +} + +func (r *countingReadWriteCloser) Close() error { + if r.closeCalls != nil { + r.closeCalls.Add(1) + } + return nil +} + +func (r *countingReadWriteCloser) CloseWrite() error { + return r.closeWriteErr +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) @@ -568,3 +2499,21 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } return u } + +type failingReadCloser struct { + chunks []string + err error +} + +func (r *failingReadCloser) Read(p []byte) (int, error) { + if len(r.chunks) == 0 { + return 0, r.err + } + n := copy(p, r.chunks[0]) + r.chunks = r.chunks[1:] + return n, nil +} + +func (r *failingReadCloser) Close() error { + return nil +} diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go new file mode 100644 index 0000000..e0dd2aa --- /dev/null +++ b/route_match_benchmark_test.go @@ -0,0 +1,130 @@ +package touka + +import "testing" + +var ( + benchmarkRouteHandlers HandlersChain + benchmarkRouteFullPath string + benchmarkRouteParamsLen int + benchmarkRouteCIPath []byte + benchmarkRouteCIFound bool +) + +func buildRouteMatchBenchmarkTree() *node { + tree := &node{} + routes := []string{ + "/", + "/health", + "/contact", + "/api/v1/users", + "/api/v1/users/:id", + "/api/v1/users/:id/settings", + "/assets/*filepath", + "/abc/b", + "/abc/:p1/cde", + "/abc/:p1/:p2/def/*filepath", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + return tree +} + +func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { + b.Helper() + + params := make(Params, 0, 4) + skipped := make([]skippedNode, 0, 8) + + value := tree.getValue(path, ¶ms, &skipped, true) + if wantFullPath == "" { + if value.handlers != nil { + b.Fatalf("expected no match for %q, got %q", path, value.fullPath) + } + } else { + if value.handlers == nil { + b.Fatalf("expected match for %q, got nil handlers", path) + } + if value.fullPath != wantFullPath { + b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) + } + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + params = params[:0] + skipped = skipped[:0] + value = tree.getValue(path, ¶ms, &skipped, true) + } + + benchmarkRouteHandlers = value.handlers + benchmarkRouteFullPath = value.fullPath + if value.params != nil { + benchmarkRouteParamsLen = len(*value.params) + } else { + benchmarkRouteParamsLen = 0 + } +} + +func BenchmarkRouteMatch(b *testing.B) { + tree := buildRouteMatchBenchmarkTree() + + b.Run("StaticHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") + }) + + b.Run("ParamHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") + }) + + b.Run("BacktrackingHit", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") + }) + + b.Run("Miss", func(b *testing.B) { + benchmarkRouteLookup(b, tree, "/does/not/exist", "") + }) + + b.Run("CaseInsensitiveHit", func(b *testing.B) { + path := "/API/V1/USERS/123/SETTINGS" + out, found := tree.findCaseInsensitivePath(path, true) + if !found { + b.Fatalf("expected fixed-path match for %q", path) + } + if got := string(out); got != "/api/v1/users/123/settings" { + b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) + + b.Run("CaseInsensitiveMiss", func(b *testing.B) { + path := "/DOES/NOT/EXIST" + out, found := tree.findCaseInsensitivePath(path, true) + if found || out != nil { + b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + out, found = tree.findCaseInsensitivePath(path, true) + } + + benchmarkRouteCIPath = out + benchmarkRouteCIFound = found + }) +} diff --git a/serve.go b/serve.go index f3ddc5f..0fc83f9 100644 --- a/serve.go +++ b/serve.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -21,329 +22,322 @@ import ( "github.com/fenthope/reco" ) -// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 const defaultShutdownTimeout = 5 * time.Second -// --- 内部辅助函数 --- +type runMode uint8 -// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" -func resolveAddress(addr []string) string { - switch len(addr) { - case 0: - return ":8080" - case 1: - return addr[0] - default: - panic("too many parameters provided for server address") +const ( + runModeHTTP runMode = iota + runModeHTTPS + runModeHTTPSRedirect +) + +type runConfig struct { + addr string + httpRedirectAddr string + tlsConfig *tls.Config + redirectHost string + redirectHostHeaders []string + useHeaderHost bool + useHeaderHostSet bool + graceful bool + shutdownTimeout time.Duration + gracefulCtx context.Context + mode runMode + shutdownDefaultSet bool + shutdownTimeoutSet bool +} + +type RunOption interface { + apply(*runConfig) error +} + +type runOptionFunc func(*runConfig) error + +func (f runOptionFunc) apply(cfg *runConfig) error { + return f(cfg) +} + +func defaultRunConfig() runConfig { + return runConfig{ + addr: ":8080", + shutdownTimeout: defaultShutdownTimeout, + mode: runModeHTTP, + useHeaderHost: true, } } -// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 -func getShutdownTimeout(timeouts []time.Duration) time.Duration { - if len(timeouts) > 0 && timeouts[0] > 0 { - return timeouts[0] - } - return defaultShutdownTimeout +type HTTPRedirectOption interface { + applyRedirect(*runConfig) error } -// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, -// 并处理其启动失败的致命错误 -// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") -func runServer(serverType string, srv *http.Server) { +type redirectOptionFunc func(*runConfig) error + +func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { + return f(cfg) +} + +func WithAddr(addr string) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("run address must not be empty") + } + cfg.addr = addr + return nil + }) +} + +func WithTLS(tlsConfig *tls.Config) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil") + } + cfg.tlsConfig = tlsConfig + if cfg.mode == runModeHTTP { + cfg.mode = runModeHTTPS + } + return nil + }) +} + +func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if addr == "" { + return errors.New("http redirect address must not be empty") + } + cfg.httpRedirectAddr = addr + cfg.mode = runModeHTTPSRedirect + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyRedirect(cfg); err != nil { + return err + } + } + return nil + }) +} + +func WithUseHeaderHost(enabled bool) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.useHeaderHost = enabled + cfg.useHeaderHostSet = true + return nil + }) +} + +func WithRedirectHost(host string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + if host == "" { + return errors.New("redirect host must not be empty") + } + cfg.redirectHost = host + return nil + }) +} + +func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { + return redirectOptionFunc(func(cfg *runConfig) error { + cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] + for _, header := range headers { + trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) + if trimmed != "" { + cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) + } + } + return nil + }) +} + +func WithGracefulShutdown(timeout time.Duration) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownTimeoutSet = true + if timeout > 0 { + cfg.shutdownTimeout = timeout + } else { + cfg.shutdownTimeout = defaultShutdownTimeout + } + return nil + }) +} + +func WithGracefulShutdownDefault() RunOption { + return runOptionFunc(func(cfg *runConfig) error { + cfg.graceful = true + cfg.shutdownDefaultSet = true + cfg.shutdownTimeout = defaultShutdownTimeout + return nil + }) +} + +func WithShutdownContext(ctx context.Context) RunOption { + return runOptionFunc(func(cfg *runConfig) error { + if ctx == nil { + return errors.New("shutdown context must not be nil") + } + cfg.gracefulCtx = ctx + return nil + }) +} + +func serveServer(srv *http.Server, serveTLS bool) error { + if serveTLS { + return srv.ListenAndServeTLS("", "") + } + return srv.ListenAndServe() +} + +func runServer(serverType string, srv *http.Server, serveTLS bool) { go func() { - var err error protocol := "http" - if srv.TLSConfig != nil { + if serveTLS { protocol = "https" } log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) - if srv.TLSConfig != nil { - // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, - // ListenAndServeTLS 的前两个参数可以为空字符串 - err = srv.ListenAndServeTLS("", "") - } else { - err = srv.ListenAndServe() - } - - // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), - // 则认为是一个严重错误,并终止程序 + err := serveServer(srv, serveTLS) if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Touka %s server failed: %v", serverType, err) } }() } -// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 -// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 -func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - <-quit // 阻塞,直到接收到上述信号之一 - log.Println("Shutting down Touka server(s)...") - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() +func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { + if tlsConfig == nil { + return nil } - - // 创建一个带超时的上下文,用于 Shutdown - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(ctx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() // 等待所有服务器的关闭 goroutine 完成 - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil + return tlsConfig.Clone() } -func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { - // 创建一个 channel 来接收操作系统信号 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 - - // 启动服务器 - serverStopped := make(chan error, 1) - for _, srv := range servers { - go func(s *http.Server) { - serverStopped <- s.ListenAndServe() - }(srv) +func parseHTTPSPort(addr string) (string, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("https address %q must include a port: %w", addr, err) } + return port, nil +} - select { - case <-ctx.Done(): - // Context 被取消 (例如,通过外部取消函数) - log.Println("Context cancelled, shutting down Touka server(s)...") - case err := <-serverStopped: - // 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("Touka HTTP server failed: %w", err) +func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) { + if serveTLS { + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + return } - log.Println("Touka HTTP server stopped gracefully.") - return nil // 服务器已自行优雅关闭,无需进一步处理 - case <-quit: - // 接收到操作系统信号 - log.Println("Shutting down Touka server(s) due to OS signal...") } - - // 关闭日志记录器 - if logger != nil { - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() - } - - // 创建一个带超时的上下文,用于 Shutdown - shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel - - // 并发地关闭所有服务器 - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(shutdownCtx); err != nil { - // 将错误发送到 channel - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() - close(errChan) // 关闭 channel,以便可以安全地遍历它 - - // 收集所有关闭过程中发生的错误 - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 - } - log.Println("Touka server(s) exited gracefully.") - return nil -} - -// --- 公共 Run 方法 --- - -// Run 启动一个不支持优雅关闭的 HTTP 服务器 -// 这是一个阻塞调用,主要用于简单的场景或快速测试 -// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 -func (engine *Engine) Run(addr ...string) error { - address := resolveAddress(addr) - srv := &http.Server{Addr: address, Handler: engine} - - // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) - return srv.ListenAndServe() } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) +func applyRedirectServerConfig(engine *Engine, srv *http.Server) { + applyServerProtocols(srv, engine.serverProtocols) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } - - runServer("HTTP", srv) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) } -// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 -func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { - srv := &http.Server{ - Addr: addr, - Handler: engine, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, +func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols { + if engine == nil { + return nil } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - engine.applyDefaultServerConfig(srv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) + if serveTLS && engine.useDefaultProtocols { + protocols := &http.Protocols{} + protocols.SetHTTP1(true) + protocols.SetHTTP2(true) + return protocols } - - return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco) + return cloneServerProtocols(engine.serverProtocols) } -// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 -func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLS") - } - - // 配置 HTTP/2 支持 (如果使用默认配置) - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{ - Http1: true, - Http2: true, // 默认在 TLS 上启用 HTTP/2 - }) - } - - srv := &http.Server{ - Addr: addr, +func buildMainServer(engine *Engine, cfg runConfig) *http.Server { + serveTLS := cfg.mode != runModeHTTP + server := &http.Server{ + Addr: cfg.addr, Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { + TLSConfig: cloneTLSConfig(cfg.tlsConfig), + } + if cfg.graceful { + server.BaseContext = func(net.Listener) context.Context { return engine.shutdownCtx - }, + } + server.RegisterOnShutdown(engine.shutdownCancel) } - srv.RegisterOnShutdown(engine.shutdownCancel) - - // 应用框架的默认配置和用户提供的自定义配置 - // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - engine.applyDefaultServerConfig(srv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(srv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } - - runServer("HTTPS", srv) - return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) + applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS)) + applyMainServerConfig(engine, server, serveTLS) + return server } -// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 -func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLS(addr, tlsConfig, timeouts...) +func firstRedirectHeaderHost(r *http.Request, headers []string) string { + if r == nil { + return "" + } + for _, header := range headers { + value := strings.TrimSpace(r.Header.Get(header)) + if value == "" { + continue + } + if comma := strings.IndexByte(value, ','); comma >= 0 { + value = strings.TrimSpace(value[:comma]) + } + if value != "" { + return value + } + } + return "" } -// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 -func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil for RunTLSRedir") +func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { + if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return "", http.StatusInternalServerError, false + } + return cfg.redirectHost, 0, true } - // --- HTTPS 服务器 --- - if engine.useDefaultProtocols { - engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) - } - httpsSrv := &http.Server{ - Addr: httpsAddr, - Handler: engine, - TLSConfig: tlsConfig, - BaseContext: func(l net.Listener) context.Context { - return engine.shutdownCtx - }, - } - httpsSrv.RegisterOnShutdown(engine.shutdownCancel) - engine.applyDefaultServerConfig(httpsSrv) - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(httpsSrv) - } else if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpsSrv) + if len(cfg.redirectHostHeaders) > 0 { + host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true + } + + if r == nil { + return "", http.StatusUpgradeRequired, false + } + host := strings.TrimSpace(r.Host) + if host == "" { + return "", http.StatusUpgradeRequired, false + } + return host, 0, true +} + +func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { + httpsAddr := cfg.addr + httpAddr := cfg.httpRedirectAddr + httpsPort, err := parseHTTPSPort(httpsAddr) + if err != nil { + return nil, err } - // --- HTTP 重定向服务器 --- redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host + host, statusCode, ok := redirectTargetHost(r, cfg) + if !ok { + http.Error(w, http.StatusText(statusCode), statusCode) + return } - _, httpsPort, err := net.SplitHostPort(httpsAddr) - if err != nil { - // 如果 httpsAddr 没有端口,这是一个配置错误 - - log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + host = parsedHost + if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { + host = "[" + host + "]" + } } targetURL := "https://" + host - // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 if httpsPort != "443" { targetURL = "https://" + net.JoinHostPort(host, httpsPort) } @@ -351,22 +345,205 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con http.Redirect(w, r, targetURL, http.StatusMovedPermanently) }) - httpSrv := &http.Server{ - Addr: httpAddr, - Handler: redirectHandler, - } - engine.applyDefaultServerConfig(httpSrv) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(httpSrv) - } - // --- 启动服务器和优雅关闭 --- - runServer("HTTPS", httpsSrv) - runServer("HTTP Redirect", httpSrv) - return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) + server := &http.Server{Addr: httpAddr, Handler: redirectHandler} + applyRedirectServerConfig(engine, server) + return server, nil } -// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 -func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { - return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) +func validateRunConfig(cfg runConfig) error { + if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { + return errors.New("WithHTTPRedirect requires WithTLS") + } + if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { + return errors.New("https mode requires WithTLS") + } + if cfg.gracefulCtx != nil && !cfg.graceful { + return errors.New("WithShutdownContext requires graceful shutdown") + } + if len(cfg.redirectHostHeaders) > 0 { + if !cfg.useHeaderHostSet || !cfg.useHeaderHost { + return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") + } + } + if cfg.useHeaderHostSet && cfg.useHeaderHost { + if cfg.redirectHost != "" { + return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") + } + } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { + if cfg.redirectHost == "" { + return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") + } + if len(cfg.redirectHostHeaders) > 0 { + return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") + } + } + return nil +} + +func effectiveShutdownTimeout(cfg runConfig) time.Duration { + if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet { + if cfg.shutdownTimeout > 0 { + return cfg.shutdownTimeout + } + } + return defaultShutdownTimeout +} + +func closeLoggerAsync(logger *reco.Logger) { + if logger == nil { + return + } + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() +} + +func shutdownServers(servers []*http.Server, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(ctx); err != nil { + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) + } + }(srv) + } + + wg.Wait() + close(errChan) + + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) + } + return nil +} + +func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(quit) + + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLS[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + select { + case err := <-serverStopped: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil { + return errors.Join(err, shutdownErr) + } + return err + } + log.Println("Touka server stopped gracefully.") + return nil + case <-quit: + log.Println("Shutting down Touka server(s) due to OS signal...") + case <-shutdownCtx.Done(): + log.Println("Context cancelled, shutting down Touka server(s)...") + } + + closeLoggerAsync(logger) + if err := shutdownServers(servers, timeout); err != nil { + return err + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// Run starts the engine with the provided startup options. +// +// Default behavior with no options: +// - HTTP only +// - listens on :8080 +// - no graceful shutdown orchestration +// +// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable +// signal-aware graceful shutdown and request-context cancellation semantics. +// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown. +func (engine *Engine) Run(opts ...RunOption) error { + cfg := defaultRunConfig() + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.apply(&cfg); err != nil { + return err + } + } + if cfg.httpRedirectAddr != "" { + cfg.mode = runModeHTTPSRedirect + } else if cfg.tlsConfig != nil { + cfg.mode = runModeHTTPS + } + if err := validateRunConfig(cfg); err != nil { + return err + } + + serveTLS := cfg.mode != runModeHTTP + + mainServer := buildMainServer(engine, cfg) + servers := []*http.Server{mainServer} + serveTLSFlags := []bool{serveTLS} + if cfg.mode == runModeHTTPSRedirect { + redirectServer, err := buildRedirectServer(engine, cfg) + if err != nil { + return err + } + servers = append(servers, redirectServer) + serveTLSFlags = append(serveTLSFlags, false) + } + + if !cfg.graceful { + if len(servers) > 1 { + serverStopped := make(chan error, len(servers)) + for i, srv := range servers { + serveTLSFlag := serveTLSFlags[i] + go func(server *http.Server, useTLS bool) { + serverStopped <- serveServer(server, useTLS) + }(srv, serveTLSFlag) + } + + err := <-serverStopped + if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return errors.Join(err, shutdownErr) + } + return shutdownErr + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil + } + + protocolLabel := "HTTP" + if serveTLS { + protocolLabel = "HTTPS" + } + log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr) + return serveServer(mainServer, serveTLS) + } + + shutdownCtx := context.Background() + if cfg.gracefulCtx != nil { + shutdownCtx = cfg.gracefulCtx + } + return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx) } diff --git a/serve_test.go b/serve_test.go new file mode 100644 index 0000000..a02f1df --- /dev/null +++ b/serve_test.go @@ -0,0 +1,492 @@ +package touka + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func generateSelfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "127.0.0.1"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("create self-signed cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("parse self-signed cert: %v", err) + } + return cert +} + +func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on ephemeral port: %v", err) + } + addr := listener.Addr().String() + if err := listener.Close(); err != nil { + t.Fatalf("close temporary listener: %v", err) + } + + srv := &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + }), + // RunShutdown uses the HTTP startup path and must not let a shared + // ServerConfigurator accidentally turn it into HTTPS. + TLSConfig: &tls.Config{}, + } + + errCh := make(chan error, 1) + go func() { + errCh <- serveServer(srv, false) + }() + + client := &http.Client{Timeout: 200 * time.Millisecond} + var resp *http.Response + requestURL := "http://" + addr + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + resp, err = client.Get(requestURL) + if err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + if err != nil { + select { + case serveErr := <-errCh: + t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr) + default: + t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err) + } + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK) + } + if string(body) != "ok" { + t.Fatalf("unexpected body: got %q want %q", string(body), "ok") + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + t.Fatalf("shutdown server: %v", err) + } + + if err := <-errCh; !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) + } +} + +func TestRunRejectsRedirectWithoutTLS(t *testing.T) { + engine := New() + err := engine.Run(WithHTTPRedirect(":80")) + if err == nil { + t.Fatal("expected redirect mode without TLS to fail") + } +} + +func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { + engine := New() + err := engine.Run( + WithAddr(":443"), + WithTLS(&tls.Config{}), + WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), + ) + if err == nil { + t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") + } +} + +func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { + cfg := defaultRunConfig() + if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { + t.Fatalf("apply graceful default option: %v", err) + } + if !cfg.graceful { + t.Fatal("expected graceful shutdown to be enabled") + } + if cfg.shutdownTimeout != defaultShutdownTimeout { + t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout) + } +} + +func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { + cfg := defaultRunConfig() + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} + if err := WithTLS(tlsConfig).apply(&cfg); err != nil { + t.Fatalf("apply TLS option: %v", err) + } + if cfg.mode != runModeHTTPS { + t.Fatalf("expected HTTPS mode, got %v", cfg.mode) + } + if cfg.graceful { + t.Fatal("expected TLS option to remain independent from graceful shutdown") + } + if cfg.tlsConfig != tlsConfig { + t.Fatal("expected TLS config to be preserved in run config") + } +} + +func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { + engine := New() + if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { + t.Fatal("expected redirect server builder to reject https address without port") + } +} + +func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { + cfg := defaultRunConfig() + ctx := t.Context() + if err := WithShutdownContext(ctx).apply(&cfg); err != nil { + t.Fatalf("apply shutdown context option: %v", err) + } + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected shutdown context without graceful shutdown to fail validation") + } +} + +func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { + cfg := defaultRunConfig() + cfg.httpRedirectAddr = ":80" + if err := validateRunConfig(cfg); err != nil { + t.Fatalf("validate run config: %v", err) + } + if cfg.mode != runModeHTTP { + t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) + } +} + +func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = false + cfg.useHeaderHostSet = true + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected configured host mode without redirect host to fail validation") + } +} + +func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { + cfg := defaultRunConfig() + cfg.mode = runModeHTTPSRedirect + cfg.tlsConfig = &tls.Config{} + cfg.useHeaderHost = true + cfg.useHeaderHostSet = true + cfg.redirectHost = "configured.example" + if err := validateRunConfig(cfg); err == nil { + t.Fatal("expected redirect host to be rejected when header host mode is enabled") + } +} + +func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { + engine := New() + server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) + if server.BaseContext == nil { + t.Fatal("expected graceful main server to set BaseContext") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for base context check: %v", err) + } + defer listener.Close() + if got := server.BaseContext(listener); got != engine.shutdownCtx { + t.Fatal("expected graceful main server to use engine shutdown context") + } +} + +func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) { + engine := New() + serverConfigured := false + tlsConfigured := false + engine.SetServerConfigurator(func(s *http.Server) { + serverConfigured = true + s.ReadTimeout = time.Second + }) + engine.SetTLSServerConfigurator(func(s *http.Server) { + tlsConfigured = true + s.IdleTimeout = time.Second + }) + + server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !tlsConfigured { + t.Fatal("expected TLS configurator to run for HTTPS main server") + } + if serverConfigured { + t.Fatal("expected generic server configurator to be skipped when TLS configurator is set") + } + if server.IdleTimeout != time.Second { + t.Fatal("expected TLS configurator changes to be applied to HTTPS main server") + } +} + +func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { + engine := New() + configured := false + engine.SetServerConfigurator(func(s *http.Server) { + configured = true + s.ReadTimeout = time.Second + }) + + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + if !configured { + t.Fatal("expected redirect server to use generic server configurator") + } + if server.ReadTimeout != time.Second { + t.Fatal("expected redirect server configurator changes to be applied") + } +} + +func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { + engine := New() + httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if !httpsServer.Protocols.HTTP2() { + t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings") + } + + httpServer := buildMainServer(engine, defaultRunConfig()) + if httpServer.Protocols.HTTP2() { + t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled") + } +} + +func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + req.Header.Set("X-First-Host", "first.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: true, + useHeaderHostSet: true, + redirectHostHeaders: []string{"X-Forwarded-Host"}, + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUpgradeRequired { + t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) + } +} + +func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{ + addr: ":443", + httpRedirectAddr: ":80", + useHeaderHost: false, + useHeaderHostSet: true, + redirectHost: "configured.example", + }) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) + req.Host = "example.com:80" + req.Header.Set("X-Forwarded-Host", "forwarded.example") + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { + t.Fatalf("unexpected redirect location: %q", location) + } +} + +func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) { + engine := New() + server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil) + req.Host = "[::1]:80" + rr := httptest.NewRecorder() + server.Handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusMovedPermanently { + t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) + } + if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" { + t.Fatalf("unexpected IPv6 redirect location: %q", location) + } +} + +func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) + if err != nil { + t.Fatalf("build redirect server: %v", err) + } + mainServer := &http.Server{Addr: occupiedAddr, Handler: engine} + + err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background()) + if err == nil { + t.Fatal("expected gracefulServe to fail when one server cannot bind") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err) + } + + conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond) + if dialErr == nil { + conn.Close() + t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr) + } + if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") { + t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) + } +} + +func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on occupied addr: %v", err) + } + occupiedAddr := occupied.Addr().String() + defer occupied.Close() + + redirectListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen for redirect addr: %v", err) + } + redirectAddr := redirectListener.Addr().String() + if err := redirectListener.Close(); err != nil { + t.Fatalf("close redirect addr probe: %v", err) + } + + engine := New() + err = engine.Run( + WithAddr(occupiedAddr), + WithTLS(&tls.Config{}), + WithHTTPRedirect(redirectAddr), + ) + if err == nil { + t.Fatal("expected non-graceful TLS redirect startup to return bind error") + } + if !strings.Contains(err.Error(), occupiedAddr) { + t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) + } +} diff --git a/touka.go b/touka.go index dd529cb..4ad81da 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type IRouter interface { - Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 - Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 +// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type Router interface { + Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 + Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc) diff --git a/tree.go b/tree.go index 31246a5..b159c8d 100644 --- a/tree.go +++ b/tree.go @@ -121,14 +121,28 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 +} + +func routeNeedsCaseInsensitiveLookup(path string) bool { + for i := 0; i < len(path); i++ { + c := path[i] + if c >= utf8.RuneSelf { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + } + return false } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -162,6 +176,9 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 + if routeNeedsCaseInsensitiveLookup(path) { + n.hasCaseInsensitivePath = true + } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -452,12 +469,14 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 + var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 + pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -467,30 +486,26 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - for i, c := range []byte(n.indices) { - if c == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: prefix + path, // 记录跳过的路径 - node: &node{ // 复制当前节点的状态 - path: n.path, - wildChild: n.wildChild, - nType: n.nType, - priority: n.priority, - children: n.children, - handlers: n.handlers, - fullPath: n.fullPath, - }, - paramsCount: globalParamsCount, // 记录当前参数计数 + if !backtrackToWildChild { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: pathAtNode, // 记录进入当前节点时的剩余路径 + node: n, + paramsCount: globalParamsCount, // 记录当前参数计数 + } } - } - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 + } } + } else { + backtrackToWildChild = false } if !n.wildChild { @@ -507,7 +522,8 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - continue walk // 继续外部循环 + backtrackToWildChild = true + continue walk // 继续外部循环 } } } @@ -547,7 +563,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -599,7 +615,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape { // 如果需要进行 URL 解码 + if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -634,6 +650,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -658,8 +675,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -688,6 +705,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount + backtrackToWildChild = true continue walk } } @@ -701,13 +719,15 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { - const stackBufSize = 128 // 栈上缓冲区的默认大小 + return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) +} - // 在常见情况下使用栈上静态大小的缓冲区. - // 如果路径太长, 则在堆上分配缓冲区. - buf := make([]byte, 0, stackBufSize) - if length := len(path) + 1; length > stackBufSize { - buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 +func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { + if buf != nil { + buf = buf[:0] + } + if cap(buf) < len(path)+1 { + buf = make([]byte, 0, len(path)+1) } ciPath := n.findCaseInsensitivePathRec( @@ -758,8 +778,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i, c := range []byte(n.indices) { - if c == '/' { // 如果索引中包含 '/' + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -781,8 +801,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i, c := range []byte(n.indices) { - if c == idxc { + for i := 0; i < len(n.indices); i++ { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -813,9 +833,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 小写匹配 - if c == idxc { + if n.indices[i] == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -832,9 +852,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i, c := range []byte(n.indices) { + for i := 0; i < len(n.indices); i++ { // 大写匹配 - if c == idxc { + if n.indices[i] == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -852,7 +872,7 @@ walk: // 外部循环用于遍历路由树 return nil // 未找到, 返回 nil } - n = n.children[0] // 移动到通配符子节点(通常是唯一一个) + n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾 switch n.nType { case param: // 参数节点 // 查找参数结束位置('/' 或路径末尾) diff --git a/tree_test.go b/tree_test.go index d3ffdfa..a35a1a8 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "testing" + "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -39,6 +40,23 @@ func getSkippedNodes() *[]skippedNode { return &ps } +func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { + t.Helper() + + resultCh := make(chan nodeValue, 1) + go func() { + resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) + }() + + select { + case value := <-resultCh: + return value + case <-time.After(2 * time.Second): + t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) + return nodeValue{} + } +} + func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -901,6 +919,34 @@ func TestTreeInvalidNodeType(t *testing.T) { } } +func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) { + tree := &node{} + routes := [...]string{ + "/:user/:repo/info/refs", + "/healthz", + "/api/db/data", + "/api/db/sum", + } + + for _, route := range routes { + tree.addRoute(route, fakeHandler(route)) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic while looking up missing path: %v", r) + } + }() + + if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil { + t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found) + } + + if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil { + t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found) + } +} + func TestTreeInvalidParamsType(t *testing.T) { tree := &node{} // add a child with wildcard @@ -1076,3 +1122,51 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } + +func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { + tests := []struct { + name string + routes []string + requestPath string + wantFullPath string + wantParams Params + }{ + { + name: "param route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/details"}, + requestPath: "/foo/bar/details", + wantFullPath: "/foo/:id/details", + wantParams: Params{{Key: "id", Value: "bar"}}, + }, + { + name: "catch-all route after static dead end", + routes: []string{"/foo/bar", "/foo/:id/*rest"}, + requestPath: "/foo/bar/baz.txt", + wantFullPath: "/foo/:id/*rest", + wantParams: Params{ + {Key: "id", Value: "bar"}, + {Key: "rest", Value: "/baz.txt"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tree := &node{} + for _, route := range tt.routes { + tree.addRoute(route, fakeHandler(route)) + } + + value := getValueWithTimeout(t, tree, tt.requestPath, false) + if value.handlers == nil { + t.Fatalf("expected handlers for %q", tt.requestPath) + } + if value.fullPath != tt.wantFullPath { + t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) + } + if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { + t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) + } + }) + } +}