diff --git a/.gitignore b/.gitignore index 6f301cd..30d74d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ -test -/bench_route_match_baseline.txt +test \ No newline at end of file diff --git a/README.md b/README.md index e2eaec8..a7b99fd 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,9 @@ func main() { c.String(http.StatusOK, "Hello, %s! You seem %s.", name, query) }) - // 启动服务器(通过 WithGracefulShutdown 启用优雅关闭) + // 启动服务器 (支持优雅关闭) log.Println("Touka Server starting on :8080...") - if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { + if err := r.RunShutdown(":8080", 10*time.Second); err != nil { log.Fatalf("Touka server failed to start: %v", err) } } diff --git a/about-touka.md b/about-touka.md index b3a16b4..86a056f 100644 --- a/about-touka.md +++ b/about-touka.md @@ -70,13 +70,13 @@ func main() { r.SetGlobalMaxRequestBodySize(10 * 1024 * 1024) // 10 MB // ... 其他配置 - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` #### 1.3. 服务器生命周期管理 -Touka 提供了对底层 `*http.Server` 的完全控制,并可通过 `Run(...)` 的启动选项启用优雅关闭逻辑。 +Touka 提供了对底层 `*http.Server` 的完全控制,并内置了优雅关闭的逻辑。 ```go func main() { @@ -90,11 +90,11 @@ func main() { fmt.Println("自定义的 HTTP 服务器配置已应用") }) - // 启动服务器,并通过 Run 选项启用优雅关闭 - // Run(...) 会阻塞当前 goroutine - // WithGracefulShutdown(10*time.Second) 表示在关闭时最多等待 10 秒 + // 启动服务器,并支持优雅关闭 + // RunShutdown 会阻塞,直到收到 SIGINT 或 SIGTERM 信号 + // 第二个参数是优雅关闭的超时时间 fmt.Println("服务器启动于 :8080") - if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { + if err := r.RunShutdown(":8080", 10*time.Second); err != nil { log.Fatalf("服务器启动失败: %v", err) } } @@ -187,7 +187,7 @@ func main() { } } - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } func AuthMiddleware() touka.HandlerFunc { @@ -313,7 +313,7 @@ func main() { }) }) - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } // templates/index.html @@ -400,7 +400,7 @@ func main() { c.JSON(http.StatusOK, touka.H{"status": "ok", "request_id": requestID}) }) - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` @@ -483,7 +483,7 @@ func main() { // 静态文件服务,如果文件不存在,也会被上面的 ErrorHandler 捕获 r.StaticDir("/files", "./non-existent-dir") - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` @@ -546,7 +546,7 @@ func main() { // 所有对 / 的访问都会映射到嵌入的 frontend/dist 目录 r.StaticFS("/", http.FS(subFS)) - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` diff --git a/compat.go b/compat.go deleted file mode 100644 index 0be715d..0000000 --- a/compat.go +++ /dev/null @@ -1,52 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// Copyright 2024 WJQSERVER. All rights reserved. -// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. -package touka - -import ( - "github.com/WJQSERVER-STUDIO/httpc" - "github.com/fenthope/reco" -) - -// --- reco 兼容函数 --- - -// GetLogReco 返回底层的 reco.Logger 实例 -// 用于需要访问 reco 特定功能的场景 -// 如果当前 logger 不是 *reco.Logger 类型,返回 nil -// -//go:fix inline -func (engine *Engine) GetLogReco() *reco.Logger { - return engine.LogReco -} - -// SetLogReco 设置 reco.Logger 实例 -// 用于向后兼容,等价于 SetLogger(l) -// -//go:fix inline -func (engine *Engine) SetLogReco(l *reco.Logger) { - engine.LogReco = l - engine.logger = l -} - -// GetLoggerReco 返回底层的 reco.Logger 实例 -// 用于需要访问 reco 特定功能的场景 -// 如果当前 logger 不是 *reco.Logger 类型,返回 nil -// -//go:fix inline -func (c *Context) GetLoggerReco() *reco.Logger { - if rl, ok := c.engine.logger.(*reco.Logger); ok { - return rl - } - return c.engine.LogReco -} - -// --- httpc 兼容函数 --- - -// GetHTTPC 返回底层的 httpc.Client 实例 -// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context -// -//go:fix inline -func (c *Context) GetHTTPC() *httpc.Client { - return c.Client() -} diff --git a/context.go b/context.go index f21ed48..2e4d2bb 100644 --- a/context.go +++ b/context.go @@ -26,6 +26,7 @@ import ( "time" "github.com/WJQSERVER/wanf" + "github.com/fenthope/reco" "github.com/go-json-experiment/json" "github.com/WJQSERVER-STUDIO/go-utils/iox" @@ -43,8 +44,6 @@ type Context struct { handlers HandlersChain // 当前请求的处理函数链 (中间件 + 最终handler) index int8 // 当前执行到处理链的哪个位置 - requestBodyPrepared bool - mu sync.RWMutex Keys map[string]any // 用于在中间件之间传递数据 @@ -72,12 +71,6 @@ type Context struct { // skippedNodes 用于记录跳过的节点信息,以便回溯 // 通常在处理嵌套路由时使用 SkippedNodes []skippedNode - - // fixedPathBuf 用于复用固定路径重定向时的大小写修正结果缓冲. - fixedPathBuf []byte - - allowedMethodsBuf []string - allowHeaderBuf []byte } // --- Context 相关方法实现 --- @@ -102,42 +95,19 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 - c.Keys = nil // 仅在首次 Set 时创建,避免每个请求都分配 map + c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 c.Errors = c.Errors[:0] // 清空 Errors 切片 c.queryCache = nil // 清空查询参数缓存 c.formCache = nil // 清空表单数据缓存 c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize - c.requestBodyPrepared = false if cap(c.SkippedNodes) > 0 { c.SkippedNodes = c.SkippedNodes[:0] } else { c.SkippedNodes = make([]skippedNode, 0, 256) } - if cap(c.fixedPathBuf) > 0 { - c.fixedPathBuf = c.fixedPathBuf[:0] - } - if cap(c.allowedMethodsBuf) > 0 { - c.allowedMethodsBuf = c.allowedMethodsBuf[:0] - } - if cap(c.allowHeaderBuf) > 0 { - c.allowHeaderBuf = c.allowHeaderBuf[:0] - } -} - -func (c *Context) writeResponseBody(data []byte, contextMsg string) { - if len(data) == 0 { - return - } - if _, err := c.Writer.Write(data); err != nil { - wrapped := fmt.Errorf("%s: %w", contextMsg, err) - c.AddError(wrapped) - if c.engine != nil && c.engine.logger != nil { - c.engine.logger.Errorf("%s: %v", contextMsg, err) - } - } } // Next 在处理链中执行下一个处理函数 @@ -267,18 +237,6 @@ func (c *Context) SetMaxRequestBodySize(size int64) { c.MaxRequestBodySize = size } -func (c *Context) prepareRequestBody() io.ReadCloser { - if c.Request == nil || c.Request.Body == nil { - return nil - } - if c.requestBodyPrepared || c.MaxRequestBodySize <= 0 { - return c.Request.Body - } - c.Request.Body = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) - c.requestBodyPrepared = true - return c.Request.Body -} - // Query 从 URL 查询参数中获取值 // 懒加载解析查询参数,并进行缓存 func (c *Context) Query(key string) string { @@ -300,39 +258,7 @@ func (c *Context) DefaultQuery(key, defaultValue string) string { // 懒加载解析表单数据,并进行缓存 func (c *Context) PostForm(key string) string { if c.formCache == nil { - if c.MaxRequestBodySize > 0 { - c.prepareRequestBody() - } - contentType := c.Request.Header.Get("Content-Type") - mediaType, _, err := mime.ParseMediaType(contentType) - if err != nil { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } - - switch mediaType { - case "multipart/form-data": - if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } - case "application/x-www-form-urlencoded": - if err := c.Request.ParseForm(); err != nil { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } - default: - if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { - if !errors.Is(err, http.ErrNotMultipart) { - c.AddError(fmt.Errorf("parse form error: %w", err)) - c.formCache = make(url.Values) - return "" - } - } - } + c.Request.ParseMultipartForm(defaultMemory) // 解析 multipart/form-data 或 application/x-www-form-urlencoded c.formCache = c.Request.PostForm } return c.formCache.Get(key) @@ -356,20 +282,20 @@ func (c *Context) Param(key string) string { func (c *Context) Raw(code int, contentType string, data []byte) { c.Writer.Header().Set("Content-Type", contentType) c.Writer.WriteHeader(code) - c.writeResponseBody(data, "failed to write raw response") + c.Writer.Write(data) } // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.writeResponseBody(fmt.Appendf(nil, format, values...), "failed to write string response") + c.Writer.Write(fmt.Appendf(nil, format, values...)) } // Text 向响应写入无需格式化的string func (c *Context) Text(code int, text string) { c.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8") c.Writer.WriteHeader(code) - c.writeResponseBody([]byte(text), "failed to write text response") + c.Writer.Write([]byte(text)) } // FileText @@ -412,11 +338,8 @@ func (c *Context) FileText(code int, filePath string) { } c.SetHeader("Content-Type", "text/plain; charset=utf-8") - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", fileInfo.Size())) - c.Writer.WriteHeader(code) - if _, err := iox.Copy(c.Writer, file); err != nil { - c.AddError(fmt.Errorf("failed to write file %s to response: %w", cleanPath, err)) - } + + c.SetBodyStream(file, int(fileInfo.Size())) } /* @@ -507,7 +430,7 @@ func (c *Context) JSONBuf(code int, obj any) { c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.WriteHeader(code) - c.writeResponseBody(buf.Bytes(), "failed to write buffered JSON response") + c.Writer.Write(buf.Bytes()) } // GOB 向响应写入GOB数据 @@ -536,7 +459,7 @@ func (c *Context) GOBBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/octet-stream") c.Writer.WriteHeader(code) - c.writeResponseBody(buf.Bytes(), "failed to write buffered GOB response") + c.Writer.Write(buf.Bytes()) } // WANF向响应写入WANF数据 @@ -565,7 +488,7 @@ func (c *Context) WANFBuf(code int, obj any) { } c.Writer.Header().Set("Content-Type", "application/vnd.wjqserver.wanf; charset=utf-8") c.Writer.WriteHeader(code) - c.writeResponseBody(buf.Bytes(), "failed to write buffered WANF response") + c.Writer.Write(buf.Bytes()) } // HTML 渲染 HTML 模板 @@ -589,7 +512,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.writeResponseBody(fmt.Appendf(nil, "\n
%v
", name, obj), "failed to write HTML response") + c.Writer.Write(fmt.Appendf(nil, "\n
%v
", name, obj)) } // HTMLBuf 先将 HTML 模板渲染到 buffer, 成功后再写入状态码和响应体. @@ -614,7 +537,7 @@ func (c *Context) HTMLBuf(code int, name string, obj any) { // 渲染成功,写入响应 c.Writer.Header().Set("Content-Type", "text/html; charset=utf-8") c.Writer.WriteHeader(code) - c.writeResponseBody(buf.Bytes(), "failed to write buffered HTML response") + c.Writer.Write(buf.Bytes()) return } @@ -634,16 +557,10 @@ func (c *Context) Redirect(code int, location string) { // ShouldBindJSON 尝试将请求体绑定到 JSON 对象 func (c *Context) ShouldBindJSON(obj any) error { - var body io.ReadCloser - if c.MaxRequestBodySize > 0 { - body = c.prepareRequestBody() - } else { - body = c.Request.Body - } - if body == nil { + if c.Request.Body == nil { return errors.New("request body is empty") } - err := json.UnmarshalRead(body, obj) + err := json.UnmarshalRead(c.Request.Body, obj) if err != nil { return fmt.Errorf("json binding error: %w", err) } @@ -652,16 +569,10 @@ func (c *Context) ShouldBindJSON(obj any) error { // ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { - var body io.ReadCloser - if c.MaxRequestBodySize > 0 { - body = c.prepareRequestBody() - } else { - body = c.Request.Body - } - if body == nil { + if c.Request.Body == nil { return errors.New("request body is empty") } - decoder, err := wanf.NewStreamDecoder(body) + decoder, err := wanf.NewStreamDecoder(c.Request.Body) if err != nil { return fmt.Errorf("failed to create WANF decoder: %w", err) } @@ -674,16 +585,10 @@ func (c *Context) ShouldBindWANF(obj any) error { // ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 func (c *Context) ShouldBindGOB(obj any) error { - var body io.ReadCloser - if c.MaxRequestBodySize > 0 { - body = c.prepareRequestBody() - } else { - body = c.Request.Body - } - if body == nil { + if c.Request.Body == nil { return errors.New("request body is empty") } - decoder := gob.NewDecoder(body) + decoder := gob.NewDecoder(c.Request.Body) if err := decoder.Decode(obj); err != nil { return fmt.Errorf("GOB binding error: %w", err) } @@ -800,10 +705,6 @@ func setFieldValue(field reflect.Value, values []string) error { // ShouldBindForm 尝试将表单数据绑定到结构体 // 支持 application/x-www-form-urlencoded 和 multipart/form-data func (c *Context) ShouldBindForm(obj any) error { - if c.MaxRequestBodySize > 0 { - c.prepareRequestBody() - } - contentType := c.Request.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) if err != nil { @@ -812,7 +713,7 @@ func (c *Context) ShouldBindForm(obj any) error { switch mediaType { case "multipart/form-data": - if err := c.Request.ParseMultipartForm(defaultMemory); err != nil { + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { return fmt.Errorf("parse multipart form error: %w", err) } case "application/x-www-form-urlencoded": @@ -826,7 +727,6 @@ func (c *Context) ShouldBindForm(obj any) error { if err := bindForm(c.Request.Form, obj); err != nil { return fmt.Errorf("form binding error: %w", err) } - c.formCache = c.Request.PostForm return nil } @@ -864,29 +764,10 @@ func (c *Context) GetErrors() []error { return c.Errors } -// Client 返回当前请求的 HTTPClient -// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例; -// 否则返回 Engine 提供的默认实例 -// -// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context +// Client 返回 Engine 提供的 HTTPClient +// 方便在请求处理函数中进行出站 HTTP 请求 func (c *Context) Client() *httpc.Client { - if c.HTTPClient != nil { - return c.HTTPClient - } - return c.engine.HTTPClient -} - -// HTTPC 返回自动关联请求 Context 的 HTTP 客户端 -// 当请求被取消时,通过此客户端发起的出站请求也会自动取消 -func (c *Context) HTTPC() *contextHTTPClient { - client := c.HTTPClient - if client == nil { - client = c.engine.HTTPClient - } - return &contextHTTPClient{ - client: client, - ctx: c.ctx, - } + return c.HTTPClient } // Context() 返回请求的上下文,用于取消操作 @@ -946,30 +827,37 @@ func (c *Context) WriteStream(reader io.Reader) (written int64, err error) { // GetReqBody 以获取一个 io.ReadCloser 接口,用于读取请求体 // 注意:请求体只能读取一次 func (c *Context) GetReqBody() io.ReadCloser { - if c.MaxRequestBodySize > 0 { - return c.prepareRequestBody() - } - if c.Request == nil || c.Request.Body == nil { - return nil - } return c.Request.Body } // GetReqBodyFull 读取并返回请求体的所有内容 // 注意:请求体只能读取一次 func (c *Context) GetReqBodyFull() ([]byte, error) { - body := c.GetReqBody() - if body == nil { + if c.Request.Body == nil { return nil, nil } - defer func() { - err := body.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - data, err := io.ReadAll(body) + var limitBytesReader io.ReadCloser + + if c.MaxRequestBodySize > 0 { + limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } else { + limitBytesReader = c.Request.Body + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } + + data, err := iox.ReadAll(limitBytesReader) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -979,18 +867,31 @@ func (c *Context) GetReqBodyFull() ([]byte, error) { // 类似 GetReqBodyFull, 返回 *bytes.Buffer func (c *Context) GetReqBodyBuffer() (*bytes.Buffer, error) { - body := c.GetReqBody() - if body == nil { + if c.Request.Body == nil { return nil, nil } - defer func() { - err := body.Close() - if err != nil { - c.AddError(fmt.Errorf("failed to close request body: %w", err)) - } - }() - data, err := io.ReadAll(body) + var limitBytesReader io.ReadCloser + + if c.MaxRequestBodySize > 0 { + limitBytesReader = NewMaxBytesReader(c.Request.Body, c.MaxRequestBodySize) + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } else { + limitBytesReader = c.Request.Body + defer func() { + err := limitBytesReader.Close() + if err != nil { + c.AddError(fmt.Errorf("failed to close request body: %w", err)) + } + }() + } + + data, err := iox.ReadAll(limitBytesReader) if err != nil { c.AddError(fmt.Errorf("failed to read request body: %w", err)) return nil, fmt.Errorf("failed to read request body: %w", err) @@ -1149,9 +1050,14 @@ func (c *Context) GetProtocol() string { return c.Request.Proto } -// GetLogger 获取engine的Logger接口 -func (c *Context) GetLogger() Logger { - return c.engine.logger +// GetHTTPC 获取框架自带传递的httpc +func (c *Context) GetHTTPC() *httpc.Client { + return c.HTTPClient +} + +// GetLogger 获取engine的Logger +func (c *Context) GetLogger() *reco.Logger { + return c.engine.LogReco } // GetReqQueryString @@ -1310,25 +1216,25 @@ func (c *Context) DeleteCookie(name string) { // === 日志记录 === func (c *Context) Debugf(format string, args ...any) { - c.engine.logger.Debugf(format, args...) + c.engine.LogReco.Debugf(format, args...) } func (c *Context) Infof(format string, args ...any) { - c.engine.logger.Infof(format, args...) + c.engine.LogReco.Infof(format, args...) } func (c *Context) Warnf(format string, args ...any) { - c.engine.logger.Warnf(format, args...) + c.engine.LogReco.Warnf(format, args...) } func (c *Context) Errorf(format string, args ...any) { - c.engine.logger.Errorf(format, args...) + c.engine.LogReco.Errorf(format, args...) } func (c *Context) Fatalf(format string, args ...any) { - c.engine.logger.Fatalf(format, args...) + c.engine.LogReco.Fatalf(format, args...) } func (c *Context) Panicf(format string, args ...any) { - c.engine.logger.Panicf(format, args...) + c.engine.LogReco.Panicf(format, args...) } diff --git a/context_benchmark_test.go b/context_benchmark_test.go deleted file mode 100644 index 3c464d0..0000000 --- a/context_benchmark_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package touka - -import ( - "net/http" - "testing" -) - -func TestContextResetKeepsKeysNilUntilSet(t *testing.T) { - c, _ := CreateTestContext(nil) - if c.Keys != nil { - t.Fatalf("expected fresh test context Keys to be nil before first Set") - } - - c.Set("answer", 42) - if c.Keys == nil { - t.Fatalf("expected Set to allocate Keys map") - } - if value, exists := c.Get("answer"); !exists || value != 42 { - t.Fatalf("expected stored value to round-trip, got %v, %t", value, exists) - } - - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("failed to build request: %v", err) - } - c.reset(UnwrapResponseWriter(c.Writer), req) - - if c.Keys != nil { - t.Fatalf("expected reset to clear Keys without allocating a new map") - } - if value, exists := c.Get("answer"); exists || value != nil { - t.Fatalf("expected cleared keys after reset, got %v, %t", value, exists) - } - - ctxValue := c.Value("missing") - if ctxValue != nil { - t.Fatalf("expected nil value for missing context key after reset, got %v", ctxValue) - } - defer func() { - if r := recover(); r == nil { - t.Fatalf("expected MustGet to panic for missing key after reset") - } - }() - _ = c.MustGet("answer") -} - -func BenchmarkContextReset(b *testing.B) { - b.Run("NoKeysUse", func(b *testing.B) { - c, _ := CreateTestContext(nil) - rawWriter := UnwrapResponseWriter(c.Writer) - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - b.Fatalf("failed to build request: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - c.reset(rawWriter, req) - } - }) - - b.Run("WithKeysUse", func(b *testing.B) { - c, _ := CreateTestContext(nil) - rawWriter := UnwrapResponseWriter(c.Writer) - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - b.Fatalf("failed to build request: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - c.reset(rawWriter, req) - c.Set("request-id", i) - } - }) - -} diff --git a/context_bodylimit_test.go b/context_bodylimit_test.go deleted file mode 100644 index 1e7696a..0000000 --- a/context_bodylimit_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package touka - -import ( - "errors" - "io" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" -) - -type zeroNilThenEOFReader struct { - readCalls int -} - -func (r *zeroNilThenEOFReader) Read(_ []byte) (int, error) { - r.readCalls++ - if r.readCalls == 1 { - return 0, nil - } - return 0, io.EOF -} - -func (r *zeroNilThenEOFReader) Close() error { - return nil -} - -func TestFileTextUsesProvidedStatusCode(t *testing.T) { - t.Helper() - - dir := t.TempDir() - filePath := filepath.Join(dir, "hello.txt") - if err := os.WriteFile(filePath, []byte("hello touka"), 0o644); err != nil { - t.Fatalf("write temp file: %v", err) - } - - rr := httptest.NewRecorder() - c, _ := CreateTestContext(rr) - - c.FileText(http.StatusCreated, filePath) - - if rr.Code != http.StatusCreated { - t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) - } - if got := rr.Header().Get("Content-Type"); got != "text/plain; charset=utf-8" { - t.Fatalf("unexpected content type: %q", got) - } - if body := rr.Body.String(); body != "hello touka" { - t.Fatalf("unexpected body: %q", body) - } -} - -func TestMaxBytesReaderAllowsExactLimit(t *testing.T) { - t.Helper() - - reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcd")), 4) - defer reader.Close() - - data, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("expected exact limit read to succeed, got %v", err) - } - if string(data) != "abcd" { - t.Fatalf("unexpected data: %q", string(data)) - } -} - -func TestMaxBytesReaderRejectsOverLimit(t *testing.T) { - t.Helper() - - reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abcde")), 4) - defer reader.Close() - - _, err := io.ReadAll(reader) - if !errors.Is(err, ErrBodyTooLarge) { - t.Fatalf("expected ErrBodyTooLarge, got %v", err) - } -} - -func TestMaxBytesReaderAllowsZeroNilThenEOFAtExactLimit(t *testing.T) { - t.Helper() - - reader := NewMaxBytesReader(&zeroNilThenEOFReader{}, 1) - defer reader.Close() - - buf := make([]byte, 1) - n, err := reader.Read(buf) - if n != 0 || err != nil { - t.Fatalf("expected initial zero,nil read result, got n=%d err=%v", n, err) - } - - n, err = reader.Read(buf) - if n != 0 || !errors.Is(err, io.EOF) { - t.Fatalf("expected EOF after retry, got n=%d err=%v", n, err) - } -} - -func TestMaxBytesReaderTreatsZeroLimitAsUnlimited(t *testing.T) { - t.Helper() - - reader := NewMaxBytesReader(io.NopCloser(strings.NewReader("abc")), 0) - defer reader.Close() - - data, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("expected zero limit to leave body unlimited, got %v", err) - } - if string(data) != "abc" { - t.Fatalf("unexpected data: %q", string(data)) - } -} - -func TestShouldBindJSONHonorsMaxRequestBodySize(t *testing.T) { - t.Helper() - - body := strings.NewReader(`{"name":"abcdef"}`) - req := httptest.NewRequest(http.MethodPost, "/json", body) - req.Header.Set("Content-Type", "application/json") - - c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) - c.SetMaxRequestBodySize(8) - - var payload struct { - Name string `json:"name"` - } - - err := c.ShouldBindJSON(&payload) - if !errors.Is(err, ErrBodyTooLarge) { - t.Fatalf("expected ErrBodyTooLarge, got %v", err) - } -} - -func TestShouldBindFormHonorsMaxRequestBodySize(t *testing.T) { - t.Helper() - - body := strings.NewReader("name=abcdef") - req := httptest.NewRequest(http.MethodPost, "/form", body) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) - c.SetMaxRequestBodySize(4) - - var payload struct { - Name string `form:"name"` - } - - err := c.ShouldBindForm(&payload) - if !errors.Is(err, ErrBodyTooLarge) { - t.Fatalf("expected ErrBodyTooLarge, got %v", err) - } -} - -func TestPostFormHonorsMaxRequestBodySize(t *testing.T) { - t.Helper() - - body := strings.NewReader("name=abcdef") - req := httptest.NewRequest(http.MethodPost, "/form", body) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - c, _ := CreateTestContextWithRequest(httptest.NewRecorder(), req) - c.SetMaxRequestBodySize(4) - - if got := c.PostForm("name"); got != "" { - t.Fatalf("expected empty value on over-limit form body, got %q", got) - } - if len(c.Errors) == 0 { - t.Fatal("expected parse error to be recorded") - } - if !errors.Is(c.Errors[0], ErrBodyTooLarge) { - t.Fatalf("expected recorded error to wrap ErrBodyTooLarge, got %v", c.Errors[0]) - } -} diff --git a/context_httpc.go b/context_httpc.go deleted file mode 100644 index 3256a3b..0000000 --- a/context_httpc.go +++ /dev/null @@ -1,58 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// Copyright 2024 WJQSERVER. All rights reserved. -// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. -package touka - -import ( - "context" - - "github.com/WJQSERVER-STUDIO/httpc" -) - -// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context -// 当请求被取消时,出站 HTTP 请求也会自动取消 -type contextHTTPClient struct { - client *httpc.Client - ctx context.Context -} - -// NewRequestBuilder 创建请求构建器,自动关联请求 Context -func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder { - return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx) -} - -// GET 创建 GET 请求构建器 -func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder { - return c.client.GET(urlStr).WithContext(c.ctx) -} - -// POST 创建 POST 请求构建器 -func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder { - return c.client.POST(urlStr).WithContext(c.ctx) -} - -// PUT 创建 PUT 请求构建器 -func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder { - return c.client.PUT(urlStr).WithContext(c.ctx) -} - -// DELETE 创建 DELETE 请求构建器 -func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder { - return c.client.DELETE(urlStr).WithContext(c.ctx) -} - -// PATCH 创建 PATCH 请求构建器 -func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder { - return c.client.PATCH(urlStr).WithContext(c.ctx) -} - -// HEAD 创建 HEAD 请求构建器 -func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder { - return c.client.HEAD(urlStr).WithContext(c.ctx) -} - -// OPTIONS 创建 OPTIONS 请求构建器 -func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder { - return c.client.OPTIONS(urlStr).WithContext(c.ctx) -} diff --git a/docs/advanced.md b/docs/advanced.md index eb44c2d..a7cb9a2 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -44,9 +44,7 @@ r.SetTLSServerConfigurator(func(server *http.Server) { Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): ```go -// 使用默认协议配置 -// 普通 HTTP 启动时默认为 HTTP/1.1;若使用 WithTLS(...) 且未手动覆盖协议集, -// HTTPS 服务器会默认启用 HTTP/1.1 与 HTTP/2。 +// 使用默认协议配置(仅 HTTP/1.1) r.SetDefaultProtocols() // 自定义协议配置 @@ -59,147 +57,33 @@ r.SetProtocols(&touka.ProtocolsConfig{ ### 启动方式 -Touka 统一通过 `Run(opts...)` 启动服务器: +Touka 提供了多种服务器启动方式: ```go // 1. 简单启动(无优雅停机) -r.Run(touka.WithAddr(":8080")) +r.Run(":8080") // 2. 带优雅停机的启动 -r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)) +r.RunShutdown(":8080", 10*time.Second) // 3. 带上下文的优雅停机 ctx, cancel := context.WithCancel(context.Background()) -defer cancel() -r.Run( - touka.WithAddr(":8080"), - touka.WithGracefulShutdown(10*time.Second), - touka.WithShutdownContext(ctx), -) +r.RunShutdownWithContext(":8080", ctx, 10*time.Second) // 4. HTTPS 启动 tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, // 其他 TLS 配置... } -// WithTLS(...) 与优雅关闭相互独立;这里演示 HTTPS + 默认优雅关闭超时。 -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithGracefulShutdownDefault(), -) +r.RunTLS(":443", tlsConfig, 10*time.Second) // 5. HTTPS + HTTP 重定向 -// WithHTTPRedirect(...) 需要与 WithTLS(...) 配合使用。 -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithHTTPRedirect(":80"), - touka.WithGracefulShutdown(10*time.Second), -) - -// 6. HTTPS + HTTP 重定向(按 header 顺序决定跳转 host) -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithHTTPRedirect( - ":80", - touka.WithUseHeaderHost(true), - touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), - ), -) - -// 7. HTTPS + HTTP 重定向(固定跳转到配置的 host) -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithHTTPRedirect( - ":80", - touka.WithUseHeaderHost(false), - touka.WithRedirectHost("example.com"), - ), -) +r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) ``` -### HTTPS Redirect Host 策略 - -`WithHTTPRedirect(addr, opts...)` 除了开启 HTTP -> HTTPS 重定向外,还支持通过 redirect 子选项控制最终跳转目标的 host。 - -可用的 redirect 子选项: - -- `touka.WithUseHeaderHost(true|false)` -- `touka.WithRedirectHostHeaders([]string{...})` -- `touka.WithRedirectHost("example.com")` - -#### 模式一:使用请求输入侧的 host - -当 `WithUseHeaderHost(true)` 时: - -- 如果没有配置 `WithRedirectHostHeaders(...)`,使用 `Request.Host` -- 如果配置了 `WithRedirectHostHeaders(...)`,按给定顺序读取这些 header,并使用第一个非空值 -- 如果配置了 `WithRedirectHostHeaders(...)` 但所有 header 都为空,返回 `426 Upgrade Required` - -示例: - -```go -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithHTTPRedirect( - ":80", - touka.WithUseHeaderHost(true), - touka.WithRedirectHostHeaders([]string{"X-Forwarded-Host", "X-Original-Host"}), - ), -) -``` - -#### 模式二:使用配置的固定 host - -当 `WithUseHeaderHost(false)` 时: - -- 不读取 `Request.Host` -- 不读取 `WithRedirectHostHeaders(...)` -- 必须配置 `WithRedirectHost("example.com")` - -示例: - -```go -r.Run( - touka.WithAddr(":443"), - touka.WithTLS(tlsConfig), - touka.WithHTTPRedirect( - ":80", - touka.WithUseHeaderHost(false), - touka.WithRedirectHost("example.com"), - ), -) -``` - -#### 严格校验规则 - -以下组合会直接返回配置错误: - -- `WithHTTPRedirect(...)` 但没有 `WithTLS(...)` -- 配置了 `WithRedirectHostHeaders(...)`,但没有显式传入 `WithUseHeaderHost(true)` -- `WithUseHeaderHost(false)` 但没有配置 `WithRedirectHost(...)` -- `WithUseHeaderHost(false)` 同时配置了 `WithRedirectHostHeaders(...)` -- `WithUseHeaderHost(true)` 同时配置了 `WithRedirectHost(...)` - -#### 优先级关系 - -1. 是否启用 `WithHTTPRedirect(...)` 决定是否进入 HTTPS + redirect 模式 -2. `WithUseHeaderHost(...)` 决定 host 来源模式 -3. 当 `WithUseHeaderHost(true)` 时: - - 配置了 `WithRedirectHostHeaders(...)` 就按 header 顺序查询 - - 未配置时使用 `Request.Host` -4. 当 `WithUseHeaderHost(false)` 时: - - 只使用 `WithRedirectHost(...)` - -**注意:** `WithRedirectHostHeaders(...)` 读取的是普通请求头值。只有在您明确知道请求经过受信任代理并会正确填充这些 header 时,才建议启用它。 - ## 优雅停机 (Graceful Shutdown) -在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。启用优雅关闭后,Touka 会监听 `SIGINT`/`SIGTERM`,并在关闭时取消活动请求的上下文。 +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 ```go r := touka.Default() @@ -207,7 +91,7 @@ r := touka.Default() // 监听 SIGINT 和 SIGTERM 信号 // 如果在 10 秒内未处理完,则强制关闭 -if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { +if err := r.RunShutdown(":8080", 10*time.Second); err != nil { log.Fatal("服务器退出异常:", err) } ``` diff --git a/docs/httpc.md b/docs/httpc.md deleted file mode 100644 index 8742c18..0000000 --- a/docs/httpc.md +++ /dev/null @@ -1,188 +0,0 @@ -# HTTP Client (httpc) - -Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。 - -## 核心特性 - -- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context -- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏 -- **链式调用**:保持 httpc 原有的组合式构建器风格 - -## 基本用法 - -### 简单 GET 请求 - -```go -r.GET("/proxy", func(c *touka.Context) { - body, err := c.HTTPC(). - GET("https://api.example.com/data"). - Text() - if err != nil { - c.JSON(500, touka.H{"error": err.Error()}) - return - } - c.String(200, body) -}) -``` - -### POST JSON 请求 - -```go -r.POST("/users", func(c *touka.Context) { - var req struct { - Name string `json:"name"` - Email string `json:"email"` - } - c.ShouldBindJSON(&req) - - var result struct { - ID int `json:"id"` - Name string `json:"name"` - } - - err := c.HTTPC(). - POST("https://api.example.com/users"). - SetHeader("Authorization", "Bearer "+token). - SetJSONBody(req). - DecodeJSON(&result) - if err != nil { - c.JSON(500, touka.H{"error": err.Error()}) - return - } - c.JSON(200, result) -}) -``` - -### 带查询参数 - -```go -r.GET("/search", func(c *touka.Context) { - query := c.Query("q") - - var result SearchResult - err := c.HTTPC(). - GET("https://api.example.com/search"). - SetQueryParam("q", query). - SetQueryParam("limit", "10"). - DecodeJSON(&result) - if err != nil { - c.JSON(500, touka.H{"error": err.Error()}) - return - } - c.JSON(200, result) -}) -``` - -## API 对比 - -### 旧方式(Deprecated) - -```go -// 需要手动 WithContext,容易忘记 -resp, err := c.Client(). - WithContext(c.Context()). - GET(url). - Execute() -``` - -### 新方式(推荐) - -```go -// 自动关联请求 Context -resp, err := c.HTTPC(). - GET(url). - Execute() -``` - -## Context 取消机制 - -使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消: - -```go -r.GET("/long-task", func(c *touka.Context) { - // 这个请求会在客户端断开时自动取消 - resp, err := c.HTTPC(). - GET("https://slow-api.example.com/data"). - Execute() - - // 如果客户端已断开,err 会包含 context.Canceled - if errors.Is(err, context.Canceled) { - return // 客户端已断开,无需处理 - } - // ... -}) -``` - -## 完整 API - -### contextHTTPClient 方法 - -| 方法 | 返回类型 | 说明 | -|------|----------|------| -| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 | -| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 | -| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 | -| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 | -| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 | -| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 | -| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 | -| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 | - -### httpc.RequestBuilder 链式方法 - -返回 `*httpc.RequestBuilder`(用于链式调用): - -| 方法 | 说明 | -|------|------| -| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) | -| `NoDefaultHeaders()` | 不添加默认 Header | -| `SetHeader(key, value)` | 设置 Header | -| `AddHeader(key, value)` | 添加 Header(可重复) | -| `SetHeaders(map)` | 批量设置 Headers | -| `SetQueryParam(key, value)` | 设置查询参数 | -| `AddQueryParam(key, value)` | 添加查询参数(可重复) | -| `SetQueryParams(map)` | 批量设置查询参数 | -| `SetBody(io.Reader)` | 设置请求 Body | -| `SetRawBody([]byte)` | 设置字节 Body | - -返回 `(*httpc.RequestBuilder, error)`(可能失败): - -| 方法 | 说明 | -|------|------| -| `SetJSONBody(any)` | 设置 JSON Body | -| `SetXMLBody(any)` | 设置 XML Body | -| `SetGOBBody(any)` | 设置 GOB Body | - -### 终结方法 - -| 方法 | 返回类型 | 说明 | -|------|----------|------| -| `Build()` | `(*http.Request, error)` | 构建请求但不执行 | -| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 | -| `DecodeJSON(v)` | `error` | 执行并解码 JSON | -| `DecodeXML(v)` | `error` | 执行并解码 XML | -| `DecodeGOB(v)` | `error` | 执行并解码 GOB | -| `Text()` | `(string, error)` | 执行并返回文本 | -| `Bytes()` | `([]byte, error)` | 执行并返回字节 | -| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 | - -## 迁移指南 - -### go:fix inline 兼容 - -旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`: - -```bash -go fix ./... -``` - -### 手动迁移 - -| 旧代码 | 新代码 | -|--------|--------| -| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` | -| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` | - -## 示例 - -完整示例请参考 [examples/httpc](../examples/httpc)。 diff --git a/docs/introduction.md b/docs/introduction.md index 87c3e40..94a7310 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -22,6 +22,6 @@ Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其 1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 -3. **健壮性**: 通过 `Run(...)` 的启动选项提供优雅停机支持,使服务在更新或关闭时能更稳妥地处理进行中的请求。 +3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/logger-migration-design.md b/docs/logger-migration-design.md deleted file mode 100644 index 7b2e0a6..0000000 --- a/docs/logger-migration-design.md +++ /dev/null @@ -1,400 +0,0 @@ -# Touka Logger 接口迁移方案 - -## 基于 Go 1.26 `go:fix inline` 的自动化迁移设计 - ---- - -## 一、问题分析 - -当前架构问题: -``` -Engine.LogReco → *reco.Logger (公开字段, 直接访问) -Context.GetLogger() → 返回 *reco.Logger (具体类型) -Context.Debugf/Infof... → 硬编码 c.engine.LogReco.Debugf(...) -``` - -这导致用户无法替换日志实现(如 zap/logrus)。 - ---- - -## 二、目标架构 - -``` -Engine.logger → Logger 接口 (私有) -Engine.LogReco → *reco.Logger (公开, Deprecated - 保持向后兼容) -Engine.GetLogger() → 返回 Logger 接口 -Engine.SetLogger(Logger)→ 设置日志实现 -Context.GetLogger() → 返回 Logger 接口 -Context.Debugf/Infof... → 调用 c.engine.logger.Debugf(...) -``` - ---- - -## 三、Logger 接口定义 - -```go -// logger.go -package touka - -// Logger 是日志接口,支持任意日志库实现 -type Logger interface { - Debugf(format string, args ...any) - Infof(format string, args ...any) - Warnf(format string, args ...any) - Errorf(format string, args ...any) - Fatalf(format string, args ...any) - Panicf(format string, args ...any) -} - -// CloserLogger 可选扩展,支持关闭操作 -type CloserLogger interface { - Logger - Close() error -} -``` - ---- - -## 四、Engine 结构变更 - -```go -// engine.go 变更 -type Engine struct { - // ... 其他字段保持不变 - - // logger 是新的日志接口 (私有) - logger Logger - - // logReco 是保留的 reco.Logger 引用 (私有) - // 用于向后兼容,当通过 SetLoggerReco 设置时同步到 logger - logReco *reco.Logger - - // 其他字段... -} -``` - -新增/修改方法: - -```go -// GetLogger 返回日志接口 -func (engine *Engine) GetLogger() Logger { - return engine.logger -} - -// SetLogger 设置任意 Logger 实现 -func (engine *Engine) SetLogger(l Logger) { - engine.logger = l - // 如果是 *reco.Logger 类型,同步更新 logReco - if rl, ok := l.(*reco.Logger); ok { - engine.logReco = rl - } else { - engine.logReco = nil - } -} - -// SetLoggerCfg 使用 reco.Config 配置日志 -func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { - logger := NewLogger(logcfg) - engine.logger = logger - engine.logReco = logger -} -``` - ---- - -## 五、`go:fix inline` 兼容性函数 - -### 5.1 旧 API 包装函数 - -在 `compat.go` 中定义: - -```go -// compat.go -package touka - -import "github.com/fenthope/reco" - -// GetLogReco 返回 reco.Logger,用于向后兼容 -// -//go:fix inline -func (engine *Engine) GetLogReco() *reco.Logger { - return engine.logReco -} - -// SetLogReco 设置 reco.Logger,用于向后兼容 -// -//go:fix inline -func (engine *Engine) SetLogReco(l *reco.Logger) { - engine.logReco = l - engine.logger = l -} -``` - -### 5.2 Context 日志方法的 inline 包装 - -```go -// context_compat.go -package touka - -// Debugf 记录 Debug 级别日志 -// -//go:fix inline -func (c *Context) Debugf(format string, args ...any) { - c.engine.logger.Debugf(format, args...) -} - -// Infof 记录 Info 级别日志 -// -//go:fix inline -func (c *Context) Infof(format string, args ...any) { - c.engine.logger.Infof(format, args...) -} - -// Warnf 记录 Warn 级别日志 -// -//go:fix inline -func (c *Context) Warnf(format string, args ...any) { - c.engine.logger.Warnf(format, args...) -} - -// Errorf 记录 Error 级别日志 -// -//go:fix inline -func (c *Context) Errorf(format string, args ...any) { - c.engine.logger.Errorf(format, args...) -} - -// Fatalf 记录 Fatal 级别日志 -// -//go:fix inline -func (c *Context) Fatalf(format string, args ...any) { - c.engine.logger.Fatalf(format, args...) -} - -// Panicf 记录 Panic 级别日志 -// -//go:fix inline -func (c *Context) Panicf(format string, args ...any) { - c.engine.logger.Panicf(format, args...) -} -``` - -### 5.3 GetLogger 返回类型的兼容处理 - -由于 `GetLogger()` 返回类型从 `*reco.Logger` 变为 `Logger`,需要提供兼容函数: - -```go -// context_compat.go (续) - -// GetLoggerReco 返回 *reco.Logger 类型,用于需要具体类型的场景 -// -//go:fix inline -func (c *Context) GetLoggerReco() *reco.Logger { - if rl, ok := c.engine.logger.(*reco.Logger); ok { - return rl - } - return nil -} -``` - ---- - -## 六、go:fix inline 工作原理 - -### 迁移前用户代码: -```go -func handler(c *touka.Context) { - // 旧 API 调用 - c.Debugf("request: %s", c.Request.URL.Path) - c.engine.LogReco.Infof("server started") -} -``` - -### go fix 执行后(自动替换): -```go -func handler(c *touka.Context) { - // Debugf 被替换为函数体 - c.engine.logger.Debugf("request: %s", c.Request.URL.Path) - - // LogReco 访问无法通过 inline 自动处理,需要手动迁移 - // 或者通过 getter 调用 -} -``` - -### 对于字段访问的处理策略: - -`engine.LogReco` 字段访问无法直接用 `go:fix inline` 处理,采用以下策略: - -1. **保留字段但标记 deprecated**:继续导出 `LogReco` 但文档标记为 deprecated -2. **提供 getter/setter**:通过 `go:fix inline` 提供 `GetLogReco/SetLogReco` -3. **渐进迁移**:用户可以在方便时手动迁移到 `GetLogger()/SetLogger()` - ---- - -## 七、迁移前后对比 - -### 场景 1:基本日志调用 - -**迁移前:** -```go -func myHandler(c *touka.Context) { - c.Debugf("processing request %s", c.Request.URL.Path) - c.Infof("user %s logged in", username) - c.Warnf("slow query: %v", duration) - c.Errorf("db error: %v", err) -} -``` - -**迁移后(自动替换):** -```go -func myHandler(c *touka.Context) { - c.engine.logger.Debugf("processing request %s", c.Request.URL.Path) - c.engine.logger.Infof("user %s logged in", username) - c.engine.logger.Warnf("slow query: %v", duration) - c.engine.logger.Errorf("db error: %v", err) -} -``` - -### 场景 2:Engine 配置日志 - -**迁移前:** -```go -engine := touka.New() -engine.LogReco = myLogger // 直接赋值 -logger := engine.LogReco // 直接读取 -``` - -**迁移后(手动 + 自动混合):** -```go -engine := touka.New() - -// 方式 1:使用新 API(推荐) -engine.SetLogger(myLogger) -logger := engine.GetLogger() - -// 方式 2:通过 go:fix inline 自动替换为 getter -// engine.SetLogReco(myLogger) ← go fix 替换 -// logger := engine.GetLogReco() ← go fix 替换 -``` - -### 场景 3:使用第三方日志库(新功能) - -```go -import "go.uber.org/zap" - -func main() { - zapLogger, _ := zap.NewProduction() - defer zapLogger.Sync() - - engine := touka.New() - // 使用 zap 替代默认的 reco.Logger - engine.SetLogger(&ZapAdapter{logger: zapLogger}) - - engine.GET("/api", func(c *touka.Context) { - c.Infof("api called") // 自动使用 zap 输出 - }) -} - -// ZapAdapter 适配 zap 到 touka.Logger 接口 -type ZapAdapter struct { - logger *zap.Logger -} - -func (z *ZapAdapter) Debugf(format string, args ...any) { - z.logger.Debug(fmt.Sprintf(format, args...)) -} - -func (z *ZapAdapter) Infof(format string, args ...any) { - z.logger.Info(fmt.Sprintf(format, args...)) -} - -func (z *ZapAdapter) Warnf(format string, args ...any) { - z.logger.Warn(fmt.Sprintf(format, args...)) -} - -func (z *ZapAdapter) Errorf(format string, args ...any) { - z.logger.Error(fmt.Sprintf(format, args...)) -} - -func (z *ZapAdapter) Fatalf(format string, args ...any) { - z.logger.Fatal(fmt.Sprintf(format, args...)) -} - -func (z *ZapAdapter) Panicf(format string, args ...any) { - z.logger.Panic(fmt.Sprintf(format, args...)) -} -``` - ---- - -## 八、内部使用迁移 - -框架内部代码也需要迁移,将直接调用 `engine.LogReco` 改为 `engine.logger`: - -需要修改的文件: -- `context.go`: writeResponseBody 中的 `c.engine.LogReco.Errorf` -- `recovery.go`: 如有使用日志 -- `logreco.go`: CloseLogger 方法 - -```go -// context.go 修改前 -func (c *Context) writeResponseBody(data []byte, contextMsg string) { - if _, err := c.Writer.Write(data); err != nil { - if c.engine.LogReco != nil { - c.engine.LogReco.Errorf("%s: %v", contextMsg, err) - } - } -} - -// context.go 修改后 -func (c *Context) writeResponseBody(data []byte, contextMsg string) { - if _, err := c.Writer.Write(data); err != nil { - if c.engine.logger != nil { - c.engine.logger.Errorf("%s: %v", contextMsg, err) - } - } -} -``` - ---- - -## 九、完整文件结构 - -``` -touka/ -├── logger.go # Logger 接口定义 -├── logreco.go # reco.Logger 相关工具函数 -├── compat.go # go:fix inline 兼容性函数 (Engine) -├── context_compat.go # go:fix inline 兼容性函数 (Context) -├── engine.go # Engine 结构变更 -├── context.go # Context 日志方法变更 -└── ... -``` - ---- - -## 十、版本策略 - -| 版本 | 变更内容 | -|------|---------| -| v1.x | 引入 Logger 接口,LogReco 标记 deprecated | -| v2.x | 移除 LogReco 公开字段,仅通过 getter/setter 访问 | -| v3.x | 移除 go:fix inline 兼容函数 | - ---- - -## 十一、go:fix inline 限制说明 - -1. **字段访问无法自动迁移**:`engine.LogReco` 字段访问需要用户手动修改 -2. **返回类型变更需谨慎**:`GetLogger()` 返回类型变更会导致依赖具体类型的代码失败 -3. **inline 函数有大小限制**:函数体过大会影响内联效果 -4. **跨包迁移**:`go:fix inline` 支持跨包,但用户必须运行 `go fix` - ---- - -## 十二、推荐迁移步骤 - -1. **框架侧**:添加 Logger 接口,添加 go:fix inline 函数 -2. **用户侧**:运行 `go fix ./...` 自动迁移可处理的部分 -3. **用户侧**:手动将 `engine.LogReco` 字段访问改为 `engine.SetLogger()/GetLogger()` -4. **用户侧**:如需使用第三方日志,实现 Logger 接口并通过 SetLogger 设置 diff --git a/docs/middleware.md b/docs/middleware.md index b688fb5..a222437 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -26,41 +26,6 @@ api.Use(AuthMiddleware()) } ``` -也可以在创建组时直接传入中间件: - -```go -api := r.Group("/api", AuthMiddleware(), RateLimitMiddleware()) -{ - api.GET("/user", handleUser) - api.POST("/data", handleData) -} -``` - -### 路由级中间件 - -为单个路由注册中间件,仅对该路由生效。 - -```go -// 单个路由中间件 -r.GET("/protected", AuthMiddleware(), func(c *touka.Context) { - c.String(http.StatusOK, "Protected content") -}) - -// 多个路由中间件(按顺序执行) -r.POST("/upload", - RateLimitMiddleware(), - AuthMiddleware(), - PermissionCheckMiddleware(), - func(c *touka.Context) { - // 处理上传 - }, -) - -// 路由组中的单个路由也可以使用路由级中间件 -api := r.Group("/api") -api.GET("/admin", AdminAuthMiddleware(), adminHandler) -``` - ## 编写自定义中间件 中间件的函数签名是 `touka.HandlerFunc`。 @@ -102,36 +67,6 @@ func APIKeyAuth() touka.HandlerFunc { } ``` -## 中间件执行顺序 - -理解中间件的执行顺序对于构建正确的处理流程至关重要。**注意:注册顺序决定了执行逻辑**,中间件必须在注册路由之前调用(全局中间件应在创建组或定义路由前注册)。中间件按照以下顺序执行: - -```go -// 全局中间件 -r.Use(GlobalMiddleware1()) -r.Use(GlobalMiddleware2()) - -// 组中间件 -api := r.Group("/api", GroupMiddleware1()) -api.Use(GroupMiddleware2()) - -// 路由级中间件 -api.GET("/users", RouteMiddleware1(), RouteMiddleware2(), userHandler) -``` - -对于 `/api/users` 请求,执行顺序为: -1. `GlobalMiddleware1()` - 全局中间件 -2. `GlobalMiddleware2()` - 全局中间件 -3. `GroupMiddleware1()` - 路由组中间件 -4. `GroupMiddleware2()` - 路由组中间件 -5. `RouteMiddleware1()` - 路由级中间件 -6. `RouteMiddleware2()` - 路由级中间件 -7. `userHandler` - 最终处理函数 - -``` -请求进入 → 全局中间件 → 路由组中间件 → 路由级中间件 → 最终处理函数 → 路由级中间件后置逻辑 → 路由组中间件后置逻辑 → 全局中间件后置逻辑 → 响应 -``` - ## 内置中间件 - **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。 diff --git a/docs/quickstart.md b/docs/quickstart.md index 2911732..94f7433 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -46,7 +46,7 @@ func main() { // 4. 启动服务器并监听 8080 端口 log.Println("Touka server is running on :8080") - if err := r.Run(touka.WithAddr(":8080")); err != nil { + if err := r.Run(":8080"); err != nil { log.Fatalf("Server failed: %v", err) } } @@ -66,11 +66,11 @@ go run main.go ## 优雅停机 -在生产环境中,我们推荐为 `Run` 追加优雅关闭选项。启用后,Touka 会监听 `SIGINT`/`SIGTERM`,在关闭时取消活动请求的上下文,并在超时前等待正在处理的请求完成。如需由应用内部事件触发关闭,还可以额外配合 `touka.WithShutdownContext(ctx)`。 +在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 ```go // 等待 10 秒以处理剩余请求 -if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { +if err := r.RunShutdown(":8080", 10*time.Second); err != nil { log.Fatalf("Server forced to shutdown: %v", err) } ``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index cb4b2a3..5dfcbd1 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -28,7 +28,7 @@ func main() { Target: target, })) - _ = r.Run(touka.WithAddr(":8080")) + _ = r.Run(":8080") } ``` @@ -59,16 +59,11 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ ```go type ReverseProxyConfig struct { - Target *url.URL - Targets []string - - LoadBalancing ReverseProxyLoadBalancingConfig - PassiveHealth ReverseProxyPassiveHealthConfig + Target *url.URL Transport http.RoundTripper FlushInterval time.Duration BufferPool BufferPool - AllowH2CUpstream bool ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error @@ -83,133 +78,12 @@ type ReverseProxyConfig struct { ### `Target` -与 `Targets` 二选一。表示单个后端目标地址,至少需要提供 `scheme` 和 `host`。 +必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 ```go target, _ := url.Parse("http://backend:9000") ``` -### `Targets` - -可选。用于配置多个后端目标地址。 - -- `Target` 与 `Targets` 互斥,只能使用其中一种 -- `Targets` 的每一项都必须是完整 URL -- 每个 target 仍然可以自带 base path 和 query - -```go -r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ - Targets: []string{ - "http://127.0.0.1:9001/base?from=a", - "http://127.0.0.1:9002/base?from=b", - }, -})) -``` - -这意味着不同 upstream 仍然可以保留各自的路径前缀和固定查询参数。 - -### `LoadBalancing` - -用于配置 upstream 选择策略和重试行为。 - -```go -type ReverseProxyLoadBalancingConfig struct { - Policy ReverseProxyLBPolicy - Retries int - TryDuration time.Duration - TryInterval time.Duration -} -``` - -当前内置策略: - -- `touka.LBRandom()` -- `touka.LBRoundRobin()` -- `touka.LBFirst()` -- `touka.LBLeastConn()` -- `touka.LBIPHash()` -- `touka.LBClientIPHash()` -- `touka.LBURIHash()` -- `touka.LBHeader("X-Upstream", fallback)` -- `touka.LBQuery("tenant", fallback)` - -其中: - -- `LBFirst()` 适合主备/故障转移顺序 -- `LBHeader` / `LBQuery` 只有在对应 header/query **缺失**时才会走 fallback -- 如果 `LBHeader` / `LBQuery` 没有显式 fallback,则默认回退到 `LBRandom()` - -```go -r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ - Targets: []string{ - "http://127.0.0.1:9001", - "http://127.0.0.1:9002", - }, - LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ - Policy: touka.LBHeader("X-Upstream", touka.LBFirst()), - Retries: 1, - }, -})) -``` - -重试说明: - -- 只对未开始收到上游响应的失败进行重试 -- 默认仅对 RFC 定义的安全方法(`GET` / `HEAD` / `OPTIONS` / `TRACE`)重试 -- `Retries` 表示额外重试次数 -- `TryDuration` 表示总尝试时间预算;如果配置了它,会优先于重试次数控制停止时机 -- `TryInterval` 表示两次重试之间的等待间隔 - -### `PassiveHealth` - -用于配置被动健康检查。它不会后台探测 upstream,而是根据真实代理请求的失败结果临时把某个 upstream 视为不健康。 - -```go -type ReverseProxyPassiveHealthConfig struct { - FailDuration time.Duration - MaxFails int - UnhealthyStatus []int -} -``` - -- `FailDuration > 0` 时启用被动健康跟踪 -- `MaxFails <= 0` 时默认按 `1` 处理 -- `UnhealthyStatus` 中的状态码会被记为一次失败,但当前请求仍会先收到该响应;后续请求才会绕过这个 upstream - -```go -r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ - Targets: []string{ - "http://127.0.0.1:9001", - "http://127.0.0.1:9002", - }, - LoadBalancing: touka.ReverseProxyLoadBalancingConfig{ - Policy: touka.LBFirst(), - }, - PassiveHealth: touka.ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - UnhealthyStatus: []int{http.StatusServiceUnavailable}, - }, -})) -``` - -### `AllowH2CUpstream` - -允许代理使用未加密 HTTP/2(h2c)与 `http://` upstream 通信。 - -- 默认关闭 -- 这是一个显式配置项 -- 启用后,Touka 会为该 upstream 使用 h2c prior-knowledge 方式连接上游 -- 这意味着上游本身也必须显式支持 h2c;它不是“先试 h2c,失败再自动回退到 h1”的协商模式 - -```go -r.GET("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ - Target: target, - AllowH2CUpstream: true, -})) -``` - -对于下游 HTTP/2 extended `CONNECT` websocket 场景,Touka 会只在该特殊桥接路径上强制与上游使用 HTTP/1.1 websocket upgrade,以匹配 Caddy 风格的桥接语义;普通 HTTP 请求不会因为这个特性而被强制降级为 HTTP/1.1。 - ### `Transport` 可选。用于自定义底层转发所使用的 `http.RoundTripper`。 @@ -276,8 +150,6 @@ r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ 在请求真正发往后端前,对出站请求做最后修改。 -如果启用了多 upstream 重试,`ModifyRequest` 可能会在同一个客户端请求里被调用多次:每一次实际发往 upstream 的尝试都会重新构造一份请求并再次执行它。因此,这个回调最好保持幂等,不要依赖“只会执行一次”的副作用。 - 常见用途: - 覆盖 `Host` @@ -370,20 +242,11 @@ const ( r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "_gateway-1", + ForwardedBy: "gateway-1", Via: "edge-1", })) ``` -如果您配置了 `ForwardedBy`,它必须是一个符合 RFC 7239 的 node identifier。 - -- IPv4:`203.0.113.43` -- IPv6 / 带端口:`[2001:db8::17]:443` -- 匿名标识:`_gateway-1` -- 未知:`unknown` - -像 `gateway-1` 这类普通 token 不再被视为合法的 `by=` 值。 - `Via` 不是“留空即禁用”的开关。当前实现中: - 如果 `Via` 非空,则使用该值追加 `Via` @@ -419,14 +282,11 @@ Touka 会尽量遵循代理链语义: Touka 的反向代理实现支持以下能力: -- `CONNECT` 隧道转发(HTTP/1.x) -- HTTP/2 extended `CONNECT` - `Connection: Upgrade` / `Upgrade` 协议升级转发 - WebSocket 等 101 Switching Protocols 场景 - SSE(Server-Sent Events)立即刷新 - Trailer 透传 - 1xx 响应透传 -- `TRACE` / `OPTIONS` 上的 `Max-Forwards` 递减与本地终止处理 例如,代理 WebSocket 服务: @@ -481,7 +341,7 @@ func main() { r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ Target: target, ForwardedHeaders: touka.ForwardedBoth, - ForwardedBy: "_gateway-1", + ForwardedBy: "gateway-1", Via: "gateway-1", FlushInterval: 100 * time.Millisecond, ModifyRequest: func(req *http.Request) { @@ -497,7 +357,7 @@ func main() { }, })) - if err := r.Run(touka.WithAddr(":8080"), touka.WithGracefulShutdown(10*time.Second)); err != nil { + if err := r.RunShutdown(":8080", 10*time.Second); err != nil { log.Fatal(err) } } diff --git a/docs/routing.md b/docs/routing.md index 70a24dc..e90308e 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -22,8 +22,6 @@ r.ANY("/any", handle) r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) ``` -服务器级 `OPTIONS *` 请求不需要单独注册路由。Touka 会直接返回一个空的 `200 OK` 响应,而不会把它当成 `/` 路由来匹配。 - ## 路径参数 (Named Parameters) 使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 @@ -142,7 +140,7 @@ func main() { r := touka.Default() fsroot, _ := fs.Sub(content, "dist") r.StaticFS("/", http.FS(fsroot)) - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` diff --git a/docs/sse.md b/docs/sse.md index a003be9..1b44521 100644 --- a/docs/sse.md +++ b/docs/sse.md @@ -125,4 +125,4 @@ r.GET("/events-graceful", func(c *touka.Context) { 2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 -**注意:** 请务必通过 `r.Run(...)` 并显式传入优雅关闭选项来启动服务器,例如 `touka.WithGracefulShutdown(...)` 或 `touka.WithGracefulShutdownDefault()`。只有启用了优雅关闭,框架才会在服务退出时取消这些请求上下文。 +**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 diff --git a/docs/static-files.md b/docs/static-files.md index b1f06a8..a2138cd 100644 --- a/docs/static-files.md +++ b/docs/static-files.md @@ -39,7 +39,7 @@ func main() { // 您也可以使用 StaticFS 服务根路径 // r.StaticFS("/", http.FS(fsroot)) - r.Run(touka.WithAddr(":8080")) + r.Run(":8080") } ``` diff --git a/ecw.go b/ecw.go index dedbe27..754571f 100644 --- a/ecw.go +++ b/ecw.go @@ -197,7 +197,7 @@ func (ecw *errorCapturingResponseWriter) Written() bool { func (ecw *errorCapturingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := ecw.w.(http.Hijacker) if !ok { - return nil, nil, http.ErrNotSupported + return nil, nil, errors.New("the underlying ResponseWriter does not support the Hijacker interface") } return hijacker.Hijack() } diff --git a/ecw_benchmark_test.go b/ecw_benchmark_test.go deleted file mode 100644 index d9a427c..0000000 --- a/ecw_benchmark_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package touka - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestErrorCapturingResponseWriterResetClearsHeaderSnapshot(t *testing.T) { - c, _ := CreateTestContext(nil) - ecw := AcquireErrorCapturingResponseWriter(c) - defer ReleaseErrorCapturingResponseWriter(ecw) - - ecw.capturedErrorSignal = true - ecw.Header().Set("Content-Type", "text/plain") - ecw.Header().Add("X-Test", "one") - - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("failed to build request: %v", err) - } - - ecw.reset(httptest.NewRecorder(), req, c, c.engine.errorHandle.handler) - - if len(ecw.headerSnapshot) != 0 { - t.Fatalf("expected header snapshot to be empty after reset, got %#v", ecw.headerSnapshot) - } -} - -func BenchmarkErrorCapturingResponseWriterReset(b *testing.B) { - c, _ := CreateTestContext(nil) - ecw := AcquireErrorCapturingResponseWriter(c) - defer ReleaseErrorCapturingResponseWriter(ecw) - - rawWriter := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - b.Fatalf("failed to build request: %v", err) - } - - keys := make([]string, 16) - for i := range keys { - keys[i] = http.CanonicalHeaderKey("X-Test-" + string(rune('A'+i))) - } - values := []string{"one", "two", "three"} - for _, key := range keys { - ecw.headerSnapshot[key] = values - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - ecw.reset(rawWriter, req, c, c.engine.errorHandle.handler) - for _, key := range keys { - ecw.headerSnapshot[key] = values - } - } -} diff --git a/engine.go b/engine.go index 15df162..c2eae91 100644 --- a/engine.go +++ b/engine.go @@ -7,11 +7,9 @@ package touka import ( "context" "errors" - "io" "reflect" "runtime" "strings" - "unicode/utf8" "net/http" @@ -19,7 +17,6 @@ import ( "github.com/WJQSERVER-STUDIO/httpc" "github.com/fenthope/reco" - "github.com/go-json-experiment/json" ) // Last 返回链中的最后一个处理函数 @@ -52,14 +49,8 @@ type Engine struct { HTTPClient *httpc.Client // 用于在此上下文中执行出站 HTTP 请求 - // LogReco 保留的 reco.Logger 字段 - // Deprecated: 使用 SetLogger/GetLogger 替代 LogReco *reco.Logger - // logger 是新的日志接口,支持任意 Logger 实现 - // 优先级: logger > LogReco - logger Logger - HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 @@ -90,11 +81,6 @@ type Engine struct { // GlobalMaxRequestBodySize 全局请求体Body大小限制 GlobalMaxRequestBodySize int64 - - notFoundChain HandlersChain - notFoundNoMethodChain HandlersChain - unmatchedFSChain HandlersChain - unmatchedFSNoMethodChain HandlersChain } // HandleFunc 注册一个或多个 HTTP 方法的路由 @@ -130,90 +116,6 @@ type ErrorHandle struct { type ErrorHandler func(c *Context, code int, err error) -var errMethodNotAllowed = errors.New("method not allowed") -var errNotFound = errors.New("not found") - -type defaultErrorResponse struct { - Code int `json:"code"` - Message string `json:"message"` - Error string `json:"error"` -} - -var defaultNotFoundBody = mustMarshalDefaultErrorBody(http.StatusNotFound, errNotFound.Error()) -var defaultMethodNotAllowedBody = mustMarshalDefaultErrorBody(http.StatusMethodNotAllowed, errMethodNotAllowed.Error()) - -func mustMarshalDefaultErrorBody(code int, errMsg string) []byte { - body, err := json.Marshal(defaultErrorResponse{ - Code: code, - Message: http.StatusText(code), - Error: errMsg, - }) - if err != nil { - panic(err) - } - return body -} - -func writeDefaultErrorJSON(c *Context, code int, body []byte) { - if c == nil || c.Writer == nil { - return - } - c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") - c.Writer.WriteHeader(code) - c.writeResponseBody(body, "failed to write default error response") - c.Writer.Flush() - c.Abort() -} - -var methodNotAllowedHandler HandlerFunc = func(c *Context) { - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) - engine := c.engine - // 是否是OPTIONS方式 - if httpMethod == http.MethodOptions { - // 如果是 OPTIONS 请求,尝试查找所有允许的方法 - allowedMethods := engine.allowedMethodsForPath(requestPath, c.allowedMethodsBuf[:0]) - c.allowedMethodsBuf = allowedMethods[:0] - if len(allowedMethods) > 0 { - // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 - allowHeader := c.allowHeaderBuf[:0] - for i, method := range allowedMethods { - if i > 0 { - allowHeader = append(allowHeader, ',', ' ') - } - allowHeader = append(allowHeader, method...) - } - c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", string(allowHeader)) - c.Status(http.StatusOK) - return - } - return - } - // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 - tempSkippedNodes := GetTempSkippedNodes() - for _, treeIter := range engine.methodTrees { - if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 - continue - } - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - *tempSkippedNodes = (*tempSkippedNodes)[:0] - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 - if value.handlers != nil { - PutTempSkippedNodes(tempSkippedNodes) - // 使用定义的ErrorHandle处理 - engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errMethodNotAllowed) - return - } - } - PutTempSkippedNodes(tempSkippedNodes) -} - -var notFoundHandler HandlerFunc = func(c *Context) { - engine := c.engine - engine.errorHandle.handler(c, http.StatusNotFound, errNotFound) -} - // defaultErrorHandle 默认错误处理 func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是否已断开连接 select { @@ -224,22 +126,16 @@ func defaultErrorHandle(c *Context, code int, err error) { // 检查客户端是 if c.Writer.Written() { return } - if len(c.Errors) == 0 { - switch { - case code == http.StatusNotFound && errors.Is(err, errNotFound): - writeDefaultErrorJSON(c, code, defaultNotFoundBody) - return - case code == http.StatusMethodNotAllowed && errors.Is(err, errMethodNotAllowed): - writeDefaultErrorJSON(c, code, defaultMethodNotAllowedBody) - return - } - } // 输出json 状态码与状态码对应描述 var errMsg string if err != nil { errMsg = err.Error() } - c.JSON(code, defaultErrorResponse{Code: code, Message: http.StatusText(code), Error: errMsg}) + c.JSON(code, H{ + "code": code, + "message": http.StatusText(code), + "error": errMsg, + }) c.Writer.Flush() c.Abort() return @@ -314,7 +210,6 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } - engine.rebuildFallbackChains() engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() @@ -370,32 +265,18 @@ func (engine *Engine) SetRedirectFixedPath(enable bool) { // 是否开启MethodNotAllowed func (engine *Engine) SetHandleMethodNotAllowed(enable bool) { engine.HandleMethodNotAllowed = enable - engine.rebuildFallbackChains() } -// SetLogger 传入 Logger 接口实例 -func (engine *Engine) SetLogger(logger Logger) { - engine.logger = logger - // 同步更新 LogReco 以保持向后兼容 - if rl, ok := logger.(*reco.Logger); ok { - engine.LogReco = rl - } else { - engine.LogReco = nil - } -} - -// GetLogger 返回 Logger 接口实例 -func (engine *Engine) GetLogger() Logger { - return engine.logger -} - -// SetLoggerCfg 使用 reco.Config 配置日志 -func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { - logger := NewLogger(logcfg) - engine.logger = logger +// SetLogger传入实例 +func (engine *Engine) SetLogger(logger *reco.Logger) { engine.LogReco = logger } +// 配置日志LoggerCfg +func (engine *Engine) SetLoggerCfg(logcfg reco.Config) { + engine.LogReco = NewLogger(logcfg) +} + // 设置自定义错误处理 func (engine *Engine) SetErrorHandler(handler ErrorHandler) { engine.errorHandle.useDefault = false @@ -424,7 +305,6 @@ func (engine *Engine) SetUnMatchFSChain(fs http.FileSystem, handlers ...HandlerF engine.unMatchFS.ServeUnmatchedAsFS = false engine.UnMatchFSRoutes = nil } - engine.rebuildFallbackChains() } // 获取默认Protocol配置 @@ -460,28 +340,11 @@ func (engine *Engine) setProtocols(config *ProtocolsConfig) { }() } -func cloneServerProtocols(protocols *http.Protocols) *http.Protocols { - if protocols == nil { - return nil - } - cloned := *protocols - return &cloned -} - -func applyServerProtocols(srv *http.Server, protocols *http.Protocols) { - if protocols != nil { - srv.Protocols = cloneServerProtocols(protocols) - if srv.Protocols.HTTP2() || srv.Protocols.UnencryptedHTTP2() { - if err := configureHTTP2ExtendedConnectServer(srv); err != nil { - panic(err) - } - } - } -} - // applyDefaultServerConfig 应用框架的默认配置到 http.Server func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { - applyServerProtocols(srv, engine.serverProtocols) + if engine.serverProtocols != nil { + srv.Protocols = engine.serverProtocols + } } // 配置全局Req Body大小限制 @@ -610,64 +473,66 @@ func PutTempSkippedNodes(skippedNodes *[]skippedNode) { // 405中间件 func MethodNotAllowed() HandlerFunc { - return methodNotAllowedHandler + return func(c *Context) { + httpMethod := c.Request.Method + requestPath := c.Request.URL.Path + engine := c.engine + // 是否是OPTIONS方式 + if httpMethod == http.MethodOptions { + // 如果是 OPTIONS 请求,尝试查找所有允许的方法 + allowedMethods := []string{} + for _, treeIter := range engine.methodTrees { + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + allowedMethods = append(allowedMethods, treeIter.method) + } + } + if len(allowedMethods) > 0 { + // 如果找到了允许的方法,返回 200 OK 并设置 Allow 头部 + c.Writer.Header().Set("Allow", strings.Join(allowedMethods, ", ")) + c.Status(http.StatusOK) + return + } + } + // 尝试遍历所有方法树,看是否有其他方法可以匹配当前路径 + for _, treeIter := range engine.methodTrees { + if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 + continue + } + // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) + if value.handlers != nil { + // 使用定义的ErrorHandle处理 + engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) + return + } + } + } } // 404最后处理 func NotFound() HandlerFunc { - return notFoundHandler + return func(c *Context) { + engine := c.engine + engine.errorHandle.handler(c, http.StatusNotFound, errors.New("not found")) + } } // 传入并设置NoRoute (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoute(handler HandlerFunc) { Engine.noRoute = handler Engine.noRoutes = nil - Engine.rebuildFallbackChains() } // 传入并设置NoRoutes (这不是最后一个处理, 你仍可以next到默认的404处理) func (Engine *Engine) NoRoutes(handlerFuncs ...HandlerFunc) { Engine.noRoute = nil Engine.noRoutes = handlerFuncs - Engine.rebuildFallbackChains() -} - -func (engine *Engine) rebuildFallbackChains() { - buildChain := func(includeMethodNotAllowed bool, includeUnmatchedFS bool) HandlersChain { - finalSize := len(engine.globalHandlers) + 1 // 最后的 NotFound - if includeMethodNotAllowed { - finalSize++ - } - if includeUnmatchedFS { - finalSize += len(engine.UnMatchFSRoutes) - } - if engine.noRoute != nil { - finalSize++ - } else { - finalSize += len(engine.noRoutes) - } - - chain := make(HandlersChain, 0, finalSize) - chain = append(chain, engine.globalHandlers...) - if includeMethodNotAllowed { - chain = append(chain, methodNotAllowedHandler) - } - if includeUnmatchedFS { - chain = append(chain, engine.UnMatchFSRoutes...) - } - if engine.noRoute != nil { - chain = append(chain, engine.noRoute) - } else if len(engine.noRoutes) > 0 { - chain = append(chain, engine.noRoutes...) - } - chain = append(chain, notFoundHandler) - return chain - } - - engine.notFoundChain = buildChain(engine.HandleMethodNotAllowed, false) - engine.notFoundNoMethodChain = buildChain(false, false) - engine.unmatchedFSChain = buildChain(engine.HandleMethodNotAllowed, engine.unMatchFS.ServeUnmatchedAsFS) - engine.unmatchedFSNoMethodChain = buildChain(false, engine.unMatchFS.ServeUnmatchedAsFS) } // combineHandlers 组合多个处理函数链为一个 @@ -682,9 +547,8 @@ func (engine *Engine) combineHandlers(h1 HandlersChain, h2 HandlersChain) Handle // Use 将全局中间件添加到 Engine // 这些中间件将应用于所有注册的路由 -func (engine *Engine) Use(middleware ...HandlerFunc) Router { +func (engine *Engine) Use(middleware ...HandlerFunc) IRouter { engine.globalHandlers = append(engine.globalHandlers, middleware...) - engine.rebuildFallbackChains() return engine } @@ -751,7 +615,7 @@ func (engine *Engine) GetRouterInfo() []RouteInfo { // Group 创建一个新的路由组 // 路由组允许将具有相同前缀路径和/或共享中间件的路由组织在一起 -func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router { +func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) IRouter { return &RouterGroup{ Handlers: engine.combineHandlers(engine.globalHandlers, handlers), // 继承全局中间件 basePath: resolveRoutePath("/", relativePath), @@ -760,7 +624,7 @@ func (engine *Engine) Group(relativePath string, handlers ...HandlerFunc) Router } // RouterGroup 表示一个路由分组,可以添加组特定的中间件和路由 -// 它也实现了 Router 接口,允许嵌套分组 +// 它也实现了 IRouter 接口,允许嵌套分组 type RouterGroup struct { Handlers HandlersChain // 组中间件,仅应用于当前组及其子组的路由 basePath string // 组路径前缀 @@ -769,7 +633,7 @@ type RouterGroup struct { // Use 将中间件应用于当前路由组 // 这些中间件将应用于当前组及其子组的所有路由 -func (group *RouterGroup) Use(middleware ...HandlerFunc) Router { +func (group *RouterGroup) Use(middleware ...HandlerFunc) IRouter { group.Handlers = append(group.Handlers, middleware...) return group } @@ -815,7 +679,7 @@ func (group *RouterGroup) ANY(relativePath string, handlers ...HandlerFunc) { } // Group 为当前组创建一个新的子组 -func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) Router { +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) IRouter { return &RouterGroup{ Handlers: group.engine.combineHandlers(group.Handlers, handlers), basePath: resolveRoutePath(group.basePath, relativePath), @@ -840,13 +704,8 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // handleRequest 负责根据请求查找路由并执行相应的处理函数链 // 这是路由查找和执行的核心逻辑 func (engine *Engine) handleRequest(c *Context) { - if isGeneralOptionsRequest(c.Request) { - engine.handleGeneralOptions(c) - return - } - httpMethod := c.Request.Method - requestPath := routeLookupPath(c.Request) + requestPath := c.Request.URL.Path // 查找对应的路由树的根节点 rootNode := engine.methodTrees.get(httpMethod) // 这里获取到的 rootNode 已经是 *node 类型 @@ -866,7 +725,7 @@ func (engine *Engine) handleRequest(c *Context) { } // 如果没有找到处理函数,检查是否需要重定向(尾部斜杠或大小写修复) - if httpMethod != http.MethodConnect && requestPath != "/" && !isGeneralOptionsRequest(c.Request) { // CONNECT 方法、服务器级 OPTIONS 和根路径不进行重定向 + if httpMethod != http.MethodConnect && requestPath != "/" { // CONNECT 方法和根路径不进行重定向 if value.tsr && engine.RedirectTrailingSlash { // 尾部斜杠重定向:/foo/ -> /foo 或 /foo -> /foo/ redirectPath := requestPath @@ -878,98 +737,51 @@ func (engine *Engine) handleRequest(c *Context) { c.Redirect(http.StatusMovedPermanently, redirectPath) // 301 永久重定向 return } - if engine.RedirectFixedPath && shouldTryFixedPathLookup(requestPath, rootNode) { - // 仅在启用固定路径重定向时执行大小写修复查找, 避免无意义的二次树遍历. - ciPath, found := rootNode.findCaseInsensitivePathWithBuffer(requestPath, c.fixedPathBuf, engine.RedirectTrailingSlash) - if found { - c.fixedPathBuf = ciPath[:0] - c.Redirect(http.StatusMovedPermanently, string(ciPath)) // 301 永久重定向到修正后的路径 - return - } - c.fixedPathBuf = c.fixedPathBuf[:0] + // 尝试不区分大小写的查找 + // 直接在 rootNode 上调用 findCaseInsensitivePath 方法 + ciPath, found := rootNode.findCaseInsensitivePath(requestPath, engine.RedirectTrailingSlash) + if found && engine.RedirectFixedPath { + c.Redirect(http.StatusMovedPermanently, BytesToString(ciPath)) // 301 永久重定向到修正后的路径 + return } } } - if engine.unMatchFS.ServeUnmatchedAsFS { - c.handlers = engine.unmatchedFSChain - } else { - c.handlers = engine.notFoundChain + // 构建处理链 + // 组合全局中间件和路由处理函数 + handlers := engine.globalHandlers + + // 如果启用了 MethodNotAllowed 处理,并且没有找到精确匹配的路由 + // 则在全局中间件之后添加 MethodNotAllowed 处理器 + if engine.HandleMethodNotAllowed { + handlers = append(handlers, MethodNotAllowed()) } + + // 如果启用了 UnMatchFS 处理,并且没有找到精确匹配的路由和 MethodNotAllowed + // 则在处理链的最后添加 UnMatchFS 处理器 + if engine.unMatchFS.ServeUnmatchedAsFS { + /* + var unMatchFSHandle = c.engine.unMatchFileServer + handlers = append(handlers, unMatchFSHandle) + */ + handlers = append(handlers, engine.UnMatchFSRoutes...) + } + + // 如果用户设置了 NoRoute 处理器,且没有匹配到任何路由、MethodNotAllowed 或 UnMatchFS + // 则在处理链的最后添加 NoRoute 处理器 + if engine.noRoute != nil { + handlers = append(handlers, engine.noRoute) + } else if len(engine.noRoutes) > 0 { + handlers = append(handlers, engine.noRoutes...) + } + + handlers = append(handlers, NotFound()) + + c.handlers = handlers c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } -func routeLookupPath(req *http.Request) string { - if req == nil { - return "" - } - - if req.Method == http.MethodConnect && req.RequestURI != "" && req.RequestURI != "*" && !strings.HasPrefix(req.RequestURI, "/") && !strings.Contains(req.RequestURI, "://") { - return "/" + req.RequestURI - } - if isGeneralOptionsRequest(req) { - return "" - } - if req.URL == nil { - return "" - } - return req.URL.Path -} - -func isGeneralOptionsRequest(req *http.Request) bool { - return req != nil && req.Method == http.MethodOptions && req.RequestURI == "*" -} - -func shouldTryFixedPathLookup(path string, root *node) bool { - if root != nil && root.hasCaseInsensitivePath { - return true - } - for i := 0; i < len(path); i++ { - c := path[i] - if c >= utf8.RuneSelf { - return true - } - if c >= 'A' && c <= 'Z' { - return true - } - } - return false -} - -func (engine *Engine) allowedMethodsForPath(requestPath string, allowedMethods []string) []string { - if cap(allowedMethods) < len(engine.methodTrees) { - allowedMethods = make([]string, 0, len(engine.methodTrees)) - } else { - allowedMethods = allowedMethods[:0] - } - tempSkippedNodes := GetTempSkippedNodes() - for _, treeIter := range engine.methodTrees { - // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - *tempSkippedNodes = (*tempSkippedNodes)[:0] - value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) - if value.handlers != nil { - allowedMethods = append(allowedMethods, treeIter.method) - } - } - PutTempSkippedNodes(tempSkippedNodes) - return allowedMethods -} - -func (engine *Engine) handleGeneralOptions(c *Context) { - if c == nil || c.Request == nil { - return - } - - c.Writer.Header().Set("Content-Length", "0") - if c.Request.ContentLength != 0 { - mb := http.MaxBytesReader(c.Writer, c.Request.Body, 4<<10) - _, _ = io.Copy(io.Discard, mb) - } - c.Writer.WriteHeader(http.StatusOK) - c.Abort() -} - // Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. // 它可以用于在长连接 (如 SSE) 中监听关闭信号. func (engine *Engine) Context() context.Context { diff --git a/engine_benchmark_test.go b/engine_benchmark_test.go deleted file mode 100644 index 666e8b2..0000000 --- a/engine_benchmark_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package touka - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -var benchmarkStatusCode int - -func buildServeHTTPBenchmarkEngine() *Engine { - engine := New() - engine.GET("/api/v1/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - engine.GET("/api/v1/users/:id", func(c *Context) { - c.Status(http.StatusNoContent) - }) - engine.GET("/api/v1/users/:id/settings", func(c *Context) { - c.Status(http.StatusNoContent) - }) - engine.POST("/api/v1/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - return engine -} - -func benchmarkServeHTTP(b *testing.B, engine *Engine, method, path string) { - b.Helper() - - req, err := http.NewRequest(method, path, nil) - if err != nil { - b.Fatalf("failed to build request: %v", err) - } - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - rr = httptest.NewRecorder() - engine.ServeHTTP(rr, req) - } - - benchmarkStatusCode = rr.Code -} - -func BenchmarkServeHTTP(b *testing.B) { - engine := buildServeHTTPBenchmarkEngine() - - b.Run("StaticHit", func(b *testing.B) { - benchmarkServeHTTP(b, engine, http.MethodGet, "/api/v1/users") - }) - - b.Run("NotFound", func(b *testing.B) { - benchmarkServeHTTP(b, engine, http.MethodGet, "/does/not/exist") - }) - - b.Run("MethodNotAllowed", func(b *testing.B) { - benchmarkServeHTTP(b, engine, http.MethodDelete, "/api/v1/users") - }) - - b.Run("OptionsAllow", func(b *testing.B) { - benchmarkServeHTTP(b, engine, http.MethodOptions, "/api/v1/users") - }) - - b.Run("FixedPathRedirect", func(b *testing.B) { - benchmarkServeHTTP(b, engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS") - }) -} diff --git a/engine_test.go b/engine_test.go deleted file mode 100644 index 4772810..0000000 --- a/engine_test.go +++ /dev/null @@ -1,306 +0,0 @@ -package touka - -import ( - "bufio" - "encoding/json" - "errors" - "html/template" - "net" - "net/http" - "testing" -) - -type failingResponseWriter struct { - header http.Header - status int - err error -} - -func (w *failingResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *failingResponseWriter) WriteHeader(statusCode int) { - if w.status == 0 { - w.status = statusCode - } -} - -func (w *failingResponseWriter) Write(p []byte) (int, error) { - if w.status == 0 { - w.status = http.StatusOK - } - if w.err != nil { - return 0, w.err - } - return len(p), nil -} - -func (w *failingResponseWriter) Flush() {} - -func (w *failingResponseWriter) Status() int { - return w.status -} - -func (w *failingResponseWriter) Size() int { - return 0 -} - -func (w *failingResponseWriter) Written() bool { - return w.status != 0 -} - -func (w *failingResponseWriter) IsHijacked() bool { - return false -} - -func (w *failingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return nil, nil, http.ErrNotSupported -} - -func TestHandleRequestRedirectFixedPath(t *testing.T) { - engine := New() - engine.GET("/api/v1/users/:id/settings", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodGet, "/API/V1/USERS/123/SETTINGS", nil, nil) - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected fixed-path redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) - } - if location := rr.Header().Get("Location"); location != "/api/v1/users/123/settings" { - t.Fatalf("expected fixed-path redirect location %q, got %q", "/api/v1/users/123/settings", location) - } -} - -func TestHandleRequestSkipsFixedPathLookupForLowercaseMiss(t *testing.T) { - engine := New() - engine.GET("/api/v1/users/:id/settings", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodGet, "/does/not/exist", nil, nil) - if rr.Code != http.StatusNotFound { - t.Fatalf("expected lowercase miss to stay as 404, got %d", rr.Code) - } -} - -func TestHandleRequestKeepsFixedPathLookupForUppercaseMiss(t *testing.T) { - engine := New() - engine.GET("/Users/Profile", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodGet, "/users/profile", nil, nil) - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected uppercase route miss to trigger fixed-path redirect, got %d", rr.Code) - } - if location := rr.Header().Get("Location"); location != "/Users/Profile" { - t.Fatalf("expected uppercase route redirect location %q, got %q", "/Users/Profile", location) - } -} - -func TestHandleRequestFixedPathLookupMissDoesNotPanic(t *testing.T) { - engine := New() - engine.GET("/Users/Profile", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - defer func() { - if r := recover(); r != nil { - t.Fatalf("unexpected panic for fixed-path miss: %v", r) - } - }() - - rr := PerformRequest(engine, http.MethodGet, "/users/unknown", nil, nil) - if rr.Code != http.StatusNotFound { - t.Fatalf("expected fixed-path miss to stay as 404, got %d", rr.Code) - } -} - -func TestNoRouteCanContinueToDefaultNotFound(t *testing.T) { - engine := New() - engine.NoRoute(func(c *Context) { - c.Writer.Header().Set("X-NoRoute", "hit") - c.Next() - }) - - rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) - if rr.Code != http.StatusNotFound { - t.Fatalf("expected default not found status %d, got %d", http.StatusNotFound, rr.Code) - } - if got := rr.Header().Get("X-NoRoute"); got != "hit" { - t.Fatalf("expected NoRoute middleware header to be preserved, got %q", got) - } -} - -func TestMethodNotAllowedDoesNotContinueToNoRoute(t *testing.T) { - engine := New() - engine.GET("/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - engine.NoRoute(func(c *Context) { - c.Writer.Header().Set("X-NoRoute", "hit") - c.Next() - }) - - rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) - if rr.Code != http.StatusMethodNotAllowed { - t.Fatalf("expected method not allowed status %d, got %d", http.StatusMethodNotAllowed, rr.Code) - } - if got := rr.Header().Get("X-NoRoute"); got != "" { - t.Fatalf("expected NoRoute chain to be skipped after 405, got header %q", got) - } -} - -func TestOptionsAllowHeaderListsMatchingMethods(t *testing.T) { - engine := New() - engine.GET("/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - engine.POST("/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodOptions, "/users", nil, nil) - if rr.Code != http.StatusOK { - t.Fatalf("expected OPTIONS allow status %d, got %d", http.StatusOK, rr.Code) - } - allow := rr.Header().Get("Allow") - if allow != "GET, POST" && allow != "POST, GET" { - t.Fatalf("expected Allow header to list matching methods, got %q", allow) - } -} - -func TestDefaultErrorHandleJSONShape(t *testing.T) { - engine := New() - rr := PerformRequest(engine, http.MethodGet, "/missing", nil, nil) - if rr.Code != http.StatusNotFound { - t.Fatalf("expected status %d, got %d", http.StatusNotFound, rr.Code) - } - - var body struct { - Code int `json:"code"` - Message string `json:"message"` - Error string `json:"error"` - } - if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { - t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) - } - if body.Code != http.StatusNotFound || body.Message != http.StatusText(http.StatusNotFound) || body.Error != "not found" { - t.Fatalf("unexpected error payload: %+v", body) - } -} - -func TestDefaultMethodNotAllowedJSONShape(t *testing.T) { - engine := New() - engine.GET("/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) - if rr.Code != http.StatusMethodNotAllowed { - t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) - } - - var body struct { - Code int `json:"code"` - Message string `json:"message"` - Error string `json:"error"` - } - if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { - t.Fatalf("expected JSON error body, got %q: %v", rr.Body.String(), err) - } - if body.Code != http.StatusMethodNotAllowed || body.Message != http.StatusText(http.StatusMethodNotAllowed) || body.Error != "method not allowed" { - t.Fatalf("unexpected error payload: %+v", body) - } -} - -func TestCustomErrorHandlerStillOverridesDefaultFastPath(t *testing.T) { - engine := New() - engine.SetErrorHandler(func(c *Context, code int, err error) { - c.Writer.Header().Set("X-Custom-Error", "1") - c.String(code, "custom:%v", err) - }) - engine.GET("/users", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - rr := PerformRequest(engine, http.MethodDelete, "/users", nil, nil) - if rr.Code != http.StatusMethodNotAllowed { - t.Fatalf("expected status %d, got %d", http.StatusMethodNotAllowed, rr.Code) - } - if got := rr.Header().Get("X-Custom-Error"); got != "1" { - t.Fatalf("expected custom error header, got %q", got) - } - if rr.Body.String() != "custom:method not allowed" { - t.Fatalf("expected custom error body, got %q", rr.Body.String()) - } -} - -func TestResponseHelpersCaptureWriteErrors(t *testing.T) { - testCases := []struct { - name string - run func(*Context) - }{ - {name: "Raw", run: func(c *Context) { c.Raw(http.StatusOK, "application/octet-stream", []byte("payload")) }}, - {name: "String", run: func(c *Context) { c.String(http.StatusOK, "value=%d", 1) }}, - {name: "Text", run: func(c *Context) { c.Text(http.StatusOK, "payload") }}, - {name: "JSONBuf", run: func(c *Context) { c.JSONBuf(http.StatusOK, map[string]string{"a": "b"}) }}, - {name: "GOBBuf", run: func(c *Context) { c.GOBBuf(http.StatusOK, struct{ A string }{A: "b"}) }}, - {name: "WANFBuf", run: func(c *Context) { c.WANFBuf(http.StatusOK, map[string]string{"a": "b"}) }}, - {name: "HTMLFallback", run: func(c *Context) { c.HTML(http.StatusOK, "page", map[string]string{"a": "b"}) }}, - {name: "HTMLBuf", run: func(c *Context) { - c.engine.HTMLRender = template.Must(template.New("page").Parse(`{{.a}}`)) - c.HTMLBuf(http.StatusOK, "page", map[string]string{"a": "b"}) - }}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - writerErr := errors.New("write failed") - w := &failingResponseWriter{err: writerErr} - c, _ := CreateTestContext(w) - - tc.run(c) - - if got := len(c.Errors); got != 1 { - t.Fatalf("expected exactly one captured error, got %d", got) - } - if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { - t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) - } - }) - } -} - -func TestDefaultErrorFastPathCapturesWriteErrors(t *testing.T) { - writerErr := errors.New("write failed") - w := &failingResponseWriter{err: writerErr} - engine := New() - c, _ := CreateTestContext(w) - c.engine = engine - req, err := http.NewRequest(http.MethodGet, "/missing", nil) - if err != nil { - t.Fatalf("failed to build request: %v", err) - } - c.reset(w, req) - - defaultErrorHandle(c, http.StatusNotFound, errNotFound) - - if len(c.Errors) == 0 { - t.Fatal("expected write error to be captured") - } - if !errors.Is(c.Errors[len(c.Errors)-1], writerErr) { - t.Fatalf("expected captured error to wrap write failure, got %v", c.Errors[len(c.Errors)-1]) - } - if c.Writer.Status() != http.StatusNotFound { - t.Fatalf("expected status %d, got %d", http.StatusNotFound, c.Writer.Status()) - } - if !c.IsAborted() { - t.Fatal("expected fast path to abort context") - } -} diff --git a/examples/httpc/main.go b/examples/httpc/main.go deleted file mode 100644 index db2be4f..0000000 --- a/examples/httpc/main.go +++ /dev/null @@ -1,103 +0,0 @@ -package main - -import ( - "fmt" - "net/http" - - "github.com/infinite-iroha/touka" -) - -func main() { - r := touka.Default() - - // 示例 1:简单 GET 请求(自动关联请求 Context) - r.GET("/proxy", func(c *touka.Context) { - // 使用 HTTPC() 方法,自动关联请求 Context - // 当客户端断开连接时,出站请求也会自动取消 - body, err := c.HTTPC(). - GET("https://httpbin.org/get"). - Text() - if err != nil { - c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) - return - } - c.String(http.StatusOK, "%s", body) - }) - - // 示例 2:带 Header 的 POST 请求 - r.POST("/users", func(c *touka.Context) { - var req struct { - Name string `json:"name"` - Email string `json:"email"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) - return - } - - var result struct { - ID int `json:"id"` - Name string `json:"name"` - } - - // 链式调用,保持 httpc 风格 - // 注意:SetJSONBody 返回 (*RequestBuilder, error) - rb, err := c.HTTPC(). - POST("https://httpbin.org/post"). - SetHeader("X-API-Key", "secret"). - SetJSONBody(req) - if err != nil { - c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) - return - } - if err := rb.DecodeJSON(&result); err != nil { - c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, result) - }) - - // 示例 3:带查询参数的请求 - r.GET("/search", func(c *touka.Context) { - query := c.DefaultQuery("q", "") - page := c.DefaultQuery("page", "1") - - var result struct { - Items []string `json:"items"` - Total int `json:"total"` - } - - err := c.HTTPC(). - GET("https://httpbin.org/get"). - SetQueryParam("q", query). - SetQueryParam("page", page). - DecodeJSON(&result) - if err != nil { - c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, result) - }) - - // 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐) - r.GET("/legacy", func(c *touka.Context) { - // 旧方式:需要手动 WithContext - body, err := c.Client(). - GET("https://httpbin.org/get"). - WithContext(c.Context()). - Text() - if err != nil { - c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) - return - } - c.String(http.StatusOK, "%s", body) - }) - - fmt.Println("Server running on :8080") - fmt.Println("Try:") - fmt.Println(" curl http://localhost:8080/proxy") - fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users") - fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'") - - // r.Run(touka.WithAddr(":8080")) -} diff --git a/examples/logger_slog/main.go b/examples/logger_slog/main.go deleted file mode 100644 index 2263960..0000000 --- a/examples/logger_slog/main.go +++ /dev/null @@ -1,71 +0,0 @@ -package main - -import ( - "fmt" - "log/slog" - "net/http" - "os" - - "github.com/infinite-iroha/touka" -) - -// SlogAdapter 将 slog.Logger 适配到 touka.Logger 接口 -type SlogAdapter struct { - logger *slog.Logger -} - -func NewSlogAdapter(handler slog.Handler) *SlogAdapter { - return &SlogAdapter{ - logger: slog.New(handler), - } -} - -func (s *SlogAdapter) Debugf(format string, args ...any) { - s.logger.Debug(fmt.Sprintf(format, args...)) -} - -func (s *SlogAdapter) Infof(format string, args ...any) { - s.logger.Info(fmt.Sprintf(format, args...)) -} - -func (s *SlogAdapter) Warnf(format string, args ...any) { - s.logger.Warn(fmt.Sprintf(format, args...)) -} - -func (s *SlogAdapter) Errorf(format string, args ...any) { - s.logger.Error(fmt.Sprintf(format, args...)) -} - -func (s *SlogAdapter) Fatalf(format string, args ...any) { - s.logger.Error(fmt.Sprintf(format, args...)) - os.Exit(1) -} - -func (s *SlogAdapter) Panicf(format string, args ...any) { - s.logger.Error(fmt.Sprintf(format, args...)) - panic(fmt.Sprintf(format, args...)) -} - -func main() { - engine := touka.New() - - // 使用 slog 替换默认的 reco.Logger - handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelDebug, - }) - slogAdapter := NewSlogAdapter(handler) - engine.SetLogger(slogAdapter) - - engine.GET("/", func(c *touka.Context) { - c.Infof("request received: %s", c.Request.URL.Path) - c.JSON(http.StatusOK, map[string]string{"message": "hello"}) - }) - - // 也可以获取 Logger 接口 - logger := engine.GetLogger() - logger.Debugf("engine started") - - // 也可以直接使用 slog - slog.Info("Server running", "addr", ":8080") - // engine.Run(":8080") -} diff --git a/go.mod b/go.mod index dee187d..42f4be4 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,14 @@ module github.com/infinite-iroha/touka go 1.26 require ( - github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 + github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 github.com/WJQSERVER-STUDIO/httpc v0.9.0 github.com/WJQSERVER/wanf v0.0.8 github.com/fenthope/reco v0.0.5 github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 - golang.org/x/net v0.52.0 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/text v0.35.0 // indirect + golang.org/x/net v0.52.0 // indirect ) diff --git a/go.sum b/go.sum index 4b9dbd9..b49879b 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= -github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3 h1:Hc1O6D50U3URkdSzfQ/SgeUU750wUBCYhefdvAbE2Ck= -github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.3/go.mod h1:nFQzepAwwdj5Hp5U+X19l4FVvsaOSBTW41BzfI/CkMA= github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= @@ -14,5 +12,3 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= diff --git a/http2xconnect.go b/http2xconnect.go deleted file mode 100644 index c691a77..0000000 --- a/http2xconnect.go +++ /dev/null @@ -1,88 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// Copyright 2026 WJQSERVER. All rights reserved. -// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. -package touka - -import ( - "crypto/tls" - "net" - "net/http" - "sync" - "time" - _ "unsafe" - - "golang.org/x/net/http2" -) - -var enableHTTP2ExtendedConnectOnce sync.Once - -//go:linkname xnetDisableHTTP2ExtendedConnectProtocol golang.org/x/net/http2.disableExtendedConnectProtocol -var xnetDisableHTTP2ExtendedConnectProtocol bool - -func enableHTTP2ExtendedConnectProtocol() { - enableHTTP2ExtendedConnectOnce.Do(func() { - xnetDisableHTTP2ExtendedConnectProtocol = false - }) -} - -func configureHTTP2ExtendedConnectServer(srv *http.Server) error { - if srv == nil { - return nil - } - enableHTTP2ExtendedConnectProtocol() - return http2.ConfigureServer(srv, nil) -} - -func newHTTP2ExtendedConnectTransport() http.RoundTripper { - enableHTTP2ExtendedConnectProtocol() - transport := cloneDefaultTransport() - transport.Protocols = new(http.Protocols) - transport.Protocols.SetHTTP1(true) - transport.Protocols.SetHTTP2(true) - return transport -} - -func newHTTP1BridgeTransport() http.RoundTripper { - return newHTTP1BridgeTransportWithTLSConfig(&tls.Config{NextProtos: []string{"http/1.1"}}) -} - -func newHTTP1BridgeTransportWithTLSConfig(tlsConfig *tls.Config) http.RoundTripper { - transport := cloneDefaultTransport() - transport.Protocols = new(http.Protocols) - transport.Protocols.SetHTTP1(true) - if tlsConfig == nil { - transport.TLSClientConfig = &tls.Config{} - } else { - transport.TLSClientConfig = tlsConfig.Clone() - } - if len(transport.TLSClientConfig.NextProtos) == 0 { - transport.TLSClientConfig.NextProtos = []string{"http/1.1"} - } - return transport -} - -func newH2CTransport() http.RoundTripper { - transport := cloneDefaultTransport() - transport.Protocols = new(http.Protocols) - transport.Protocols.SetUnencryptedHTTP2(true) - return transport -} - -func cloneDefaultTransport() *http.Transport { - if transport, ok := http.DefaultTransport.(*http.Transport); ok { - return transport.Clone() - } - return &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } -} diff --git a/iox_benchmark_test.go b/iox_benchmark_test.go deleted file mode 100644 index 9b43590..0000000 --- a/iox_benchmark_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package touka - -import ( - "bytes" - "io" - "testing" - - "github.com/WJQSERVER-STUDIO/go-utils/iox" -) - -type benchmarkResetReader struct { - data []byte - off int -} - -func (r *benchmarkResetReader) Read(p []byte) (int, error) { - if r.off >= len(r.data) { - return 0, io.EOF - } - n := copy(p, r.data[r.off:]) - r.off += n - return n, nil -} - -func (r *benchmarkResetReader) Reset() { - r.off = 0 -} - -type benchmarkDiscardWriter struct{} - -func (benchmarkDiscardWriter) Write(p []byte) (int, error) { - return len(p), nil -} - -var benchmarkIOXResult int64 -var benchmarkIOXBytes []byte - -func BenchmarkIOXCopyComparison(b *testing.B) { - payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) - - b.Run("io.Copy", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - w := benchmarkDiscardWriter{} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - n, err := io.Copy(w, r) - if err != nil { - b.Fatalf("io.Copy failed: %v", err) - } - benchmarkIOXResult = n - } - }) - - b.Run("iox.Copy", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - w := benchmarkDiscardWriter{} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - n, err := iox.Copy(w, r) - if err != nil { - b.Fatalf("iox.Copy failed: %v", err) - } - benchmarkIOXResult = n - } - }) -} - -func BenchmarkIOXCopyBufferComparison(b *testing.B) { - payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) - - b.Run("io.CopyBuffer", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - w := benchmarkDiscardWriter{} - buf := make([]byte, 32*1024) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - n, err := io.CopyBuffer(w, r, buf) - if err != nil { - b.Fatalf("io.CopyBuffer failed: %v", err) - } - benchmarkIOXResult = n - } - }) - - b.Run("iox.CopyBuffer", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - w := benchmarkDiscardWriter{} - buf := make([]byte, 32*1024) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - n, err := iox.CopyBuffer(w, r, buf) - if err != nil { - b.Fatalf("iox.CopyBuffer failed: %v", err) - } - benchmarkIOXResult = n - } - }) -} - -func BenchmarkIOXReadAllComparison(b *testing.B) { - payload := bytes.Repeat([]byte("0123456789abcdef"), 4096) - - b.Run("io.ReadAll", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - data, err := io.ReadAll(r) - if err != nil { - b.Fatalf("io.ReadAll failed: %v", err) - } - benchmarkIOXBytes = data - } - }) - - b.Run("iox.ReadAll", func(b *testing.B) { - r := &benchmarkResetReader{data: payload} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Reset() - data, err := io.ReadAll(r) - if err != nil { - b.Fatalf("iox.ReadAll failed: %v", err) - } - benchmarkIOXBytes = data - } - }) -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 1be0077..0000000 --- a/logger.go +++ /dev/null @@ -1,23 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// Copyright 2024 WJQSERVER. All rights reserved. -// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. -package touka - -// Logger 是日志接口,支持多种日志库实现(reco、zap、logrus 等) -// 用户可以通过实现此接口来替换默认的日志实现 -type Logger interface { - Debugf(format string, args ...any) - Infof(format string, args ...any) - Warnf(format string, args ...any) - Errorf(format string, args ...any) - Fatalf(format string, args ...any) - Panicf(format string, args ...any) -} - -// CloserLogger 可选扩展接口,支持关闭操作 -// 如果 Logger 实现了此接口,Engine 在关闭时会调用 Close() -type CloserLogger interface { - Logger - Close() error -} diff --git a/logreco.go b/logreco.go index e37dd53..4bda8d3 100644 --- a/logreco.go +++ b/logreco.go @@ -39,16 +39,7 @@ func CloseLogger(logger *reco.Logger) { } } -// CloseLogger 关闭 Engine 的日志实现 -// 如果 logger 实现了 CloserLogger 接口,会调用其 Close 方法 func (engine *Engine) CloseLogger() { - if cl, ok := engine.logger.(CloserLogger); ok { - if err := cl.Close(); err != nil { - log.Printf("Close Logger Error: %s", err) - } - return - } - // 兼容旧代码 if engine.LogReco != nil { CloseLogger(engine.LogReco) } diff --git a/maxreader.go b/maxreader.go index 4d3fb2c..c6201e6 100644 --- a/maxreader.go +++ b/maxreader.go @@ -23,21 +23,19 @@ type maxBytesReader struct { n int64 // read 是一个原子计数器, 用于安全地在多个 goroutine 之间跟踪已读取的字节数. read atomic.Int64 - // emptyAtLimit 记录在达到上限后是否已经遇到过一次 0,nil 读. - emptyAtLimit atomic.Bool } // NewMaxBytesReader 创建并返回一个 io.ReadCloser, 它从 r 读取数据, // 但在读取的字节数超过 n 后会返回 ErrBodyTooLarge 错误. // // 如果 r 为 nil, 会 panic. -// 如果 n 小于等于 0, 则读取不受限制, 直接返回原始的 r. +// 如果 n 小于 0, 则读取不受限制, 直接返回原始的 r. func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { if r == nil { panic("NewMaxBytesReader called with a nil reader") } - // 如果限制为非正数, 意味着不限制, 直接返回原始的 ReadCloser. - if n <= 0 { + // 如果限制为负数, 意味着不限制, 直接返回原始的 ReadCloser. + if n < 0 { return r } return &maxBytesReader{ @@ -48,53 +46,48 @@ func NewMaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser { // Read 方法从底层的 ReadCloser 读取数据, 同时检查是否超过了字节限制. func (mbr *maxBytesReader) Read(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - // 在函数开始时只加载一次原子变量, 减少后续的原子操作开销. readSoFar := mbr.read.Load() - remaining := mbr.n - readSoFar - if remaining < 0 { + + // 快速失败路径: 如果在读取之前就已经达到了限制, 立即返回错误. + if readSoFar >= mbr.n { return 0, ErrBodyTooLarge } - if remaining == 0 { - var probe [1]byte - n, err := mbr.r.Read(probe[:]) - if n > 0 { - mbr.read.Add(int64(n)) - return 0, ErrBodyTooLarge - } - if err != nil { - return 0, err - } - if mbr.emptyAtLimit.Swap(true) { - return 0, ErrBodyTooLarge - } - return 0, nil - } - mbr.emptyAtLimit.Store(false) - // 最多多读一个字节, 以区分“恰好到上限”和“已经超限”。 - if int64(len(p))-1 > remaining { - p = p[:remaining+1] + // 计算当前还可以读取多少字节. + remaining := mbr.n - readSoFar + + // 如果请求读取的长度大于剩余可读长度, 我们需要限制本次读取的长度. + // 这样可以保证即使 p 很大, 我们也只读取到恰好达到 maxBytes 的字节数. + if int64(len(p)) > remaining { + p = p[:remaining] } // 从底层 Reader 读取数据. n, err := mbr.r.Read(p) - if int64(n) <= remaining { - if n > 0 { - mbr.read.Add(int64(n)) - } + // 如果实际读取到了数据, 更新原子计数器. + if n > 0 { + readSoFar = mbr.read.Add(int64(n)) + } + + // 如果底层 Read 返回错误 (例如 io.EOF). + if err != nil { + // 如果是 EOF, 并且我们还没有读满 n 个字节, 这是一个正常的结束. + // 如果已经读满了 n 个字节, 即使是 EOF, 也可以认为成功了. return n, err } - // 读取结果跨过了限制,只向上层暴露允许的部分。 - if remaining > 0 { - mbr.read.Add(remaining) + // 读后检查: 如果这次读取使得总字节数超过了限制, 返回超限错误. + // 这是处理"跨越"限制情况的关键. + if readSoFar > mbr.n { + // 返回实际读取的字节数 n, 并附上超限错误. + // 上层调用者知道已经有 n 字节被读入了缓冲区 p, 但流已因超限而关闭. + return n, ErrBodyTooLarge } - return int(remaining), ErrBodyTooLarge + + // 一切正常, 返回读取的字节数和 nil 错误. + return n, nil } // Close 方法关闭底层的 ReadCloser, 保证资源释放. diff --git a/mergectx.go b/mergectx.go index 404f7b1..e5d3ec4 100644 --- a/mergectx.go +++ b/mergectx.go @@ -11,16 +11,18 @@ import ( ) // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. -// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播. -// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消. type mergedContext struct { + // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. context.Context + // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context + // 用于手动取消此 mergedContext 的函数. + cancel context.CancelFunc } // MergeCtx 创建并返回一个新的 context.Context. // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. +// 自动被取消 (逻辑或关系). // // 新的 context 会继承: // - Deadline: 所有父 context 中最早的截止时间. @@ -30,8 +32,7 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return context.WithCancel(context.Background()) } if len(parents) == 1 { - ctx, cancel := context.WithCancelCause(parents[0]) - return ctx, func() { cancel(nil) } + return context.WithCancel(parents[0]) } var earliestDeadline time.Time @@ -43,71 +44,37 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - // cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播. - cancelCtx, cancelCause := context.WithCancelCause(context.Background()) - - // deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline). - // 当 cancelCtx 被取消时, deadlineCtx 也会被取消; - // 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx. - var deadlineCtx context.Context - var deadlineCancel context.CancelFunc + var baseCtx context.Context + var baseCancel context.CancelFunc if !earliestDeadline.IsZero() { - deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded) - } - - // 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline), - // 否则用 cancelCtx. - embedCtx := cancelCtx - if deadlineCtx != nil { - embedCtx = deadlineCtx + baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) + } else { + baseCtx, baseCancel = context.WithCancel(context.Background()) } mc := &mergedContext{ - Context: embedCtx, + Context: baseCtx, parents: parents, + cancel: baseCancel, } - // 启动监控 goroutine, 监听 parent 取消或 deadline 到期. + // 启动一个监控 goroutine. go func() { - // 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏. - parentDone := orDone(append(mc.parents, cancelCtx)...) + defer mc.cancel() - if deadlineCtx != nil { - defer deadlineCancel() - select { - case <-parentDone: - // parent 取消或手动 cancel() - for _, p := range mc.parents { - if p.Err() != nil { - cancelCause(context.Cause(p)) - return - } - } - // 手动 cancel(), cause 已由 cancelCause() 设置 - case <-deadlineCtx.Done(): - // deadline 到期, 需要关闭 cancelCtx 并设置 cause - cancelCause(context.DeadlineExceeded) - } - } else { - <-parentDone - for _, p := range mc.parents { - if p.Err() != nil { - cancelCause(context.Cause(p)) - return - } - } + // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. + // 同时监听 baseCtx.Done() 以便支持手动取消. + select { + case <-orDone(mc.parents...): + case <-mc.Context.Done(): } }() - return mc, func() { cancelCause(nil) } + return mc, mc.cancel } -// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), -// 再按传入顺序从 parents 中查找. +// Value 返回当前Ctx Value func (mc *mergedContext) Value(key any) any { - if v := mc.Context.Value(key); v != nil { - return v - } for _, p := range mc.parents { if val := p.Value(key); val != nil { return val @@ -116,20 +83,45 @@ func (mc *mergedContext) Value(key any) any { return nil } -// Deadline, Done, Err 均由嵌入的 context.Context 提供. +// Deadline 实现了 context.Context 的 Deadline 方法. +func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { + return mc.Context.Deadline() +} -// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. +// Done 实现了 context.Context 的 Done 方法. +func (mc *mergedContext) Done() <-chan struct{} { + return mc.Context.Done() +} + +// Err 实现了 context.Context 的 Err 方法. +func (mc *mergedContext) Err() error { + return mc.Context.Err() +} + +// orDone 是一个辅助函数, 返回一个 channel. +// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. +// 这是一个非阻塞的、不会泄漏 goroutine 的实现. func orDone(contexts ...context.Context) <-chan struct{} { done := make(chan struct{}) + var once sync.Once + closeDone := func() { + once.Do(func() { + close(done) + }) + } + + // 为每个父 context 启动一个 goroutine. for _, ctx := range contexts { go func(c context.Context) { select { case <-c.Done(): - once.Do(func() { close(done) }) + closeDone() case <-done: + // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. } }(ctx) } + return done } diff --git a/mergectx_test.go b/mergectx_test.go deleted file mode 100644 index d6d1225..0000000 --- a/mergectx_test.go +++ /dev/null @@ -1,256 +0,0 @@ -package touka - -import ( - "context" - "errors" - "testing" - "time" -) - -func TestMergeCtx_NoParents(t *testing.T) { - ctx, cancel := MergeCtx() - defer cancel() - - if ctx.Err() != nil { - t.Fatal("expected no error before cancel") - } - cancel() - if ctx.Err() == nil { - t.Fatal("expected error after cancel") - } -} - -func TestMergeCtx_SingleParent(t *testing.T) { - parent, parentCancel := context.WithCancel(context.Background()) - - ctx, cancel := MergeCtx(parent) - defer cancel() - - if ctx.Err() != nil { - t.Fatal("expected no error before parent cancel") - } - - parentCancel() - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after parent cancel") - } -} - -func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - cancel1() - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after p1 cancel") - } - // p2 should still be fine - if p2.Err() != nil { - t.Fatal("expected p2 to be unaffected") - } -} - -func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel1() - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - cancel2() - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after p2 cancel") - } -} - -func TestMergeCtx_ExternalCancel(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel1() - defer cancel2() - - ctx, cancel := MergeCtx(p1, p2) - - cancel() - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after external cancel") - } -} - -func TestMergeCtx_CausePropagation(t *testing.T) { - testErr := errors.New("test cause") - - p1, cancel1 := context.WithCancelCause(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - cancel1(testErr) - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after p1 cancel") - } - - cause := context.Cause(ctx) - if cause != testErr { - t.Fatalf("expected cause %v, got %v", testErr, cause) - } - cancel1(nil) // cleanup (already cancelled, no-op) -} - -func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) { - testErr := errors.New("second parent cause") - - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancelCause(context.Background()) - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - cancel2(testErr) - - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after p2 cancel") - } - - cause := context.Cause(ctx) - if cause != testErr { - t.Fatalf("expected cause %v, got %v", testErr, cause) - } - - cancel1() -} - -func TestMergeCtx_Deadline_Earliest(t *testing.T) { - now := time.Now() - early := now.Add(100 * time.Millisecond) - late := now.Add(1 * time.Hour) - - p1, cancel1 := context.WithDeadline(context.Background(), late) - p2, cancel2 := context.WithDeadline(context.Background(), early) - defer cancel1() - defer cancel2() - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - dl, ok := ctx.Deadline() - if !ok { - t.Fatal("expected deadline to be set") - } - if !dl.Equal(early) { - t.Fatalf("expected deadline %v, got %v", early, dl) - } -} - -func TestMergeCtx_Deadline_Expires(t *testing.T) { - p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancelP() - - ctx, cancel := MergeCtx(p) - defer cancel() - - <-ctx.Done() - - if ctx.Err() == nil { - t.Fatal("expected error after deadline expires") - } -} - -func TestMergeCtx_ValueLookup(t *testing.T) { - type key struct{} - p1 := context.WithValue(context.Background(), key{}, "from_p1") - p2 := context.WithValue(context.Background(), key{}, "from_p2") - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - val := ctx.Value(key{}) - if val != "from_p1" { - t.Fatalf("expected 'from_p1', got %v", val) - } -} - -func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) { - type key1 struct{} - type key2 struct{} - p1 := context.WithValue(context.Background(), key1{}, "val1") - p2 := context.WithValue(context.Background(), key2{}, "val2") - - ctx, cancel := MergeCtx(p1, p2) - defer cancel() - - if v := ctx.Value(key1{}); v != "val1" { - t.Fatalf("expected 'val1', got %v", v) - } - if v := ctx.Value(key2{}); v != "val2" { - t.Fatalf("expected 'val2', got %v", v) - } - if v := ctx.Value("missing"); v != nil { - t.Fatalf("expected nil, got %v", v) - } -} - -func TestMergeCtx_ContextInterface(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel1() - defer cancel2() - - var ctx context.Context - ctx, _ = MergeCtx(p1, p2) - - // Verify all Context interface methods work - _ = ctx.Done() - _ = ctx.Err() - _, _ = ctx.Deadline() - _ = ctx.Value("any") -} - -func TestOrDone_SingleContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - done := orDone(ctx) - - cancel() - <-done // should not block -} - -func TestOrDone_MultipleContexts(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - - done := orDone(p1, p2) - - cancel1() - <-done // should not block -} - -func TestOrDone_SecondContextCancels(t *testing.T) { - p1, cancel1 := context.WithCancel(context.Background()) - p2, cancel2 := context.WithCancel(context.Background()) - defer cancel1() - - done := orDone(p1, p2) - - cancel2() - <-done // should not block -} diff --git a/protocols_test.go b/protocols_test.go index 0e2bf1f..73f16e9 100644 --- a/protocols_test.go +++ b/protocols_test.go @@ -70,25 +70,42 @@ func TestApplyDefaultServerConfig(t *testing.T) { } } -func TestTLSRunDefaultsProtocolInheritance(t *testing.T) { +func TestRunTLSProtocolInheritance(t *testing.T) { engine := New() - srv := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) - - if !srv.Protocols.HTTP2() { - t.Error("TLS run defaults: expected HTTP/2 to be enabled for default config") + // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) } - // 模拟用户设置了自定义协议后进入 TLS 运行模式 + srv := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv) + + if !srv.Protocols.HTTP2() { + t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + } + + // 模拟用户设置了自定义协议后调用 RunTLS engine = New() engine.SetProtocols(&ProtocolsConfig{ Http1: true, Http2: false, // 用户明确不想要 HTTP/2 }) - srv2 := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) + } + + srv2 := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv2) if srv2.Protocols.HTTP2() { - t.Error("TLS run defaults: expected HTTP/2 to remain disabled when user set custom protocols") + t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") } } diff --git a/respw.go b/respw.go index ef5cc3c..dd94db3 100644 --- a/respw.go +++ b/respw.go @@ -113,7 +113,7 @@ func (rw *responseWriterImpl) Hijack() (net.Conn, *bufio.ReadWriter, error) { // 尝试从底层 ResponseWriter 获取 Hijacker 接口 hj, ok := rw.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, http.ErrNotSupported + return nil, nil, errors.New("http.Hijacker interface not supported") } // 调用底层的 Hijack 方法 diff --git a/reverseproxy.go b/reverseproxy.go index 1cf0078..1730b1e 100644 --- a/reverseproxy.go +++ b/reverseproxy.go @@ -6,8 +6,6 @@ package touka import ( "context" - "crypto/rand" - "encoding/base64" "errors" "fmt" "io" @@ -16,18 +14,14 @@ import ( "net" "net/http" "net/http/httptrace" - "net/http/httputil" "net/netip" "net/textproto" "net/url" - "regexp" "strconv" "strings" "sync" "sync/atomic" "time" - - "golang.org/x/net/http2" ) // ForwardedHeadersPolicy controls how forwarding headers are generated. @@ -50,294 +44,32 @@ type BufferPool interface { // ReverseProxyConfig configures the reverse proxy handler. type ReverseProxyConfig struct { Target *url.URL - Targets []string - LoadBalancing ReverseProxyLoadBalancingConfig - PassiveHealth ReverseProxyPassiveHealthConfig - - Transport http.RoundTripper + Transport http.RoundTripper FlushInterval time.Duration - BufferPool BufferPool - AllowH2CUpstream bool + BufferPool BufferPool - ModifyRequest func(*http.Request) + ModifyRequest func(*http.Request) ModifyResponse func(*http.Response) error - ErrorHandler func(http.ResponseWriter, *http.Request, error) + ErrorHandler func(http.ResponseWriter, *http.Request, error) ForwardedHeaders ForwardedHeadersPolicy - ForwardedBy string - Via string - PreserveHost bool - - RequestHeaders *HeaderOps - ResponseHeaders *RespHeaderOps + ForwardedBy string + Via string + PreserveHost bool } var ( - errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") - errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") - errReverseProxyNoAvailableUpstreams = errors.New("reverse proxy has no available upstreams") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") ) -type HeaderOps struct { - Add map[string][]string - Set map[string][]string - Delete []string - Replace map[string][]Replacement -} - -type Replacement struct { - Search string - Replace string - SearchRegexp string - re *regexp.Regexp -} - -type RespHeaderOps struct { - *HeaderOps - Deferred bool -} - -func (ops *HeaderOps) applyToRequest(req *http.Request) { - if ops == nil { - return - } - ops.applyTo(req.Header, newReverseProxyReplacer(req)) -} - -func (ops *RespHeaderOps) applyToResponse(hdr http.Header) { - if ops == nil { - return - } - ops.applyTo(hdr, newReverseProxyReplacerFromHeader(hdr)) -} - -func (ops *HeaderOps) applyTo(hdr http.Header, repl *reverseProxyReplacer) { - if ops == nil { - return - } - if repl == nil { - repl = &reverseProxyReplacer{} - } - - for fieldName, vals := range ops.Add { - fieldName = repl.Replace(fieldName) - for _, v := range vals { - hdr.Add(fieldName, repl.Replace(v)) - } - } - - for fieldName, vals := range ops.Set { - fieldName = repl.Replace(fieldName) - hdr.Del(fieldName) - for _, v := range vals { - hdr.Add(fieldName, repl.Replace(v)) - } - } - - var deleteAll bool - var exactDeletes []string - var suffixPatterns, prefixPatterns, containsPatterns []string - - for _, fieldName := range ops.Delete { - fieldName = strings.ToLower(repl.Replace(fieldName)) - if fieldName == "*" { - deleteAll = true - break - } - switch { - case strings.HasPrefix(fieldName, "*") && strings.HasSuffix(fieldName, "*"): - containsPatterns = append(containsPatterns, fieldName[1:len(fieldName)-1]) - case strings.HasPrefix(fieldName, "*"): - suffixPatterns = append(suffixPatterns, fieldName[1:]) - case strings.HasSuffix(fieldName, "*"): - prefixPatterns = append(prefixPatterns, fieldName[:len(fieldName)-1]) - default: - exactDeletes = append(exactDeletes, fieldName) - } - } - - if deleteAll { - for k := range hdr { - hdr.Del(k) - } - } else if len(exactDeletes) > 0 || len(suffixPatterns) > 0 || len(prefixPatterns) > 0 || len(containsPatterns) > 0 { - toDelete := make([]string, 0, len(exactDeletes)) - for k := range hdr { - kl := strings.ToLower(k) - for _, d := range exactDeletes { - if kl == d { - toDelete = append(toDelete, k) - goto skip - } - } - for _, p := range containsPatterns { - if strings.Contains(kl, p) { - toDelete = append(toDelete, k) - goto skip - } - } - for _, p := range suffixPatterns { - if strings.HasSuffix(kl, p) { - toDelete = append(toDelete, k) - goto skip - } - } - for _, p := range prefixPatterns { - if strings.HasPrefix(kl, p) { - toDelete = append(toDelete, k) - goto skip - } - } - skip: - } - for _, k := range toDelete { - hdr.Del(k) - } - } - - ops.applyReplace(hdr, repl) -} - -func (ops *HeaderOps) applyReplace(hdr http.Header, repl *reverseProxyReplacer) { - if ops == nil || len(ops.Replace) == 0 { - return - } - for fieldName, replacements := range ops.Replace { - fieldName = http.CanonicalHeaderKey(repl.Replace(fieldName)) - if fieldName == "*" { - for fn, vals := range hdr { - for i := range vals { - for _, r := range replacements { - hdr[fn][i] = r.apply(vals[i]) - } - } - } - continue - } - vals, ok := hdr[fieldName] - if !ok { - continue - } - for i := range vals { - for _, r := range replacements { - hdr[fieldName][i] = r.apply(vals[i]) - } - } - } -} - -func (r *Replacement) apply(s string) string { - if r == nil || s == "" { - return s - } - if r.SearchRegexp != "" && r.re != nil { - return r.re.ReplaceAllString(s, r.Replace) - } - if r.Search != "" { - return strings.ReplaceAll(s, r.Search, r.Replace) - } - return s -} - -func (ops *HeaderOps) Provision() error { - if ops == nil { - return nil - } - for fieldName, replacements := range ops.Replace { - for i, r := range replacements { - if r.SearchRegexp == "" { - continue - } - if r.Search != "" { - return fmt.Errorf("replacement %d for header field %q: cannot specify both Search and SearchRegexp", i, fieldName) - } - re, err := regexp.Compile(r.SearchRegexp) - if err != nil { - return fmt.Errorf("replacement %d for header field %q: %v", i, fieldName, err) - } - replacements[i].re = re - } - } - return nil -} - -type reverseProxyReplacer struct { - method, host, path, query, scheme, uri, proto string -} - -func newReverseProxyReplacer(req *http.Request) *reverseProxyReplacer { - if req == nil || req.URL == nil { - return &reverseProxyReplacer{} - } - uri := req.RequestURI - if uri == "" { - uri = req.URL.RequestURI() - } - return &reverseProxyReplacer{ - method: req.Method, - host: req.Host, - path: req.URL.EscapedPath(), - query: req.URL.RawQuery, - scheme: reverseProxyRequestScheme(req), - uri: uri, - proto: req.Proto, - } -} - -func newReverseProxyReplacerFromHeader(hdr http.Header) *reverseProxyReplacer { - return &reverseProxyReplacer{} -} - -func (r *reverseProxyReplacer) Replace(s string) string { - if r == nil || s == "" { - return s - } - if r.method != "" { - s = strings.ReplaceAll(s, "{method}", r.method) - } - if r.host != "" { - s = strings.ReplaceAll(s, "{host}", r.host) - } - if r.path != "" { - s = strings.ReplaceAll(s, "{path}", r.path) - } - if r.query != "" { - s = strings.ReplaceAll(s, "{query}", r.query) - } - if r.scheme != "" { - s = strings.ReplaceAll(s, "{scheme}", r.scheme) - } - if r.uri != "" { - s = strings.ReplaceAll(s, "{uri}", r.uri) - } - if r.proto != "" { - s = strings.ReplaceAll(s, "{proto}", r.proto) - } - return s -} - type reverseProxyHandler struct { config ReverseProxyConfig - upstreams []*reverseProxyUpstream + target *url.URL receivedBy string configError error - roundRobin atomic.Uint64 -} - -var reverseProxyCopyBufferPool = sync.Pool{ - New: func() any { - buf := make([]byte, 32*1024) - return &buf - }, -} - -var reverseProxyCandidatePool = sync.Pool{ - New: func() any { - s := make([]*reverseProxyUpstream, 0, 8) - return &s - }, } type reverseProxyStatusError struct { @@ -345,34 +77,6 @@ type reverseProxyStatusError struct { err error } -type reverseProxyExtendedConnectBridge struct { - body io.ReadCloser -} - -type reverseProxyH2ReadWriteCloser struct { - io.ReadCloser - ResponseWriter - controller *http.ResponseController -} - -func (rwc *reverseProxyH2ReadWriteCloser) Write(p []byte) (int, error) { - n, err := rwc.ResponseWriter.Write(p) - if err != nil { - return n, err - } - if err := rwc.controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - return n, err - } - return n, nil -} - -func (rwc *reverseProxyH2ReadWriteCloser) Close() error { - if rwc.ReadCloser == nil { - return nil - } - return rwc.ReadCloser.Close() -} - func (e *reverseProxyStatusError) Error() string { if e == nil || e.err == nil { return "" @@ -493,29 +197,19 @@ func ReverseProxy(config ReverseProxyConfig) HandlerFunc { } func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { + target := cloneReverseProxyURL(config.Target) + if target != nil { + normalizeReverseProxyTarget(target) + } + proxy := &reverseProxyHandler{ config: config, + target: target, receivedBy: reverseProxyReceivedBy(config.Via), } - if config.RequestHeaders != nil { - if err := config.RequestHeaders.Provision(); err != nil { - proxy.configError = err - return proxy - } - } - if config.ResponseHeaders != nil && config.ResponseHeaders.HeaderOps != nil { - if err := config.ResponseHeaders.HeaderOps.Provision(); err != nil { - proxy.configError = err - return proxy - } - } - - upstreams, err := buildReverseProxyUpstreams(config) - if err != nil { + if err := validateReverseProxyTarget(target); err != nil { proxy.configError = err - } else { - proxy.upstreams = upstreams } switch config.ForwardedHeaders { @@ -523,17 +217,6 @@ func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { default: proxy.config.ForwardedHeaders = ForwardedBoth } - proxy.config.ForwardedBy = strings.TrimSpace(proxy.config.ForwardedBy) - if reverseProxyUsesForwardedHeader(proxy.config.ForwardedHeaders) { - if err := validateReverseProxyForwardedBy(proxy.config.ForwardedBy); err != nil { - proxy.configError = err - } - } - if proxy.configError == nil { - if err := validateReverseProxyLBPolicy(proxy.config.LoadBalancing.Policy); err != nil { - proxy.configError = err - } - } return proxy } @@ -546,75 +229,62 @@ func (p *reverseProxyHandler) ServeHTTP(c *Context) { return } - updatedMaxForwards, handledLocally, err := p.handleMaxForwards(c) - if err != nil { - p.handleError(c, err) - return - } - if handledLocally { - return + transport := p.config.Transport + if transport == nil { + transport = http.DefaultTransport } ctx, cancel := p.requestContext(c) defer cancel() - attempted := make(map[string]struct{}, len(p.upstreams)) - attempts := 0 - started := time.Now() - var lastErr error - for { - upstream, err := p.selectUpstream(c, attempted) - if err != nil { - if lastErr != nil { - p.handleError(c, lastErr) - return - } - p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: err}) - return - } + outreq := c.Request.Clone(ctx) + if c.Request.ContentLength == 0 { + outreq.Body = nil + } + if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false - attempts++ - upstream.inFlight.Add(1) - served, attemptErr, retriable := p.serveUpstreamAttempt(c, ctx, upstream, updatedMaxForwards) - upstream.inFlight.Add(-1) + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - if served { - return - } - if attemptErr != nil { - lastErr = attemptErr - } - if retriable && p.shouldRetryAttempt(c.Request, attempts, started) { - attempted[upstream.key] = struct{}{} - if !p.waitRetryInterval(ctx, started) { - if lastErr != nil { - p.handleError(c, lastErr) - } - return - } - continue - } - if attemptErr != nil { - p.handleError(c, attemptErr) - return - } - if lastErr != nil { - p.handleError(c, lastErr) - return - } - p.handleError(c, &reverseProxyStatusError{status: http.StatusBadGateway, err: errReverseProxyNoAvailableUpstreams}) + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + p.handleError(c, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + }) return } -} -func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (bool, error, bool) { - outreq, connectWriter, cleanup, err := p.buildOutgoingRequest(c, ctx, upstream, updatedMaxForwards) - if err != nil { - return false, err, false + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) } - defer cleanup() - transport := p.transportForUpstream(outreq, upstream) rawWriter := reverseProxyBaseResponseWriter(c.Writer) var ( roundTripMu sync.Mutex @@ -644,65 +314,26 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte roundTripDone = true roundTripMu.Unlock() if err != nil { - if reverseProxyShouldCountPassiveFailure(outreq, err) { - upstream.recordFailure(time.Now(), p.config.PassiveHealth) - } - return false, err, true - } - if reverseProxyStatusIsUnhealthy(p.config.PassiveHealth, res.StatusCode) { - upstream.recordFailure(time.Now(), p.config.PassiveHealth) - } - - if bridge := reverseProxyExtendedConnectBridgeFromContext(outreq.Context()); bridge != nil { - if res.StatusCode == http.StatusSwitchingProtocols { - appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) - if !p.modifyResponse(c, res, outreq) { - return true, nil, false - } - if err := p.handleBridgedExtendedConnectResponse(c, outreq, res, bridge); err != nil { - return false, err, false - } - return true, nil, false - } - return false, &reverseProxyStatusError{status: http.StatusBadGateway, err: fmt.Errorf("extended CONNECT backend returned status %d instead of 101", res.StatusCode)}, false - } - - if outreq.Method == http.MethodConnect && res.StatusCode >= http.StatusOK && res.StatusCode < http.StatusMultipleChoices { - removeHopByHopHeaders(res.Header) - res.Header.Del("Content-Length") - res.Header.Del("Transfer-Encoding") - res.ContentLength = -1 - res.TransferEncoding = nil - appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) - if !p.modifyResponse(c, res, outreq) { - return true, nil, false - } - handleConnect := p.handleConnectResponse - if reverseProxyIsExtendedConnectRequest(outreq) { - handleConnect = p.handleExtendedConnectResponse - } - if err := handleConnect(c, outreq, res, connectWriter); err != nil { - return false, err, false - } - return true, nil, false + p.handleError(c, err) + return } if res.StatusCode == http.StatusSwitchingProtocols { appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return true, nil, false + return } if err := p.handleUpgradeResponse(c, outreq, res); err != nil { - return false, err, false + p.handleError(c, err) } - return true, nil, false + return } removeHopByHopHeaders(res.Header) appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) if !p.modifyResponse(c, res, outreq) { - return true, nil, false + return } reverseProxyCopyHeader(c.Writer.Header(), res.Header) @@ -722,10 +353,7 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte defer res.Body.Close() c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) p.logf(c, "reverse proxy body copy failed: %v", err) - if reverseProxyShouldPanicOnCopyError(c.Request) { - panic(http.ErrAbortHandler) - } - return true, nil, false + return } res.Body.Close() @@ -733,9 +361,13 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte c.Writer.Flush() } + // Keep the stdlib-compatible fallback here. + // If the backend only exposes additional trailer keys after the body has been + // fully read, the trailer map can grow and those values must be written using + // the TrailerPrefix form instead of the pre-announced bare header keys. if len(res.Trailer) == announcedTrailers { reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) - return true, nil, false + return } for key, values := range res.Trailer { @@ -744,249 +376,6 @@ func (p *reverseProxyHandler) serveUpstreamAttempt(c *Context, ctx context.Conte c.Writer.Header().Add(prefixedKey, value) } } - return true, nil, false -} - -func (p *reverseProxyHandler) buildOutgoingRequest(c *Context, ctx context.Context, upstream *reverseProxyUpstream, updatedMaxForwards string) (*http.Request, *io.PipeWriter, func(), error) { - outreq := c.Request.Clone(ctx) - bridgeCtx, bridged, err := reverseProxyPrepareExtendedConnectBridge(outreq) - if err != nil { - return nil, nil, nil, err - } - if bridged { - outreq = outreq.WithContext(bridgeCtx) - } - if outreq.Method == http.MethodConnect || c.Request.ContentLength == 0 { - outreq.Body = nil - } else if c.Request.GetBody != nil { - body, err := c.Request.GetBody() - if err != nil { - return nil, nil, nil, fmt.Errorf("reverse proxy failed to replay request body: %w", err) - } - outreq.Body = body - } else if outreq.Body != nil { - outreq.Body = &noopCloseReader{readCloser: outreq.Body} - } - if outreq.Header == nil { - outreq.Header = make(http.Header) - } - outreq.Close = false - var connectWriter *io.PipeWriter - if outreq.Method == http.MethodConnect && !bridged { - pipeReader, pipeWriter := io.Pipe() - outreq.Body = pipeReader - outreq.ContentLength = -1 - connectWriter = pipeWriter - } - cleanup := func() { - if outreq.Body != nil { - _ = outreq.Body.Close() - } - if connectWriter != nil { - _ = connectWriter.Close() - } - } - - if outreq.Method == http.MethodConnect && !reverseProxyIsExtendedConnectRequest(outreq) { - if err := rewriteReverseProxyConnectRequest(outreq, upstream.target); err != nil { - cleanup() - return nil, nil, nil, err - } - } else { - rewriteReverseProxyURL(outreq, upstream.target) - if !p.config.PreserveHost { - outreq.Host = "" - } - outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) - } - if updatedMaxForwards != "" { - outreq.Header.Set("Max-Forwards", updatedMaxForwards) - } - - reqUpType := reverseProxyUpgradeType(outreq.Header) - if reqUpType != "" && !isPrintableASCII(reqUpType) { - cleanup() - return nil, nil, nil, &reverseProxyStatusError{ - status: http.StatusBadRequest, - err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), - } - } - - removeHopByHopHeaders(outreq.Header) - if headerValuesContainToken(c.Request.Header["Te"], "trailers") { - outreq.Header.Set("Te", "trailers") - } - if reqUpType != "" { - outreq.Header.Set("Connection", "Upgrade") - outreq.Header.Set("Upgrade", reqUpType) - } - - p.addForwardingHeaders(c.Request, outreq) - appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) - - if _, ok := outreq.Header["User-Agent"]; !ok { - outreq.Header.Set("User-Agent", "") - } - - if p.config.RequestHeaders != nil { - p.config.RequestHeaders.applyToRequest(outreq) - } - - if p.config.ModifyRequest != nil { - p.config.ModifyRequest(outreq) - } - - return outreq, connectWriter, cleanup, nil -} - -func (p *reverseProxyHandler) transportForUpstream(req *http.Request, upstream *reverseProxyUpstream) http.RoundTripper { - if p.config.Transport != nil { - return p.config.Transport - } - if reverseProxyExtendedConnectBridgeFromContext(req.Context()) != nil { - if upstream.bridgeTransport != nil { - return upstream.bridgeTransport - } - return http.DefaultTransport - } - if upstream.useH2C && upstream.h2cTransport != nil { - return upstream.h2cTransport - } - if reverseProxyIsExtendedConnectRequest(req) && upstream.extendedConnectTransport != nil { - return upstream.extendedConnectTransport - } - return http.DefaultTransport -} - -func (p *reverseProxyHandler) shouldRetryAttempt(req *http.Request, attempts int, started time.Time) bool { - if req == nil || req.Context().Err() != nil || !reverseProxyCanRetryRequest(req) { - return false - } - lb := p.config.LoadBalancing - if lb.TryDuration > 0 { - return time.Since(started) < lb.TryDuration - } - return attempts <= lb.Retries -} - -func (p *reverseProxyHandler) waitRetryInterval(ctx context.Context, started time.Time) bool { - interval := p.config.LoadBalancing.TryInterval - tryDuration := p.config.LoadBalancing.TryDuration - if tryDuration > 0 && interval == 0 { - interval = 250 * time.Millisecond - } - if tryDuration > 0 { - remaining := tryDuration - time.Since(started) - if remaining <= 0 { - return false - } - if interval <= 0 { - return ctx.Err() == nil - } - if interval > remaining { - return false - } - } - if interval <= 0 { - return ctx.Err() == nil - } - timer := time.NewTimer(interval) - defer timer.Stop() - select { - case <-ctx.Done(): - return false - case <-timer.C: - return true - } -} - -func (p *reverseProxyHandler) handleMaxForwards(c *Context) (string, bool, error) { - if c == nil || c.Request == nil { - return "", false, nil - } - - switch c.Request.Method { - case http.MethodOptions, http.MethodTrace: - default: - return "", false, nil - } - - rawValue := textproto.TrimString(c.Request.Header.Get("Max-Forwards")) - if rawValue == "" { - return "", false, nil - } - - value, err := strconv.Atoi(rawValue) - if err != nil || value < 0 { - return "", false, &reverseProxyStatusError{ - status: http.StatusBadRequest, - err: fmt.Errorf("invalid Max-Forwards value %q", rawValue), - } - } - if value == 0 { - switch c.Request.Method { - case http.MethodTrace: - return "", true, p.writeLocalTraceResponse(c) - case http.MethodOptions: - p.writeLocalOptionsResponse(c) - return "", true, nil - } - } - - return strconv.Itoa(value - 1), false, nil -} - -func (p *reverseProxyHandler) writeLocalTraceResponse(c *Context) error { - if c == nil || c.Request == nil { - return nil - } - - traceReq := c.Request.Clone(c.Request.Context()) - traceReq.Body = nil - traceReq.ContentLength = 0 - traceReq.TransferEncoding = nil - traceReq.RequestURI = c.Request.RequestURI - if traceReq.RequestURI == "" && traceReq.URL != nil { - traceReq.RequestURI = traceReq.URL.RequestURI() - } - traceReq.Header = traceReq.Header.Clone() - for _, key := range []string{"Authorization", "Proxy-Authorization", "Cookie", "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "Content-Length", "Transfer-Encoding", "Trailer"} { - traceReq.Header.Del(key) - } - - dump, err := httputil.DumpRequest(traceReq, false) - if err != nil { - return &reverseProxyStatusError{status: http.StatusInternalServerError, err: err} - } - - c.Writer.Header().Set("Content-Type", "message/http") - c.Writer.WriteHeader(http.StatusOK) - _, err = c.Writer.Write(dump) - return err -} - -func (p *reverseProxyHandler) writeLocalOptionsResponse(c *Context) { - if c == nil { - return - } - - if c.engine != nil { - if c.Request != nil && c.Request.RequestURI != "*" { - if allow := c.engine.allowedMethodsForPath(routeLookupPath(c.Request), c.allowedMethodsBuf[:0]); len(allow) > 0 { - c.allowedMethodsBuf = allow[:0] - allowHeader := c.allowHeaderBuf[:0] - for i, method := range allow { - if i > 0 { - allowHeader = append(allowHeader, ',', ' ') - } - allowHeader = append(allowHeader, method...) - } - c.allowHeaderBuf = allowHeader[:0] - c.Writer.Header().Set("Allow", string(allowHeader)) - } - } - } - c.Writer.WriteHeader(http.StatusOK) } func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { @@ -1067,14 +456,7 @@ func appendXForwardedFor(header http.Header, clientIP string) { } func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { - if p.config.ResponseHeaders != nil && !p.config.ResponseHeaders.Deferred { - p.config.ResponseHeaders.applyToResponse(res.Header) - } - if p.config.ModifyResponse == nil { - if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { - p.config.ResponseHeaders.applyToResponse(res.Header) - } return true } if err := p.config.ModifyResponse(res); err != nil { @@ -1082,9 +464,6 @@ func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req p.handleError(c, err) return false } - if p.config.ResponseHeaders != nil && p.config.ResponseHeaders.Deferred { - p.config.ResponseHeaders.applyToResponse(res.Header) - } return true } @@ -1143,11 +522,7 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques clientConn, brw, err := c.Writer.Hijack() if err != nil { backConn.Close() - status := http.StatusBadGateway - if errors.Is(err, http.ErrNotSupported) { - status = http.StatusNotImplemented - } - return &reverseProxyStatusError{status: status, err: err} + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} } defer clientConn.Close() @@ -1186,231 +561,6 @@ func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Reques return firstErr } -func (p *reverseProxyHandler) handleConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { - if backWrite == nil { - res.Body.Close() - return &reverseProxyStatusError{ - status: http.StatusBadGateway, - err: errors.New("reverse proxy CONNECT tunnel is missing backend writer"), - } - } - backRead := res.Body - - clientConn, brw, err := c.Writer.Hijack() - if err != nil { - backRead.Close() - _ = backWrite.Close() - status := http.StatusBadGateway - if errors.Is(err, http.ErrNotSupported) { - status = http.StatusNotImplemented - } - return &reverseProxyStatusError{status: status, err: err} - } - - defer clientConn.Close() - defer backRead.Close() - defer backWrite.Close() - - backConnClosed := make(chan struct{}) - go func() { - select { - case <-req.Context().Done(): - case <-backConnClosed: - } - backRead.Close() - _ = backWrite.Close() - }() - defer close(backConnClosed) - - res.Body = nil - if err := res.Write(brw); err != nil { - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - if err := brw.Flush(); err != nil { - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - - errc := make(chan error, 2) - go func() { - if _, err := io.Copy(clientConn, backRead); err != nil { - errc <- err - return - } - if cw, ok := clientConn.(interface{ CloseWrite() error }); ok { - errc <- cw.CloseWrite() - return - } - errc <- errReverseProxyCopyDone - }() - go func() { - if _, err := io.Copy(backWrite, clientConn); err != nil { - errc <- err - return - } - errc <- backWrite.Close() - }() - - firstErr := <-errc - if firstErr == nil { - firstErr = <-errc - } - if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { - return nil - } - return firstErr -} - -func (p *reverseProxyHandler) handleBridgedExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, bridge *reverseProxyExtendedConnectBridge) error { - if c == nil || c.Request == nil { - res.Body.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT bridge requires a valid request context")} - } - backConn, ok := res.Body.(io.ReadWriteCloser) - if !ok { - res.Body.Close() - return &reverseProxyStatusError{ - status: http.StatusBadGateway, - err: errors.New("backend returned bridged websocket response without writable body"), - } - } - - controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - - responseHeader := c.Writer.Header() - reverseProxyCopyHeader(responseHeader, res.Header) - removeHopByHopHeaders(responseHeader) - responseHeader.Del("Sec-WebSocket-Accept") - c.Writer.WriteHeader(http.StatusOK) - if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - backConn.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - - conn := &reverseProxyH2ReadWriteCloser{ReadCloser: bridge.body, ResponseWriter: c.Writer, controller: controller} - - var closeOnce sync.Once - closeTunnel := func() { - closeOnce.Do(func() { - _ = conn.Close() - _ = backConn.Close() - }) - } - go func() { - <-req.Context().Done() - closeTunnel() - }() - - errc := make(chan error, 2) - copyer := switchProtocolCopier{user: conn, backend: backConn} - go copyer.copyToBackend(errc) - go copyer.copyFromBackend(errc) - - var firstErr error - for range 2 { - err := <-errc - if reverseProxyIsBenignTunnelError(err) { - continue - } - if firstErr == nil { - firstErr = err - closeTunnel() - } - } - closeTunnel() - if reverseProxyIsBenignTunnelError(firstErr) { - return nil - } - return firstErr -} - -func (p *reverseProxyHandler) handleExtendedConnectResponse(c *Context, req *http.Request, res *http.Response, backWrite *io.PipeWriter) error { - if c == nil || c.Request == nil { - res.Body.Close() - if backWrite != nil { - _ = backWrite.Close() - } - return &reverseProxyStatusError{status: http.StatusBadGateway, err: errors.New("extended CONNECT requires a valid request context")} - } - if backWrite == nil { - res.Body.Close() - return &reverseProxyStatusError{ - status: http.StatusBadGateway, - err: errors.New("reverse proxy extended CONNECT tunnel is missing backend writer"), - } - } - - controller := http.NewResponseController(reverseProxyBaseResponseWriter(c.Writer)) - if err := controller.EnableFullDuplex(); err != nil && !errors.Is(err, http.ErrNotSupported) { - res.Body.Close() - _ = backWrite.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - - reverseProxyCopyHeader(c.Writer.Header(), res.Header) - c.Writer.WriteHeader(res.StatusCode) - if err := controller.Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - res.Body.Close() - _ = backWrite.Close() - return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} - } - - var closeOnce sync.Once - closeTunnel := func() { - closeOnce.Do(func() { - _ = c.Request.Body.Close() - _ = backWrite.Close() - _ = res.Body.Close() - }) - } - go func() { - <-req.Context().Done() - closeTunnel() - }() - - errc := make(chan error, 2) - go func() { - _, err := io.Copy(backWrite, c.Request.Body) - closeErr := backWrite.Close() - if err != nil && !reverseProxyIsBenignTunnelError(err) { - errc <- err - return - } - errc <- closeErr - }() - go func() { - copyErr := p.copyResponse(c.Writer, res.Body, -1) - closeErr := res.Body.Close() - if copyErr != nil { - errc <- copyErr - return - } - errc <- closeErr - }() - - var firstErr error - for range 2 { - err := <-errc - if reverseProxyIsBenignTunnelError(err) { - continue - } - if firstErr == nil { - firstErr = err - closeTunnel() - } - } - closeTunnel() - if reverseProxyIsBenignTunnelError(firstErr) { - return nil - } - - return firstErr - -} - func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { return -1 @@ -1436,10 +586,6 @@ func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, fl if p.config.BufferPool != nil { buf = p.config.BufferPool.Get() defer p.config.BufferPool.Put(buf) - } else { - bufp := reverseProxyCopyBufferPool.Get().(*[]byte) - buf = *bufp - defer reverseProxyCopyBufferPool.Put(bufp) } _, err := p.copyBuffer(writer, src, buf) return err @@ -1453,7 +599,7 @@ func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byt var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && !errors.Is(rerr, io.EOF) && !reverseProxyIsBenignTunnelError(rerr) { + if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { p.logf(nil, "reverse proxy read error during body copy: %v", rerr) } if nr > 0 { @@ -1492,10 +638,6 @@ func reverseProxyStatusCode(err error) int { if errors.As(err, &statusErr) && statusErr.status > 0 { return statusErr.status } - var netErr net.Error - if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &netErr) && netErr.Timeout()) { - return http.StatusGatewayTimeout - } return http.StatusBadGateway } @@ -1509,75 +651,6 @@ func validateReverseProxyTarget(target *url.URL) error { return nil } -func buildReverseProxyUpstreams(config ReverseProxyConfig) ([]*reverseProxyUpstream, error) { - if config.Target != nil && len(config.Targets) > 0 { - return nil, errors.New("reverse proxy Target and Targets cannot be used together") - } - - targets := make([]*url.URL, 0, max(1, len(config.Targets))) - if config.Target != nil { - target := cloneReverseProxyURL(config.Target) - normalizeReverseProxyTarget(target) - if err := validateReverseProxyTarget(target); err != nil { - return nil, err - } - targets = append(targets, target) - } - for i, rawTarget := range config.Targets { - trimmed := strings.TrimSpace(rawTarget) - if trimmed == "" { - return nil, fmt.Errorf("reverse proxy target at index %d is empty", i) - } - target, err := url.Parse(trimmed) - if err != nil { - return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) - } - normalizeReverseProxyTarget(target) - if err := validateReverseProxyTarget(target); err != nil { - return nil, fmt.Errorf("reverse proxy target at index %d is invalid: %w", i, err) - } - targets = append(targets, target) - } - if len(targets) == 0 { - return nil, errReverseProxyNilTarget - } - - upstreams := make([]*reverseProxyUpstream, 0, len(targets)) - for i, target := range targets { - useH2C := strings.EqualFold(target.Scheme, "h2c") - if useH2C { - target = cloneReverseProxyURL(target) - target.Scheme = "http" - } - upstream := &reverseProxyUpstream{ - key: fmt.Sprintf("%d:%s", i, target.String()), - target: target, - index: i, - useH2C: useH2C || config.AllowH2CUpstream, - } - if config.Transport == nil { - upstream.extendedConnectTransport = newHTTP2ExtendedConnectTransport() - upstream.bridgeTransport = newHTTP1BridgeTransport() - if upstream.useH2C { - upstream.h2cTransport = newH2CTransport() - } - } - upstreams = append(upstreams, upstream) - } - return upstreams, nil -} - -func validateReverseProxyForwardedBy(value string) error { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return nil - } - if !isValidForwardedNodeIdentifier(trimmed) { - return fmt.Errorf("reverse proxy ForwardedBy must be an RFC 7239 node identifier, got %q", value) - } - return nil -} - func normalizeReverseProxyTarget(target *url.URL) { switch strings.ToLower(target.Scheme) { case "ws": @@ -1659,136 +732,6 @@ func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { return strings.Join(pairs, ";") } -func reverseProxyUsesForwardedHeader(policy ForwardedHeadersPolicy) bool { - return policy == ForwardedBoth || policy == ForwardedRFC7239Only -} - -func reverseProxyPrepareExtendedConnectBridge(req *http.Request) (context.Context, bool, error) { - if req == nil { - return context.Background(), false, nil - } - protocol := reverseProxyExtendedConnectProtocol(req) - if req.Method != http.MethodConnect || protocol == "" || !strings.EqualFold(protocol, "websocket") { - return req.Context(), false, nil - } - - bridge := &reverseProxyExtendedConnectBridge{body: req.Body} - ctx := context.WithValue(req.Context(), reverseProxyExtendedConnectBridge{}, bridge) - req.Header.Del(":protocol") - req.Method = http.MethodGet - req.Body = http.NoBody - req.ContentLength = 0 - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Sec-WebSocket-Version", "13") - key, err := reverseProxyGenerateWebSocketKey() - if err != nil { - return nil, false, fmt.Errorf("reverse proxy failed to generate websocket key: %w", err) - } - req.Header.Set("Sec-WebSocket-Key", key) - return ctx, true, nil -} - -func reverseProxyExtendedConnectBridgeFromContext(ctx context.Context) *reverseProxyExtendedConnectBridge { - if ctx == nil { - return nil - } - bridge, _ := ctx.Value(reverseProxyExtendedConnectBridge{}).(*reverseProxyExtendedConnectBridge) - return bridge -} - -func reverseProxyGenerateWebSocketKey() (string, error) { - key := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(key), nil -} - -func reverseProxyIsExtendedConnectRequest(req *http.Request) bool { - return reverseProxyExtendedConnectProtocol(req) != "" -} - -func reverseProxyExtendedConnectProtocol(req *http.Request) string { - if req == nil || req.Method != http.MethodConnect || req.Header == nil { - return "" - } - return textproto.TrimString(req.Header.Get(":protocol")) -} - -func isValidForwardedNodeIdentifier(value string) bool { - if value == "" { - return false - } - if strings.HasPrefix(value, "[") { - closing := strings.IndexByte(value, ']') - if closing <= 1 { - return false - } - addr, err := netip.ParseAddr(value[1:closing]) - if err != nil || !addr.Is6() { - return false - } - if closing == len(value)-1 { - return true - } - if value[closing+1] != ':' { - return false - } - return isValidForwardedNodePort(value[closing+2:]) - } - - host, port, hasPort := strings.Cut(value, ":") - if hasPort { - switch { - case host == "unknown", isValidForwardedObfuscatedIdentifier(host): - return isValidForwardedNodePort(port) - default: - addr, err := netip.ParseAddr(host) - return err == nil && addr.Is4() && isValidForwardedNodePort(port) - } - } - - if value == "unknown" || isValidForwardedObfuscatedIdentifier(value) { - return true - } - addr, err := netip.ParseAddr(value) - return err == nil && addr.Is4() -} - -func isValidForwardedNodePort(value string) bool { - if value == "" { - return false - } - if isValidForwardedObfuscatedIdentifier(value) { - return true - } - if len(value) > 5 { - return false - } - port, err := strconv.Atoi(value) - return err == nil && port > 0 && port <= 65535 -} - -func isValidForwardedObfuscatedIdentifier(value string) bool { - if len(value) < 2 || value[0] != '_' { - return false - } - for i := 1; i < len(value); i++ { - b := value[i] - if (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') { - continue - } - switch b { - case '.', '_', '-': - continue - default: - return false - } - } - return true -} - func formatForwardedFor(clientIP string) string { addr, err := netip.ParseAddr(clientIP) if err != nil { @@ -1856,8 +799,8 @@ func reverseProxyViaProtocol(major, minor int, raw string) string { if major > 0 { return strconv.Itoa(major) + "." + strconv.Itoa(minor) } - if after, ok := strings.CutPrefix(raw, "HTTP/"); ok { - return after + if strings.HasPrefix(raw, "HTTP/") { + return strings.TrimPrefix(raw, "HTTP/") } return raw } @@ -1874,47 +817,6 @@ func rewriteReverseProxyURL(req *http.Request, target *url.URL) { } } -func rewriteReverseProxyConnectRequest(req *http.Request, target *url.URL) error { - connectTarget, err := reverseProxyConnectTarget(target) - if err != nil { - return &reverseProxyStatusError{status: http.StatusBadRequest, err: err} - } - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path = "" - req.URL.RawPath = "" - req.URL.RawQuery = "" - req.URL.Opaque = connectTarget - req.Host = connectTarget - return nil -} - -func reverseProxyConnectTarget(target *url.URL) (string, error) { - if target == nil { - return "", errReverseProxyNilTarget - } - host := target.Hostname() - if host == "" { - return "", errReverseProxyInvalidTarget - } - port := target.Port() - if port == "" { - switch strings.ToLower(target.Scheme) { - case "http": - port = "80" - case "https": - port = "443" - default: - return "", fmt.Errorf("reverse proxy CONNECT target requires a supported scheme, got %q", target.Scheme) - } - } - portNum, err := strconv.Atoi(port) - if err != nil || portNum <= 0 || portNum > 65535 { - return "", fmt.Errorf("reverse proxy CONNECT target has invalid port %q", port) - } - return net.JoinHostPort(host, port), nil -} - func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { if base.RawPath == "" && incoming.RawPath == "" { return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" @@ -1971,7 +873,7 @@ var reverseProxyHopHeaders = []string{ func removeHopByHopHeaders(header http.Header) { for _, connectionValue := range header["Connection"] { - for token := range strings.SplitSeq(connectionValue, ",") { + for _, token := range strings.Split(connectionValue, ",") { trimmed := textproto.TrimString(token) if trimmed != "" { header.Del(trimmed) @@ -1995,7 +897,7 @@ func headerValuesContainToken(values []string, token string) bool { return false } for _, value := range values { - for part := range strings.SplitSeq(value, ",") { + for _, part := range strings.Split(value, ",") { if strings.EqualFold(textproto.TrimString(part), token) { return true } @@ -2017,59 +919,6 @@ func cleanReverseProxyQueryParams(rawQuery string) string { return values.Encode() } -func reverseProxyShouldPanicOnCopyError(req *http.Request) bool { - return req != nil && req.Context().Value(http.ServerContextKey) != nil -} - -func reverseProxyCanRetryRequest(req *http.Request) bool { - if req == nil || req.Method == http.MethodConnect || reverseProxyUpgradeType(req.Header) != "" || !reverseProxyMethodIsSafe(req.Method) { - return false - } - if req.Body == nil || req.ContentLength == 0 { - return true - } - return req.GetBody != nil -} - -func reverseProxyShouldCountPassiveFailure(req *http.Request, err error) bool { - if err == nil || reverseProxyIsBenignTunnelError(err) { - return false - } - if req != nil && req.Context().Err() != nil { - return false - } - return !errors.Is(err, context.Canceled) -} - -func reverseProxyMethodIsSafe(method string) bool { - switch method { - case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: - return true - default: - return false - } -} - -func reverseProxyIsBenignTunnelError(err error) bool { - return err == nil || errors.Is(err, errReverseProxyCopyDone) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) || reverseProxyIsClosedBodyError(err) -} - -func reverseProxyIsClosedBodyError(err error) bool { - if err == nil { - return false - } - var streamErr http2.StreamError - if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeCancel { - return true - } - switch err.Error() { - case "body closed by handler", "http2: response body closed", "response body closed": - return true - default: - return false - } -} - func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { return UnwrapResponseWriter(writer) } diff --git a/reverseproxy_benchmark_test.go b/reverseproxy_benchmark_test.go deleted file mode 100644 index b496f5c..0000000 --- a/reverseproxy_benchmark_test.go +++ /dev/null @@ -1,355 +0,0 @@ -package touka - -import ( - "bufio" - "bytes" - "errors" - "io" - "net" - "net/http" - "strings" - "testing" - "time" -) - -type benchmarkReadSeeker struct { - data []byte - off int -} - -func (r *benchmarkReadSeeker) Read(p []byte) (int, error) { - if r.off >= len(r.data) { - return 0, io.EOF - } - n := copy(p, r.data[r.off:]) - r.off += n - return n, nil -} - -func (r *benchmarkReadSeeker) Reset() { - r.off = 0 -} - -type benchmarkResponseWriter struct { - header http.Header - status int - size int -} - -func newBenchmarkResponseWriter() *benchmarkResponseWriter { - return &benchmarkResponseWriter{header: make(http.Header)} -} - -func (w *benchmarkResponseWriter) Header() http.Header { - return w.header -} - -func (w *benchmarkResponseWriter) WriteHeader(statusCode int) { - if w.status == 0 { - w.status = statusCode - } -} - -func (w *benchmarkResponseWriter) Write(p []byte) (int, error) { - if w.status == 0 { - w.status = http.StatusOK - } - w.size += len(p) - return len(p), nil -} - -func (w *benchmarkResponseWriter) Flush() {} - -func (w *benchmarkResponseWriter) Status() int { - return w.status -} - -func (w *benchmarkResponseWriter) Size() int { - return w.size -} - -func (w *benchmarkResponseWriter) Written() bool { - return w.status != 0 -} - -func (w *benchmarkResponseWriter) IsHijacked() bool { - return false -} - -func (w *benchmarkResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return nil, nil, http.ErrNotSupported -} - -func (w *benchmarkResponseWriter) reset() { - clear(w.header) - w.status = 0 - w.size = 0 -} - -var benchmarkReverseProxySink int - -func BenchmarkReverseProxyCopyResponse(b *testing.B) { - body := bytes.Repeat([]byte("0123456789abcdef"), 4096) - proxy := newReverseProxyHandler(ReverseProxyConfig{}) - dst := newBenchmarkResponseWriter() - src := &benchmarkReadSeeker{data: body} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - dst.reset() - src.Reset() - if err := proxy.copyResponse(dst, src, 0); err != nil { - b.Fatalf("copyResponse failed: %v", err) - } - } - - benchmarkReverseProxySink = dst.Size() -} - -func BenchmarkReverseProxyAvailableUpstreams(b *testing.B) { - proxy := &reverseProxyHandler{ - upstreams: []*reverseProxyUpstream{ - {key: "a", index: 0}, - {key: "b", index: 1}, - {key: "c", index: 2}, - {key: "d", index: 3}, - }, - config: ReverseProxyConfig{ - PassiveHealth: ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - MaxFails: 3, - }, - }, - } - - now := time.Now() - proxy.upstreams[0].failures = []time.Time{now.Add(-30 * time.Second)} - proxy.upstreams[1].failures = []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - benchmarkReverseProxySink = len(proxy.availableUpstreams(now, nil)) - } -} - -func BenchmarkReverseProxySelectUpstream(b *testing.B) { - proxy := &reverseProxyHandler{ - upstreams: []*reverseProxyUpstream{ - {key: "a", index: 0}, - {key: "b", index: 1}, - {key: "c", index: 2}, - {key: "d", index: 3}, - }, - config: ReverseProxyConfig{ - LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, - PassiveHealth: ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - MaxFails: 3, - }, - }, - } - proxy.upstreams[0].failures = []time.Time{time.Now().Add(-30 * time.Second)} - - c, _ := CreateTestContext(nil) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - selected, err := proxy.selectUpstream(c, nil) - if err != nil { - b.Fatalf("selectUpstream failed: %v", err) - } - benchmarkReverseProxySink = selected.index - } -} - -func BenchmarkReverseProxySelectUpstreamHeaderPolicy(b *testing.B) { - proxy := &reverseProxyHandler{ - upstreams: []*reverseProxyUpstream{ - {key: "a", index: 0}, - {key: "b", index: 1}, - {key: "c", index: 2}, - {key: "d", index: 3}, - }, - config: ReverseProxyConfig{ - LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, - }, - } - c, _ := CreateTestContext(nil) - c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b", "tenant-c"} - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - selected, err := proxy.selectUpstream(c, nil) - if err != nil { - b.Fatalf("selectUpstream failed: %v", err) - } - benchmarkReverseProxySink = selected.index - } -} - -func TestReverseProxyCopyResponseWithoutBufferPool(t *testing.T) { - proxy := newReverseProxyHandler(ReverseProxyConfig{}) - dst := newBenchmarkResponseWriter() - src := bytes.NewBufferString("hello, reverse proxy") - - if err := proxy.copyResponse(dst, src, 0); err != nil { - t.Fatalf("copyResponse failed: %v", err) - } - - if got, want := dst.Size(), len("hello, reverse proxy"); got != want { - t.Fatalf("expected %d bytes copied, got %d", want, got) - } -} - -type fixedLenBufferPool struct { - buf []byte -} - -func (p *fixedLenBufferPool) Get() []byte { - return p.buf -} - -func (p *fixedLenBufferPool) Put(buf []byte) { - p.buf = buf -} - -type recordingReader struct { - chunk int - reads []int - left int -} - -func (r *recordingReader) Read(p []byte) (int, error) { - if r.left == 0 { - return 0, io.EOF - } - n := min(r.chunk, len(p), r.left) - if n == 0 { - return 0, errors.New("reader received zero-length buffer") - } - for i := range n { - p[i] = 'x' - } - r.left -= n - r.reads = append(r.reads, len(p)) - return n, nil -} - -func TestReverseProxyCopyResponseRespectsCustomBufferLength(t *testing.T) { - pool := &fixedLenBufferPool{buf: make([]byte, 8, 32*1024)} - proxy := newReverseProxyHandler(ReverseProxyConfig{BufferPool: pool}) - dst := newBenchmarkResponseWriter() - src := &recordingReader{chunk: 8, left: 24} - - if err := proxy.copyResponse(dst, src, 0); err != nil { - t.Fatalf("copyResponse failed: %v", err) - } - - if len(src.reads) == 0 { - t.Fatal("expected reader to be used") - } - for _, size := range src.reads { - if size != 8 { - t.Fatalf("expected custom buffer length 8 to be preserved, got read size %d", size) - } - } -} - -func TestReverseProxyAvailableUpstreamsFiltersExcludedAndUnhealthy(t *testing.T) { - now := time.Now() - proxy := &reverseProxyHandler{ - upstreams: []*reverseProxyUpstream{ - {key: "a"}, - {key: "b", failures: []time.Time{now.Add(-20 * time.Second), now.Add(-10 * time.Second)}}, - {key: "c"}, - }, - config: ReverseProxyConfig{ - PassiveHealth: ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - MaxFails: 2, - }, - }, - } - - available := proxy.availableUpstreams(now, map[string]struct{}{"c": {}}) - if len(available) != 1 { - t.Fatalf("expected only one available upstream, got %d", len(available)) - } - if available[0].key != "a" { - t.Fatalf("expected upstream 'a', got %q", available[0].key) - } -} - -func TestReverseProxyHeaderPolicyUsesAllHeaderValues(t *testing.T) { - proxy := &reverseProxyHandler{ - upstreams: []*reverseProxyUpstream{ - {key: "a", index: 0}, - {key: "b", index: 1}, - {key: "c", index: 2}, - }, - config: ReverseProxyConfig{ - LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBHeader("X-Tenant", LBRandom())}, - }, - } - - c, _ := CreateTestContext(nil) - c.Request.Header["X-Tenant"] = []string{"tenant-a", "tenant-b"} - - selectedA, err := proxy.selectUpstream(c, nil) - if err != nil { - t.Fatalf("selectUpstream failed: %v", err) - } - selectedB, err := proxy.selectUpstream(c, nil) - if err != nil { - t.Fatalf("selectUpstream failed: %v", err) - } - if selectedA.key != selectedB.key { - t.Fatalf("expected stable selection for identical multi-value header, got %q and %q", selectedA.key, selectedB.key) - } - - c.Request.Header["X-Tenant"] = []string{"tenant-b", "tenant-a"} - selectedC, err := proxy.selectUpstream(c, nil) - if err != nil { - t.Fatalf("selectUpstream failed: %v", err) - } - if selectedC == nil { - t.Fatal("expected upstream for reordered multi-value header") - } -} - -func TestReverseProxyHeaderPolicyMatchesJoinCompatibility(t *testing.T) { - candidates := []*reverseProxyUpstream{ - {key: "a", index: 0}, - {key: "b", index: 1}, - {key: "c", index: 2}, - } - - testCases := [][]string{ - {"tenant-a"}, - {"tenant-a", "tenant-b"}, - {"", "tenant-b"}, - {"tenant-a", ""}, - {"", ""}, - } - - for _, values := range testCases { - got := reverseProxySelectHRWValues(candidates, values) - want := reverseProxySelectHRW(candidates, strings.Join(values, ",")) - if got == nil || want == nil { - t.Fatalf("expected non-nil upstreams for values %v", values) - } - if got.key != want.key { - t.Fatalf("expected joined compatibility for values %v, got %q want %q", values, got.key, want.key) - } - } -} - -var _ io.Writer = (*benchmarkResponseWriter)(nil) diff --git a/reverseproxy_headers_replace_test.go b/reverseproxy_headers_replace_test.go deleted file mode 100644 index 0c0d599..0000000 --- a/reverseproxy_headers_replace_test.go +++ /dev/null @@ -1,530 +0,0 @@ -package touka - -import ( - "io" - "net/http" - "net/http/httptest" - "net/url" - "regexp" - "testing" -) - -func TestReverseProxyHeaderOpsReplaceSubstring(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("X-Server"); got != "Caddy" { - t.Errorf("expected X-Server=Caddy, got %q", got) - } - if got := r.Header.Get("X-Location"); got != "/api/v2/resource" { - t.Errorf("expected X-Location=/api/v2/resource, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Replace: map[string][]Replacement{ - "X-Server": {{Search: "NGINX", Replace: "Caddy"}}, - "X-Location": {{Search: "v1", Replace: "v2"}}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - req.Header.Set("X-Server", "NGINX") - req.Header.Set("X-Location", "/api/v1/resource") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyHeaderOpsReplaceRegexp(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("X-Route"); got != "/proxy-upstream" { - t.Errorf("expected X-Route=/proxy-upstream, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Replace: map[string][]Replacement{ - "X-Route": {{SearchRegexp: `^/([^/]+)/(.+)$`, Replace: "/proxy-$2"}}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - req.Header.Set("X-Route", "/original/upstream") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyHeaderOpsReplaceWildcard(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("X-Host-A"); got != "new.example.com" { - t.Errorf("expected X-Host-A=new.example.com, got %q", got) - } - if got := r.Header.Get("X-Host-B"); got != "new.example.com" { - t.Errorf("expected X-Host-B=new.example.com, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Replace: map[string][]Replacement{ - "*": {{Search: "old.example.com", Replace: "new.example.com"}}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - req.Header.Set("X-Host-A", "old.example.com") - req.Header.Set("X-Host-B", "old.example.com") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyHeaderOpsReplaceResponse(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Backend", "backend-internal:8080") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - ResponseHeaders: &RespHeaderOps{ - HeaderOps: &HeaderOps{ - Replace: map[string][]Replacement{ - "X-Backend": {{Search: "backend-internal:8080", Replace: "public.example.com"}}, - }, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - resp, err := http.Get(proxy.URL + "/test") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if got := resp.Header.Get("X-Backend"); got != "public.example.com" { - t.Errorf("expected X-Backend=public.example.com, got %q", got) - } -} - -func TestReverseProxyHeaderOpsProvisionInvalidRegexp(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Replace: map[string][]Replacement{ - "X-Test": {{SearchRegexp: "[invalid"}}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("expected status 500, got %d", resp.StatusCode) - } -} - -func TestReplacementApply(t *testing.T) { - tests := []struct { - name string - r *Replacement - s string - want string - }{ - {name: "nil replacement", r: nil, s: "hello", want: "hello"}, - {name: "empty string", r: &Replacement{Search: "x", Replace: "y"}, s: "", want: ""}, - {name: "substring match", r: &Replacement{Search: "world", Replace: "go"}, s: "hello world", want: "hello go"}, - {name: "substring no match", r: &Replacement{Search: "foo", Replace: "bar"}, s: "hello world", want: "hello world"}, - {name: "substring multiple", r: &Replacement{Search: "a", Replace: "b"}, s: "aaa", want: "bbb"}, - {name: "regexp match", r: &Replacement{SearchRegexp: `\d+`, Replace: "N", re: regexp.MustCompile(`\d+`)}, s: "abc123def", want: "abcNdef"}, - {name: "regexp no match", r: &Replacement{SearchRegexp: `z+`, Replace: "Z", re: regexp.MustCompile(`z+`)}, s: "abc", want: "abc"}, - {name: "empty search and regexp", r: &Replacement{}, s: "unchanged", want: "unchanged"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.r.apply(tt.s); got != tt.want { - t.Errorf("Replacement.apply() = %q, want %q", got, tt.want) - } - }) - } -} - -func BenchmarkHeaderOpsAdd(b *testing.B) { - ops := &HeaderOps{ - Add: map[string][]string{ - "X-Custom-1": {"value-1"}, - "X-Custom-2": {"value-2"}, - "X-Custom-3": {"value-3"}, - }, - } - hdr := make(http.Header) - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr = make(http.Header) - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsSet(b *testing.B) { - ops := &HeaderOps{ - Set: map[string][]string{ - "X-Frame-Options": {"DENY"}, - "X-Content-Type-Options": {"nosniff"}, - "X-XSS-Protection": {"1; mode=block"}, - }, - } - hdr := make(http.Header) - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr = make(http.Header) - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsDeleteSingle(b *testing.B) { - ops := &HeaderOps{ - Delete: []string{"X-Powered-By"}, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("X-Powered-By", "Express") - hdr.Set("X-Keep", "value") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsDeleteWildcard(b *testing.B) { - ops := &HeaderOps{ - Delete: []string{"X-Debug-*"}, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("X-Debug-1", "v1") - hdr.Set("X-Debug-2", "v2") - hdr.Set("X-Keep", "value") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsReplaceSubstring(b *testing.B) { - ops := &HeaderOps{ - Replace: map[string][]Replacement{ - "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, - }, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("Location", "http://internal:8080/api/v1/users") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsReplaceRegexp(b *testing.B) { - re := regexp.MustCompile(`^http://([^/]+)(/.*)$`) - ops := &HeaderOps{ - Replace: map[string][]Replacement{ - "Location": {{SearchRegexp: `^http://([^/]+)(/.*)$`, Replace: "https://public.example.com$2", re: re}}, - }, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("Location", "http://internal:8080/api/v1/users") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsReplaceWildcard(b *testing.B) { - ops := &HeaderOps{ - Replace: map[string][]Replacement{ - "*": {{Search: "internal.example.com", Replace: "public.example.com"}}, - }, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("X-Host", "internal.example.com") - hdr.Set("X-Origin", "internal.example.com") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkHeaderOpsMixed(b *testing.B) { - ops := &HeaderOps{ - Add: map[string][]string{ - "X-Request-ID": {"req-123"}, - }, - Set: map[string][]string{ - "X-Frame-Options": {"DENY"}, - }, - Delete: []string{"X-Powered-By"}, - Replace: map[string][]Replacement{ - "Location": {{Search: "http://internal:8080", Replace: "https://public.example.com"}}, - }, - } - repl := &reverseProxyReplacer{} - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - hdr := make(http.Header) - hdr.Set("X-Powered-By", "Express") - hdr.Set("Location", "http://internal:8080/api") - ops.applyTo(hdr, repl) - } -} - -func BenchmarkReplacementApplySubstring(b *testing.B) { - r := &Replacement{Search: "old.example.com", Replace: "new.example.com"} - s := "https://old.example.com/api/v1/resource" - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = r.apply(s) - } -} - -func BenchmarkReplacementApplyRegexp(b *testing.B) { - r := &Replacement{SearchRegexp: `^https?://[^/]+`, Replace: "https://new.example.com", re: regexp.MustCompile(`^https?://[^/]+`)} - s := "https://old.example.com/api/v1/resource" - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = r.apply(s) - } -} - -func TestReverseProxyReplacerDynamicVars(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, "http://example.com/api/v1/users?sort=name&limit=10", nil) - req.Host = "example.com" - repl := newReverseProxyReplacer(req) - - tests := []struct { - name string - input string - want string - }{ - {"method", "{method}", "GET"}, - {"host", "{host}", "example.com"}, - {"path", "{path}", "/api/v1/users"}, - {"query", "{query}", "sort=name&limit=10"}, - {"scheme", "{scheme}", "http"}, - {"proto", "{proto}", "HTTP/1.1"}, - {"combined", "X-{method}-{path}", "X-GET-/api/v1/users"}, - {"no vars", "static-value", "static-value"}, - {"empty", "", ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := repl.Replace(tt.input); got != tt.want { - t.Errorf("Replace(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} - -func TestReverseProxyReplacerNilRequest(t *testing.T) { - repl := newReverseProxyReplacer(nil) - if got := repl.Replace("{method}"); got != "{method}" { - t.Errorf("expected unchanged string with nil request, got %q", got) - } -} - -func TestReverseProxyReplacerNilReplacer(t *testing.T) { - var repl *reverseProxyReplacer - if got := repl.Replace("{method}"); got != "{method}" { - t.Errorf("expected unchanged string with nil replacer, got %q", got) - } -} - -func TestReverseProxyReplacerFromHeader(t *testing.T) { - hdr := make(http.Header) - repl := newReverseProxyReplacerFromHeader(hdr) - if got := repl.Replace("{method}"); got != "{method}" { - t.Errorf("expected unchanged string from header replacer, got %q", got) - } -} - -func TestReverseProxyHeaderOpsWithDynamicVars(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("X-Forwarded-Path"); got != "/dynamic/path" { - t.Errorf("expected X-Forwarded-Path=/dynamic/path, got %q", got) - } - if got := r.Header.Get("X-Forwarded-Method"); got != "GET" { - t.Errorf("expected X-Forwarded-Method=GET, got %q", got) - } - if got := r.Header.Get("X-Forwarded-Host"); got != "client.example" { - t.Errorf("expected X-Forwarded-Host=client.example, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/dynamic/path", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Add: map[string][]string{ - "X-Forwarded-Path": {"{path}"}, - "X-Forwarded-Method": {"{method}"}, - "X-Forwarded-Host": {"{host}"}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/dynamic/path", nil) - req.Host = "client.example" - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} diff --git a/reverseproxy_headers_test.go b/reverseproxy_headers_test.go deleted file mode 100644 index 4a4ae26..0000000 --- a/reverseproxy_headers_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package touka - -import ( - "io" - "net/http" - "net/http/httptest" - "net/url" - "testing" -) - -func TestReverseProxyHeaderOpsAdd(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got := r.Header.Get("X-Custom-Header"); got != "test-value" { - t.Errorf("expected X-Custom-Header=test-value, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Add: map[string][]string{ - "X-Custom-Header": {"test-value"}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - resp, err := http.Get(proxy.URL + "/test") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyHeaderOpsDelete(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("X-Sensitive") != "" { - t.Errorf("expected X-Sensitive header to be deleted, got %q", r.Header.Get("X-Sensitive")) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Delete: []string{"X-Sensitive"}, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - req.Header.Set("X-Sensitive", "should-be-removed") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyHeaderOpsSet(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - got := r.Header.Get("X-Replace") - if got != "new-value" { - t.Errorf("expected X-Replace=new-value, got %q", got) - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - RequestHeaders: &HeaderOps{ - Set: map[string][]string{ - "X-Replace": {"new-value"}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - req, _ := http.NewRequest(http.MethodGet, proxy.URL+"/test", nil) - req.Header.Set("X-Replace", "old-value") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -func TestReverseProxyResponseHeaderOps(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Backend", "backend-server") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - ResponseHeaders: &RespHeaderOps{ - HeaderOps: &HeaderOps{ - Set: map[string][]string{ - "X-Custom": {"custom-value"}, - }, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - resp, err := http.Get(proxy.URL + "/test") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if got := resp.Header.Get("X-Custom"); got != "custom-value" { - t.Errorf("expected X-Custom=custom-value, got %q", got) - } -} - -func TestReverseProxyResponseHeaderOpsDelete(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Powered-By", "Express") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer backend.Close() - - target, err := url.Parse(backend.URL) - if err != nil { - t.Fatalf("parse target: %v", err) - } - - engine := New() - engine.GET("/test", ReverseProxy(ReverseProxyConfig{ - Target: target, - ResponseHeaders: &RespHeaderOps{ - HeaderOps: &HeaderOps{ - Delete: []string{"X-Powered-By"}, - }, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - resp, err := http.Get(proxy.URL + "/test") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, _ = io.ReadAll(resp.Body) - - if got := resp.Header.Get("X-Powered-By"); got != "" { - t.Errorf("expected X-Powered-By to be deleted, got %q", got) - } -} diff --git a/reverseproxy_lb.go b/reverseproxy_lb.go deleted file mode 100644 index ce5e949..0000000 --- a/reverseproxy_lb.go +++ /dev/null @@ -1,409 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// Copyright 2026 WJQSERVER. All rights reserved. -// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. -package touka - -import ( - "fmt" - "math/rand/v2" - "net/http" - "net/textproto" - "net/url" - "slices" - "strings" - "sync" - "sync/atomic" - "time" -) - -// ReverseProxyLoadBalancingConfig configures upstream selection and retries. -type ReverseProxyLoadBalancingConfig struct { - Policy ReverseProxyLBPolicy - Retries int - TryDuration time.Duration - TryInterval time.Duration -} - -// ReverseProxyPassiveHealthConfig configures inline passive health tracking. -type ReverseProxyPassiveHealthConfig struct { - FailDuration time.Duration - MaxFails int - UnhealthyStatus []int -} - -// ReverseProxyLBPolicy selects an upstream from the configured target pool. -// Use the helper constructors such as LBRandom or LBHeader to build a policy. -type ReverseProxyLBPolicy struct { - kind reverseProxyLBPolicyKind - key string - fallback *ReverseProxyLBPolicy -} - -type reverseProxyLBPolicyKind uint8 - -const ( - reverseProxyLBPolicyRandom reverseProxyLBPolicyKind = iota - reverseProxyLBPolicyRoundRobin - reverseProxyLBPolicyFirst - reverseProxyLBPolicyLeastConn - reverseProxyLBPolicyIPHash - reverseProxyLBPolicyClientIPHash - reverseProxyLBPolicyURIHash - reverseProxyLBPolicyHeader - reverseProxyLBPolicyQuery -) - -type reverseProxyUpstream struct { - key string - target *url.URL - index int - useH2C bool - extendedConnectTransport http.RoundTripper - bridgeTransport http.RoundTripper - h2cTransport http.RoundTripper - inFlight atomic.Int64 - - passiveMu sync.Mutex - failures []time.Time -} - -func LBRandom() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRandom} -} - -func LBRoundRobin() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyRoundRobin} -} - -func LBFirst() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyFirst} -} - -func LBLeastConn() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyLeastConn} -} - -func LBIPHash() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyIPHash} -} - -func LBClientIPHash() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyClientIPHash} -} - -func LBURIHash() ReverseProxyLBPolicy { - return ReverseProxyLBPolicy{kind: reverseProxyLBPolicyURIHash} -} - -func LBHeader(field string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { - policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyHeader, key: textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(field))} - if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { - policy.fallback = &fallback - } - return policy -} - -func LBQuery(key string, fallback ReverseProxyLBPolicy) ReverseProxyLBPolicy { - policy := ReverseProxyLBPolicy{kind: reverseProxyLBPolicyQuery, key: strings.TrimSpace(key)} - if fallback.kind != reverseProxyLBPolicyRandom || fallback.key != "" || fallback.fallback != nil { - policy.fallback = &fallback - } - return policy -} - -func validateReverseProxyLBPolicy(policy ReverseProxyLBPolicy) error { - switch policy.kind { - case reverseProxyLBPolicyRandom, reverseProxyLBPolicyRoundRobin, reverseProxyLBPolicyFirst, - reverseProxyLBPolicyLeastConn, reverseProxyLBPolicyIPHash, reverseProxyLBPolicyClientIPHash, - reverseProxyLBPolicyURIHash: - return nil - case reverseProxyLBPolicyHeader: - if policy.key == "" { - return fmt.Errorf("reverse proxy header load-balancing policy requires a header field") - } - case reverseProxyLBPolicyQuery: - if policy.key == "" { - return fmt.Errorf("reverse proxy query load-balancing policy requires a query key") - } - default: - return fmt.Errorf("reverse proxy load-balancing policy is invalid") - } - if policy.fallback != nil { - return validateReverseProxyLBPolicy(*policy.fallback) - } - return nil -} - -func (p *reverseProxyHandler) selectUpstream(c *Context, excluded map[string]struct{}) (*reverseProxyUpstream, error) { - now := time.Now() - policy := p.config.LoadBalancing.Policy - candidateBuf := reverseProxyCandidatePool.Get().(*[]*reverseProxyUpstream) - candidates := p.availableUpstreamsInto(now, excluded, *candidateBuf) - if len(candidates) == 0 && len(excluded) > 0 { - candidates = p.availableUpstreamsInto(now, nil, candidates[:0]) - } - if len(candidates) == 0 { - *candidateBuf = candidates[:0] - reverseProxyCandidatePool.Put(candidateBuf) - return nil, errReverseProxyNoAvailableUpstreams - } - selected := p.selectUpstreamWithPolicy(c, candidates, policy) - *candidateBuf = candidates[:0] - reverseProxyCandidatePool.Put(candidateBuf) - return selected, nil -} - -func (p *reverseProxyHandler) availableUpstreams(now time.Time, excluded map[string]struct{}) []*reverseProxyUpstream { - return p.availableUpstreamsInto(now, excluded, nil) -} - -func (p *reverseProxyHandler) availableUpstreamsInto(now time.Time, excluded map[string]struct{}, candidates []*reverseProxyUpstream) []*reverseProxyUpstream { - if cap(candidates) < len(p.upstreams) { - candidates = make([]*reverseProxyUpstream, 0, len(p.upstreams)) - } else { - candidates = candidates[:0] - } - for _, upstream := range p.upstreams { - if _, skip := excluded[upstream.key]; skip { - continue - } - if !upstream.healthy(now, p.config.PassiveHealth) { - continue - } - candidates = append(candidates, upstream) - } - return candidates -} - -func (p *reverseProxyHandler) selectUpstreamWithPolicy(c *Context, candidates []*reverseProxyUpstream, policy ReverseProxyLBPolicy) *reverseProxyUpstream { - if len(candidates) == 0 { - return nil - } - - switch policy.kind { - case reverseProxyLBPolicyRoundRobin: - return candidates[p.nextRoundRobinIndex(len(candidates))] - case reverseProxyLBPolicyFirst: - return candidates[0] - case reverseProxyLBPolicyLeastConn: - return p.selectLeastConnUpstream(candidates) - case reverseProxyLBPolicyIPHash: - return reverseProxySelectHRW(candidates, reverseProxyClientIP(c.Request.RemoteAddr)) - case reverseProxyLBPolicyClientIPHash: - return reverseProxySelectHRW(candidates, c.RequestIP()) - case reverseProxyLBPolicyURIHash: - if c.Request == nil || c.Request.URL == nil { - return reverseProxySelectRandom(candidates) - } - return reverseProxySelectHRW(candidates, c.Request.URL.RequestURI()) - case reverseProxyLBPolicyHeader: - if c.Request != nil && c.Request.Header != nil { - if values, ok := c.Request.Header[policy.key]; ok { - return reverseProxySelectHRWValues(candidates, values) - } - } - return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) - case reverseProxyLBPolicyQuery: - if c.Request != nil && c.Request.URL != nil { - if values, ok := c.Request.URL.Query()[policy.key]; ok { - return reverseProxySelectHRW(candidates, strings.Join(values, ",")) - } - } - return p.selectUpstreamWithPolicy(c, candidates, reverseProxyFallbackPolicy(policy)) - case reverseProxyLBPolicyRandom: - fallthrough - default: - return reverseProxySelectRandom(candidates) - } -} - -func (p *reverseProxyHandler) nextRoundRobinIndex(size int) int { - if size <= 1 { - return 0 - } - return int((p.roundRobin.Add(1) - 1) % uint64(size)) -} - -func (p *reverseProxyHandler) selectLeastConnUpstream(candidates []*reverseProxyUpstream) *reverseProxyUpstream { - if len(candidates) == 0 { - return nil - } - selected := candidates[0] - lowest := selected.inFlight.Load() - ties := []*reverseProxyUpstream{selected} - for _, upstream := range candidates[1:] { - count := upstream.inFlight.Load() - switch { - case count < lowest: - selected = upstream - lowest = count - ties = []*reverseProxyUpstream{upstream} - case count == lowest: - ties = append(ties, upstream) - } - } - if len(ties) == 1 { - return selected - } - return ties[p.nextRoundRobinIndex(len(ties))] -} - -func reverseProxySelectRandom(candidates []*reverseProxyUpstream) *reverseProxyUpstream { - if len(candidates) == 0 { - return nil - } - if len(candidates) == 1 { - return candidates[0] - } - return candidates[rand.IntN(len(candidates))] -} - -func reverseProxySelectHRW(candidates []*reverseProxyUpstream, key string) *reverseProxyUpstream { - if len(candidates) == 0 { - return nil - } - if key == "" { - return reverseProxySelectRandom(candidates) - } - selected := candidates[0] - bestScore := reverseProxyHRWScore(key, selected.key) - for _, upstream := range candidates[1:] { - score := reverseProxyHRWScore(key, upstream.key) - if score > bestScore { - selected = upstream - bestScore = score - } - } - return selected -} - -func reverseProxySelectHRWValues(candidates []*reverseProxyUpstream, values []string) *reverseProxyUpstream { - if len(candidates) == 0 { - return nil - } - if len(values) == 0 { - return reverseProxySelectRandom(candidates) - } - selected := candidates[0] - bestScore := reverseProxyHRWValuesScore(values, selected.key) - for _, upstream := range candidates[1:] { - score := reverseProxyHRWValuesScore(values, upstream.key) - if score > bestScore { - selected = upstream - bestScore = score - } - } - return selected -} - -func reverseProxyHRWScore(key, upstreamKey string) uint64 { - const ( - offset64 = 14695981039346656037 - prime64 = 1099511628211 - ) - h := uint64(offset64) - for i := 0; i < len(key); i++ { - h ^= uint64(key[i]) - h *= prime64 - } - h ^= 0xff - h *= prime64 - for i := 0; i < len(upstreamKey); i++ { - h ^= uint64(upstreamKey[i]) - h *= prime64 - } - return h -} - -func reverseProxyHRWValuesScore(values []string, upstreamKey string) uint64 { - const ( - offset64 = 14695981039346656037 - prime64 = 1099511628211 - ) - h := uint64(offset64) - for valueIndex, value := range values { - for i := 0; i < len(value); i++ { - h ^= uint64(value[i]) - h *= prime64 - } - if valueIndex+1 < len(values) { - h ^= ',' - h *= prime64 - } - } - h ^= 0xff - h *= prime64 - for i := 0; i < len(upstreamKey); i++ { - h ^= uint64(upstreamKey[i]) - h *= prime64 - } - return h -} - -func reverseProxyFallbackPolicy(policy ReverseProxyLBPolicy) ReverseProxyLBPolicy { - if policy.fallback != nil { - return *policy.fallback - } - return LBRandom() -} - -func (u *reverseProxyUpstream) healthy(now time.Time, config ReverseProxyPassiveHealthConfig) bool { - maxFails := reverseProxyPassiveMaxFails(config) - if config.FailDuration <= 0 || maxFails <= 0 { - return true - } - - u.passiveMu.Lock() - defer u.passiveMu.Unlock() - u.pruneFailuresLocked(now, config.FailDuration) - return len(u.failures) < maxFails -} - -func (u *reverseProxyUpstream) recordFailure(now time.Time, config ReverseProxyPassiveHealthConfig) { - maxFails := reverseProxyPassiveMaxFails(config) - if config.FailDuration <= 0 || maxFails <= 0 { - return - } - - u.passiveMu.Lock() - defer u.passiveMu.Unlock() - u.pruneFailuresLocked(now, config.FailDuration) - u.failures = append(u.failures, now) -} - -func (u *reverseProxyUpstream) pruneFailuresLocked(now time.Time, window time.Duration) { - if len(u.failures) == 0 || window <= 0 { - if window <= 0 { - u.failures = nil - } - return - } - cutoff := now.Add(-window) - keep := 0 - for _, failureAt := range u.failures { - if failureAt.Before(cutoff) { - continue - } - u.failures[keep] = failureAt - keep++ - } - u.failures = u.failures[:keep] -} - -func reverseProxyPassiveMaxFails(config ReverseProxyPassiveHealthConfig) int { - if config.FailDuration <= 0 { - return 0 - } - if config.MaxFails <= 0 { - return 1 - } - return config.MaxFails -} - -func reverseProxyStatusIsUnhealthy(config ReverseProxyPassiveHealthConfig, status int) bool { - if status <= 0 { - return false - } - return slices.Contains(config.UnhealthyStatus, status) -} diff --git a/reverseproxy_test.go b/reverseproxy_test.go index 6863da7..f82aff9 100644 --- a/reverseproxy_test.go +++ b/reverseproxy_test.go @@ -2,10 +2,6 @@ package touka import ( "bufio" - "bytes" - "context" - crand "crypto/rand" - "crypto/tls" "errors" "fmt" "io" @@ -15,14 +11,9 @@ import ( "net/http/httptrace" "net/textproto" "net/url" - "strconv" "strings" - "sync" - "sync/atomic" "testing" "time" - - "golang.org/x/net/http2" ) func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { @@ -79,7 +70,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ Target: target, ForwardedHeaders: ForwardedBoth, - ForwardedBy: "_proxy-node", + ForwardedBy: "proxy-node", Via: "proxy.test", })) @@ -115,8 +106,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { t.Fatalf("unexpected body: %q", string(body)) } if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) + t.Fatalf("unexpected status: %d", resp.StatusCode) } if got.Path != "/base/api/ping" { t.Fatalf("unexpected upstream path: %q", got.Path) @@ -154,7 +144,7 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { if !strings.Contains(got.Forwarded, "for=198.51.100.10") { t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) } - if !strings.Contains(got.Forwarded, "by=_proxy-node") { + if !strings.Contains(got.Forwarded, "by=proxy-node") { t.Fatalf("forwarded header missing by token: %q", got.Forwarded) } if !strings.Contains(got.Forwarded, "host=client.example") { @@ -180,61 +170,6 @@ func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { } } -func TestReverseProxyRejectsInvalidForwardedBy(t *testing.T) { - t.Helper() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - ForwardedHeaders: ForwardedBoth, - ForwardedBy: "proxy-node", - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusInternalServerError { - t.Fatalf("unexpected status: %d", rr.Code) - } -} - -func TestReverseProxyForwardedByTrimsWhitespace(t *testing.T) { - t.Helper() - - forwardedCh := make(chan string, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - forwardedCh <- r.Header.Get("Forwarded") - w.WriteHeader(http.StatusNoContent) - })) - defer backend.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, backend.URL), - ForwardedHeaders: ForwardedBoth, - ForwardedBy: " _proxy-node ", - })) - - req := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) - req.RemoteAddr = "198.51.100.10:4567" - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusNoContent { - t.Fatalf("unexpected status: %d", rr.Code) - } - - select { - case forwarded := <-forwardedCh: - if !strings.Contains(forwarded, "by=_proxy-node") { - t.Fatalf("unexpected Forwarded header: %q", forwarded) - } - if strings.Contains(forwarded, `by=" _proxy-node "`) { - t.Fatalf("forwarded header should not preserve surrounding whitespace: %q", forwarded) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for backend Forwarded header") - } -} - func TestReverseProxyDefaultViaFallback(t *testing.T) { t.Helper() @@ -268,544 +203,6 @@ func TestReverseProxyDefaultViaFallback(t *testing.T) { } } -func TestReverseProxyRejectsConflictingTargetConfig(t *testing.T) { - t.Helper() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - Targets: []string{"http://example.net"}, - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusInternalServerError { - t.Fatalf("unexpected status: %d", rr.Code) - } -} - -func TestReverseProxyTargetsRoundRobinPreservesFullURLTargets(t *testing.T) { - t.Helper() - - type snapshot struct { - Path string - RawQuery string - } - - backendOneCh := make(chan snapshot, 1) - backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backendOneCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} - _, _ = io.WriteString(w, "one") - })) - defer backendOne.Close() - - backendTwoCh := make(chan snapshot, 1) - backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backendTwoCh <- snapshot{Path: r.URL.Path, RawQuery: r.URL.RawQuery} - _, _ = io.WriteString(w, "two") - })) - defer backendTwo.Close() - - engine := New() - engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ - Targets: []string{ - backendOne.URL + "/one?from=one", - backendTwo.URL + "/two?from=two", - }, - LoadBalancing: ReverseProxyLoadBalancingConfig{Policy: LBRoundRobin()}, - })) - - first := PerformRequest(engine, http.MethodGet, "/api/ping?q=1", nil, nil) - if first.Code != http.StatusOK || first.Body.String() != "one" { - t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) - } - second := PerformRequest(engine, http.MethodGet, "/api/pong?q=2", nil, nil) - if second.Code != http.StatusOK || second.Body.String() != "two" { - t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) - } - - select { - case got := <-backendOneCh: - if got.Path != "/one/api/ping" || got.RawQuery != "from=one&q=1" { - t.Fatalf("unexpected first upstream request: %#v", got) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for first upstream request") - } - - select { - case got := <-backendTwoCh: - if got.Path != "/two/api/pong" || got.RawQuery != "from=two&q=2" { - t.Fatalf("unexpected second upstream request: %#v", got) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for second upstream request") - } -} - -func TestReverseProxyHeaderPolicyFallbackAndStickiness(t *testing.T) { - t.Helper() - - backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "one") - })) - defer backendOne.Close() - - backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "two") - })) - defer backendTwo.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backendOne.URL, backendTwo.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBHeader("X-Upstream", LBFirst()), - }, - })) - - fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { - t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) - } - - headers := http.Header{"X-Upstream": {"tenant-a"}} - firstSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) - secondSticky := PerformRequest(engine, http.MethodGet, "/proxy", nil, headers) - if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { - t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) - } - if firstSticky.Body.String() != secondSticky.Body.String() { - t.Fatalf("header policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) - } -} - -func TestReverseProxyQueryPolicyFallbackAndStickiness(t *testing.T) { - t.Helper() - - backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "one") - })) - defer backendOne.Close() - - backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "two") - })) - defer backendTwo.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backendOne.URL, backendTwo.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBQuery("tenant", LBFirst()), - }, - })) - - fallbackResp := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if fallbackResp.Code != http.StatusOK || fallbackResp.Body.String() != "one" { - t.Fatalf("unexpected fallback response: code=%d body=%q", fallbackResp.Code, fallbackResp.Body.String()) - } - - firstSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) - secondSticky := PerformRequest(engine, http.MethodGet, "/proxy?tenant=a", nil, nil) - if firstSticky.Code != http.StatusOK || secondSticky.Code != http.StatusOK { - t.Fatalf("unexpected sticky statuses: %d %d", firstSticky.Code, secondSticky.Code) - } - if firstSticky.Body.String() != secondSticky.Body.String() { - t.Fatalf("query policy should be sticky, got %q and %q", firstSticky.Body.String(), secondSticky.Body.String()) - } -} - -func TestReverseProxyClientIPHashUsesParsedClientIP(t *testing.T) { - t.Helper() - - backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "one") - })) - defer backendOne.Close() - - backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "two") - })) - defer backendTwo.Close() - - engine := New() - engine.SetRemoteIPHeaders([]string{"CF-Connecting-IP"}) - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backendOne.URL, backendTwo.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBClientIPHash(), - }, - })) - - reqOne := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) - reqOne.RemoteAddr = "10.0.0.1:1234" - reqOne.Header.Set("CF-Connecting-IP", "203.0.113.10") - rrOne := httptest.NewRecorder() - engine.ServeHTTP(rrOne, reqOne) - - reqTwo := httptest.NewRequest(http.MethodGet, "http://client.example/proxy", nil) - reqTwo.RemoteAddr = "10.0.0.2:5678" - reqTwo.Header.Set("CF-Connecting-IP", "203.0.113.10") - rrTwo := httptest.NewRecorder() - engine.ServeHTTP(rrTwo, reqTwo) - - if rrOne.Code != http.StatusOK || rrTwo.Code != http.StatusOK { - t.Fatalf("unexpected statuses: %d %d", rrOne.Code, rrTwo.Code) - } - if rrOne.Body.String() != rrTwo.Body.String() { - t.Fatalf("client IP hash should use parsed client IP, got %q and %q", rrOne.Body.String(), rrTwo.Body.String()) - } -} - -func TestReverseProxyRetriesSafeRequestsAcrossTargets(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "ok") - })) - defer backend.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{"http://127.0.0.1:1", backend.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBFirst(), - Retries: 1, - }, - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusOK || rr.Body.String() != "ok" { - t.Fatalf("unexpected retry response: code=%d body=%q", rr.Code, rr.Body.String()) - } -} - -func TestReverseProxyModifyRequestRunsPerRetryAttempt(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, r.Header.Get("X-Attempt")) - })) - defer backend.Close() - - var attempts atomic.Int64 - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{"http://127.0.0.1:1", backend.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBFirst(), - Retries: 1, - }, - ModifyRequest: func(req *http.Request) { - req.Header.Set("X-Attempt", strconv.FormatInt(attempts.Add(1), 10)) - }, - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusOK { - t.Fatalf("unexpected status: %d", rr.Code) - } - if rr.Body.String() != "2" { - t.Fatalf("ModifyRequest should run again for the retry attempt, got %q", rr.Body.String()) - } -} - -func TestReverseProxyDoesNotRetryUnsafeRequestsAcrossTargets(t *testing.T) { - t.Helper() - - backendCalls := make(chan struct{}, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backendCalls <- struct{}{} - _, _ = io.WriteString(w, "ok") - })) - defer backend.Close() - - engine := New() - engine.POST("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{"http://127.0.0.1:1", backend.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBFirst(), - Retries: 1, - }, - })) - - rr := PerformRequest(engine, http.MethodPost, "/proxy", strings.NewReader("payload"), nil) - if rr.Code != http.StatusBadGateway { - t.Fatalf("unexpected status: %d", rr.Code) - } - - select { - case <-backendCalls: - t.Fatal("unsafe POST request should not be retried to the next upstream") - default: - } -} - -func TestReverseProxyLeastConnPrefersLessBusyUpstream(t *testing.T) { - t.Helper() - - backendOneStarted := make(chan struct{}, 1) - releaseBackendOne := make(chan struct{}) - backendOne := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backendOneStarted <- struct{}{} - <-releaseBackendOne - _, _ = io.WriteString(w, "one") - })) - defer backendOne.Close() - - backendTwo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "two") - })) - defer backendTwo.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backendOne.URL, backendTwo.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBLeastConn(), - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - client := proxy.Client() - client.Timeout = 5 * time.Second - - firstRespCh := make(chan string, 1) - firstErrCh := make(chan error, 1) - go func() { - resp, err := client.Get(proxy.URL + "/proxy") - if err != nil { - firstErrCh <- err - return - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - firstErrCh <- err - return - } - firstRespCh <- string(body) - }() - - select { - case <-backendOneStarted: - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for first backend request") - } - - secondResp, err := client.Get(proxy.URL + "/proxy") - if err != nil { - close(releaseBackendOne) - t.Fatalf("second request failed: %v", err) - } - secondBody, err := io.ReadAll(secondResp.Body) - _ = secondResp.Body.Close() - if err != nil { - close(releaseBackendOne) - t.Fatalf("read second response: %v", err) - } - if string(secondBody) != "two" { - close(releaseBackendOne) - t.Fatalf("least_conn should pick the less busy upstream, got %q", string(secondBody)) - } - - close(releaseBackendOne) - select { - case err := <-firstErrCh: - t.Fatalf("first request failed: %v", err) - case body := <-firstRespCh: - if body != "one" { - t.Fatalf("unexpected first response body: %q", body) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for first response body") - } -} - -func TestReverseProxyPassiveHealthSkipsUnhealthyTargetsOnLaterRequests(t *testing.T) { - t.Helper() - - primaryCalls := make(chan struct{}, 4) - primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - primaryCalls <- struct{}{} - w.WriteHeader(http.StatusServiceUnavailable) - _, _ = io.WriteString(w, "primary down") - })) - defer primary.Close() - - secondaryCalls := make(chan struct{}, 4) - secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - secondaryCalls <- struct{}{} - _, _ = io.WriteString(w, "secondary up") - })) - defer secondary.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{primary.URL, secondary.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBFirst(), - }, - PassiveHealth: ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - UnhealthyStatus: []int{http.StatusServiceUnavailable}, - }, - })) - - first := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if first.Code != http.StatusServiceUnavailable || first.Body.String() != "primary down" { - t.Fatalf("unexpected first response: code=%d body=%q", first.Code, first.Body.String()) - } - second := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if second.Code != http.StatusOK || second.Body.String() != "secondary up" { - t.Fatalf("unexpected second response: code=%d body=%q", second.Code, second.Body.String()) - } - - select { - case <-primaryCalls: - default: - t.Fatal("expected primary to receive the first request") - } - select { - case <-secondaryCalls: - default: - t.Fatal("expected secondary to receive the second request") - } - select { - case <-primaryCalls: - t.Fatal("primary should not receive the second request while unhealthy") - default: - } -} - -func TestReverseProxyPassiveHealthIgnoresClientCancellation(t *testing.T) { - t.Helper() - - started := make(chan struct{}, 1) - release := make(chan struct{}) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - started <- struct{}{} - <-release - _, _ = io.WriteString(w, "ok") - })) - defer backend.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backend.URL}, - PassiveHealth: ReverseProxyPassiveHealthConfig{ - FailDuration: time.Minute, - }, - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - ctx, cancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxy.URL+"/proxy", nil) - if err != nil { - t.Fatalf("new request: %v", err) - } - client := proxy.Client() - respCh := make(chan error, 1) - go func() { - resp, err := client.Do(req) - if resp != nil { - _ = resp.Body.Close() - } - respCh <- err - }() - - select { - case <-started: - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for backend request") - } - cancel() - close(release) - select { - case <-respCh: - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for canceled request to finish") - } - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusOK || rr.Body.String() != "ok" { - t.Fatalf("healthy backend should remain selectable after client cancellation, got code=%d body=%q", rr.Code, rr.Body.String()) - } -} - -func TestReverseProxyTryDurationPreventsLateRetry(t *testing.T) { - t.Helper() - - backendCalls := make(chan struct{}, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - backendCalls <- struct{}{} - _, _ = io.WriteString(w, "ok") - })) - defer backend.Close() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Targets: []string{"http://127.0.0.1:1", backend.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBFirst(), - Retries: 3, - TryDuration: 100 * time.Millisecond, - TryInterval: 250 * time.Millisecond, - }, - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusBadGateway { - t.Fatalf("unexpected status: %d", rr.Code) - } - - select { - case <-backendCalls: - t.Fatal("retry budget should expire before the next upstream attempt") - default: - } -} - -func TestReverseProxyAllowH2CUpstream(t *testing.T) { - t.Helper() - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen h2c upstream: %v", err) - } - server := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Upstream-Proto", r.Proto) - _, _ = io.WriteString(w, "ok") - })} - server.Protocols = new(http.Protocols) - server.Protocols.SetUnencryptedHTTP2(true) - errCh := make(chan error, 1) - go func() { - errCh <- server.Serve(listener) - }() - defer func() { - _ = server.Close() - <-errCh - }() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://"+listener.Addr().String()), - AllowH2CUpstream: true, - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusOK || rr.Body.String() != "ok" { - t.Fatalf("unexpected response: code=%d body=%q", rr.Code, rr.Body.String()) - } - if got := rr.Header().Get("X-Upstream-Proto"); got != "HTTP/2.0" { - t.Fatalf("expected h2c upstream proto, got %q", got) - } -} - func TestReverseProxyCustomErrorHandler(t *testing.T) { t.Helper() @@ -832,148 +229,6 @@ func TestReverseProxyCustomErrorHandler(t *testing.T) { } } -func TestReverseProxyH2ReadWriteCloserWriteReturnsWrittenCountOnFlushError(t *testing.T) { - t.Helper() - - flushErr := errors.New("flush failed") - writer := &flushErrorResponseWriter{flushErr: flushErr} - conn := &reverseProxyH2ReadWriteCloser{ - ReadCloser: io.NopCloser(strings.NewReader("")), - ResponseWriter: writer, - controller: http.NewResponseController(reverseProxyBaseResponseWriter(writer)), - } - - n, err := conn.Write([]byte("ping")) - if n != len("ping") { - t.Fatalf("unexpected bytes written: %d", n) - } - if !errors.Is(err, flushErr) { - t.Fatalf("unexpected write error: %v", err) - } - if got := writer.body.String(); got != "ping" { - t.Fatalf("unexpected buffered body: %q", got) - } -} - -func TestReverseProxyExtendedConnectBridgeKeyGenerationFailureReturnsError(t *testing.T) { - t.Helper() - - transportCalled := atomic.Bool{} - entropyErr := errors.New("entropy source unavailable") - originalReader := crand.Reader - crand.Reader = errorReader{err: entropyErr} - t.Cleanup(func() { - crand.Reader = originalReader - }) - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - transportCalled.Store(true) - return nil, errors.New("unexpected round trip") - }), - ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { - w.WriteHeader(reverseProxyStatusCode(err)) - _, _ = io.WriteString(w, err.Error()) - }, - })) - - headers := make(http.Header) - headers.Set(":protocol", "websocket") - rr := PerformRequest(engine, http.MethodConnect, "/ws", nil, headers) - - if transportCalled.Load() { - t.Fatal("transport should not be called when websocket key generation fails") - } - if rr.Code != http.StatusBadGateway { - t.Fatalf("unexpected status: %d", rr.Code) - } - if body := rr.Body.String(); !strings.Contains(body, "reverse proxy failed to generate websocket key") || !strings.Contains(body, entropyErr.Error()) { - t.Fatalf("unexpected error body: %q", body) - } -} - -func TestHTTP2TransportBuildersDoNotPanicWhenDefaultTransportIsCustom(t *testing.T) { - t.Helper() - - originalDefaultTransport := http.DefaultTransport - http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("unexpected round trip") - }) - t.Cleanup(func() { - http.DefaultTransport = originalDefaultTransport - }) - - assertTransport := func(name string, rt http.RoundTripper, check func(*http.Transport)) { - t.Helper() - transport, ok := rt.(*http.Transport) - if !ok { - t.Fatalf("%s returned %T, want *http.Transport", name, rt) - } - check(transport) - } - - assertTransport("newHTTP2ExtendedConnectTransport", newHTTP2ExtendedConnectTransport(), func(transport *http.Transport) { - if transport.Protocols == nil || !transport.Protocols.HTTP1() || !transport.Protocols.HTTP2() { - t.Fatalf("unexpected protocols for extended connect transport: %#v", transport.Protocols) - } - }) - assertTransport("newHTTP1BridgeTransportWithTLSConfig", newHTTP1BridgeTransportWithTLSConfig(nil), func(transport *http.Transport) { - if transport.Protocols == nil || !transport.Protocols.HTTP1() || transport.Protocols.HTTP2() || transport.Protocols.UnencryptedHTTP2() { - t.Fatalf("unexpected protocols for bridge transport: %#v", transport.Protocols) - } - if transport.TLSClientConfig == nil || len(transport.TLSClientConfig.NextProtos) != 1 || transport.TLSClientConfig.NextProtos[0] != "http/1.1" { - t.Fatalf("unexpected TLS next protos for bridge transport: %#v", transport.TLSClientConfig) - } - }) - assertTransport("newH2CTransport", newH2CTransport(), func(transport *http.Transport) { - if transport.Protocols == nil || !transport.Protocols.UnencryptedHTTP2() || transport.Protocols.HTTP1() || transport.Protocols.HTTP2() { - t.Fatalf("unexpected protocols for h2c transport: %#v", transport.Protocols) - } - }) -} - -func TestNewHTTP1BridgeTransportWithTLSConfigClonesInput(t *testing.T) { - t.Helper() - - tlsConfig := &tls.Config{InsecureSkipVerify: true} - rt := newHTTP1BridgeTransportWithTLSConfig(tlsConfig) - transport, ok := rt.(*http.Transport) - if !ok { - t.Fatalf("unexpected transport type: %T", rt) - } - if transport.TLSClientConfig == nil { - t.Fatal("expected TLS client config") - } - if transport.TLSClientConfig == tlsConfig { - t.Fatal("expected bridge transport to clone TLS config") - } - if len(tlsConfig.NextProtos) != 0 { - t.Fatalf("input TLS config was mutated: %#v", tlsConfig.NextProtos) - } - if got := transport.TLSClientConfig.NextProtos; len(got) != 1 || got[0] != "http/1.1" { - t.Fatalf("unexpected transport NextProtos: %#v", got) - } -} - -func TestReverseProxyTimeoutReturnsGatewayTimeout(t *testing.T) { - t.Helper() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - return nil, context.DeadlineExceeded - }), - })) - - rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) - if rr.Code != http.StatusGatewayTimeout { - t.Fatalf("unexpected status: %d", rr.Code) - } -} - func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { t.Helper() @@ -1197,1081 +452,6 @@ func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { } } -func TestReverseProxyUpgradeNeedsHijacker(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hj, ok := w.(http.Hijacker) - if !ok { - t.Fatal("backend response writer does not support hijack") - } - conn, brw, err := hj.Hijack() - if err != nil { - t.Fatalf("backend hijack failed: %v", err) - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") - _ = brw.Flush() - })) - defer backend.Close() - - engine := New() - engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) - - req := httptest.NewRequest(http.MethodGet, "http://client.example/ws", nil) - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusNotImplemented { - t.Fatalf("unexpected status: %d", rr.Code) - } -} - -func TestReverseProxyMaxForwardsTraceHandledLocally(t *testing.T) { - t.Helper() - - called := make(chan struct{}, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called <- struct{}{} - w.WriteHeader(http.StatusNoContent) - })) - defer backend.Close() - - engine := New() - engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) - - req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) - req.RequestURI = "/trace" - req.Header.Set("Max-Forwards", "0") - req.Header.Set("Authorization", "secret") - req.Header.Set("Cookie", "a=b") - req.Header.Set("Forwarded", "for=192.0.2.1") - - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - resp := rr.Result() - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body: %v", err) - } - _ = resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) - } - if got := resp.Header.Get("Content-Type"); got != "message/http" { - t.Fatalf("unexpected content type: %q", got) - } - if !strings.Contains(string(body), "TRACE /trace HTTP/1.1") { - t.Fatalf("trace body missing request line: %q", string(body)) - } - if strings.Contains(string(body), "Authorization:") { - t.Fatalf("trace body leaked authorization header: %q", string(body)) - } - if strings.Contains(string(body), "Cookie:") { - t.Fatalf("trace body leaked cookie header: %q", string(body)) - } - if strings.Contains(string(body), "Forwarded:") { - t.Fatalf("trace body leaked forwarded header: %q", string(body)) - } - - select { - case <-called: - t.Fatal("backend should not be called when Max-Forwards is zero") - default: - } -} - -func TestReverseProxyMaxForwardsTraceDecrementsBeforeForwarding(t *testing.T) { - t.Helper() - - maxForwardsCh := make(chan string, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - maxForwardsCh <- r.Header.Get("Max-Forwards") - w.WriteHeader(http.StatusNoContent) - })) - defer backend.Close() - - engine := New() - engine.Handle(http.MethodTrace, "/trace", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) - - req := httptest.NewRequest(http.MethodTrace, "http://client.example/trace", nil) - req.Header.Set("Max-Forwards", "2") - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusNoContent { - t.Fatalf("unexpected status: %d", rr.Code) - } - - select { - case got := <-maxForwardsCh: - if got != "1" { - t.Fatalf("unexpected Max-Forwards header: %q", got) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for backend Max-Forwards") - } -} - -func TestReverseProxyMaxForwardsOptionsHandledLocally(t *testing.T) { - t.Helper() - - called := make(chan struct{}, 1) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called <- struct{}{} - w.WriteHeader(http.StatusNoContent) - })) - defer backend.Close() - - engine := New() - engine.GET("/proxy", func(c *Context) { c.Status(http.StatusNoContent) }) - engine.OPTIONS("/proxy", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) - - req := httptest.NewRequest(http.MethodOptions, "http://client.example/proxy", nil) - req.Header.Set("Max-Forwards", "0") - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("unexpected status: %d", rr.Code) - } - allow := rr.Header().Get("Allow") - if !strings.Contains(allow, http.MethodGet) || !strings.Contains(allow, http.MethodOptions) { - t.Fatalf("unexpected Allow header: %q", allow) - } - - select { - case <-called: - t.Fatal("backend should not be called when Max-Forwards is zero") - default: - } -} - -func TestEngineHandlesOptionsAsteriskLocally(t *testing.T) { - t.Helper() - - engine := New() - engine.OPTIONS("/", func(c *Context) { - c.Status(http.StatusNoContent) - }) - - req := httptest.NewRequest(http.MethodOptions, "http://client.example/", nil) - req.RequestURI = "*" - req.URL.Path = "" - req.URL.RawPath = "" - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("unexpected status for OPTIONS *: %d", rr.Code) - } - if got := rr.Header().Get("Content-Length"); got != "0" { - t.Fatalf("unexpected Content-Length header: %q", got) - } -} - -func TestReverseProxyConnectTunnel(t *testing.T) { - t.Helper() - - backendAddr := "" - errCh := make(chan error, 4) - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { - errCh <- fmt.Errorf("unexpected method: %s", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - if got, want := r.RequestURI, backendAddr; got != want { - errCh <- fmt.Errorf("unexpected CONNECT target %q, want %q", got, want) - w.WriteHeader(http.StatusBadRequest) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- errors.New("backend response writer does not support hijack") - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("backend hijack failed: %w", err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\nVia: 1.1 upstream\r\n\r\n") - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("backend flush failed: %w", err) - return - } - - line, err := brw.ReadString('\n') - if err != nil { - errCh <- fmt.Errorf("backend read failed: %w", err) - return - } - _, _ = io.WriteString(brw, strings.ToUpper(line)) - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("backend write failed: %w", err) - return - } - })) - defer backend.Close() - backendAddr = strings.TrimPrefix(backend.URL, "http://") - - engine := New() - engine.Handle(http.MethodConnect, "/:authority", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, backend.URL), - Via: "proxy.test", - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) - if err != nil { - t.Fatalf("dial proxy: %v", err) - } - defer conn.Close() - - if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - t.Fatalf("set deadline: %v", err) - } - - _, err = fmt.Fprintf(conn, "CONNECT origin.example:443 HTTP/1.1\r\nHost: origin.example:443\r\n\r\n") - if err != nil { - t.Fatalf("write connect request: %v", err) - } - - reader := bufio.NewReader(conn) - statusLine, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("read status line: %v", err) - } - if !strings.Contains(statusLine, "200") { - t.Fatalf("unexpected status line: %q", statusLine) - } - - headers, err := textproto.NewReader(reader).ReadMIMEHeader() - if err != nil { - t.Fatalf("read headers: %v", err) - } - respHeader := http.Header(headers) - if got := respHeader.Get("Content-Length"); got != "" { - t.Fatalf("CONNECT response should not include Content-Length, got %q", got) - } - if got := respHeader.Get("Transfer-Encoding"); got != "" { - t.Fatalf("CONNECT response should not include Transfer-Encoding, got %q", got) - } - if gotVia := respHeader.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.1 upstream" || gotVia[1] != "1.1 proxy.test" { - t.Fatalf("unexpected Via response header: %#v", gotVia) - } - - if _, err := io.WriteString(conn, "ping\n"); err != nil { - t.Fatalf("write tunneled payload: %v", err) - } - message, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("read tunneled payload: %v", err) - } - if message != "PING\n" { - t.Fatalf("unexpected tunneled payload: %q", message) - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyConnectNeedsHijacker(t *testing.T) { - t.Helper() - - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hj, ok := w.(http.Hijacker) - if !ok { - t.Fatal("backend response writer does not support hijack") - } - conn, brw, err := hj.Hijack() - if err != nil { - t.Fatalf("backend hijack failed: %v", err) - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 200 Connection Established\r\n\r\n") - _ = brw.Flush() - })) - defer backend.Close() - - engine := New() - engine.Handle(http.MethodConnect, "/tunnel", ReverseProxy(ReverseProxyConfig{Target: mustParseURL(t, backend.URL)})) - - req := httptest.NewRequest(http.MethodConnect, "http://client.example/tunnel", nil) - req.URL.Path = "/tunnel" - req.RequestURI = "/tunnel" - rr := httptest.NewRecorder() - engine.ServeHTTP(rr, req) - - if rr.Code != http.StatusNotImplemented { - t.Fatalf("unexpected status: %d", rr.Code) - } -} - -func TestReverseProxyHTTP2ExtendedConnect(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - errCh := make(chan error, 4) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - if got := r.Header.Get(":protocol"); got != "" { - errCh <- fmt.Errorf("unexpected upstream :protocol header: %q", got) - w.WriteHeader(http.StatusBadRequest) - return - } - if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { - errCh <- fmt.Errorf("unexpected upstream Connection header: %#v", r.Header.Values("Connection")) - w.WriteHeader(http.StatusBadRequest) - return - } - if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { - errCh <- fmt.Errorf("unexpected upstream Upgrade header: %q", r.Header.Get("Upgrade")) - w.WriteHeader(http.StatusBadRequest) - return - } - if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { - errCh <- errors.New("missing upstream Sec-WebSocket-Key header") - w.WriteHeader(http.StatusBadRequest) - return - } - if got := r.URL.Path; got != "/ws" { - errCh <- fmt.Errorf("unexpected upstream path: %q", got) - w.WriteHeader(http.StatusBadRequest) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- errors.New("upstream response writer does not support hijack") - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("upstream hijack failed: %w", err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade, X-Hop-Token\r\nX-Hop-Token: hidden\r\nSec-WebSocket-Accept: ignored\r\n\r\n") - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("upstream flush failed: %w", err) - return - } - - line, err := brw.ReadString('\n') - if err != nil { - errCh <- fmt.Errorf("read tunneled request body failed: %w", err) - return - } - if _, err := io.WriteString(brw, "echo:"+line); err != nil { - errCh <- fmt.Errorf("write tunneled response body failed: %w", err) - return - } - _ = brw.Flush() - })) - defer upstream.Close() - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), - Via: "proxy.test", - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() - - pr, pw := io.Pipe() - req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %d", resp.StatusCode) - } - if got := resp.Header.Get("Upgrade"); got != "" { - t.Fatalf("bridged extended CONNECT response should not expose Upgrade header, got %q", got) - } - if got := resp.Header.Get("X-Hop-Token"); got != "" { - t.Fatalf("bridged extended CONNECT response should not expose hop-by-hop token header, got %q", got) - } - if gotVia := resp.Header.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { - t.Fatalf("unexpected Via response header: %#v", gotVia) - } - - if _, err := io.WriteString(pw, "ping\n"); err != nil { - t.Fatalf("write tunneled request body: %v", err) - } - message, err := bufio.NewReader(resp.Body).ReadString('\n') - if err != nil { - t.Fatalf("read tunneled response body: %v", err) - } - if message != "echo:ping\n" { - t.Fatalf("unexpected tunneled response body: %q", message) - } - if err := pw.Close(); err != nil { - t.Fatalf("close tunneled request body: %v", err) - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyHTTP2ExtendedConnectBridgeClosesBackendOnce(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - closeCalls := atomic.Int32{} - backendReadDone := make(chan struct{}, 1) - transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { - if req.Method != http.MethodGet { - return nil, fmt.Errorf("unexpected upstream method: %s", req.Method) - } - var respondOnce sync.Once - var backend *countingReadWriteCloser - backend = &countingReadWriteCloser{ - readDataCh: make(chan []byte, 1), - closeCalls: &closeCalls, - closeWriteErr: nil, - afterWrite: func() { - respondOnce.Do(func() { - backendReadDone <- struct{}{} - backend.readDataCh <- []byte("echo:ping\n") - close(backend.readDataCh) - }) - }, - } - return &http.Response{ - StatusCode: http.StatusSwitchingProtocols, - Header: http.Header{ - "Connection": []string{"Upgrade"}, - "Upgrade": []string{"websocket"}, - "Sec-WebSocket-Accept": []string{"ignored"}, - }, - Body: backend, - Request: req, - }, nil - }) - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - Transport: transport, - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - clientTransport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer clientTransport.CloseIdleConnections() - - pr, pw := io.Pipe() - req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := clientTransport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - t.Fatalf("unexpected status: %d", resp.StatusCode) - } - if _, err := io.WriteString(pw, "ping\n"); err != nil { - _ = resp.Body.Close() - t.Fatalf("write tunneled request body: %v", err) - } - select { - case <-backendReadDone: - case <-time.After(2 * time.Second): - _ = resp.Body.Close() - t.Fatal("backend did not receive tunneled request body") - } - message, err := bufio.NewReader(resp.Body).ReadString('\n') - if err != nil { - _ = resp.Body.Close() - t.Fatalf("read tunneled response body: %v", err) - } - if message != "echo:ping\n" { - _ = resp.Body.Close() - t.Fatalf("unexpected tunneled response body: %q", message) - } - if err := pw.Close(); err != nil { - _ = resp.Body.Close() - t.Fatalf("close tunneled request body: %v", err) - } - if err := resp.Body.Close(); err != nil { - t.Fatalf("close response body: %v", err) - } - - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - if closeCalls.Load() > 0 { - break - } - time.Sleep(10 * time.Millisecond) - } - if got := closeCalls.Load(); got != 1 { - t.Fatalf("expected backend connection to close exactly once, got %d", got) - } -} - -func TestReverseProxyHTTP2ExtendedConnectForcesHTTP1ToTLSUpstream(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - errCh := make(chan error, 4) - upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.ProtoMajor != 1 { - errCh <- fmt.Errorf("expected bridged upstream protocol HTTP/1.x, got %s", r.Proto) - w.WriteHeader(http.StatusBadRequest) - return - } - if r.Method != http.MethodGet { - errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { - errCh <- fmt.Errorf("unexpected websocket bridge headers: Connection=%#v Upgrade=%q", r.Header.Values("Connection"), r.Header.Get("Upgrade")) - w.WriteHeader(http.StatusBadRequest) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- errors.New("upstream response writer does not support hijack") - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("upstream hijack failed: %w", err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("upstream flush failed: %w", err) - return - } - - line, err := brw.ReadString('\n') - if err != nil { - errCh <- fmt.Errorf("read tunneled request body failed: %w", err) - return - } - if _, err := io.WriteString(brw, "echo:"+line); err != nil { - errCh <- fmt.Errorf("write tunneled response body failed: %w", err) - return - } - _ = brw.Flush() - })) - upstream.EnableHTTP2 = true - upstream.StartTLS() - defer upstream.Close() - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Transport: newHTTP1BridgeTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true}), - Via: "proxy.test", - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() - - pr, pw := io.Pipe() - req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) - } - if _, err := io.WriteString(pw, "ping\n"); err != nil { - t.Fatalf("write tunneled request body: %v", err) - } - message, err := bufio.NewReader(resp.Body).ReadString('\n') - if err != nil { - t.Fatalf("read tunneled response body: %v", err) - } - if message != "echo:ping\n" { - t.Fatalf("unexpected tunneled response body: %q", message) - } - if err := pw.Close(); err != nil { - t.Fatalf("close tunneled request body: %v", err) - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyHTTP2ExtendedConnectTargetsRoundRobin(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - errCh := make(chan error, 8) - newBackend := func(name string) *httptest.Server { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - errCh <- fmt.Errorf("%s unexpected upstream method: %s", name, r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - if got := r.Header.Get(":protocol"); got != "" { - errCh <- fmt.Errorf("%s unexpected upstream :protocol header: %q", name, got) - w.WriteHeader(http.StatusBadRequest) - return - } - if !headerValuesContainToken(r.Header.Values("Connection"), "Upgrade") { - errCh <- fmt.Errorf("%s unexpected upstream Connection header: %#v", name, r.Header.Values("Connection")) - w.WriteHeader(http.StatusBadRequest) - return - } - if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { - errCh <- fmt.Errorf("%s unexpected upstream Upgrade header: %q", name, r.Header.Get("Upgrade")) - w.WriteHeader(http.StatusBadRequest) - return - } - if got := r.Header.Get("Sec-WebSocket-Key"); got == "" { - errCh <- fmt.Errorf("%s missing upstream Sec-WebSocket-Key header", name) - w.WriteHeader(http.StatusBadRequest) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- fmt.Errorf("%s upstream response writer does not support hijack", name) - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("%s upstream hijack failed: %w", name, err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("%s upstream flush failed: %w", name, err) - return - } - - line, err := brw.ReadString('\n') - if err != nil { - errCh <- fmt.Errorf("%s read tunneled request body failed: %w", name, err) - return - } - if _, err := io.WriteString(brw, name+":"+line); err != nil { - errCh <- fmt.Errorf("%s write tunneled response body failed: %w", name, err) - return - } - _ = brw.Flush() - })) - return server - } - - backendOne := newBackend("one") - defer backendOne.Close() - backendTwo := newBackend("two") - defer backendTwo.Close() - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Targets: []string{backendOne.URL, backendTwo.URL}, - LoadBalancing: ReverseProxyLoadBalancingConfig{ - Policy: LBRoundRobin(), - }, - Via: "proxy.test", - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() - - doRequest := func(payload string) string { - pr, pw := io.Pipe() - req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) - } - if _, err := io.WriteString(pw, payload+"\n"); err != nil { - t.Fatalf("write tunneled request body: %v", err) - } - if err := pw.Close(); err != nil { - t.Fatalf("close tunneled request body: %v", err) - } - message, err := bufio.NewReader(resp.Body).ReadString('\n') - if err != nil { - t.Fatalf("read tunneled response body: %v", err) - } - return message - } - - if got := doRequest("ping"); got != "one:ping\n" { - t.Fatalf("unexpected first tunneled response: %q", got) - } - if got := doRequest("pong"); got != "two:pong\n" { - t.Fatalf("unexpected second tunneled response: %q", got) - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyHTTP2ExtendedConnectAllowsHalfClose(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - errCh := make(chan error, 4) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- errors.New("upstream response writer does not support hijack") - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("upstream hijack failed: %w", err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") - if err := brw.Flush(); err != nil { - errCh <- fmt.Errorf("upstream flush failed: %w", err) - return - } - - reader := bufio.NewReader(brw) - line, err := reader.ReadString('\n') - if err != nil { - errCh <- fmt.Errorf("read tunneled request body failed: %w", err) - return - } - if _, err := io.WriteString(brw, "ack:"+line); err != nil { - errCh <- fmt.Errorf("write immediate tunneled response failed: %w", err) - return - } - _ = brw.Flush() - - if _, err := io.Copy(io.Discard, reader); err != nil { - errCh <- fmt.Errorf("wait for request half-close failed: %w", err) - return - } - if _, err := io.WriteString(brw, "after-close\n"); err != nil { - errCh <- fmt.Errorf("write post-close tunneled response failed: %w", err) - return - } - _ = brw.Flush() - })) - defer upstream.Close() - - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Via: "proxy.test", - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() - - pr, pw := io.Pipe() - req, err := http.NewRequest(http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) - } - - reader := bufio.NewReader(resp.Body) - if _, err := io.WriteString(pw, "ping\n"); err != nil { - t.Fatalf("write tunneled request body: %v", err) - } - message, err := reader.ReadString('\n') - if err != nil { - t.Fatalf("read immediate tunneled response: %v", err) - } - if message != "ack:ping\n" { - t.Fatalf("unexpected immediate tunneled response: %q", message) - } - if err := pw.Close(); err != nil { - t.Fatalf("close tunneled request body: %v", err) - } - - message, err = reader.ReadString('\n') - if err != nil { - t.Fatalf("read post-close tunneled response: %v", err) - } - if message != "after-close\n" { - t.Fatalf("unexpected post-close tunneled response: %q", message) - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyHTTP2ExtendedConnectCancelDoesNotTriggerProxyError(t *testing.T) { - t.Helper() - - enableHTTP2ExtendedConnectProtocol() - - errCh := make(chan error, 4) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - errCh <- fmt.Errorf("unexpected upstream method: %s", r.Method) - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - hj, ok := w.(http.Hijacker) - if !ok { - errCh <- errors.New("upstream response writer does not support hijack") - return - } - conn, brw, err := hj.Hijack() - if err != nil { - errCh <- fmt.Errorf("upstream hijack failed: %w", err) - return - } - defer conn.Close() - - _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") - _ = brw.Flush() - - <-r.Context().Done() - })) - defer upstream.Close() - - proxyErrCh := make(chan error, 1) - engine := New() - engine.Handle(http.MethodConnect, "/ws", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, upstream.URL), - Via: "proxy.test", - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - select { - case proxyErrCh <- err: - default: - } - }, - })) - - proxy := httptest.NewUnstartedServer(engine) - proxy.EnableHTTP2 = true - if err := configureHTTP2ExtendedConnectServer(proxy.Config); err != nil { - t.Fatalf("configure proxy HTTP/2 server: %v", err) - } - proxy.StartTLS() - defer proxy.Close() - - transport := &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - pr, pw := io.Pipe() - req, err := http.NewRequestWithContext(ctx, http.MethodConnect, proxy.URL+"/ws", pr) - if err != nil { - t.Fatalf("new CONNECT request: %v", err) - } - req.Header.Set(":protocol", "websocket") - - resp, err := transport.RoundTrip(req) - if err != nil { - t.Fatalf("round trip extended CONNECT: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("unexpected status: %d body=%q", resp.StatusCode, string(body)) - } - - writeErrCh := make(chan error, 1) - go func() { - _, err := io.WriteString(pw, strings.Repeat("x", 1<<20)) - writeErrCh <- err - }() - time.Sleep(50 * time.Millisecond) - - cancel() - if err := pw.CloseWithError(context.Canceled); err != nil { - t.Fatalf("close request body with cancellation: %v", err) - } - select { - case <-writeErrCh: - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for request body writer to unblock") - } - - select { - case err := <-proxyErrCh: - t.Fatalf("proxy error handler should not be called on cancellation, got: %v", err) - case <-time.After(200 * time.Millisecond): - } - - select { - case err := <-errCh: - t.Fatal(err) - default: - } -} - -func TestReverseProxyAbortsStreamingCopyFailure(t *testing.T) { - t.Helper() - - engine := New() - engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ - Target: mustParseURL(t, "http://example.com"), - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{ - "Content-Type": []string{"text/plain"}, - }, - Body: &failingReadCloser{chunks: []string{"ok"}, err: errors.New("boom")}, - ContentLength: -1, - Request: req, - }, nil - }), - })) - - proxy := httptest.NewServer(engine) - defer proxy.Close() - - resp, err := proxy.Client().Get(proxy.URL + "/proxy") - if err != nil { - t.Fatalf("perform request: %v", err) - } - _, err = io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err == nil { - t.Fatal("expected body read to fail after upstream copy error") - } -} - func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { t.Helper() @@ -2380,117 +560,6 @@ func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } -type flushErrorResponseWriter struct { - header http.Header - body bytes.Buffer - status int - written bool - flushErr error -} - -func (w *flushErrorResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *flushErrorResponseWriter) WriteHeader(statusCode int) { - if w.written { - return - } - w.status = statusCode - w.written = true -} - -func (w *flushErrorResponseWriter) Write(p []byte) (int, error) { - if !w.written { - w.WriteHeader(http.StatusOK) - } - return w.body.Write(p) -} - -func (w *flushErrorResponseWriter) Flush() {} - -func (w *flushErrorResponseWriter) FlushError() error { - if !w.written { - w.WriteHeader(http.StatusOK) - } - return w.flushErr -} - -func (w *flushErrorResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return nil, nil, http.ErrNotSupported -} - -func (w *flushErrorResponseWriter) Status() int { - return w.status -} - -func (w *flushErrorResponseWriter) Size() int { - return w.body.Len() -} - -func (w *flushErrorResponseWriter) Written() bool { - return w.written -} - -func (w *flushErrorResponseWriter) IsHijacked() bool { - return false -} - -type errorReader struct { - err error -} - -func (r errorReader) Read([]byte) (int, error) { - return 0, r.err -} - -type countingReadWriteCloser struct { - readData []byte - readDataCh chan []byte - writeBuf bytes.Buffer - closeCalls *atomic.Int32 - closeWriteErr error - afterWrite func() -} - -func (r *countingReadWriteCloser) Read(p []byte) (int, error) { - if len(r.readData) == 0 && r.readDataCh != nil { - data, ok := <-r.readDataCh - if !ok { - return 0, io.EOF - } - r.readData = data - } - if len(r.readData) == 0 { - return 0, io.EOF - } - n := copy(p, r.readData) - r.readData = r.readData[n:] - return n, nil -} - -func (r *countingReadWriteCloser) Write(p []byte) (int, error) { - n, err := r.writeBuf.Write(p) - if err == nil && r.afterWrite != nil { - r.afterWrite() - } - return n, err -} - -func (r *countingReadWriteCloser) Close() error { - if r.closeCalls != nil { - r.closeCalls.Add(1) - } - return nil -} - -func (r *countingReadWriteCloser) CloseWrite() error { - return r.closeWriteErr -} - func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) @@ -2499,21 +568,3 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } return u } - -type failingReadCloser struct { - chunks []string - err error -} - -func (r *failingReadCloser) Read(p []byte) (int, error) { - if len(r.chunks) == 0 { - return 0, r.err - } - n := copy(p, r.chunks[0]) - r.chunks = r.chunks[1:] - return n, nil -} - -func (r *failingReadCloser) Close() error { - return nil -} diff --git a/route_match_benchmark_test.go b/route_match_benchmark_test.go deleted file mode 100644 index e0dd2aa..0000000 --- a/route_match_benchmark_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package touka - -import "testing" - -var ( - benchmarkRouteHandlers HandlersChain - benchmarkRouteFullPath string - benchmarkRouteParamsLen int - benchmarkRouteCIPath []byte - benchmarkRouteCIFound bool -) - -func buildRouteMatchBenchmarkTree() *node { - tree := &node{} - routes := []string{ - "/", - "/health", - "/contact", - "/api/v1/users", - "/api/v1/users/:id", - "/api/v1/users/:id/settings", - "/assets/*filepath", - "/abc/b", - "/abc/:p1/cde", - "/abc/:p1/:p2/def/*filepath", - } - - for _, route := range routes { - tree.addRoute(route, fakeHandler(route)) - } - - return tree -} - -func benchmarkRouteLookup(b *testing.B, tree *node, path string, wantFullPath string) { - b.Helper() - - params := make(Params, 0, 4) - skipped := make([]skippedNode, 0, 8) - - value := tree.getValue(path, ¶ms, &skipped, true) - if wantFullPath == "" { - if value.handlers != nil { - b.Fatalf("expected no match for %q, got %q", path, value.fullPath) - } - } else { - if value.handlers == nil { - b.Fatalf("expected match for %q, got nil handlers", path) - } - if value.fullPath != wantFullPath { - b.Fatalf("expected full path %q for %q, got %q", wantFullPath, path, value.fullPath) - } - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - params = params[:0] - skipped = skipped[:0] - value = tree.getValue(path, ¶ms, &skipped, true) - } - - benchmarkRouteHandlers = value.handlers - benchmarkRouteFullPath = value.fullPath - if value.params != nil { - benchmarkRouteParamsLen = len(*value.params) - } else { - benchmarkRouteParamsLen = 0 - } -} - -func BenchmarkRouteMatch(b *testing.B) { - tree := buildRouteMatchBenchmarkTree() - - b.Run("StaticHit", func(b *testing.B) { - benchmarkRouteLookup(b, tree, "/api/v1/users", "/api/v1/users") - }) - - b.Run("ParamHit", func(b *testing.B) { - benchmarkRouteLookup(b, tree, "/api/v1/users/123", "/api/v1/users/:id") - }) - - b.Run("BacktrackingHit", func(b *testing.B) { - benchmarkRouteLookup(b, tree, "/abc/b/d/def/some/file.txt", "/abc/:p1/:p2/def/*filepath") - }) - - b.Run("Miss", func(b *testing.B) { - benchmarkRouteLookup(b, tree, "/does/not/exist", "") - }) - - b.Run("CaseInsensitiveHit", func(b *testing.B) { - path := "/API/V1/USERS/123/SETTINGS" - out, found := tree.findCaseInsensitivePath(path, true) - if !found { - b.Fatalf("expected fixed-path match for %q", path) - } - if got := string(out); got != "/api/v1/users/123/settings" { - b.Fatalf("expected fixed-path result %q, got %q", "/api/v1/users/123/settings", got) - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - out, found = tree.findCaseInsensitivePath(path, true) - } - - benchmarkRouteCIPath = out - benchmarkRouteCIFound = found - }) - - b.Run("CaseInsensitiveMiss", func(b *testing.B) { - path := "/DOES/NOT/EXIST" - out, found := tree.findCaseInsensitivePath(path, true) - if found || out != nil { - b.Fatalf("expected no fixed-path match for %q, got %q, %t", path, string(out), found) - } - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - out, found = tree.findCaseInsensitivePath(path, true) - } - - benchmarkRouteCIPath = out - benchmarkRouteCIFound = found - }) -} diff --git a/serve.go b/serve.go index 0fc83f9..f3ddc5f 100644 --- a/serve.go +++ b/serve.go @@ -14,7 +14,6 @@ import ( "net/http" "os" "os/signal" - "strings" "sync" "syscall" "time" @@ -22,322 +21,329 @@ import ( "github.com/fenthope/reco" ) +// defaultShutdownTimeout 定义了在强制关闭前等待优雅关闭的最长时间 const defaultShutdownTimeout = 5 * time.Second -type runMode uint8 +// --- 内部辅助函数 --- -const ( - runModeHTTP runMode = iota - runModeHTTPS - runModeHTTPSRedirect -) - -type runConfig struct { - addr string - httpRedirectAddr string - tlsConfig *tls.Config - redirectHost string - redirectHostHeaders []string - useHeaderHost bool - useHeaderHostSet bool - graceful bool - shutdownTimeout time.Duration - gracefulCtx context.Context - mode runMode - shutdownDefaultSet bool - shutdownTimeoutSet bool -} - -type RunOption interface { - apply(*runConfig) error -} - -type runOptionFunc func(*runConfig) error - -func (f runOptionFunc) apply(cfg *runConfig) error { - return f(cfg) -} - -func defaultRunConfig() runConfig { - return runConfig{ - addr: ":8080", - shutdownTimeout: defaultShutdownTimeout, - mode: runModeHTTP, - useHeaderHost: true, +// resolveAddress 解析传入的地址参数,如果没有则返回默认的 ":8080" +func resolveAddress(addr []string) string { + switch len(addr) { + case 0: + return ":8080" + case 1: + return addr[0] + default: + panic("too many parameters provided for server address") } } -type HTTPRedirectOption interface { - applyRedirect(*runConfig) error -} - -type redirectOptionFunc func(*runConfig) error - -func (f redirectOptionFunc) applyRedirect(cfg *runConfig) error { - return f(cfg) -} - -func WithAddr(addr string) RunOption { - return runOptionFunc(func(cfg *runConfig) error { - if addr == "" { - return errors.New("run address must not be empty") - } - cfg.addr = addr - return nil - }) -} - -func WithTLS(tlsConfig *tls.Config) RunOption { - return runOptionFunc(func(cfg *runConfig) error { - if tlsConfig == nil { - return errors.New("tls.Config must not be nil") - } - cfg.tlsConfig = tlsConfig - if cfg.mode == runModeHTTP { - cfg.mode = runModeHTTPS - } - return nil - }) -} - -func WithHTTPRedirect(addr string, opts ...HTTPRedirectOption) RunOption { - return runOptionFunc(func(cfg *runConfig) error { - if addr == "" { - return errors.New("http redirect address must not be empty") - } - cfg.httpRedirectAddr = addr - cfg.mode = runModeHTTPSRedirect - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.applyRedirect(cfg); err != nil { - return err - } - } - return nil - }) -} - -func WithUseHeaderHost(enabled bool) HTTPRedirectOption { - return redirectOptionFunc(func(cfg *runConfig) error { - cfg.useHeaderHost = enabled - cfg.useHeaderHostSet = true - return nil - }) -} - -func WithRedirectHost(host string) HTTPRedirectOption { - return redirectOptionFunc(func(cfg *runConfig) error { - if host == "" { - return errors.New("redirect host must not be empty") - } - cfg.redirectHost = host - return nil - }) -} - -func WithRedirectHostHeaders(headers []string) HTTPRedirectOption { - return redirectOptionFunc(func(cfg *runConfig) error { - cfg.redirectHostHeaders = cfg.redirectHostHeaders[:0] - for _, header := range headers { - trimmed := http.CanonicalHeaderKey(strings.TrimSpace(header)) - if trimmed != "" { - cfg.redirectHostHeaders = append(cfg.redirectHostHeaders, trimmed) - } - } - return nil - }) -} - -func WithGracefulShutdown(timeout time.Duration) RunOption { - return runOptionFunc(func(cfg *runConfig) error { - cfg.graceful = true - cfg.shutdownTimeoutSet = true - if timeout > 0 { - cfg.shutdownTimeout = timeout - } else { - cfg.shutdownTimeout = defaultShutdownTimeout - } - return nil - }) -} - -func WithGracefulShutdownDefault() RunOption { - return runOptionFunc(func(cfg *runConfig) error { - cfg.graceful = true - cfg.shutdownDefaultSet = true - cfg.shutdownTimeout = defaultShutdownTimeout - return nil - }) -} - -func WithShutdownContext(ctx context.Context) RunOption { - return runOptionFunc(func(cfg *runConfig) error { - if ctx == nil { - return errors.New("shutdown context must not be nil") - } - cfg.gracefulCtx = ctx - return nil - }) -} - -func serveServer(srv *http.Server, serveTLS bool) error { - if serveTLS { - return srv.ListenAndServeTLS("", "") +// getShutdownTimeout 解析可选的超时参数,如果无效或未提供则返回默认值 +func getShutdownTimeout(timeouts []time.Duration) time.Duration { + if len(timeouts) > 0 && timeouts[0] > 0 { + return timeouts[0] } - return srv.ListenAndServe() + return defaultShutdownTimeout } -func runServer(serverType string, srv *http.Server, serveTLS bool) { +// runServer 是一个内部辅助函数,负责在一个新的 goroutine 中启动一个 http.Server, +// 并处理其启动失败的致命错误 +// serverType 用于在日志中标识服务器类型 (例如 "HTTP", "HTTPS") +func runServer(serverType string, srv *http.Server) { go func() { + var err error protocol := "http" - if serveTLS { + if srv.TLSConfig != nil { protocol = "https" } log.Printf("Touka %s server listening on %s://%s", serverType, protocol, srv.Addr) - err := serveServer(srv, serveTLS) + if srv.TLSConfig != nil { + // 对于 HTTPS 服务器,如果 srv.TLSConfig.Certificates 已配置, + // ListenAndServeTLS 的前两个参数可以为空字符串 + err = srv.ListenAndServeTLS("", "") + } else { + err = srv.ListenAndServe() + } + + // 如果服务器停止不是因为被优雅关闭 (http.ErrServerClosed), + // 则认为是一个严重错误,并终止程序 if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Touka %s server failed: %v", serverType, err) } }() } -func cloneTLSConfig(tlsConfig *tls.Config) *tls.Config { - if tlsConfig == nil { - return nil - } - return tlsConfig.Clone() -} +// handleGracefulShutdown 监听系统信号 (SIGINT, SIGTERM) 并优雅地关闭所有提供的服务器 +// 这是所有支持优雅关闭的 RunXXX 方法的最终归宿 +func handleGracefulShutdown(servers []*http.Server, timeout time.Duration, logger *reco.Logger) error { + // 创建一个 channel 来接收操作系统信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 + <-quit // 阻塞,直到接收到上述信号之一 + log.Println("Shutting down Touka server(s)...") -func parseHTTPSPort(addr string) (string, error) { - _, port, err := net.SplitHostPort(addr) - if err != nil { - return "", fmt.Errorf("https address %q must include a port: %w", addr, err) - } - return port, nil -} - -func applyMainServerConfig(engine *Engine, srv *http.Server, serveTLS bool) { - if serveTLS { - if engine.TLSServerConfigurator != nil { - engine.TLSServerConfigurator(srv) - return - } - } - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } -} - -func applyRedirectServerConfig(engine *Engine, srv *http.Server) { - applyServerProtocols(srv, engine.serverProtocols) - if engine.ServerConfigurator != nil { - engine.ServerConfigurator(srv) - } -} - -func effectiveServerProtocols(engine *Engine, serveTLS bool) *http.Protocols { - if engine == nil { - return nil - } - if serveTLS && engine.useDefaultProtocols { - protocols := &http.Protocols{} - protocols.SetHTTP1(true) - protocols.SetHTTP2(true) - return protocols - } - return cloneServerProtocols(engine.serverProtocols) -} - -func buildMainServer(engine *Engine, cfg runConfig) *http.Server { - serveTLS := cfg.mode != runModeHTTP - server := &http.Server{ - Addr: cfg.addr, - Handler: engine, - TLSConfig: cloneTLSConfig(cfg.tlsConfig), - } - if cfg.graceful { - server.BaseContext = func(net.Listener) context.Context { - return engine.shutdownCtx - } - server.RegisterOnShutdown(engine.shutdownCancel) - } - applyServerProtocols(server, effectiveServerProtocols(engine, serveTLS)) - applyMainServerConfig(engine, server, serveTLS) - return server -} - -func firstRedirectHeaderHost(r *http.Request, headers []string) string { - if r == nil { - return "" - } - for _, header := range headers { - value := strings.TrimSpace(r.Header.Get(header)) - if value == "" { - continue - } - if comma := strings.IndexByte(value, ','); comma >= 0 { - value = strings.TrimSpace(value[:comma]) - } - if value != "" { - return value - } - } - return "" -} - -func redirectTargetHost(r *http.Request, cfg runConfig) (string, int, bool) { - if cfg.useHeaderHostSet && !cfg.useHeaderHost { - if cfg.redirectHost == "" { - return "", http.StatusInternalServerError, false - } - return cfg.redirectHost, 0, true + // 关闭日志记录器 + if logger != nil { + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() } - if len(cfg.redirectHostHeaders) > 0 { - host := firstRedirectHeaderHost(r, cfg.redirectHostHeaders) - if host == "" { - return "", http.StatusUpgradeRequired, false - } - return host, 0, true - } + // 创建一个带超时的上下文,用于 Shutdown + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() - if r == nil { - return "", http.StatusUpgradeRequired, false - } - host := strings.TrimSpace(r.Host) - if host == "" { - return "", http.StatusUpgradeRequired, false - } - return host, 0, true -} + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel -func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { - httpsAddr := cfg.addr - httpAddr := cfg.httpRedirectAddr - httpsPort, err := parseHTTPSPort(httpsAddr) - if err != nil { - return nil, err - } - - redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, statusCode, ok := redirectTargetHost(r, cfg) - if !ok { - http.Error(w, http.StatusText(statusCode), statusCode) - return - } - - if parsedHost, _, err := net.SplitHostPort(host); err == nil { - host = parsedHost - if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { - host = "[" + host + "]" + // 并发地关闭所有服务器 + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(ctx); err != nil { + // 将错误发送到 channel + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) } + }(srv) + } + + wg.Wait() // 等待所有服务器的关闭 goroutine 完成 + close(errChan) // 关闭 channel,以便可以安全地遍历它 + + // 收集所有关闭过程中发生的错误 + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +func handleGracefulShutdownWithContext(servers []*http.Server, ctx context.Context, timeout time.Duration, logger *reco.Logger) error { + // 创建一个 channel 来接收操作系统信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 监听中断和终止信号 + + // 启动服务器 + serverStopped := make(chan error, 1) + for _, srv := range servers { + go func(s *http.Server) { + serverStopped <- s.ListenAndServe() + }(srv) + } + + select { + case <-ctx.Done(): + // Context 被取消 (例如,通过外部取消函数) + log.Println("Context cancelled, shutting down Touka server(s)...") + case err := <-serverStopped: + // 服务器自身停止 (例如,端口被占用,或 ListenAndServe 返回错误) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("Touka HTTP server failed: %w", err) + } + log.Println("Touka HTTP server stopped gracefully.") + return nil // 服务器已自行优雅关闭,无需进一步处理 + case <-quit: + // 接收到操作系统信号 + log.Println("Shutting down Touka server(s) due to OS signal...") + } + + // 关闭日志记录器 + if logger != nil { + go func() { + log.Println("Closing Touka logger...") + CloseLogger(logger) + }() + } + + // 创建一个带超时的上下文,用于 Shutdown + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, len(servers)) // 用于收集关闭错误的 channel + + // 并发地关闭所有服务器 + for _, srv := range servers { + wg.Add(1) + go func(s *http.Server) { + defer wg.Done() + if err := s.Shutdown(shutdownCtx); err != nil { + // 将错误发送到 channel + errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) + } + }(srv) + } + + wg.Wait() + close(errChan) // 关闭 channel,以便可以安全地遍历它 + + // 收集所有关闭过程中发生的错误 + var shutdownErrors []error + for err := range errChan { + shutdownErrors = append(shutdownErrors, err) + log.Printf("Shutdown error: %v", err) + } + + if len(shutdownErrors) > 0 { + return errors.Join(shutdownErrors...) // Go 1.20+ 的 errors.Join,用于合并多个错误 + } + log.Println("Touka server(s) exited gracefully.") + return nil +} + +// --- 公共 Run 方法 --- + +// Run 启动一个不支持优雅关闭的 HTTP 服务器 +// 这是一个阻塞调用,主要用于简单的场景或快速测试 +// 建议在生产环境中使用 RunShutdown 或其他支持优雅关闭的方法 +func (engine *Engine) Run(addr ...string) error { + address := resolveAddress(addr) + srv := &http.Server{Addr: address, Handler: engine} + + // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 + engine.applyDefaultServerConfig(srv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) + } + log.Printf("Starting Touka HTTP server on %s (no graceful shutdown)", address) + return srv.ListenAndServe() +} + +// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 +func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error { + srv := &http.Server{ + Addr: addr, + Handler: engine, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, + } + srv.RegisterOnShutdown(engine.shutdownCancel) + + // 应用框架的默认配置和用户提供的自定义配置 + engine.applyDefaultServerConfig(srv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) + } + + runServer("HTTP", srv) + return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) +} + +// RunShutdown 启动一个支持优雅关闭的 HTTP 服务器 +func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, timeouts ...time.Duration) error { + srv := &http.Server{ + Addr: addr, + Handler: engine, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, + } + srv.RegisterOnShutdown(engine.shutdownCancel) + + // 应用框架的默认配置和用户提供的自定义配置 + engine.applyDefaultServerConfig(srv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) + } + + return handleGracefulShutdownWithContext([]*http.Server{srv}, ctx, getShutdownTimeout(timeouts), engine.LogReco) +} + +// RunTLS 启动一个支持优雅关闭的 HTTPS 服务器 +func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunTLS") + } + + // 配置 HTTP/2 支持 (如果使用默认配置) + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, // 默认在 TLS 上启用 HTTP/2 + }) + } + + srv := &http.Server{ + Addr: addr, + Handler: engine, + TLSConfig: tlsConfig, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, + } + srv.RegisterOnShutdown(engine.shutdownCancel) + + // 应用框架的默认配置和用户提供的自定义配置 + // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator + engine.applyDefaultServerConfig(srv) + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(srv) + } else if engine.ServerConfigurator != nil { + engine.ServerConfigurator(srv) + } + + runServer("HTTPS", srv) + return handleGracefulShutdown([]*http.Server{srv}, getShutdownTimeout(timeouts), engine.LogReco) +} + +// RunWithTLS 是 RunTLS 的别名,为了保持向后兼容性或更直观的命名 +func (engine *Engine) RunWithTLS(addr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + return engine.RunTLS(addr, tlsConfig, timeouts...) +} + +// RunTLSRedir 启动 HTTP 重定向服务器和 HTTPS 应用服务器,两者都支持优雅关闭 +func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + if tlsConfig == nil { + return errors.New("tls.Config must not be nil for RunTLSRedir") + } + + // --- HTTPS 服务器 --- + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) + } + httpsSrv := &http.Server{ + Addr: httpsAddr, + Handler: engine, + TLSConfig: tlsConfig, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, + } + httpsSrv.RegisterOnShutdown(engine.shutdownCancel) + engine.applyDefaultServerConfig(httpsSrv) + if engine.TLSServerConfigurator != nil { + engine.TLSServerConfigurator(httpsSrv) + } else if engine.ServerConfigurator != nil { + engine.ServerConfigurator(httpsSrv) + } + + // --- HTTP 重定向服务器 --- + redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + + _, httpsPort, err := net.SplitHostPort(httpsAddr) + if err != nil { + // 如果 httpsAddr 没有端口,这是一个配置错误 + + log.Fatalf("Invalid HTTPS address for redirection '%s': must include a port.", httpsAddr) } targetURL := "https://" + host + // 只有在非标准 HTTPS 端口 (443) 时才附加端口号 if httpsPort != "443" { targetURL = "https://" + net.JoinHostPort(host, httpsPort) } @@ -345,205 +351,22 @@ func buildRedirectServer(engine *Engine, cfg runConfig) (*http.Server, error) { http.Redirect(w, r, targetURL, http.StatusMovedPermanently) }) + httpSrv := &http.Server{ + Addr: httpAddr, + Handler: redirectHandler, + } + engine.applyDefaultServerConfig(httpSrv) + if engine.ServerConfigurator != nil { + engine.ServerConfigurator(httpSrv) + } - server := &http.Server{Addr: httpAddr, Handler: redirectHandler} - applyRedirectServerConfig(engine, server) - return server, nil + // --- 启动服务器和优雅关闭 --- + runServer("HTTPS", httpsSrv) + runServer("HTTP Redirect", httpSrv) + return handleGracefulShutdown([]*http.Server{httpsSrv, httpSrv}, getShutdownTimeout(timeouts), engine.LogReco) } -func validateRunConfig(cfg runConfig) error { - if cfg.mode == runModeHTTPSRedirect && cfg.tlsConfig == nil { - return errors.New("WithHTTPRedirect requires WithTLS") - } - if cfg.mode == runModeHTTPS && cfg.tlsConfig == nil { - return errors.New("https mode requires WithTLS") - } - if cfg.gracefulCtx != nil && !cfg.graceful { - return errors.New("WithShutdownContext requires graceful shutdown") - } - if len(cfg.redirectHostHeaders) > 0 { - if !cfg.useHeaderHostSet || !cfg.useHeaderHost { - return errors.New("WithRedirectHostHeaders requires WithUseHeaderHost(true)") - } - } - if cfg.useHeaderHostSet && cfg.useHeaderHost { - if cfg.redirectHost != "" { - return errors.New("WithRedirectHost cannot be used when WithUseHeaderHost(true)") - } - } else if cfg.useHeaderHostSet && !cfg.useHeaderHost { - if cfg.redirectHost == "" { - return errors.New("WithUseHeaderHost(false) requires WithRedirectHost") - } - if len(cfg.redirectHostHeaders) > 0 { - return errors.New("WithRedirectHostHeaders cannot be used when WithUseHeaderHost(false)") - } - } - return nil -} - -func effectiveShutdownTimeout(cfg runConfig) time.Duration { - if cfg.shutdownTimeoutSet || cfg.shutdownDefaultSet { - if cfg.shutdownTimeout > 0 { - return cfg.shutdownTimeout - } - } - return defaultShutdownTimeout -} - -func closeLoggerAsync(logger *reco.Logger) { - if logger == nil { - return - } - go func() { - log.Println("Closing Touka logger...") - CloseLogger(logger) - }() -} - -func shutdownServers(servers []*http.Server, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - var wg sync.WaitGroup - errChan := make(chan error, len(servers)) - for _, srv := range servers { - wg.Add(1) - go func(s *http.Server) { - defer wg.Done() - if err := s.Shutdown(ctx); err != nil { - errChan <- fmt.Errorf("server on %s shutdown failed: %w", s.Addr, err) - } - }(srv) - } - - wg.Wait() - close(errChan) - - var shutdownErrors []error - for err := range errChan { - shutdownErrors = append(shutdownErrors, err) - log.Printf("Shutdown error: %v", err) - } - if len(shutdownErrors) > 0 { - return errors.Join(shutdownErrors...) - } - return nil -} - -func gracefulServe(servers []*http.Server, serveTLS []bool, timeout time.Duration, logger *reco.Logger, shutdownCtx context.Context) error { - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - defer signal.Stop(quit) - - serverStopped := make(chan error, len(servers)) - for i, srv := range servers { - serveTLSFlag := serveTLS[i] - go func(server *http.Server, useTLS bool) { - serverStopped <- serveServer(server, useTLS) - }(srv, serveTLSFlag) - } - - select { - case err := <-serverStopped: - if err != nil && !errors.Is(err, http.ErrServerClosed) { - if shutdownErr := shutdownServers(servers, timeout); shutdownErr != nil { - return errors.Join(err, shutdownErr) - } - return err - } - log.Println("Touka server stopped gracefully.") - return nil - case <-quit: - log.Println("Shutting down Touka server(s) due to OS signal...") - case <-shutdownCtx.Done(): - log.Println("Context cancelled, shutting down Touka server(s)...") - } - - closeLoggerAsync(logger) - if err := shutdownServers(servers, timeout); err != nil { - return err - } - log.Println("Touka server(s) exited gracefully.") - return nil -} - -// Run starts the engine with the provided startup options. -// -// Default behavior with no options: -// - HTTP only -// - listens on :8080 -// - no graceful shutdown orchestration -// -// Add WithGracefulShutdown(...) or WithGracefulShutdownDefault() to enable -// signal-aware graceful shutdown and request-context cancellation semantics. -// Add WithTLS(...) to run HTTPS; this is independent from graceful shutdown. -func (engine *Engine) Run(opts ...RunOption) error { - cfg := defaultRunConfig() - for _, opt := range opts { - if opt == nil { - continue - } - if err := opt.apply(&cfg); err != nil { - return err - } - } - if cfg.httpRedirectAddr != "" { - cfg.mode = runModeHTTPSRedirect - } else if cfg.tlsConfig != nil { - cfg.mode = runModeHTTPS - } - if err := validateRunConfig(cfg); err != nil { - return err - } - - serveTLS := cfg.mode != runModeHTTP - - mainServer := buildMainServer(engine, cfg) - servers := []*http.Server{mainServer} - serveTLSFlags := []bool{serveTLS} - if cfg.mode == runModeHTTPSRedirect { - redirectServer, err := buildRedirectServer(engine, cfg) - if err != nil { - return err - } - servers = append(servers, redirectServer) - serveTLSFlags = append(serveTLSFlags, false) - } - - if !cfg.graceful { - if len(servers) > 1 { - serverStopped := make(chan error, len(servers)) - for i, srv := range servers { - serveTLSFlag := serveTLSFlags[i] - go func(server *http.Server, useTLS bool) { - serverStopped <- serveServer(server, useTLS) - }(srv, serveTLSFlag) - } - - err := <-serverStopped - if shutdownErr := shutdownServers(servers, defaultShutdownTimeout); shutdownErr != nil { - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return errors.Join(err, shutdownErr) - } - return shutdownErr - } - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return err - } - return nil - } - - protocolLabel := "HTTP" - if serveTLS { - protocolLabel = "HTTPS" - } - log.Printf("Starting Touka %s server on %s", protocolLabel, cfg.addr) - return serveServer(mainServer, serveTLS) - } - - shutdownCtx := context.Background() - if cfg.gracefulCtx != nil { - shutdownCtx = cfg.gracefulCtx - } - return gracefulServe(servers, serveTLSFlags, effectiveShutdownTimeout(cfg), engine.LogReco, shutdownCtx) +// RunWithTLSRedir 是 RunTLSRedir 的别名,为了保持向后兼容性 +func (engine *Engine) RunWithTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Config, timeouts ...time.Duration) error { + return engine.RunTLSRedir(httpAddr, httpsAddr, tlsConfig, timeouts...) } diff --git a/serve_test.go b/serve_test.go deleted file mode 100644 index a02f1df..0000000 --- a/serve_test.go +++ /dev/null @@ -1,492 +0,0 @@ -package touka - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "errors" - "io" - "math/big" - "net" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" -) - -func generateSelfSignedCert(t *testing.T) tls.Certificate { - t.Helper() - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("generate private key: %v", err) - } - - tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "127.0.0.1"}, - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - }, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - } - - der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey) - if err != nil { - t.Fatalf("create self-signed cert: %v", err) - } - - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) - - cert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - t.Fatalf("parse self-signed cert: %v", err) - } - return cert -} - -func TestServeServerHTTPModeIgnoresTLSConfig(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen on ephemeral port: %v", err) - } - addr := listener.Addr().String() - if err := listener.Close(); err != nil { - t.Fatalf("close temporary listener: %v", err) - } - - srv := &http.Server{ - Addr: addr, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("ok")) - }), - // RunShutdown uses the HTTP startup path and must not let a shared - // ServerConfigurator accidentally turn it into HTTPS. - TLSConfig: &tls.Config{}, - } - - errCh := make(chan error, 1) - go func() { - errCh <- serveServer(srv, false) - }() - - client := &http.Client{Timeout: 200 * time.Millisecond} - var resp *http.Response - requestURL := "http://" + addr - - deadline := time.Now().Add(3 * time.Second) - for time.Now().Before(deadline) { - resp, err = client.Get(requestURL) - if err == nil { - break - } - time.Sleep(20 * time.Millisecond) - } - if err != nil { - select { - case serveErr := <-errCh: - t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: request error=%v, serve error=%v", err, serveErr) - default: - t.Fatalf("expected HTTP server to accept plain HTTP with TLSConfig set: %v", err) - } - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read response body: %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status code: got %d want %d", resp.StatusCode, http.StatusOK) - } - if string(body) != "ok" { - t.Fatalf("unexpected body: got %q want %q", string(body), "ok") - } - - shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := srv.Shutdown(shutdownCtx); err != nil { - t.Fatalf("shutdown server: %v", err) - } - - if err := <-errCh; !errors.Is(err, http.ErrServerClosed) { - t.Fatalf("serveServer should stop with ErrServerClosed after shutdown, got %v", err) - } -} - -func TestRunRejectsRedirectWithoutTLS(t *testing.T) { - engine := New() - err := engine.Run(WithHTTPRedirect(":80")) - if err == nil { - t.Fatal("expected redirect mode without TLS to fail") - } -} - -func TestRunRejectsRedirectHostHeadersWithoutExplicitUseHeaderHostTrue(t *testing.T) { - engine := New() - err := engine.Run( - WithAddr(":443"), - WithTLS(&tls.Config{}), - WithHTTPRedirect(":80", WithRedirectHostHeaders([]string{"X-Forwarded-Host"})), - ) - if err == nil { - t.Fatal("expected redirect host headers without explicit WithUseHeaderHost(true) to fail") - } -} - -func TestWithGracefulShutdownDefaultUsesDefaultTimeout(t *testing.T) { - cfg := defaultRunConfig() - if err := WithGracefulShutdownDefault().apply(&cfg); err != nil { - t.Fatalf("apply graceful default option: %v", err) - } - if !cfg.graceful { - t.Fatal("expected graceful shutdown to be enabled") - } - if cfg.shutdownTimeout != defaultShutdownTimeout { - t.Fatalf("expected default shutdown timeout %v, got %v", defaultShutdownTimeout, cfg.shutdownTimeout) - } -} - -func TestWithTLSDoesNotRequireGracefulShutdown(t *testing.T) { - cfg := defaultRunConfig() - tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} - if err := WithTLS(tlsConfig).apply(&cfg); err != nil { - t.Fatalf("apply TLS option: %v", err) - } - if cfg.mode != runModeHTTPS { - t.Fatalf("expected HTTPS mode, got %v", cfg.mode) - } - if cfg.graceful { - t.Fatal("expected TLS option to remain independent from graceful shutdown") - } - if cfg.tlsConfig != tlsConfig { - t.Fatal("expected TLS config to be preserved in run config") - } -} - -func TestBuildRedirectServerRejectsHTTPSAddrWithoutPort(t *testing.T) { - engine := New() - if _, err := buildRedirectServer(engine, runConfig{addr: "example.com", httpRedirectAddr: ":80"}); err == nil { - t.Fatal("expected redirect server builder to reject https address without port") - } -} - -func TestValidateRunConfigRejectsShutdownContextWithoutGraceful(t *testing.T) { - cfg := defaultRunConfig() - ctx := t.Context() - if err := WithShutdownContext(ctx).apply(&cfg); err != nil { - t.Fatalf("apply shutdown context option: %v", err) - } - if err := validateRunConfig(cfg); err == nil { - t.Fatal("expected shutdown context without graceful shutdown to fail validation") - } -} - -func TestValidateRunConfigDoesNotMutateMode(t *testing.T) { - cfg := defaultRunConfig() - cfg.httpRedirectAddr = ":80" - if err := validateRunConfig(cfg); err != nil { - t.Fatalf("validate run config: %v", err) - } - if cfg.mode != runModeHTTP { - t.Fatalf("expected validateRunConfig to leave mode unchanged, got %v", cfg.mode) - } -} - -func TestValidateRunConfigRejectsConfiguredHostModeWithoutRedirectHost(t *testing.T) { - cfg := defaultRunConfig() - cfg.mode = runModeHTTPSRedirect - cfg.tlsConfig = &tls.Config{} - cfg.useHeaderHost = false - cfg.useHeaderHostSet = true - if err := validateRunConfig(cfg); err == nil { - t.Fatal("expected configured host mode without redirect host to fail validation") - } -} - -func TestValidateRunConfigRejectsRedirectHostWhenHeaderModeEnabled(t *testing.T) { - cfg := defaultRunConfig() - cfg.mode = runModeHTTPSRedirect - cfg.tlsConfig = &tls.Config{} - cfg.useHeaderHost = true - cfg.useHeaderHostSet = true - cfg.redirectHost = "configured.example" - if err := validateRunConfig(cfg); err == nil { - t.Fatal("expected redirect host to be rejected when header host mode is enabled") - } -} - -func TestBuildMainServerGracefulSetsBaseContextAndShutdownHook(t *testing.T) { - engine := New() - server := buildMainServer(engine, runConfig{addr: ":8080", graceful: true, mode: runModeHTTP}) - if server.BaseContext == nil { - t.Fatal("expected graceful main server to set BaseContext") - } - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen for base context check: %v", err) - } - defer listener.Close() - if got := server.BaseContext(listener); got != engine.shutdownCtx { - t.Fatal("expected graceful main server to use engine shutdown context") - } -} - -func TestBuildMainServerTLSConfiguratorPrecedence(t *testing.T) { - engine := New() - serverConfigured := false - tlsConfigured := false - engine.SetServerConfigurator(func(s *http.Server) { - serverConfigured = true - s.ReadTimeout = time.Second - }) - engine.SetTLSServerConfigurator(func(s *http.Server) { - tlsConfigured = true - s.IdleTimeout = time.Second - }) - - server := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) - if !tlsConfigured { - t.Fatal("expected TLS configurator to run for HTTPS main server") - } - if serverConfigured { - t.Fatal("expected generic server configurator to be skipped when TLS configurator is set") - } - if server.IdleTimeout != time.Second { - t.Fatal("expected TLS configurator changes to be applied to HTTPS main server") - } -} - -func TestBuildRedirectServerUsesGenericConfigurator(t *testing.T) { - engine := New() - configured := false - engine.SetServerConfigurator(func(s *http.Server) { - configured = true - s.ReadTimeout = time.Second - }) - - server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - if !configured { - t.Fatal("expected redirect server to use generic server configurator") - } - if server.ReadTimeout != time.Second { - t.Fatal("expected redirect server configurator changes to be applied") - } -} - -func TestTLSRunDoesNotMutateDefaultHTTPProtocols(t *testing.T) { - engine := New() - httpsServer := buildMainServer(engine, runConfig{addr: ":443", mode: runModeHTTPS, tlsConfig: &tls.Config{}}) - if !httpsServer.Protocols.HTTP2() { - t.Fatal("expected HTTPS server to enable HTTP/2 under default protocol settings") - } - - httpServer := buildMainServer(engine, defaultRunConfig()) - if httpServer.Protocols.HTTP2() { - t.Fatal("expected later plain HTTP server to keep default HTTP/2 disabled") - } -} - -func TestBuildRedirectServerRedirectsWithoutGracefulMode(t *testing.T) { - engine := New() - server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - - req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) - req.Host = "example.com:80" - rr := httptest.NewRecorder() - server.Handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) - } - if location := rr.Header().Get("Location"); location != "https://example.com/plain/path?q=1" { - t.Fatalf("unexpected redirect location: %q", location) - } -} - -func TestBuildRedirectServerUsesConfiguredHeadersInOrder(t *testing.T) { - engine := New() - server, err := buildRedirectServer(engine, runConfig{ - addr: ":443", - httpRedirectAddr: ":80", - useHeaderHost: true, - useHeaderHostSet: true, - redirectHostHeaders: []string{"X-First-Host", "X-Forwarded-Host"}, - }) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - - req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) - req.Host = "example.com:80" - req.Header.Set("X-Forwarded-Host", "forwarded.example") - req.Header.Set("X-First-Host", "first.example") - rr := httptest.NewRecorder() - server.Handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) - } - if location := rr.Header().Get("Location"); location != "https://first.example/plain/path?q=1" { - t.Fatalf("unexpected redirect location: %q", location) - } -} - -func TestBuildRedirectServerReturns426WhenConfiguredHeadersMiss(t *testing.T) { - engine := New() - server, err := buildRedirectServer(engine, runConfig{ - addr: ":443", - httpRedirectAddr: ":80", - useHeaderHost: true, - useHeaderHostSet: true, - redirectHostHeaders: []string{"X-Forwarded-Host"}, - }) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - - req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) - req.Host = "example.com:80" - rr := httptest.NewRecorder() - server.Handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusUpgradeRequired { - t.Fatalf("expected status %d when configured redirect headers miss, got %d", http.StatusUpgradeRequired, rr.Code) - } -} - -func TestBuildRedirectServerUsesConfiguredRedirectHostWhenHeaderModeDisabled(t *testing.T) { - engine := New() - server, err := buildRedirectServer(engine, runConfig{ - addr: ":443", - httpRedirectAddr: ":80", - useHeaderHost: false, - useHeaderHostSet: true, - redirectHost: "configured.example", - }) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - - req := httptest.NewRequest(http.MethodGet, "http://example.com/plain/path?q=1", nil) - req.Host = "example.com:80" - req.Header.Set("X-Forwarded-Host", "forwarded.example") - rr := httptest.NewRecorder() - server.Handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) - } - if location := rr.Header().Get("Location"); location != "https://configured.example/plain/path?q=1" { - t.Fatalf("unexpected redirect location: %q", location) - } -} - -func TestBuildRedirectServerPreservesIPv6BracketsInRedirectURL(t *testing.T) { - engine := New() - server, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: ":80"}) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - - req := httptest.NewRequest(http.MethodGet, "http://[::1]/plain/path?q=1", nil) - req.Host = "[::1]:80" - rr := httptest.NewRecorder() - server.Handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusMovedPermanently { - t.Fatalf("expected redirect status %d, got %d", http.StatusMovedPermanently, rr.Code) - } - if location := rr.Header().Get("Location"); location != "https://[::1]/plain/path?q=1" { - t.Fatalf("unexpected IPv6 redirect location: %q", location) - } -} - -func TestGracefulServeShutsDownSiblingServersOnStartupFailure(t *testing.T) { - occupied, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen on occupied addr: %v", err) - } - occupiedAddr := occupied.Addr().String() - defer occupied.Close() - - redirectListener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen for redirect addr: %v", err) - } - redirectAddr := redirectListener.Addr().String() - if err := redirectListener.Close(); err != nil { - t.Fatalf("close redirect addr probe: %v", err) - } - - engine := New() - redirectServer, err := buildRedirectServer(engine, runConfig{addr: ":443", httpRedirectAddr: redirectAddr}) - if err != nil { - t.Fatalf("build redirect server: %v", err) - } - mainServer := &http.Server{Addr: occupiedAddr, Handler: engine} - - err = gracefulServe([]*http.Server{mainServer, redirectServer}, []bool{false, false}, 200*time.Millisecond, nil, context.Background()) - if err == nil { - t.Fatal("expected gracefulServe to fail when one server cannot bind") - } - if !strings.Contains(err.Error(), occupiedAddr) { - t.Fatalf("expected startup failure to mention occupied address %q, got %v", occupiedAddr, err) - } - - conn, dialErr := net.DialTimeout("tcp", redirectAddr, 200*time.Millisecond) - if dialErr == nil { - conn.Close() - t.Fatalf("expected sibling redirect server to be shut down after startup failure, but %s is still accepting connections", redirectAddr) - } - if !strings.Contains(dialErr.Error(), "refused") && !strings.Contains(dialErr.Error(), "reset") { - t.Fatalf("unexpected dial result after shutdown, got %v", dialErr) - } -} - -func TestRunNonGracefulRedirectReturnsStartupError(t *testing.T) { - occupied, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen on occupied addr: %v", err) - } - occupiedAddr := occupied.Addr().String() - defer occupied.Close() - - redirectListener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen for redirect addr: %v", err) - } - redirectAddr := redirectListener.Addr().String() - if err := redirectListener.Close(); err != nil { - t.Fatalf("close redirect addr probe: %v", err) - } - - engine := New() - err = engine.Run( - WithAddr(occupiedAddr), - WithTLS(&tls.Config{}), - WithHTTPRedirect(redirectAddr), - ) - if err == nil { - t.Fatal("expected non-graceful TLS redirect startup to return bind error") - } - if !strings.Contains(err.Error(), occupiedAddr) { - t.Fatalf("expected startup error to mention occupied address %q, got %v", occupiedAddr, err) - } -} diff --git a/touka.go b/touka.go index 4ad81da..dd529cb 100644 --- a/touka.go +++ b/touka.go @@ -22,10 +22,10 @@ type HandlerFunc func(*Context) // HandlersChain 定义处理函数链(中间件栈)的类型。 type HandlersChain []HandlerFunc -// Router 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 -type Router interface { - Group(relativePath string, handlers ...HandlerFunc) Router // 创建路由分组 - Use(middleware ...HandlerFunc) Router // 应用中间件到当前组或子组 +// IRouter 定义了路由注册的接口,提供路由分组和HTTP方法注册的能力。 +type IRouter interface { + Group(relativePath string, handlers ...HandlerFunc) IRouter // 创建路由分组 + Use(middleware ...HandlerFunc) IRouter // 应用中间件到当前组或子组 Handle(httpMethod, relativePath string, handlers ...HandlerFunc) // 注册通用HTTP方法 GET(relativePath string, handlers ...HandlerFunc) diff --git a/tree.go b/tree.go index b159c8d..31246a5 100644 --- a/tree.go +++ b/tree.go @@ -121,28 +121,14 @@ const ( // node 表示路由树中的一个节点. type node struct { - path string // 当前节点的路径段 - indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 - wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) - hasCaseInsensitivePath bool // 根节点是否包含需要 fixed-path 大小写修正的路由 - nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) - priority uint32 // 节点的优先级, 用于查找时优先匹配 - children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 - handlers HandlersChain // 绑定到此节点的处理函数链 - fullPath string // 完整路径, 用于调试和错误信息 -} - -func routeNeedsCaseInsensitiveLookup(path string) bool { - for i := 0; i < len(path); i++ { - c := path[i] - if c >= utf8.RuneSelf { - return true - } - if c >= 'A' && c <= 'Z' { - return true - } - } - return false + path string // 当前节点的路径段 + indices string // 子节点第一个字符的索引字符串, 用于快速查找子节点 + wildChild bool // 是否包含通配符子节点(:param 或 *catchAll) + nType nodeType // 节点的类型(静态, 根, 参数, 捕获所有) + priority uint32 // 节点的优先级, 用于查找时优先匹配 + children []*node // 子节点切片, 最多有一个 :param 风格的节点位于数组末尾 + handlers HandlersChain // 绑定到此节点的处理函数链 + fullPath string // 完整路径, 用于调试和错误信息 } // incrementChildPrio 增加给定子节点的优先级并在必要时重新排序. @@ -176,9 +162,6 @@ func (n *node) incrementChildPrio(pos int) int { func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path // 记录完整的路径 n.priority++ // 增加当前节点的优先级 - if routeNeedsCaseInsensitiveLookup(path) { - n.hasCaseInsensitivePath = true - } // 如果是空树(根节点) if len(n.path) == 0 && len(n.children) == 0 { @@ -469,14 +452,12 @@ type skippedNode struct { // 建议进行 TSR(尾部斜杠重定向). func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { var globalParamsCount int16 // 全局参数计数 - var backtrackToWildChild bool walk: // 外部循环用于遍历路由树 for { prefix := n.path // 当前节点的路径前缀 if len(path) > len(prefix) { if path[:len(prefix)] == prefix { // 如果路径以当前节点的前缀开头 - pathAtNode := path path = path[len(prefix):] // 移除已匹配的前缀 // 在访问 path[0] 之前进行安全检查 @@ -486,26 +467,30 @@ walk: // 外部循环用于遍历路由树 // 优先尝试所有非通配符子节点, 通过匹配索引字符 idxc := path[0] // 剩余路径的第一个字符 - if !backtrackToWildChild { - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == idxc { // 如果找到匹配的索引字符 - // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 - if n.wildChild { - index := len(*skippedNodes) - *skippedNodes = (*skippedNodes)[:index+1] - (*skippedNodes)[index] = skippedNode{ - path: pathAtNode, // 记录进入当前节点时的剩余路径 - node: n, - paramsCount: globalParamsCount, // 记录当前参数计数 - } + for i, c := range []byte(n.indices) { + if c == idxc { // 如果找到匹配的索引字符 + // 如果当前节点有通配符子节点, 则将当前节点添加到 skippedNodes, 以便回溯 + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: prefix + path, // 记录跳过的路径 + node: &node{ // 复制当前节点的状态 + path: n.path, + wildChild: n.wildChild, + nType: n.nType, + priority: n.priority, + children: n.children, + handlers: n.handlers, + fullPath: n.fullPath, + }, + paramsCount: globalParamsCount, // 记录当前参数计数 } - - n = n.children[i] // 移动到匹配的子节点 - continue walk // 继续外部循环 } + + n = n.children[i] // 移动到匹配的子节点 + continue walk // 继续外部循环 } - } else { - backtrackToWildChild = false } if !n.wildChild { @@ -522,8 +507,7 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] // 恢复参数切片 } globalParamsCount = skippedNode.paramsCount // 恢复参数计数 - backtrackToWildChild = true - continue walk // 继续外部循环 + continue walk // 继续外部循环 } } } @@ -563,7 +547,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path[:end] // 提取参数值 - if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { + if unescape { // 如果需要进行 URL 解码 if v, err := url.QueryUnescape(val); err == nil { val = v // 解码成功则更新值 } @@ -615,7 +599,7 @@ walk: // 外部循环用于遍历路由树 i := len(*value.params) *value.params = (*value.params)[:i+1] // 扩展切片 val := path // 参数值是剩余的整个路径 - if unescape && (strings.IndexByte(val, '%') >= 0 || strings.IndexByte(val, '+') >= 0) { + if unescape { // 如果需要进行 URL 解码 if v, err := url.QueryUnescape(path); err == nil { val = v // 解码成功则更新值 } @@ -650,7 +634,6 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount - backtrackToWildChild = true continue walk } } @@ -675,8 +658,8 @@ walk: // 外部循环用于遍历路由树 } // 未找到处理函数. 检查此路径加尾部斜杠是否存在处理函数, 以进行尾部斜杠重定向建议 - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == '/' { // 如果索引中包含 '/' + for i, c := range []byte(n.indices) { + if c == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 value.tsr = (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) // 或者子节点是 catchAll 且其子节点有处理函数 @@ -705,7 +688,6 @@ walk: // 外部循环用于遍历路由树 *value.params = (*value.params)[:skippedNode.paramsCount] } globalParamsCount = skippedNode.paramsCount - backtrackToWildChild = true continue walk } } @@ -719,15 +701,13 @@ walk: // 外部循环用于遍历路由树 // 它还可以选择修复尾部斜杠. // 它返回大小写校正后的路径和一个布尔值, 指示查找是否成功. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { - return n.findCaseInsensitivePathWithBuffer(path, nil, fixTrailingSlash) -} + const stackBufSize = 128 // 栈上缓冲区的默认大小 -func (n *node) findCaseInsensitivePathWithBuffer(path string, buf []byte, fixTrailingSlash bool) ([]byte, bool) { - if buf != nil { - buf = buf[:0] - } - if cap(buf) < len(path)+1 { - buf = make([]byte, 0, len(path)+1) + // 在常见情况下使用栈上静态大小的缓冲区. + // 如果路径太长, 则在堆上分配缓冲区. + buf := make([]byte, 0, stackBufSize) + if length := len(path) + 1; length > stackBufSize { + buf = make([]byte, 0, length) // 如果路径太长, 则分配更大的缓冲区 } ciPath := n.findCaseInsensitivePathRec( @@ -778,8 +758,8 @@ walk: // 外部循环用于遍历路由树 // 未找到处理函数. // 尝试通过添加尾部斜杠来修复路径 if fixTrailingSlash { - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == '/' { // 如果索引中包含 '/' + for i, c := range []byte(n.indices) { + if c == '/' { // 如果索引中包含 '/' n = n.children[i] // 移动到对应的子节点 if (len(n.path) == 1 && n.handlers != nil) || // 如果子节点路径是 '/' 且有处理函数 (n.nType == catchAll && n.children[0].handlers != nil) { // 或者子节点是 catchAll 且其子节点有处理函数 @@ -801,8 +781,8 @@ walk: // 外部循环用于遍历路由树 if rb[0] != 0 { // 旧 rune 未处理完 idxc := rb[0] - for i := 0; i < len(n.indices); i++ { - if n.indices[i] == idxc { + for i, c := range []byte(n.indices) { + if c == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -833,9 +813,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i := 0; i < len(n.indices); i++ { + for i, c := range []byte(n.indices) { // 小写匹配 - if n.indices[i] == idxc { + if c == idxc { // 必须使用递归方法, 因为大写字节和小写字节都可能作为索引存在 if out := n.children[i].findCaseInsensitivePathRec( path, ciPath, rb, fixTrailingSlash, @@ -852,9 +832,9 @@ walk: // 外部循环用于遍历路由树 rb = shiftNRuneBytes(rb, off) idxc := rb[0] - for i := 0; i < len(n.indices); i++ { + for i, c := range []byte(n.indices) { // 大写匹配 - if n.indices[i] == idxc { + if c == idxc { // 继续处理子节点 n = n.children[i] npLen = len(n.path) @@ -872,7 +852,7 @@ walk: // 外部循环用于遍历路由树 return nil // 未找到, 返回 nil } - n = n.children[len(n.children)-1] // 通配符子节点约定始终位于末尾 + n = n.children[0] // 移动到通配符子节点(通常是唯一一个) switch n.nType { case param: // 参数节点 // 查找参数结束位置('/' 或路径末尾) diff --git a/tree_test.go b/tree_test.go index a35a1a8..d3ffdfa 100644 --- a/tree_test.go +++ b/tree_test.go @@ -11,7 +11,6 @@ import ( "regexp" "strings" "testing" - "time" ) // Used as a workaround since we can't compare functions or their addresses @@ -40,23 +39,6 @@ func getSkippedNodes() *[]skippedNode { return &ps } -func getValueWithTimeout(t *testing.T, tree *node, path string, unescape bool) nodeValue { - t.Helper() - - resultCh := make(chan nodeValue, 1) - go func() { - resultCh <- tree.getValue(path, getParams(), getSkippedNodes(), unescape) - }() - - select { - case value := <-resultCh: - return value - case <-time.After(2 * time.Second): - t.Fatalf("lookup for path %q timed out, likely stuck in backtracking", path) - return nodeValue{} - } -} - func checkRequests(t *testing.T, tree *node, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { @@ -919,34 +901,6 @@ func TestTreeInvalidNodeType(t *testing.T) { } } -func TestFindCaseInsensitivePathWithStaticAndParamRoutesDoesNotPanicOnMiss(t *testing.T) { - tree := &node{} - routes := [...]string{ - "/:user/:repo/info/refs", - "/healthz", - "/api/db/data", - "/api/db/sum", - } - - for _, route := range routes { - tree.addRoute(route, fakeHandler(route)) - } - - defer func() { - if r := recover(); r != nil { - t.Fatalf("unexpected panic while looking up missing path: %v", r) - } - }() - - if out, found := tree.findCaseInsensitivePath("/does-not-exist", true); found || out != nil { - t.Fatalf("expected missing path lookup to return no match, got %q, %t", string(out), found) - } - - if out, found := tree.findCaseInsensitivePath("/does-not-exist", false); found || out != nil { - t.Fatalf("expected missing path lookup without trailing slash fix to return no match, got %q, %t", string(out), found) - } -} - func TestTreeInvalidParamsType(t *testing.T) { tree := &node{} // add a child with wildcard @@ -1122,51 +1076,3 @@ func TestComplexBacktrackingWithCatchAll(t *testing.T) { t.Errorf("处理路径 '%s' 时参数不匹配: \n 得到: %v\n 想要: %v", reqPath, *value.params, wantParams) } } - -func TestBacktrackingFallsThroughToWildcardBranch(t *testing.T) { - tests := []struct { - name string - routes []string - requestPath string - wantFullPath string - wantParams Params - }{ - { - name: "param route after static dead end", - routes: []string{"/foo/bar", "/foo/:id/details"}, - requestPath: "/foo/bar/details", - wantFullPath: "/foo/:id/details", - wantParams: Params{{Key: "id", Value: "bar"}}, - }, - { - name: "catch-all route after static dead end", - routes: []string{"/foo/bar", "/foo/:id/*rest"}, - requestPath: "/foo/bar/baz.txt", - wantFullPath: "/foo/:id/*rest", - wantParams: Params{ - {Key: "id", Value: "bar"}, - {Key: "rest", Value: "/baz.txt"}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tree := &node{} - for _, route := range tt.routes { - tree.addRoute(route, fakeHandler(route)) - } - - value := getValueWithTimeout(t, tree, tt.requestPath, false) - if value.handlers == nil { - t.Fatalf("expected handlers for %q", tt.requestPath) - } - if value.fullPath != tt.wantFullPath { - t.Fatalf("expected full path %q for %q, got %q", tt.wantFullPath, tt.requestPath, value.fullPath) - } - if value.params == nil || !reflect.DeepEqual(*value.params, tt.wantParams) { - t.Fatalf("expected params %v for %q, got %v", tt.wantParams, tt.requestPath, value.params) - } - }) - } -}