diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d3e55a2..f7754d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,8 +2,6 @@ name: Go Test on: push: - tags: - - '*' jobs: test: @@ -13,9 +11,9 @@ jobs: uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version-file: 'go.mod' - name: Run tests run: go test -v ./... diff --git a/README.md b/README.md index 3ab971f..a7b99fd 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,20 @@ Touka(灯花) 是一个基于 Go 语言构建的多层次、高性能 Web 框架。其设计目标是为开发者提供**更直接的控制、有效的扩展能力,以及针对特定场景的行为优化**。 -**想深入了解 Touka 吗?请阅读我们的 -> [深度指南 (about-touka.md)](about-touka.md)** +## 文档 -这份深度指南包含了对框架设计哲学、核心功能(路由、上下文、中间件、错误处理等)的全面剖析,并提供了大量可直接使用的代码示例,帮助您快速上手并精通 Touka。 +我们提供了详尽的文档来帮助您快速上手并深入了解 Touka: + +- **[灯花框架简介 (introduction.md)](docs/introduction.md)** +- **[快速开始 (quickstart.md)](docs/quickstart.md)** +- **[路由系统 (routing.md)](docs/routing.md)** +- **[上下文 Context (context.md)](docs/context.md)** +- **[中间件 (middleware.md)](docs/middleware.md)** +- **[统一错误处理 (error-handling.md)](docs/error-handling.md)** +- **[静态文件与资源 (static-files.md)](docs/static-files.md)** +- **[反向代理 (reverse-proxy.md)](docs/reverse-proxy.md)** +- **[Server-Sent Events (sse.md)](docs/sse.md)** +- **[高级特性与优化 (advanced.md)](docs/advanced.md)** ### 快速上手 @@ -72,11 +83,9 @@ func main() { - [jwt](https://github.com/fenthope/jwt) - [带宽限制](https://github.com/fenthope/toukautil/blob/main/bandwithlimiter.go) -## 文档与贡献 +## 贡献 -* **深度指南:** **[about-touka.md](about-touka.md)** -* **API 文档:** 访问 [pkg.go.dev/github.com/infinite-iroha/touka](https://pkg.go.dev/github.com/infinite-iroha/touka) 查看完整的 API 参考。 -* **贡献:** 我们欢迎任何形式的贡献,无论是错误报告、功能建议还是代码提交。请遵循项目的贡献指南。 +我们欢迎任何形式的贡献,无论是错误报告、功能建议还是代码提交。请遵循项目的贡献指南。 ## 相关项目 diff --git a/context.go b/context.go index c79e4cc..c37371f 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,8 @@ import ( "net/url" "os" "path/filepath" + "reflect" + "strconv" "strings" "sync" "time" @@ -65,6 +67,10 @@ type Context struct { // 请求体Body大小限制 MaxRequestBodySize int64 + + // skippedNodes 用于记录跳过的节点信息,以便回溯 + // 通常在处理嵌套路由时使用 + SkippedNodes []skippedNode } // --- Context 相关方法实现 --- @@ -80,7 +86,13 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { } c.Request = req - c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //c.Params = c.Params[:0] // 清空 Params 切片,而不是重新分配,以复用底层数组 + //避免params长度为0 + if cap(c.Params) > 0 { + c.Params = c.Params[:0] + } else { + c.Params = make(Params, 0, c.engine.maxParams) + } c.handlers = nil c.index = -1 // 初始为 -1,`Next()` 将其设置为 0 c.Keys = make(map[string]any) // 每次请求重新创建 map,避免数据污染 @@ -90,6 +102,12 @@ func (c *Context) reset(w http.ResponseWriter, req *http.Request) { c.ctx = req.Context() // 使用请求的上下文,继承其取消信号和值 c.sameSite = http.SameSiteDefaultMode // 默认 SameSite 模式 c.MaxRequestBodySize = c.engine.GlobalMaxRequestBodySize + + if cap(c.SkippedNodes) > 0 { + c.SkippedNodes = c.SkippedNodes[:0] + } else { + c.SkippedNodes = make([]skippedNode, 0, 256) + } } // Next 在处理链中执行下一个处理函数 @@ -270,7 +288,7 @@ func (c *Context) Raw(code int, contentType string, data []byte) { // String 向响应写入格式化的字符串 func (c *Context) String(code int, format string, values ...any) { c.Writer.WriteHeader(code) - c.Writer.Write([]byte(fmt.Sprintf(format, values...))) + c.Writer.Write(fmt.Appendf(nil, format, values...)) } // Text 向响应写入无需格式化的string @@ -325,7 +343,6 @@ func (c *Context) FileText(code int, filePath string) { } /* -// not fot work // FileTextSafeDir func (c *Context) FileTextSafeDir(code int, filePath string, safeDir string) { @@ -394,6 +411,7 @@ func (c *Context) JSON(code int, obj any) { c.Writer.WriteHeader(code) if err := json.MarshalWrite(c.Writer, obj); err != nil { c.AddError(fmt.Errorf("failed to marshal JSON: %w", err)) + c.Errorf("failed to marshal JSON: %s", err) c.ErrorUseHandle(http.StatusInternalServerError, fmt.Errorf("failed to marshal JSON: %w", err)) return } @@ -448,7 +466,7 @@ func (c *Context) HTML(code int, name string, obj any) { // 可以扩展支持其他渲染器接口 } // 默认简单输出,用于未配置 HTMLRender 的情况 - c.Writer.Write([]byte(fmt.Sprintf("\n
%v
", name, obj))) + c.Writer.Write(fmt.Appendf(nil, "\n
%v
", name, obj)) } // Redirect 执行 HTTP 重定向 @@ -473,7 +491,7 @@ func (c *Context) ShouldBindJSON(obj any) error { return nil } -// ShouldBindWANF +// ShouldBindWANF 尝试将 WANF 格式的请求体绑定到对象 func (c *Context) ShouldBindWANF(obj any) error { if c.Request.Body == nil { return errors.New("request body is empty") @@ -489,23 +507,174 @@ func (c *Context) ShouldBindWANF(obj any) error { return nil } -// Deprecated: This function is a reserved placeholder for future API extensions -// and is not yet implemented. It will either be properly defined or removed in v2.0.0. Do not use. -// ShouldBind 尝试将请求体绑定到各种类型(JSON, Form, XML 等) -// 这是一个复杂的通用绑定接口,通常根据 Content-Type 或其他头部来判断绑定方式 -// 预留接口,可根据项目需求进行扩展 +// ShouldBindGOB 尝试将 GOB 格式的请求体绑定到对象 +func (c *Context) ShouldBindGOB(obj any) error { + if c.Request.Body == nil { + return errors.New("request body is empty") + } + decoder := gob.NewDecoder(c.Request.Body) + if err := decoder.Decode(obj); err != nil { + return fmt.Errorf("GOB binding error: %w", err) + } + return nil +} + +// bindForm 将 url.Values 绑定到结构体 +// 支持 form tag 标签,如 `form:"field_name"` +func bindForm(values url.Values, obj any) error { + val := reflect.ValueOf(obj) + if val.Kind() != reflect.Pointer || val.Elem().Kind() != reflect.Struct { + return errors.New("obj must be a pointer to struct") + } + + val = val.Elem() + typ := val.Type() + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + if !field.CanSet() { + continue + } + + tag := fieldType.Tag.Get("form") + if tag == "" { + tag = fieldType.Name + } + if tag == "-" { + continue + } + + formValues := values[tag] + if len(formValues) == 0 { + continue + } + + if err := setFieldValue(field, formValues); err != nil { + return fmt.Errorf("field %s: %w", fieldType.Name, err) + } + } + return nil +} + +// setFieldValue 将字符串值设置到反射值 +func setFieldValue(field reflect.Value, values []string) error { + if !field.CanSet() { + return nil + } + + value := values[0] + + switch field.Kind() { + case reflect.String: + field.SetString(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == "" { + return nil + } + v, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + field.SetInt(v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == "" { + return nil + } + v, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + field.SetUint(v) + case reflect.Float32, reflect.Float64: + if value == "" { + return nil + } + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + field.SetFloat(v) + case reflect.Bool: + if value == "" { + return nil + } + v, err := strconv.ParseBool(value) + if err != nil { + return err + } + field.SetBool(v) + case reflect.Pointer: + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + return setFieldValue(field.Elem(), values) + case reflect.Slice: + slice := reflect.MakeSlice(field.Type(), len(values), len(values)) + elemType := field.Type().Elem() + for i, v := range values { + if err := setFieldValue(slice.Index(i), []string{v}); err != nil { + return err + } + _ = elemType + } + field.Set(slice) + default: + return fmt.Errorf("unsupported type: %s", field.Kind()) + } + return nil +} + +// ShouldBindForm 尝试将表单数据绑定到结构体 +// 支持 application/x-www-form-urlencoded 和 multipart/form-data +func (c *Context) ShouldBindForm(obj any) error { + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content type: %w", err) + } + + switch mediaType { + case "multipart/form-data": + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { + return fmt.Errorf("parse multipart form error: %w", err) + } + case "application/x-www-form-urlencoded": + if err := c.Request.ParseForm(); err != nil { + return fmt.Errorf("parse form error: %w", err) + } + default: + return fmt.Errorf("unsupported form content type: %s", mediaType) + } + + if err := bindForm(c.Request.Form, obj); err != nil { + return fmt.Errorf("form binding error: %w", err) + } + return nil +} + +// ShouldBind 尝试根据 Content-Type 将请求体绑定到结构体 +// 支持的类型:application/json, application/x-www-form-urlencoded, multipart/form-data, application/wanf, application/vnd.wjqserver.wanf, application/gob func (c *Context) ShouldBind(obj any) error { - // TODO: 完整的通用绑定逻辑 - // 可以根据 c.Request.Header.Get("Content-Type") 来判断是 JSON, Form, XML 等 - // 例如: - // contentType := c.Request.Header.Get("Content-Type") - // if strings.HasPrefix(contentType, "application/json") { - // return c.ShouldBindJSON(obj) - // } - // if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") || strings.HasPrefix(contentType, "multipart/form-data") { - // return c.ShouldBindForm(obj) // 需要实现 ShouldBindForm - // } - return errors.New("generic binding not fully implemented yet, implement based on Content-Type") + contentType := c.Request.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content type: %w", err) + } + + switch mediaType { + case "application/json": + return c.ShouldBindJSON(obj) + case "application/x-www-form-urlencoded", "multipart/form-data": + return c.ShouldBindForm(obj) + case "application/wanf", "application/vnd.wjqserver.wanf": + return c.ShouldBindWANF(obj) + case "application/gob": + return c.ShouldBindGOB(obj) + default: + return fmt.Errorf("unsupported content type: %s", mediaType) + } } // AddError 添加一个错误到 Context diff --git a/docs/advanced.md b/docs/advanced.md new file mode 100644 index 0000000..a7cb9a2 --- /dev/null +++ b/docs/advanced.md @@ -0,0 +1,317 @@ +# 高级特性与优化 + +本章节涵盖了 Touka 的一些深层特性以及在生产环境中的最佳实践。 + +## 性能优化 + +### 1. Context 池化 + +Touka 使用 `sync.Pool` 来重用 `touka.Context` 对象。这极大减少了每个请求产生的内存分配和 GC 压力。 +- **代价**: 您必须在处理器返回后立即停止对该 `Context` 指针的任何引用。 +- **解决方案**: 如果需要在后台 Goroutine 中使用请求数据,请预先提取所需数据(如 `c.Query` 的值),或者深拷贝该对象(不推荐)。 + +### 2. 预分配参数切片 + +在路由匹配过程中,Touka 会预分配路径参数切片,并根据路由深度进行缓存,从而在路由查找时实现几乎零分配。 + +## 服务器配置 + +### 服务器配置器 (ServerConfigurator) + +Touka 允许您在服务器启动前对底层 `*http.Server` 进行自定义配置: + +```go +r := touka.New() + +// 配置 HTTP 服务器 +r.SetServerConfigurator(func(server *http.Server) { + server.ReadTimeout = 30 * time.Second + server.WriteTimeout = 30 * time.Second + server.IdleTimeout = 120 * time.Second + server.MaxHeaderBytes = 1 << 20 // 1MB +}) + +// 专门配置 HTTPS 服务器(优先级高于 ServerConfigurator) +r.SetTLSServerConfigurator(func(server *http.Server) { + server.ReadTimeout = 30 * time.Second + server.WriteTimeout = 30 * time.Second + // HTTPS 特定配置... +}) +``` + +### 协议配置 + +Touka 支持配置 HTTP/1.1、HTTP/2 和 H2C(HTTP/2 Cleartext): + +```go +// 使用默认协议配置(仅 HTTP/1.1) +r.SetDefaultProtocols() + +// 自定义协议配置 +r.SetProtocols(&touka.ProtocolsConfig{ + Http1: true, // 启用 HTTP/1.1 + Http2: true, // 启用 HTTP/2(需要 TLS) + Http2_Cleartext: true, // 启用 H2C(无需 TLS 的 HTTP/2) +}) +``` + +### 启动方式 + +Touka 提供了多种服务器启动方式: + +```go +// 1. 简单启动(无优雅停机) +r.Run(":8080") + +// 2. 带优雅停机的启动 +r.RunShutdown(":8080", 10*time.Second) + +// 3. 带上下文的优雅停机 +ctx, cancel := context.WithCancel(context.Background()) +r.RunShutdownWithContext(":8080", ctx, 10*time.Second) + +// 4. HTTPS 启动 +tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + // 其他 TLS 配置... +} +r.RunTLS(":443", tlsConfig, 10*time.Second) + +// 5. HTTPS + HTTP 重定向 +r.RunTLSRedir(":80", ":443", tlsConfig, 10*time.Second) +``` + +## 优雅停机 (Graceful Shutdown) + +在部署新版本时,我们希望服务器停止接收新请求,但能处理完当前正在进行的请求。 + +```go +r := touka.Default() +// ... 注册路由 ... + +// 监听 SIGINT 和 SIGTERM 信号 +// 如果在 10 秒内未处理完,则强制关闭 +if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + log.Fatal("服务器退出异常:", err) +} +``` + +### SSE 长连接的优雅关闭 + +对于 SSE 等长连接场景,Touka 会自动将引擎的关闭信号注入到请求的 Context 中: + +```go +r.GET("/events", func(c *touka.Context) { + c.EventStream(func(w io.Writer) bool { + select { + case <-c.Request.Context().Done(): + // 收到关闭信号,优雅退出 + return false + case <-time.After(1 * time.Second): + // 发送数据 + event := touka.Event{Data: "tick"} + event.Render(w) + return true + } + }) +}) +``` + +## 路由行为配置 + +```go +r := touka.New() + +// 是否自动重定向尾部斜杠(默认 true) +// /foo/ -> /foo 或 /foo -> /foo/ +r.SetRedirectTrailingSlash(true) + +// 是否自动修复路径大小写(默认 true) +// /FOO -> /foo +r.SetRedirectFixedPath(true) + +// 是否处理 405 Method Not Allowed(默认 true) +// 当路径匹配但方法不匹配时返回 405 而非 404 +r.SetHandleMethodNotAllowed(true) +``` + +### 自定义 404 处理 + +```go +// 单个处理器 +r.NoRoute(func(c *touka.Context) { + c.JSON(http.StatusNotFound, touka.H{ + "error": "Page not found", + "path": c.Request.URL.Path, + }) +}) + +// 处理器链(可以在 404 前执行额外中间件) +r.NoRoutes( + LogNotFoundMiddleware(), + func(c *touka.Context) { + c.JSON(http.StatusNotFound, touka.H{"error": "Not found"}) + }, +) +``` + +### 未匹配路径作为静态文件服务 + +```go +// 当没有路由匹配时,尝试从文件系统中查找文件 +// 非常适合单页应用(SPA)部署 +r.SetUnMatchFS(http.Dir("./frontend/dist")) + +// 也可以添加额外的中间件 +r.SetUnMatchFS(http.Dir("./frontend/dist"), AuthMiddleware()) +``` + +## IP 地址解析配置 + +在反向代理环境中,正确配置 IP 解析非常重要: + +```go +r := touka.New() + +// 是否信任代理头部获取客户端 IP(默认 true) +r.SetForwardByClientIP(true) + +// 设置用于获取客户端 IP 的头部列表(按优先级排序) +r.SetRemoteIPHeaders([]string{ + "X-Forwarded-For", + "X-Real-IP", + "CF-Connecting-IP", // Cloudflare +}) +``` + +如果您同时使用 Touka 的 `ReverseProxy` 把请求继续转发给其他后端,请再参考 `docs/reverse-proxy.md` 中关于 `Forwarded`、`X-Forwarded-*` 与 `Via` 的说明。前者解决“当前请求的客户端 IP 如何被 Touka 正确解析”,后者解决“代理后的请求如何把链路信息继续传给下一跳”。 + +## 请求体大小限制 + +为了防止恶意的大数据包攻击(如慢速 HTTP 攻击或内存溢出),Touka 内置了请求体大小限制机制。 + +### 全局限制 + +```go +// 设置全局最大请求体大小(例如 10MB) +r.SetGlobalMaxRequestBodySize(10 << 20) +``` + +### 单个请求限制 + +```go +r.POST("/upload", func(c *touka.Context) { + // 为特定请求设置限制(覆盖全局设置) + c.SetMaxRequestBodySize(100 << 20) // 100MB + + body, err := c.GetReqBodyFull() + if err != nil { + // 如果超过限制,会返回 ErrBodyTooLarge + if errors.Is(err, touka.ErrBodyTooLarge) { + c.ErrorUseHandle(http.StatusRequestEntityTooLarge, err) + return + } + c.ErrorUseHandle(http.StatusBadRequest, err) + return + } + // 处理 body... +}) +``` + +## 与标准库集成 + +Touka 遵循 `net/http` 哲学。您可以方便地使用现有的标准库组件。 + +### 适配 `http.HandlerFunc` + +```go +r.GET("/pprof/*any", touka.AdapterStdFunc(pprof.Index)) +``` + +### 适配 `http.Handler` + +```go +// 适配 http.FileServer +fileServer := http.FileServer(http.Dir("./static")) +r.GET("/static/*filepath", touka.AdapterStdHandle(http.StripPrefix("/static", fileServer))) +``` + +### 手动注入 + +由于 `Engine` 实现了 `http.Handler` 接口,您可以将其挂载到任何地方。 + +```go +s := &http.Server{ + Addr: ":8080", + Handler: r, // Engine 实例 + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, +} +s.ListenAndServe() +``` + +## 自定义日志集成 + +Touka 默认集成了 `reco` 日志库。您可以自定义其输出行为。 + +```go +logConfig := reco.Config{ + Level: reco.LevelInfo, + Mode: reco.ModeText, // 或 reco.ModeJSON + Output: os.Stdout, + Async: true, // 异步写入提高性能 + TimeFormat: time.RFC3339, +} +r.SetLoggerCfg(logConfig) + +// 或直接传入日志实例 +logger, _ := reco.New(logConfig) +r.SetLogger(logger) + +// 关闭日志(在服务器关闭时) +defer r.CloseLogger() +``` + +## HTTP 客户端配置 + +Touka 内置了 `httpc` HTTP 客户端,可以在请求处理中方便地发起出站请求: + +```go +// 创建自定义 HTTP 客户端 +customClient := httpc.New() +r.SetHTTPClient(customClient) + +// 在处理器中使用 +r.GET("/proxy", func(c *touka.Context) { + resp, err := c.GetHTTPC().Get("https://api.example.com/data") + // ... +}) +``` + +## 条件中间件 + +Touka 支持根据条件动态启用或禁用中间件: + +```go +// 单个条件中间件 +r.Use(r.UseIf(config.EnableLogging, AccessLoggerMiddleware())) + +// 条件中间件链 +r.Use(r.UseChainIf(config.EnableMetrics, + MetricsMiddleware, + PrometheusMiddleware, + MonitoringMiddleware, +)) +``` + +## 获取路由信息 + +```go +// 获取所有已注册的路由信息 +routes := r.GetRouterInfo() +for _, route := range routes { + fmt.Printf("Method: %s, Path: %s, Handler: %s, Group: %s\n", + route.Method, route.Path, route.Handler, route.Group) +} +``` diff --git a/docs/context.md b/docs/context.md new file mode 100644 index 0000000..c13c158 --- /dev/null +++ b/docs/context.md @@ -0,0 +1,469 @@ +# 上下文 (Context) + +`touka.Context` 是 Touka 框架中最重要的结构。它携带了关于当前 HTTP 请求的所有必要信息,并提供了一系列方法来解析请求和构建响应。 + +## 请求数据解析 + +### 路径参数 (Path Parameters) + +```go +// 路由: /users/:id +r.GET("/users/:id", func(c *touka.Context) { + id := c.Param("id") + c.String(http.StatusOK, "User ID: %s", id) +}) +``` + +### 查询参数 (Query Parameters) + +```go +// /welcome?firstname=Jane&lastname=Doe +r.GET("/welcome", func(c *touka.Context) { + firstname := c.DefaultQuery("firstname", "Guest") + lastname := c.Query("lastname") // 快捷方式,不存在则为空 + + c.String(http.StatusOK, "Hello %s %s", firstname, lastname) +}) +``` + +### 表单数据 (Form Data) + +```go +r.POST("/form_post", func(c *touka.Context) { + message := c.PostForm("message") + nick := c.DefaultPostForm("nick", "anonymous") + + c.JSON(http.StatusOK, touka.H{ + "status": "posted", + "message": message, + "nick": nick, + }) +}) +``` + +### 请求体读取 + +```go +// 读取完整请求体 +r.POST("/raw", func(c *touka.Context) { + body, err := c.GetReqBodyFull() + if err != nil { + c.ErrorUseHandle(http.StatusBadRequest, err) + return + } + c.Raw(http.StatusOK, "application/octet-stream", body) +}) + +// 获取 io.ReadCloser(只能读取一次) +r.POST("/stream", func(c *touka.Context) { + reader := c.GetReqBody() + defer reader.Close() + // 处理 reader... +}) +``` + +### 客户端信息 + +```go +r.GET("/client-info", func(c *touka.Context) { + // 获取客户端 IP(支持代理转发) + ip := c.RequestIP() + // 或使用别名 + ip = c.ClientIP() + + // 获取 User-Agent + ua := c.UserAgent() + + // 获取 Content-Type + ct := c.ContentType() + + // 获取请求协议 + proto := c.GetProtocol() + + c.JSON(http.StatusOK, touka.H{ + "ip": ip, + "userAgent": ua, + "protocol": proto, + }) +}) +``` + +### 请求头 + +```go +r.GET("/headers", func(c *touka.Context) { + // 获取单个请求头 + auth := c.GetReqHeader("Authorization") + + // 获取所有请求头 + allHeaders := c.GetAllReqHeader() + + c.JSON(http.StatusOK, touka.H{ + "authorization": auth, + "allHeaders": allHeaders, + }) +}) +``` + +## 数据绑定 + +### JSON 绑定 + +Touka 提供了非常便捷的 JSON 绑定功能,它会自动解析请求体并填充到结构体中。 + +```go +type LoginRequest struct { + User string `json:"user"` + Password string `json:"password"` +} + +r.POST("/login", func(c *touka.Context) { + var json LoginRequest + if err := c.ShouldBindJSON(&json); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + + if json.User != "admin" || json.Password != "123" { + c.JSON(http.StatusUnauthorized, touka.H{"status": "unauthorized"}) + return + } + + c.JSON(http.StatusOK, touka.H{"status": "you are logged in"}) +}) +``` + +### 表单绑定 + +```go +type UserForm struct { + Name string `form:"name"` + Email string `form:"email"` + Age int `form:"age"` +} + +r.POST("/user", func(c *touka.Context) { + var form UserForm + if err := c.ShouldBindForm(&form); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, form) +}) +``` + +### 通用绑定 + +`ShouldBind` 方法会根据请求的 `Content-Type` 自动选择绑定方式: + +```go +r.POST("/data", func(c *touka.Context) { + var data MyData + // 自动根据 Content-Type 绑定(支持 JSON、Form、WANF、GOB) + if err := c.ShouldBind(&data); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, data) +}) +``` + +### WANF 绑定 + +```go +r.POST("/wanf", func(c *touka.Context) { + var data MyData + if err := c.ShouldBindWANF(&data); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, data) +}) +``` + +### GOB 绑定 + +```go +r.POST("/gob", func(c *touka.Context) { + var data MyData + if err := c.ShouldBindGOB(&data); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, data) +}) +``` + +## 响应构建 + +### 基础格式 + +Touka 支持多种响应格式: + +```go +// JSON +c.JSON(http.StatusOK, touka.H{"message": "hey"}) + +// 字符串 (支持格式化) +c.String(http.StatusOK, "welcome %s", name) + +// 纯文本 +c.Text(http.StatusOK, "just text") + +// 原始数据 +c.Raw(http.StatusOK, "application/octet-stream", []byte("raw bytes")) + +// HTML 模板 +c.HTML(http.StatusOK, "index.tmpl", touka.H{"title": "Main website"}) +``` + +### WANF 响应 + +```go +// WANF 格式响应 +c.WANF(http.StatusOK, touka.H{"message": "wanf format"}) +``` + +### GOB 响应 + +```go +// GOB 格式响应 +c.GOB(http.StatusOK, myData) +``` + +### 文件与流 + +```go +// 服务本地文件(触发浏览器下载) +c.File("/local/file.go") + +// 将文件内容作为响应体(不触发下载) +c.SetRespBodyFile(http.StatusOK, "config.json") + +// 以文本形式发送文件 +c.FileText(http.StatusOK, "/path/to/file.txt") + +// 写入数据流 +c.WriteStream(reader) + +// 设置响应体为流 +c.SetBodyStream(reader, contentSize) // contentSize 为 -1 表示未知大小 +``` + +### 响应头操作 + +```go +// 设置响应头 +c.SetHeader("X-Custom-Header", "value") + +// 添加响应头(不覆盖已有值) +c.AddHeader("X-Custom-Header", "another-value") + +// 删除响应头 +c.DelHeader("X-Custom-Header") + +// 批量设置响应头 +c.SetHeaders(map[string][]string{ + "X-Header-1": {"value1"}, + "X-Header-2": {"value2a", "value2b"}, +}) + +// 获取所有响应头 +headers := c.GetAllRespHeader() +``` + +### 状态码 + +```go +// 设置状态码(不写入响应体) +c.Status(http.StatusNoContent) +``` + +### 重定向 + +```go +c.Redirect(http.StatusMovedPermanently, "http://google.com/") +``` + +## Cookie 操作 + +```go +// 设置 Cookie +c.SetCookie("session_id", "abc123", 3600, "/", "example.com", true, true) + +// 设置 SameSite 属性 +c.SetSameSite(http.SameSiteStrictMode) + +// 使用完整 Cookie 对象 +cookie := &http.Cookie{ + Name: "token", + Value: "xyz", + Path: "/", +} +c.SetCookieData(cookie) + +// 获取 Cookie +value, err := c.GetCookie("session_id") +if err != nil { + c.String(http.StatusUnauthorized, "Cookie not found") + return +} + +// 删除 Cookie +c.DeleteCookie("session_id") +``` + +## 数据传递 (Keys/Values) + +您可以在中间件和处理器之间共享数据。 + +```go +// 在中间件中设置 +c.Set("user_id", 12345) + +// 在处理器中获取 +id, exists := c.Get("user_id") +val := c.MustGet("user_id").(int) + +// 类型安全的获取方法 +str, exists := c.GetString("key") +i, exists := c.GetInt("key") +b, exists := c.GetBool("key") +f, exists := c.GetFloat64("key") +t, exists := c.GetTime("key") +d, exists := c.GetDuration("key") +``` + +## 错误处理 + +```go +r.GET("/error", func(c *touka.Context) { + // 添加错误到上下文(可以添加多个) + c.AddError(errors.New("error 1")) + c.AddError(errors.New("error 2")) + + // 获取所有错误 + errs := c.GetErrors() + + // 使用全局错误处理器 + c.ErrorUseHandle(http.StatusInternalServerError, errors.New("something went wrong")) +}) +``` + +## 日志记录 + +Touka 集成了 `reco` 日志库,可以直接在 Context 中使用: + +```go +r.GET("/log", func(c *touka.Context) { + c.Debugf("Debug message: %s", "details") + c.Infof("User accessed /log") + c.Warnf("Warning: %v", someWarning) + c.Errorf("Error occurred: %v", someError) + + // 获取底层日志器 + logger := c.GetLogger() + logger.CustomLog("level", "message") + + c.String(http.StatusOK, "Logged") +}) +``` + +## HTTP 客户端 + +Touka 集成了 `httpc` HTTP 客户端,方便发起出站请求: + +```go +r.GET("/proxy", func(c *touka.Context) { + // 获取 HTTP 客户端 + client := c.GetHTTPC() + // 或 + client = c.Client() + + // 发起请求 + resp, err := client.Get("https://api.example.com/data") + if err != nil { + c.ErrorUseHandle(http.StatusBadGateway, err) + return + } + defer resp.Body.Close() + + // 将响应流式传输给客户端 + c.SetHeader("Content-Type", resp.Header.Get("Content-Type")) + c.WriteStream(resp.Body) +}) +``` + +## 状态管理 + +- `c.Abort()`: 停止执行后续的处理器/中间件。 +- `c.AbortWithStatus(code)`: 中止并设置状态码。 +- `c.IsAborted()`: 检查是否已中止。 +- `c.Next()`: 执行后续的处理链。这常用于中间件中,在执行完某些前置逻辑后,显式调用 `Next`,并在其返回后执行后置逻辑。 + +## 请求上下文 (Go Context) + +Touka Context 实现了 Go 标准库的 `context.Context` 接口: + +```go +r.GET("/long-task", func(c *touka.Context) { + // 获取 Go context + ctx := c.Context() + + // 监听取消信号 + select { + case <-ctx.Done(): + // 客户端断开连接或超时 + return + case result := <-doLongTask(ctx): + c.JSON(http.StatusOK, result) + } +}) + +// 其他 context 方法 +done := c.Done() // 获取 Done channel +err := c.Err() // 获取错误 +val := c.Value("key") // 获取值(同时查找 Keys 和 Go context) +``` + +## 其他方法 + +```go +// 获取原始请求 URI +uri := c.GetRequestURI() + +// 获取请求路径 +path := c.GetRequestURIPath() + +// 获取查询字符串 +query := c.GetReqQueryString() + +// 获取请求协议版本 +proto := c.GetProtocol() // 例如 "HTTP/1.1" +``` + +## 对象池化 + +为了提高性能,Touka 的 Context 对象是复用的。 + +**重要提示:不要在 Goroutine 中持久化持有 `touka.Context` 指针。如果您需要在 Goroutine 中使用请求数据,请务必在派生 Goroutine 前提取所需的值。** + +```go +// 错误示例 ❌ +r.GET("/bad", func(c *touka.Context) { + go func() { + time.Sleep(5 * time.Second) + // 此时 c 可能已被复用,数据不安全 + log.Println(c.Query("name")) + }() +}) + +// 正确示例 ✓ +r.GET("/good", func(c *touka.Context) { + name := c.Query("name") // 提前提取值 + go func() { + time.Sleep(5 * time.Second) + log.Println(name) // 使用提取的值,安全 + }() +}) +``` diff --git a/docs/error-handling.md b/docs/error-handling.md new file mode 100644 index 0000000..5b53579 --- /dev/null +++ b/docs/error-handling.md @@ -0,0 +1,66 @@ +# 统一错误处理 + +Touka 的核心优势之一是其**高度统一且自动化**的错误处理机制。 + +## 全局错误处理器 + +您可以为整个引擎设置一个统一的错误处理器。无论错误来自您的业务代码,还是来自框架内部(如 404/405),甚至是来自标准库的 `http.FileServer`,最终都会流向这个处理器。 + +```go +r.SetErrorHandler(func(c *touka.Context, code int, err error) { + // 您可以在这里定义统一的错误响应格式 + c.JSON(code, touka.H{ + "code": code, + "message": http.StatusText(code), + "detail": err.Error(), + }) + + // 也可以记录日志 + c.Errorf("HTTP Error %d: %v", code, err) +}) +``` + +## `errorCapturingResponseWriter` (ecw) 的工作原理 + +很多时候,我们希望拦截标准库组件(如 `http.FileServer`)产生的错误,以便能够应用我们自定义的 404 页面或 JSON 响应。 + +Touka 通过包装标准的 `http.ResponseWriter` 实现了这一点: + +1. **拦截写入**: 当 `http.FileServer` 等组件尝试调用 `WriteHeader(statusCode)` 且 `statusCode >= 400` 时,Touka 的包装器会捕获这个状态码。 +2. **阻止输出**: 它会阻止组件继续向响应体写入默认的错误消息(如 `404 page not found`)。 +3. **回调处理**: 包装器随后会调用全局配置的 `ErrorHandler`。 + +这意味着您可以像这样轻松地为静态文件服务设置自定义错误处理: + +```go +r := touka.New() + +// 设置全局错误处理 +r.SetErrorHandler(func(c *touka.Context, code int, err error) { + if code == http.StatusNotFound { + c.String(http.StatusNotFound, "找不到此资源") + return + } + c.String(code, "发生错误: %v", err) +}) + +// 服务静态目录 +r.StaticDir("/static", "./public") +// 如果用户访问 /static/missing-file.jpg,他将看到 "找不到此资源" +``` + +## 手动触发错误处理 + +您也可以在处理器中通过 `c.ErrorUseHandle` 手动触发此流程: + +```go +r.GET("/item/:id", func(c *touka.Context) { + item, err := db.GetItem(c.Param("id")) + if err != nil { + // 调用全局错误处理器 + c.ErrorUseHandle(http.StatusInternalServerError, err) + return + } + c.JSON(http.StatusOK, item) +}) +``` diff --git a/docs/introduction.md b/docs/introduction.md new file mode 100644 index 0000000..94a7310 --- /dev/null +++ b/docs/introduction.md @@ -0,0 +1,27 @@ +# Touka (灯花) 简介 + +Touka 是一个基于 Go 语言构建的高性能、多层次 Web 框架。其设计目标是为开发者提供**更直接的控制、有效的扩展能力,以及针对特定场景的行为优化**。 + +## 为什么选择 Touka? + +在众多的 Go Web 框架中,Touka 致力于在保持类似 Gin 的易用性的同时,提供更深度的底层控制和更强大的统一错误处理机制。 + +### 核心特性 + +- **高性能路由**: 基于基数树(Radix Tree)实现的路由系统,支持高效的路径匹配、参数捕获和通配符路由。 +- **极致性能优化**: + - **Context 复用**: 通过对象池(sync.Pool)管理 `touka.Context`,显著减少 GC 压力。 + - **最小化内存分配**: 在热点路径上尽可能减少临时对象的产生。 +- **统一错误处理**: 独创的 `errorCapturingResponseWriter` 机制,能够捕获包括标准库 `http.FileServer` 在内的所有组件产生的错误状态码,并交由全局处理器统一处理。 +- **无缝集成 SSE**: 内置对 Server-Sent Events 的支持,提供简单易用的回调式 API 和高度灵活的通道式 API。 +- **内置反向代理**: 支持请求转发、协议升级、转发头维护、Trailer 与流式响应透传。 +- **静态资源增强**: 针对本地文件、目录以及 Go 嵌入式文件系统(embed.FS)提供了开箱即用的支持。 +- **标准库兼容**: 提供了适配器,可以轻松将现有的 `http.Handler` 或 `http.HandlerFunc` 集成到 Touka 中。 + +## 设计哲学 + +1. **直接性**: 框架 API 设计直观,尽可能减少开发者需要记忆的概念。 +2. **可扩展性**: 每一个核心组件(如日志、错误处理器、渲染器)都是可插拔或可定制的。 +3. **健壮性**: 内置优雅停机支持,确保在服务器更新或关闭时请求能得到正确处理。 + +Touka 不仅仅是一个处理 HTTP 请求的工具,它还是构建现代化、可维护、高可用 Web 应用的坚实基础。 diff --git a/docs/middleware.md b/docs/middleware.md new file mode 100644 index 0000000..a222437 --- /dev/null +++ b/docs/middleware.md @@ -0,0 +1,99 @@ +# 中间件 (Middleware) + +中间件是运行在 HTTP 请求处理链中的函数。它们可以拦截请求、修改数据、控制流向(通过 `c.Next()` 或 `c.Abort()`),并执行通用的前置/后置逻辑。 + +## 如何使用中间件 + +### 全局中间件 + +全局中间件应用于所有注册的路由。 + +```go +r := touka.New() +r.Use(touka.Recovery()) // 崩溃恢复 +r.Use(MyCustomLogger()) // 自定义日志 +``` + +### 路由组中间件 + +仅应用于特定组下的路由。 + +```go +api := r.Group("/api") +api.Use(AuthMiddleware()) +{ + api.GET("/user", handleUser) +} +``` + +## 编写自定义中间件 + +中间件的函数签名是 `touka.HandlerFunc`。 + +### 示例:请求计时器 + +```go +func TimerMiddleware() touka.HandlerFunc { + return func(c *touka.Context) { + // --- 前置逻辑 --- + start := time.Now() + + // 执行处理链中的下一个函数 + c.Next() + + // --- 后置逻辑 --- + duration := time.Since(start) + log.Printf("Request %s %s took %v", c.Request.Method, c.Request.URL.Path, duration) + } +} +``` + +### 示例:简单的 API 密钥验证 + +```go +func APIKeyAuth() touka.HandlerFunc { + return func(c *touka.Context) { + apiKey := c.GetReqHeader("X-API-KEY") + if apiKey != "secret-token" { + // 验证失败,返回错误并中止后续逻辑 + c.JSON(http.StatusUnauthorized, touka.H{"error": "Invalid API Key"}) + c.Abort() + return + } + + // 验证通过,继续执行 + c.Next() + } +} +``` + +## 内置中间件 + +- **Recovery**: 捕获任何发生的 panic,恢复运行并返回 500 错误。它还负责调用全局错误处理器。 + +Touka 的设计非常精简,许多扩展功能(如 Gzip, JWT, Sessions)由外部或第三方库提供,您可以轻松通过 `r.Use()` 集成它们。 + +## 条件中间件 (Conditional Middleware) + +Touka 支持根据布尔条件动态启用或禁用中间件。这在根据环境配置启用插件时非常有用。 + +### `UseIf` + +```go +// 仅在 Debug 模式下启用日志 +r.Use(r.UseIf(config.IsDebug, MyDebugLogger)) +``` + +### `UseChainIf` (条件中间件链) + +如果您有一组相关的中间件需要根据同一条件启用,可以使用 `UseChainIf`。 + +```go +r.Use(r.UseChainIf(config.EnableMetrics, + MetricsMiddleware, + PrometheusMiddleware, + MonitoringMiddleware, +)) +``` + +这些方法利用了 `MiddlewareXFunc`(即返回 `HandlerFunc` 的工厂函数),确保中间件实例按需创建或高效复用。 diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 0000000..94f7433 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,76 @@ +# 快速开始 + +本指南将帮助您在几分钟内启动并运行一个 Touka 应用。 + +## 安装 + +确保您的环境中已经安装了 Go 1.26 或更高版本。 + +在您的项目目录中运行: + +```bash +go get github.com/infinite-iroha/touka +``` + +## 基础示例 + +创建一个 `main.go` 文件,并粘贴以下代码: + +```go +package main + +import ( + "net/http" + "time" + "log" + "github.com/infinite-iroha/touka" +) + +func main() { + // 1. 创建默认引擎(包含 Recovery 中间件) + r := touka.Default() + + // 2. 注册一个简单的 GET 路由 + r.GET("/ping", func(c *touka.Context) { + c.JSON(http.StatusOK, touka.H{ + "message": "pong", + "time": time.Now().Unix(), + }) + }) + + // 3. 注册带参数的路由 + r.GET("/hello/:name", func(c *touka.Context) { + name := c.Param("name") + c.String(http.StatusOK, "Hello, %s!", name) + }) + + // 4. 启动服务器并监听 8080 端口 + log.Println("Touka server is running on :8080") + if err := r.Run(":8080"); err != nil { + log.Fatalf("Server failed: %v", err) + } +} +``` + +## 运行应用 + +执行以下命令启动服务器: + +```bash +go run main.go +``` + +现在,您可以访问: +- `http://localhost:8080/ping` +- `http://localhost:8080/hello/World` + +## 优雅停机 + +在生产环境中,我们推荐使用 `RunShutdown` 方法来启动服务器,它会监听系统信号并在关闭前等待正在处理的请求完成。 + +```go +// 等待 10 秒以处理剩余请求 +if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + log.Fatalf("Server forced to shutdown: %v", err) +} +``` diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md new file mode 100644 index 0000000..5dfcbd1 --- /dev/null +++ b/docs/reverse-proxy.md @@ -0,0 +1,377 @@ +# 反向代理 + +Touka 内置了反向代理能力,可以直接把某一组请求转发到后端服务,同时保留 Touka 的路由、中间件与统一错误处理风格。 + +`touka.ReverseProxy` 返回一个 `HandlerFunc`,因此它可以像普通路由处理器一样直接挂到 `GET`、`ANY`、路由组等位置。 + +## 最简单的用法 + +```go +package main + +import ( + "log" + "net/url" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + target, err := url.Parse("http://127.0.0.1:9000") + if err != nil { + log.Fatal(err) + } + + r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + })) + + _ = r.Run(":8080") +} +``` + +当客户端访问 `http://127.0.0.1:8080/api/users` 时,请求会被转发到 `http://127.0.0.1:9000/api/users`。 + +## 带基础路径的代理 + +如果目标服务部署在一个子路径下,可以直接把目标地址写成带路径的 URL: + +```go +target, _ := url.Parse("http://127.0.0.1:9000/backend") + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, +})) +``` + +此时: + +- `/api/users` 会转发到 `/backend/api/users` +- `/api/orders?id=10` 会转发到 `/backend/api/orders?id=10` + +目标 URL 自身携带的查询参数也会被保留并与原请求查询参数合并。 +合并后的出站查询串会再经过一次规范化处理,因此某些非标准分隔符(例如 `;`)或非法参数片段可能被重编码、折叠或直接丢弃。 +这是为了尽量让代理链各跳对查询参数的解析结果保持一致,并减少参数走私这类解析歧义风险。 + +## 配置项说明 + +```go +type ReverseProxyConfig struct { + Target *url.URL + + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + + ModifyRequest func(*http.Request) + ModifyResponse func(*http.Response) error + ErrorHandler func(http.ResponseWriter, *http.Request, error) + + ForwardedHeaders ForwardedHeadersPolicy + ForwardedBy string + Via string + PreserveHost bool +} +``` + +### `Target` + +必填。表示后端目标地址,至少需要提供 `scheme` 和 `host`。 + +```go +target, _ := url.Parse("http://backend:9000") +``` + +### `Transport` + +可选。用于自定义底层转发所使用的 `http.RoundTripper`。 + +如果留空,则默认使用 `http.DefaultTransport`。 + +```go +proxyTransport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 20, +} + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + Transport: proxyTransport, +})) +``` + +### `FlushInterval` + +控制代理在复制响应体时的主动刷新间隔: + +- `0`:不额外定时刷新 +- `> 0`:按指定间隔刷新 +- `< 0`:每次写入后立即刷新 + +对于 SSE 和无 `Content-Length` 的流式响应,Touka 会自动立即刷新,不依赖该配置。 + +### `BufferPool` + +可选。用于为响应体复制过程提供可复用的字节缓冲区,以减少大响应或高并发代理场景下的临时内存分配。 + +如果留空,Touka 会在复制响应体时按需分配默认缓冲区。 + +```go +type bytePool struct { + pool sync.Pool +} + +func (p *bytePool) Get() []byte { + if buf, ok := p.pool.Get().([]byte); ok { + return buf + } + return make([]byte, 32*1024) +} + +func (p *bytePool) Put(buf []byte) { + if cap(buf) >= 32*1024 { + p.pool.Put(buf[:32*1024]) + } +} + +proxyPool := &bytePool{} + +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + BufferPool: proxyPool, +})) +``` + +通常只有在您已经观察到明显的分配压力,或代理的响应体较大、吞吐较高时,才需要专门配置它。 + +### `ModifyRequest` + +在请求真正发往后端前,对出站请求做最后修改。 + +常见用途: + +- 覆盖 `Host` +- 增加鉴权头 +- 重写路径 +- 注入内部追踪头 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Internal-Token", "gateway-token") + }, +})) +``` + +### `ModifyResponse` + +在后端返回响应后、写回客户端前,对响应做额外处理。 + +注意:`ModifyResponse` 也会作用于 `101 Switching Protocols` 响应。 +如果该代理路由需要转发 WebSocket 或其他 Upgrade 流量,请不要在这里消费、完全缓冲,或替换 `resp.Body` 为只读对象;后续升级流程仍然要求它保留 `io.ReadWriteCloser` 能力。 +更稳妥的做法是对 `101` 响应直接跳过这类处理。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ModifyResponse: func(resp *http.Response) error { + if resp.StatusCode == http.StatusSwitchingProtocols { + return nil + } + resp.Header.Set("X-Proxy", "touka") + return nil + }, +})) +``` + +如果该函数返回错误,会转入 `ErrorHandler` 或默认的 `502 Bad Gateway` 处理流程。 + +### `ErrorHandler` + +用于处理连接后端失败、协议升级失败、`ModifyResponse` 返回错误等情况。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("upstream unavailable")) + }, +})) +``` + +### `PreserveHost` + +默认情况下,代理请求的 `Host` 会跟随后端目标地址。 + +如果设置为 `true`,则会保留客户端原始 `Host`。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + PreserveHost: true, +})) +``` + +这在某些依赖原始域名进行路由或租户识别的后端服务中会比较有用。 + +## 转发头策略 + +Touka 支持两类常见的代理转发头: + +- 兼容性更好的 `X-Forwarded-*` +- 标准化的 `Forwarded`(RFC 7239) + +可选值: + +```go +const ( + ForwardedBoth ForwardedHeadersPolicy = iota + ForwardedNone + ForwardedXForwardedOnly + ForwardedRFC7239Only +) +``` + +推荐默认使用 `ForwardedBoth`。 + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ForwardedHeaders: touka.ForwardedBoth, + ForwardedBy: "gateway-1", + Via: "edge-1", +})) +``` + +`Via` 不是“留空即禁用”的开关。当前实现中: + +- 如果 `Via` 非空,则使用该值追加 `Via` +- 如果 `Via` 为空,则会回退到固定值 `touka-engine` + +因此,把 `Via` 留空时,发送出去的请求仍会包含 `Via` 头,只是使用默认标识 `touka-engine`。 + +如果您希望上游清楚区分不同入口、环境或网关实例,仍然建议显式设置一个稳定且可公开暴露的代理标识,例如: + +```go +r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + Via: "edge-gateway", +})) +``` + +当前版本没有提供“完全禁用追加 Via”的单独配置项,因此不要把空字符串当作关闭手段。 + +### Touka 会如何处理这些头? + +Touka 会尽量遵循代理链语义: + +- 已有的 `X-Forwarded-For` 会保留,并在末尾追加当前 hop 的客户端 IP +- 已有的 `Forwarded` 会保留,并在末尾追加当前 hop 的条目 +- 已有的 `X-Forwarded-Host` 与 `X-Forwarded-Proto` 会优先保留;如果缺失,则由当前请求补齐 +- `Via` 会追加当前代理标识 + +这意味着在 Touka 前面还有一层可信代理(如 Nginx、Traefik、Cloudflare、网关)时,上游服务仍然可以看到完整的代理链。 + +如果您**不信任**客户端传入的这些头,请在进入 `ReverseProxy` 之前自行清理,或在 `ModifyRequest` 中显式重写。 + +## 协议升级与流式响应 + +Touka 的反向代理实现支持以下能力: + +- `Connection: Upgrade` / `Upgrade` 协议升级转发 +- WebSocket 等 101 Switching Protocols 场景 +- SSE(Server-Sent Events)立即刷新 +- Trailer 透传 +- 1xx 响应透传 + +例如,代理 WebSocket 服务: + +```go +target, _ := url.Parse("http://127.0.0.1:9001") + +r.ANY("/ws/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, +})) +``` + +## Hop-by-hop 头处理 + +根据 HTTP 代理语义,Touka 在转发时会移除连接级别的 hop-by-hop 头,避免把只应作用于单跳连接的头继续传给下游。 + +典型包括: + +- `Connection` +- `Proxy-Connection` +- `Keep-Alive` +- `Proxy-Authenticate` +- `Proxy-Authorization` +- `TE` +- `Trailer` +- `Transfer-Encoding` +- `Upgrade` + +同时,若请求本身是合法的协议升级请求,Touka 会在剥离后重新补回必要的 `Connection: Upgrade` 与 `Upgrade` 头。 + +## 一个更完整的例子 + +```go +package main + +import ( + "log" + "net/http" + "net/url" + "time" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + target, err := url.Parse("http://127.0.0.1:9000") + if err != nil { + log.Fatal(err) + } + + r.ANY("/api/*path", touka.ReverseProxy(touka.ReverseProxyConfig{ + Target: target, + ForwardedHeaders: touka.ForwardedBoth, + ForwardedBy: "gateway-1", + Via: "gateway-1", + FlushInterval: 100 * time.Millisecond, + ModifyRequest: func(req *http.Request) { + req.Header.Set("X-Gateway", "touka") + }, + ModifyResponse: func(resp *http.Response) error { + resp.Header.Set("X-Proxy", "touka") + return nil + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("bad gateway")) + }, + })) + + if err := r.RunShutdown(":8080", 10*time.Second); err != nil { + log.Fatal(err) + } +} +``` + +## 与 `SetForwardByClientIP` 的关系 + +`ReverseProxy` 负责把请求转发给后端,并维护代理链头。 + +而 `SetForwardByClientIP` / `SetRemoteIPHeaders` 是 Touka 在**接收请求**时,用于解析当前请求客户端 IP 的逻辑。 + +两者通常会一起出现,但解决的是两个不同方向的问题: + +- `ReverseProxy`:出站转发 +- `SetForwardByClientIP`:入站解析 + +如果您的 Touka 本身就部署在其他代理之后,建议同时正确配置这两部分。 diff --git a/docs/routing.md b/docs/routing.md new file mode 100644 index 0000000..e90308e --- /dev/null +++ b/docs/routing.md @@ -0,0 +1,153 @@ +# 路由系统 + +Touka 拥有一个强大且灵活的路由系统,底层基于高性能的基数树(Radix Tree)实现。 + +## 基础路由 + +您可以为所有标准的 HTTP 方法注册处理器: + +```go +r.GET("/someGet", handle) +r.POST("/somePost", handle) +r.PUT("/somePut", handle) +r.DELETE("/someDelete", handle) +r.PATCH("/somePatch", handle) +r.HEAD("/someHead", handle) +r.OPTIONS("/someOptions", handle) + +// 注册所有上述方法的路由 +r.ANY("/any", handle) + +// 同时注册多个方法 +r.HandleFunc([]string{"GET", "POST"}, "/multi", handle) +``` + +## 路径参数 (Named Parameters) + +使用冒号 `:` 定义路径参数。参数值可以通过 `c.Param(key)` 获取。 + +```go +// 匹配 /user/john, 不匹配 /user/ 或 /user/john/send +r.GET("/user/:name", func(c *touka.Context) { + name := c.Param("name") + c.String(http.StatusOK, "Hello %s", name) +}) + +// 匹配 /user/john/send +r.GET("/user/:name/:action", func(c *touka.Context) { + name := c.Param("name") + action := c.Param("action") + c.String(http.StatusOK, "%s is doing %s", name, action) +}) +``` + +## 通配符路由 (Catch-all Parameters) + +使用星号 `*` 定义通配符路由,它会捕获路径中该位置之后的所有内容。 + +```go +// 匹配 /src/main.go, /src/scripts/app.js 等 +r.GET("/src/*filepath", func(c *touka.Context) { + path := c.Param("filepath") + c.String(http.StatusOK, "Viewing file: %s", path) +}) +``` + +## 路由组 (RouterGroup) + +路由组允许您共享公共路径前缀或中间件,使代码结构更清晰。 + +```go +v1 := r.Group("/api/v1") +{ + v1.GET("/login", loginEndpoint) + v1.GET("/submit", submitEndpoint) +} + +v2 := r.Group("/api/v2") +v2.Use(AuthMiddleware()) // 仅应用于 v2 组 +{ + v2.POST("/data", dataEndpoint) +} +``` + +## 路由行为配置 + +Touka 允许您自定义路由匹配的行为: + +- **RedirectTrailingSlash**: 如果启用(默认),请求 `/foo/` 会被重定向到 `/foo`(如果只有后者注册了),反之亦然。 +- **RedirectFixedPath**: 如果启用(默认),引擎会尝试修复路径大小写或移除多余的斜杠并重定向。 +- **HandleMethodNotAllowed**: 如果启用,当请求路径匹配但方法不匹配时,返回 405 而非 404。 + +```go +r := touka.New() +r.SetRedirectTrailingSlash(true) +r.SetHandleMethodNotAllowed(true) +``` + +## 获取已注册路由信息 + +您可以使用 `GetRouterInfo` 获取当前引擎中所有已注册路由的列表。 + +```go +routes := r.GetRouterInfo() +for _, route := range routes { + fmt.Printf("Method: %s, Path: %s\n", route.Method, route.Path) +} +``` + +## 自定义 404 处理 + +当请求没有匹配到任何路由时,Touka 会返回 404。您可以自定义 404 的处理逻辑: + +```go +// 使用单个处理器 +r.NoRoute(func(c *touka.Context) { + c.JSON(http.StatusNotFound, touka.H{ + "error": "资源未找到", + "path": c.Request.URL.Path, + }) +}) + +// 使用处理器链 +r.NoRoutes( + LogNotFoundMiddleware(), + func(c *touka.Context) { + c.JSON(http.StatusNotFound, touka.H{"error": "Not found"}) + }, +) +``` + +**注意**:`NoRoute` 和 `NoRoutes` 不是处理链的终点,您仍然可以在其中调用 `c.Next()` 来继续执行默认的 404 处理。 + +## 静态文件路由 + +Touka 提供了便捷的方法来注册静态文件路由: + +```go +// 服务整个目录 +r.StaticDir("/assets", "./static") +// 访问 /assets/js/main.js 将返回 ./static/js/main.js + +// 服务单个文件 +r.StaticFile("/favicon.ico", "./resources/favicon.ico") + +// 服务嵌入式文件系统 +//go:embed dist/* +var content embed.FS + +func main() { + r := touka.Default() + fsroot, _ := fs.Sub(content, "dist") + r.StaticFS("/", http.FS(fsroot)) + r.Run(":8080") +} +``` + +这些方法同样可以在路由组中使用: + +```go +api := r.Group("/api") +api.StaticDir("/files", "./uploads") +api.StaticFile("/logo", "./assets/logo.png") +``` diff --git a/docs/sse.md b/docs/sse.md new file mode 100644 index 0000000..bafc553 --- /dev/null +++ b/docs/sse.md @@ -0,0 +1,131 @@ +# Server-Sent Events (SSE) + +Server-Sent Events 允许服务器向客户端实时推送数据。Touka 对此提供了原生且易用的支持。 + +## 核心结构:`Event` + +`Event` 结构体代表一个 SSE 消息: + +```go +type Event struct { + Event string // 事件名称 + Data string // 数据内容 (支持多行) + Id string // 事件 ID + Retry string // 重连时间 (毫秒) +} +``` + +## 模式一:回调模式 (EventStream) + +这是最推荐的使用方式,它更简单且能自动管理连接生命周期。 + +```go +r.GET("/events", func(c *touka.Context) { + c.EventStream(func(w io.Writer) bool { + // 构建事件 + event := touka.Event{ + Data: "现在的时间是: " + time.Now().Format(time.RFC3339), + } + + // 渲染并写入 + if err := event.Render(w); err != nil { + return false // 发生写入错误(如客户端断开),返回 false 停止流 + } + + time.Sleep(2 * time.Second) + return true // 返回 true 继续下一次循环 + }) +}) +``` + +## 模式二:通道模式 (EventStreamChan) + +如果您需要更高级的并发控制(例如从多个异步源接收数据),可以使用通道模式。 + +```go +r.GET("/events-chan", func(c *touka.Context) { + eventChan, errChan := c.EventStreamChan() + + // 监听错误/断开连接 + go func() { + if err := <-errChan; err != nil { + log.Printf("SSE 错误: %v", err) + } + }() + + // 发送数据 + go func() { + defer close(eventChan) // 务必在结束时关闭 + + for i := 0; i < 10; i++ { + select { + case <-c.Request.Context().Done(): + return + default: + eventChan <- touka.Event{ + Data: fmt.Sprintf("消息 #%d", i), + } + time.Sleep(1 * time.Second) + } + } + }() +}) +``` + +## 最佳实践 + +1. **资源回收**: 确保在 `EventStreamChan` 模式下正确监听 `c.Request.Context().Done()` 以避免 Goroutine 泄漏。 +2. **数据格式**: SSE 协议要求数据为 UTF-8。Touka 的 `Render` 方法会自动处理多行数据并加上必要的 `data:` 前缀。 +3. **超时管理**: SSE 连接通常是长连接,请确保您的反向代理(如 Nginx)配置了足够大的写超时时间。 + +## 优雅关闭与资源清理 + +在长连接场景下,正确处理客户端断开或服务器关闭信号至关重要,以防止资源泄漏。 + +### 示例:监听 Context 取消信号 + +```go +r.GET("/events-graceful", func(c *touka.Context) { + // 设置响应头(如果手动处理,EventStream 会自动设置) + + // 使用 Context 的 Done 通道来感知连接关闭 + ctx := c.Request.Context() + + // 启动一个用于模拟数据生成的循环 + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + c.EventStream(func(w io.Writer) bool { + select { + case <-ctx.Done(): + // 收到优雅关闭信号(客户端离开或服务器正在关闭) + fmt.Println("SSE 连接正在关闭,开始清理资源...") + return false // 返回 false 告知框架停止流 + + case t := <-ticker.C: + event := touka.Event{ + Data: "Tick at " + t.Format(time.RFC3339), + } + if err := event.Render(w); err != nil { + return false + } + return true + } + }) + + fmt.Println("SSE 连接已彻底释放") +}) +``` + +在该示例中,我们显式地在回调函数中使用 `select` 监听 `ctx.Done()`。虽然 Touka 的 `EventStream` 内部也会检查此信号,但在回调内部自行处理可以执行更复杂的清理逻辑(如关闭数据库连接、停止特定的 Goroutine 等)。 + +### 为什么会出现 "context deadline exceeded"? + +如果您在优雅停机时遇到 `context deadline exceeded` 错误,通常是因为 SSE 连接仍然活跃,而 `http.Server.Shutdown` 正在等待它们结束。 + +在 Touka 的新版本中,我们通过 `BaseContext` 将 `Engine` 的关闭信号注入到了每个请求的 `Context` 中。这意味着: +1. 当服务器收到关闭信号时,`engine.shutdownCtx` 会被取消。 +2. 随后,所有活跃请求的 `c.Request.Context()` 也会收到取消信号。 +3. 您的 SSE 处理器中的 `case <-c.Request.Context().Done():` 会立即触发,从而优雅地结束连接。 + +**注意:** 请务必使用 `RunShutdown`、`RunTLS` 或 `RunTLSRedir` 来启动服务器,以便框架能自动管理这些信号。 diff --git a/docs/static-files.md b/docs/static-files.md new file mode 100644 index 0000000..a2138cd --- /dev/null +++ b/docs/static-files.md @@ -0,0 +1,63 @@ +# 静态文件与资源 + +Touka 提供了多种方式来服务静态文件,这些方法都集成了 Touka 的统一错误处理机制。 + +## 服务本地目录 + +`StaticDir` 方法将 URL 路径映射到本地文件系统目录。 + +```go +// 访问 /assets/js/main.js 将读取 ./static/js/main.js +r.StaticDir("/assets", "./static") +``` + +## 服务单个文件 + +`StaticFile` 用于将特定的 URL 映射到单个本地文件。 + +```go +r.StaticFile("/favicon.ico", "./resources/favicon.ico") +``` + +## 集成 Go 嵌入式资源 (embed.FS) + +使用 Go 1.16+ 的 `embed` 特性,您可以将整个静态前端项目编译进二进制文件中。 + +```go +//go:embed dist/* +var content embed.FS + +func main() { + r := touka.Default() + + // 剥离 "dist" 前缀并包装为 http.FS + fsroot, _ := fs.Sub(content, "dist") + + // 使用 StaticFS 提供服务 + r.StaticFS("/static", http.FS(fsroot)) + + // 您也可以使用 StaticFS 服务根路径 + // r.StaticFS("/", http.FS(fsroot)) + + r.Run(":8080") +} +``` + +## 未匹配路径作为文件服务 (UnMatchFS) + +这是一个独特的功能:当没有任何 API 路由匹配时,尝试从指定的文件系统中查找并返回文件。这非常适合用于单页应用(SPA)的部署。 + +```go +r := touka.New() +r.SetUnMatchFS(http.Dir("./frontend/dist")) + +// API 路由 +r.GET("/api/status", handleStatus) + +// 如果请求 /index.html 且没有 /index.html 的路由, +// 则会从 ./frontend/dist/index.html 读取。 +``` + +## 性能提示 + +对于高负载的静态资源分发,虽然 Touka 表现出色,但我们仍建议在生产环境中使用 Nginx 或 CDN 站在 Touka 前面来处理静态文件,让 Touka 专注于处理动态逻辑。 diff --git a/ecw.go b/ecw.go index c87be28..754571f 100644 --- a/ecw.go +++ b/ecw.go @@ -7,6 +7,7 @@ package touka import ( "bufio" "errors" + "maps" "net" "net/http" "sync" @@ -27,7 +28,7 @@ type errorCapturingResponseWriter struct { // errorResponseWriterPool 是用于复用 errorCapturingResponseWriter 实例的对象池 var errorResponseWriterPool = sync.Pool{ - New: func() interface{} { + New: func() any { return &errorCapturingResponseWriter{ headerSnapshot: make(http.Header), // 预先初始化 map, 减少 reset 时的分配 } @@ -91,9 +92,8 @@ func (ecw *errorCapturingResponseWriter) WriteHeader(statusCode int) { // 是成功状态码 // 将 ecw.headerSnapshot 中(由 FileServer 在此之前通过 ecw.Header() 设置的) // 任何头部直接复制到原始的 w.Header(), 确保多值头部正确传递 - for k, v := range ecw.headerSnapshot { - ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 - } + // 直接赋值 []string, 保留所有值 + maps.Copy(ecw.w.Header(), ecw.headerSnapshot) ecw.w.WriteHeader(statusCode) // 实际写入状态码到原始 ResponseWriter ecw.responseStarted = true // 标记成功响应已开始 } @@ -112,9 +112,8 @@ func (ecw *errorCapturingResponseWriter) Write(data []byte) (int, error) { ecw.statusCode = http.StatusOK // 隐式 200 OK } // 将 headerSnapshot 中的头部复制到原始 ResponseWriter 的 Header - for k, v := range ecw.headerSnapshot { - ecw.w.Header()[k] = v // 直接赋值 []string, 保留所有值 - } + // 直接赋值 []string, 保留所有值 + maps.Copy(ecw.w.Header(), ecw.headerSnapshot) ecw.w.WriteHeader(ecw.Status()) // 发送实际的状态码 (可能是 200 或之前设置的 2xx) ecw.responseStarted = true } diff --git a/engine.go b/engine.go index 581258c..c2eae91 100644 --- a/engine.go +++ b/engine.go @@ -51,7 +51,7 @@ type Engine struct { LogReco *reco.Logger - HTMLRender interface{} // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 + HTMLRender any // 用于 HTML 模板渲染,可以设置为 *template.Template 或自定义渲染器接口 routesInfo []RouteInfo // 存储所有注册的路由信息 @@ -67,6 +67,9 @@ type Engine struct { Protocols ProtocolsConfig //协议版本配置 useDefaultProtocols bool //是否使用默认协议 + shutdownCtx context.Context + shutdownCancel context.CancelFunc + // ServerConfigurator 允许在服务器启动前对其进行自定义配置 // 例如,设置 ReadTimeout, WriteTimeout 等 ServerConfigurator func(*http.Server) @@ -207,15 +210,16 @@ func New() *Engine { TLSServerConfigurator: nil, GlobalMaxRequestBodySize: -1, } + engine.shutdownCtx, engine.shutdownCancel = context.WithCancel(context.Background()) //engine.SetProtocols(GetDefaultProtocolsConfig()) engine.SetDefaultProtocols() engine.SetLoggerCfg(defaultLogRecoConfig) // 初始化 Context Pool,为每个新 Context 实例提供一个构造函数 - engine.pool.New = func() interface{} { + engine.pool.New = func() any { return &Context{ Writer: newResponseWriter(nil), // 初始时可以传入nil,在ServeHTTP中会重新设置实际的 http.ResponseWriter Params: make(Params, 0, engine.maxParams), // 预分配 Params 切片以减少内存分配 - Keys: make(map[string]interface{}), + Keys: make(map[string]any), Errors: make([]error, 0), ctx: context.Background(), // 初始上下文,后续会被请求的 Context 覆盖 HTTPClient: engine.HTTPClient, @@ -315,11 +319,16 @@ func GetDefaultProtocolsConfig() *ProtocolsConfig { // 设置默认Protocols func (engine *Engine) SetDefaultProtocols() { engine.useDefaultProtocols = true - engine.SetProtocols(GetDefaultProtocolsConfig()) + engine.setProtocols(GetDefaultProtocolsConfig()) } // 设置Protocol func (engine *Engine) SetProtocols(config *ProtocolsConfig) { + engine.setProtocols(config) + engine.useDefaultProtocols = false +} + +func (engine *Engine) setProtocols(config *ProtocolsConfig) { engine.Protocols = *config engine.serverProtocols = &http.Protocols{} // 初始化指针 func() { @@ -329,7 +338,13 @@ func (engine *Engine) SetProtocols(config *ProtocolsConfig) { p.SetUnencryptedHTTP2(config.Http2_Cleartext) *engine.serverProtocols = p // 将值赋给指针指向的结构体 }() - engine.useDefaultProtocols = false +} + +// applyDefaultServerConfig 应用框架的默认配置到 http.Server +func (engine *Engine) applyDefaultServerConfig(srv *http.Server) { + if engine.serverProtocols != nil { + srv.Protocols = engine.serverProtocols + } } // 配置全局Req Body大小限制 @@ -421,6 +436,41 @@ func getHandlerName(h HandlerFunc) string { } +const MaxSkippedNodesCap = 256 + +// TempSkippedNodesPool 存储 *[]skippedNode 以复用内存 +var TempSkippedNodesPool = sync.Pool{ + New: func() any { + // 返回一个指向容量为 256 的新切片的指针 + s := make([]skippedNode, 0, MaxSkippedNodesCap) + return &s + }, +} + +// GetTempSkippedNodes 从 Pool 中获取一个 *[]skippedNode 指针 +func GetTempSkippedNodes() *[]skippedNode { + // 直接返回 Pool 中存储的指针 + return TempSkippedNodesPool.Get().(*[]skippedNode) +} + +// PutTempSkippedNodes 将用完的 *[]skippedNode 指针放回 Pool +func PutTempSkippedNodes(skippedNodes *[]skippedNode) { + if skippedNodes == nil || *skippedNodes == nil { + return + } + + // 检查容量是否符合预期。如果容量不足,则丢弃,不放回 Pool。 + if cap(*skippedNodes) < MaxSkippedNodesCap { + return // 丢弃该对象,让 Pool 在下次 Get 时通过 New 重新分配 + } + + // 长度重置为 0,保留容量,实现复用 + *skippedNodes = (*skippedNodes)[:0] + + // 将指针存回 Pool + TempSkippedNodesPool.Put(skippedNodes) +} + // 405中间件 func MethodNotAllowed() HandlerFunc { return func(c *Context) { @@ -432,9 +482,10 @@ func MethodNotAllowed() HandlerFunc { // 如果是 OPTIONS 请求,尝试查找所有允许的方法 allowedMethods := []string{} for _, treeIter := range engine.methodTrees { - var tempSkippedNodes []skippedNode // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) + PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { allowedMethods = append(allowedMethods, treeIter.method) } @@ -451,9 +502,10 @@ func MethodNotAllowed() HandlerFunc { if treeIter.method == httpMethod { // 已经处理过当前方法,跳过 continue } - var tempSkippedNodes []skippedNode // 用于临时查找,不影响主 Context // 注意这里 treeIter.root 才是正确的,因为 treeIter 是 methodTree 类型 - value := treeIter.root.getValue(requestPath, nil, &tempSkippedNodes, false) // 只查找是否存在,不需要参数 + tempSkippedNodes := GetTempSkippedNodes() + value := treeIter.root.getValue(requestPath, nil, tempSkippedNodes, false) // 只查找是否存在,不需要参数 + PutTempSkippedNodes(tempSkippedNodes) if value.handlers != nil { // 使用定义的ErrorHandle处理 engine.errorHandle.handler(c, http.StatusMethodNotAllowed, errors.New("method not allowed")) @@ -661,9 +713,8 @@ func (engine *Engine) handleRequest(c *Context) { // 查找匹配的节点和处理函数 // 这里传递 &c.Params 而不是重新创建,以利用 Context 中预分配的容量 // skippedNodes 内部使用,因此无需从外部传入已分配的 slice - var skippedNodes []skippedNode // 用于回溯的跳过节点 // 直接在 rootNode 上调用 getValue 方法 - value := rootNode.getValue(requestPath, &c.Params, &skippedNodes, true) // unescape=true 对路径参数进行 URL 解码 + value := rootNode.getValue(requestPath, &c.Params, &c.SkippedNodes, true) // unescape=true 对路径参数进行 URL 解码 if value.handlers != nil { //c.handlers = engine.combineHandlers(engine.globalHandlers, value.handlers) // 组合全局中间件和路由处理函数 @@ -730,3 +781,9 @@ func (engine *Engine) handleRequest(c *Context) { c.Next() // 执行处理函数链 //c.Writer.Flush() // 确保所有缓冲的响应数据被发送 } + +// Context 返回 Engine 的根上下文, 该上下文在服务器优雅关闭时会被取消. +// 它可以用于在长连接 (如 SSE) 中监听关闭信号. +func (engine *Engine) Context() context.Context { + return engine.shutdownCtx +} diff --git a/go.mod b/go.mod index d8b0dfc..42f4be4 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,16 @@ module github.com/infinite-iroha/touka -go 1.25.1 +go 1.26 require ( github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 - github.com/WJQSERVER-STUDIO/httpc v0.8.2 - github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee - github.com/fenthope/reco v0.0.4 - github.com/go-json-experiment/json v0.0.0-20250910080747-cc2cfa0554c3 + github.com/WJQSERVER-STUDIO/httpc v0.9.0 + github.com/WJQSERVER/wanf v0.0.8 + github.com/fenthope/reco v0.0.5 + github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 ) require ( github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/net v0.52.0 // indirect ) diff --git a/go.sum b/go.sum index ca56e55..b49879b 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,14 @@ github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2 h1:AiIHXP21LpK7pFfqUlUstgQEWzjbekZgxOuvVwiMfyM= github.com/WJQSERVER-STUDIO/go-utils/iox v0.0.2/go.mod h1:mCLqYU32bTmEE6dpj37MKKiZgz70Jh/xyK9vVbq6pok= -github.com/WJQSERVER-STUDIO/httpc v0.8.2 h1:PFPLodV0QAfGEP6915J57vIqoKu9cGuuiXG/7C9TNUk= -github.com/WJQSERVER-STUDIO/httpc v0.8.2/go.mod h1:8WhHVRO+olDFBSvL5PC/bdMkb6U3vRdPJ4p4pnguV5Y= -github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee h1:tJ31DNBn6UhWkk8fiikAQWqULODM+yBcGAEar1tzdZc= -github.com/WJQSERVER/wanf v0.0.0-20250810023226-e51d9d0737ee/go.mod h1:q2Pyg+G+s1acMWxrbI4CwS/Yk76/BzLREEdZ8iFwUNE= -github.com/fenthope/reco v0.0.4 h1:yo2g3aWwdoMpaZWZX4SdZOW7mCK82RQIU/YI8ZUQThM= -github.com/fenthope/reco v0.0.4/go.mod h1:eMyS8HpdMVdJ/2WJt6Cvt8P1EH9Igzj5lSJrgc+0jeg= -github.com/go-json-experiment/json v0.0.0-20250813233538-9b1f9ea2e11b h1:6Q4zRHXS/YLOl9Ng1b1OOOBWMidAQZR3Gel0UKPC/KU= -github.com/go-json-experiment/json v0.0.0-20250813233538-9b1f9ea2e11b/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= -github.com/go-json-experiment/json v0.0.0-20250910080747-cc2cfa0554c3 h1:02WINGfSX5w0Mn+F28UyRoSt9uvMhKguwWMlOAh6U/0= -github.com/go-json-experiment/json v0.0.0-20250910080747-cc2cfa0554c3/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= +github.com/WJQSERVER-STUDIO/httpc v0.9.0 h1:MpXcQQqukrSLHH/2tTfnXrhqD6nEDHB/gbzehXaS8o4= +github.com/WJQSERVER-STUDIO/httpc v0.9.0/go.mod h1:filzryrl4eAtFVyl4oVHcJqx1SpNFbrCn+ddQPLlCSg= +github.com/WJQSERVER/wanf v0.0.8 h1:1Ri9d7nKhu22hGxP8O9B9rXnYym6DYGKgi6WRVx3VF8= +github.com/WJQSERVER/wanf v0.0.8/go.mod h1:R0Zw/1skEMVlQ9m5atbkmanlW+9h2bkdq7+wbPY+F/8= +github.com/fenthope/reco v0.0.5 h1:Z/bOunFf4LSgYP/IxG9fe2pTrIq7bPsDflflbNR5Agw= +github.com/fenthope/reco v0.0.5/go.mod h1:nd5gMkuJHN2+2Iwwt3xy+HSqRaROauIjHNkmQWRsHyM= +github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao= +github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= diff --git a/protocols_test.go b/protocols_test.go new file mode 100644 index 0000000..73f16e9 --- /dev/null +++ b/protocols_test.go @@ -0,0 +1,111 @@ +package touka + +import ( + "crypto/tls" + "net/http" + "testing" +) + +func TestApplyDefaultServerConfig(t *testing.T) { + engine := New() + + // 1. 测试默认协议 + srv1 := &http.Server{} + engine.applyDefaultServerConfig(srv1) + + if srv1.Protocols == nil { + t.Fatal("srv1.Protocols should not be nil after applyDefaultServerConfig") + } + + // 默认配置是 Http1: true, Http2: false, Http2_Cleartext: false + if !srv1.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be enabled by default") + } + if srv1.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be disabled by default") + } + + // 2. 测试自定义协议 + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + Http2_Cleartext: true, + }) + + srv2 := &http.Server{} + engine.applyDefaultServerConfig(srv2) + + if srv2.Protocols == nil { + t.Fatal("srv2.Protocols should not be nil after applyDefaultServerConfig") + } + + if !srv2.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be enabled after SetProtocols") + } + if !srv2.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be enabled after SetProtocols") + } + if !srv2.Protocols.UnencryptedHTTP2() { + t.Error("Expected Unencrypted HTTP/2 to be enabled after SetProtocols") + } + + // 3. 再次更改协议并验证 + engine.SetProtocols(&ProtocolsConfig{ + Http1: false, + Http2: true, + Http2_Cleartext: false, + }) + + srv3 := &http.Server{} + engine.applyDefaultServerConfig(srv3) + + if srv3.Protocols == nil { + t.Fatal("srv3.Protocols should not be nil") + } + if srv3.Protocols.HTTP1() { + t.Error("Expected HTTP/1 to be disabled") + } + if !srv3.Protocols.HTTP2() { + t.Error("Expected HTTP/2 to be enabled") + } +} + +func TestRunTLSProtocolInheritance(t *testing.T) { + engine := New() + + // 模拟 RunTLS 中的逻辑: 如果使用默认协议, 则启用 HTTP/2 + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) + } + + srv := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv) + + if !srv.Protocols.HTTP2() { + t.Error("RunTLS simulation: Expected HTTP/2 to be enabled for default config") + } + + // 模拟用户设置了自定义协议后调用 RunTLS + engine = New() + engine.SetProtocols(&ProtocolsConfig{ + Http1: true, + Http2: false, // 用户明确不想要 HTTP/2 + }) + + if engine.useDefaultProtocols { + engine.setProtocols(&ProtocolsConfig{ + Http1: true, + Http2: true, + }) + } + + srv2 := &http.Server{TLSConfig: &tls.Config{}} + engine.applyDefaultServerConfig(srv2) + + if srv2.Protocols.HTTP2() { + t.Error("RunTLS simulation: Expected HTTP/2 to be DISABLED if user set custom protocols previously") + } +} diff --git a/recovery.go b/recovery.go index 5dfb837..dc4d892 100644 --- a/recovery.go +++ b/recovery.go @@ -18,7 +18,7 @@ import ( // PanicHandlerFunc 定义了用户自定义的 panic 处理函数类型 // 它接收当前的 Context 和 panic 的值 -type PanicHandlerFunc func(c *Context, panicInfo interface{}) +type PanicHandlerFunc func(c *Context, panicInfo any) // RecoveryWithOptions 返回一个可配置的 panic 恢复中间件 // @@ -50,7 +50,7 @@ func Recovery() HandlerFunc { } // defaultPanicHandler 是默认的 panic 处理逻辑 -func defaultPanicHandler(c *Context, r interface{}) { +func defaultPanicHandler(c *Context, r any) { // 检查连接是否已由客户端关闭 // 常见的错误类型包括 net.OpError (其内部错误可能是 os.SyscallError), // 以及在 HTTP/2 中可能出现的特定 stream 错误 @@ -107,7 +107,7 @@ func defaultPanicHandler(c *Context, r interface{}) { // isBrokenPipeError 检查 recover() 捕获的值是否表示一个由客户端断开连接引起的网络错误 // 这对于防止在已关闭的连接上写入响应至关重要 -func isBrokenPipeError(r interface{}) bool { +func isBrokenPipeError(r any) bool { // 将 recover() 的结果转换为 error 类型 err, ok := r.(error) if !ok { diff --git a/respw.go b/respw.go index 2cf6700..dd94db3 100644 --- a/respw.go +++ b/respw.go @@ -45,6 +45,15 @@ func newResponseWriter(w http.ResponseWriter) ResponseWriter { } } +// UnwrapResponseWriter returns the underlying stdlib response writer when the +// provided writer is Touka's internal wrapper. +func UnwrapResponseWriter(w ResponseWriter) http.ResponseWriter { + if wrapped, ok := w.(*responseWriterImpl); ok && wrapped.ResponseWriter != nil { + return wrapped.ResponseWriter + } + return w +} + func (rw *responseWriterImpl) reset(w http.ResponseWriter) { rw.ResponseWriter = w rw.status = 0 @@ -56,6 +65,10 @@ func (rw *responseWriterImpl) WriteHeader(statusCode int) { if rw.hijacked { return } + if statusCode >= 100 && statusCode < 200 && statusCode != http.StatusSwitchingProtocols { + rw.ResponseWriter.WriteHeader(statusCode) + return + } if rw.status == 0 { // 确保只设置一次 rw.status = statusCode rw.ResponseWriter.WriteHeader(statusCode) diff --git a/reverseproxy.go b/reverseproxy.go new file mode 100644 index 0000000..1730b1e --- /dev/null +++ b/reverseproxy.go @@ -0,0 +1,933 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2026 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/netip" + "net/textproto" + "net/url" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ForwardedHeadersPolicy controls how forwarding headers are generated. +// The zero value uses both X-Forwarded-* and RFC 7239 Forwarded headers. +type ForwardedHeadersPolicy int + +const ( + ForwardedBoth ForwardedHeadersPolicy = iota + ForwardedNone + ForwardedXForwardedOnly + ForwardedRFC7239Only +) + +// BufferPool provides temporary buffers for response body copying. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +// ReverseProxyConfig configures the reverse proxy handler. +type ReverseProxyConfig struct { + Target *url.URL + + Transport http.RoundTripper + FlushInterval time.Duration + BufferPool BufferPool + + ModifyRequest func(*http.Request) + ModifyResponse func(*http.Response) error + ErrorHandler func(http.ResponseWriter, *http.Request, error) + + ForwardedHeaders ForwardedHeadersPolicy + ForwardedBy string + Via string + PreserveHost bool +} + +var ( + errReverseProxyNilTarget = errors.New("reverse proxy target is nil") + errReverseProxyInvalidTarget = errors.New("reverse proxy target must include scheme and host") + errReverseProxyCopyDone = errors.New("reverse proxy switch protocol copy complete") +) + +type reverseProxyHandler struct { + config ReverseProxyConfig + target *url.URL + receivedBy string + configError error +} + +type reverseProxyStatusError struct { + status int + err error +} + +func (e *reverseProxyStatusError) Error() string { + if e == nil || e.err == nil { + return "" + } + return e.err.Error() +} + +func (e *reverseProxyStatusError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +type noopCloseReader struct { + readCloser io.ReadCloser + closed atomic.Bool +} + +func (n *noopCloseReader) Read(p []byte) (int, error) { + if n.closed.Load() { + return 0, errors.New("reverse proxy read on closed body") + } + return n.readCloser.Read(p) +} + +func (n *noopCloseReader) Close() error { + n.closed.Store(true) + return nil +} + +type maxLatencyWriter struct { + dst ResponseWriter + latency time.Duration + + mu sync.Mutex + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + n, err := m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return n, err + } + if m.flushPending { + return n, err + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return n, err +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.flushPending { + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +type switchProtocolCopier struct { + user io.ReadWriter + backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + if _, err := io.Copy(c.user, c.backend); err != nil { + errc <- err + return + } + if cw, ok := c.user.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + if _, err := io.Copy(c.backend, c.user); err != nil { + errc <- err + return + } + if cw, ok := c.backend.(interface{ CloseWrite() error }); ok { + errc <- cw.CloseWrite() + return + } + errc <- errReverseProxyCopyDone +} + +// ReverseProxy returns a handler that proxies requests to the configured backend. +func ReverseProxy(config ReverseProxyConfig) HandlerFunc { + proxy := newReverseProxyHandler(config) + return func(c *Context) { + proxy.ServeHTTP(c) + } +} + +func newReverseProxyHandler(config ReverseProxyConfig) *reverseProxyHandler { + target := cloneReverseProxyURL(config.Target) + if target != nil { + normalizeReverseProxyTarget(target) + } + + proxy := &reverseProxyHandler{ + config: config, + target: target, + receivedBy: reverseProxyReceivedBy(config.Via), + } + + if err := validateReverseProxyTarget(target); err != nil { + proxy.configError = err + } + + switch config.ForwardedHeaders { + case ForwardedBoth, ForwardedNone, ForwardedXForwardedOnly, ForwardedRFC7239Only: + default: + proxy.config.ForwardedHeaders = ForwardedBoth + } + + return proxy +} + +func (p *reverseProxyHandler) ServeHTTP(c *Context) { + defer c.Abort() + + if p.configError != nil { + p.handleError(c, &reverseProxyStatusError{status: http.StatusInternalServerError, err: p.configError}) + return + } + + transport := p.config.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx, cancel := p.requestContext(c) + defer cancel() + + outreq := c.Request.Clone(ctx) + if c.Request.ContentLength == 0 { + outreq.Body = nil + } + if outreq.Body != nil { + outreq.Body = &noopCloseReader{readCloser: outreq.Body} + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) + } + outreq.Close = false + + rewriteReverseProxyURL(outreq, p.target) + if !p.config.PreserveHost { + outreq.Host = "" + } + outreq.URL.RawQuery = cleanReverseProxyQueryParams(outreq.URL.RawQuery) + + reqUpType := reverseProxyUpgradeType(outreq.Header) + if reqUpType != "" && !isPrintableASCII(reqUpType) { + p.handleError(c, &reverseProxyStatusError{ + status: http.StatusBadRequest, + err: fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), + }) + return + } + + removeHopByHopHeaders(outreq.Header) + if headerValuesContainToken(c.Request.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + p.addForwardingHeaders(c.Request, outreq) + appendViaHeader(outreq.Header, reverseProxyViaProtocol(c.Request.ProtoMajor, c.Request.ProtoMinor, c.Request.Proto), p.receivedBy) + + if _, ok := outreq.Header["User-Agent"]; !ok { + outreq.Header.Set("User-Agent", "") + } + + if p.config.ModifyRequest != nil { + p.config.ModifyRequest(outreq) + } + + rawWriter := reverseProxyBaseResponseWriter(c.Writer) + var ( + roundTripMu sync.Mutex + roundTripDone bool + ) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMu.Lock() + defer roundTripMu.Unlock() + if roundTripDone { + return nil + } + h := c.Writer.Header() + saved := h.Clone() + clear(h) + reverseProxyCopyHeader(h, http.Header(header)) + rawWriter.WriteHeader(code) + clear(h) + reverseProxyCopyHeader(h, saved) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + roundTripMu.Lock() + roundTripDone = true + roundTripMu.Unlock() + if err != nil { + p.handleError(c, err) + return + } + + if res.StatusCode == http.StatusSwitchingProtocols { + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + if !p.modifyResponse(c, res, outreq) { + return + } + if err := p.handleUpgradeResponse(c, outreq, res); err != nil { + p.handleError(c, err) + } + return + } + + removeHopByHopHeaders(res.Header) + appendViaHeader(res.Header, reverseProxyViaProtocol(res.ProtoMajor, res.ProtoMinor, res.Proto), p.receivedBy) + + if !p.modifyResponse(c, res, outreq) { + return + } + + reverseProxyCopyHeader(c.Writer.Header(), res.Header) + + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for key := range res.Trailer { + trailerKeys = append(trailerKeys, key) + } + c.Writer.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + c.Writer.WriteHeader(res.StatusCode) + + if err := p.copyResponse(c.Writer, res.Body, p.flushInterval(res)); err != nil { + defer res.Body.Close() + c.AddError(fmt.Errorf("reverse proxy body copy failed: %w", err)) + p.logf(c, "reverse proxy body copy failed: %v", err) + return + } + res.Body.Close() + + if len(res.Trailer) > 0 { + c.Writer.Flush() + } + + // Keep the stdlib-compatible fallback here. + // If the backend only exposes additional trailer keys after the body has been + // fully read, the trailer map can grow and those values must be written using + // the TrailerPrefix form instead of the pre-announced bare header keys. + if len(res.Trailer) == announcedTrailers { + reverseProxyCopyHeader(c.Writer.Header(), res.Trailer) + return + } + + for key, values := range res.Trailer { + prefixedKey := http.TrailerPrefix + key + for _, value := range values { + c.Writer.Header().Add(prefixedKey, value) + } + } +} + +func (p *reverseProxyHandler) requestContext(c *Context) (context.Context, context.CancelFunc) { + ctx := c.Request.Context() + if ctx.Done() != nil { + return ctx, func() {} + } + + // Follow the same compatibility path as net/http/httputil.ReverseProxy: + // request contexts are normally cancelable, but middleware can still replace + // c.Request with one backed by context.Background/TODO or another context with + // a nil Done channel. In that case CloseNotifier still provides disconnect + // propagation for the upstream round trip. + rawWriter := reverseProxyBaseResponseWriter(c.Writer) + cn, ok := rawWriter.(http.CloseNotifier) + if !ok { + return ctx, func() {} + } + + ctx, cancel := context.WithCancel(ctx) + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} + +func (p *reverseProxyHandler) addForwardingHeaders(in *http.Request, out *http.Request) { + if p.config.ForwardedHeaders == ForwardedNone { + return + } + + clientIP := reverseProxyClientIP(in.RemoteAddr) + scheme := reverseProxyRequestScheme(in) + host := in.Host + + if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedXForwardedOnly { + if clientIP != "" { + appendXForwardedFor(out.Header, clientIP) + } + if host != "" { + if len(out.Header.Values("X-Forwarded-Host")) == 0 { + out.Header.Set("X-Forwarded-Host", host) + } + } + if scheme != "" { + if len(out.Header.Values("X-Forwarded-Proto")) == 0 { + out.Header.Set("X-Forwarded-Proto", scheme) + } + } + } + + if p.config.ForwardedHeaders == ForwardedBoth || p.config.ForwardedHeaders == ForwardedRFC7239Only { + if forwardedValue := buildForwardedHeaderValue(clientIP, p.config.ForwardedBy, host, scheme); forwardedValue != "" { + if prior := out.Header.Values("Forwarded"); len(prior) > 0 { + forwardedValue = strings.Join(prior, ", ") + ", " + forwardedValue + out.Header.Del("Forwarded") + } + out.Header.Add("Forwarded", forwardedValue) + } + } +} + +func appendXForwardedFor(header http.Header, clientIP string) { + if clientIP == "" { + return + } + prior := header.Values("X-Forwarded-For") + if len(prior) == 0 { + header.Set("X-Forwarded-For", clientIP) + return + } + header.Set("X-Forwarded-For", strings.Join(prior, ", ")+", "+clientIP) +} + +func (p *reverseProxyHandler) modifyResponse(c *Context, res *http.Response, req *http.Request) bool { + if p.config.ModifyResponse == nil { + return true + } + if err := p.config.ModifyResponse(res); err != nil { + res.Body.Close() + p.handleError(c, err) + return false + } + return true +} + +func (p *reverseProxyHandler) handleError(c *Context, err error) { + if err == nil { + return + } + c.AddError(err) + if c.Writer.IsHijacked() { + p.logf(c, "reverse proxy error after hijack: %v", err) + return + } + if p.config.ErrorHandler != nil { + p.config.ErrorHandler(c.Writer, c.Request, err) + if c.Writer.Written() || c.Writer.IsHijacked() { + return + } + } + c.ErrorUseHandle(reverseProxyStatusCode(err), err) +} + +func (p *reverseProxyHandler) handleUpgradeResponse(c *Context, req *http.Request, res *http.Response) error { + reqUpType := reverseProxyUpgradeType(req.Header) + resUpType := reverseProxyUpgradeType(res.Header) + if reqUpType == "" || resUpType == "" { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("invalid upgrade negotiation: request protocol=%q, response protocol=%q", reqUpType, resUpType), + } + } + if !isPrintableASCII(resUpType) { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), + } + } + if !strings.EqualFold(reqUpType, resUpType) { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), + } + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + res.Body.Close() + return &reverseProxyStatusError{ + status: http.StatusBadGateway, + err: errors.New("backend returned 101 response without writable body"), + } + } + + clientConn, brw, err := c.Writer.Hijack() + if err != nil { + backConn.Close() + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + defer clientConn.Close() + defer backConn.Close() + + backConnClosed := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + case <-backConnClosed: + } + backConn.Close() + }() + defer close(backConnClosed) + + res.Body = nil + if err := res.Write(brw); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + if err := brw.Flush(); err != nil { + return &reverseProxyStatusError{status: http.StatusBadGateway, err: err} + } + + errc := make(chan error, 2) + copyer := switchProtocolCopier{user: clientConn, backend: backConn} + go copyer.copyToBackend(errc) + go copyer.copyFromBackend(errc) + + firstErr := <-errc + if firstErr == nil { + firstErr = <-errc + } + if errors.Is(firstErr, errReverseProxyCopyDone) || errors.Is(firstErr, net.ErrClosed) || errors.Is(firstErr, io.EOF) || errors.Is(firstErr, context.Canceled) { + return nil + } + return firstErr +} + +func (p *reverseProxyHandler) flushInterval(res *http.Response) time.Duration { + if baseType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type")); baseType == "text/event-stream" { + return -1 + } + if res.ContentLength == -1 { + return -1 + } + return p.config.FlushInterval +} + +func (p *reverseProxyHandler) copyResponse(dst ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var writer io.Writer = dst + + if flushInterval != 0 { + mlw := &maxLatencyWriter{dst: dst, latency: flushInterval} + defer mlw.stop() + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + writer = mlw + } + + var buf []byte + if p.config.BufferPool != nil { + buf = p.config.BufferPool.Get() + defer p.config.BufferPool.Put(buf) + } + _, err := p.copyBuffer(writer, src, buf) + return err +} + +func (p *reverseProxyHandler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && !errors.Is(rerr, io.EOF) && !errors.Is(rerr, context.Canceled) { + p.logf(nil, "reverse proxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if errors.Is(rerr, io.EOF) { + return written, nil + } + return written, rerr + } + } +} + +func (p *reverseProxyHandler) logf(c *Context, format string, args ...any) { + if c != nil { + if logger := c.GetLogger(); logger != nil { + logger.Errorf(format, args...) + return + } + } + log.Printf(format, args...) +} + +func reverseProxyStatusCode(err error) int { + var statusErr *reverseProxyStatusError + if errors.As(err, &statusErr) && statusErr.status > 0 { + return statusErr.status + } + return http.StatusBadGateway +} + +func validateReverseProxyTarget(target *url.URL) error { + if target == nil { + return errReverseProxyNilTarget + } + if target.Scheme == "" || target.Host == "" { + return errReverseProxyInvalidTarget + } + return nil +} + +func normalizeReverseProxyTarget(target *url.URL) { + switch strings.ToLower(target.Scheme) { + case "ws": + target.Scheme = "http" + case "wss": + target.Scheme = "https" + } +} + +func cloneReverseProxyURL(target *url.URL) *url.URL { + if target == nil { + return nil + } + clone := *target + return &clone +} + +func reverseProxyReceivedBy(configValue string) string { + trimmed := strings.TrimSpace(configValue) + if trimmed != "" { + return trimmed + } + return "touka-engine" +} + +func reverseProxyClientIP(remoteAddr string) string { + if remoteAddr == "" { + return "" + } + if addrPort, err := netip.ParseAddrPort(remoteAddr); err == nil { + return addrPort.Addr().String() + } + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + if addr, parseErr := netip.ParseAddr(host); parseErr == nil { + return addr.String() + } + return host + } + if addr, err := netip.ParseAddr(remoteAddr); err == nil { + return addr.String() + } + return remoteAddr +} + +func reverseProxyRequestScheme(req *http.Request) string { + if req == nil { + return "" + } + if req.TLS != nil { + return "https" + } + if req.URL != nil { + scheme := strings.ToLower(req.URL.Scheme) + if scheme != "" { + return scheme + } + } + return "http" +} + +func buildForwardedHeaderValue(clientIP, by, host, scheme string) string { + pairs := make([]string, 0, 4) + if by != "" { + pairs = append(pairs, "by="+formatForwardedParameterValue(by)) + } + if clientIP != "" { + pairs = append(pairs, "for="+formatForwardedFor(clientIP)) + } + if host != "" { + pairs = append(pairs, "host="+formatForwardedParameterValue(host)) + } + if scheme != "" { + pairs = append(pairs, "proto="+formatForwardedParameterValue(strings.ToLower(scheme))) + } + if len(pairs) == 0 { + return "" + } + return strings.Join(pairs, ";") +} + +func formatForwardedFor(clientIP string) string { + addr, err := netip.ParseAddr(clientIP) + if err != nil { + return formatForwardedParameterValue(clientIP) + } + if addr.Is6() { + return quoteForwardedString("[" + addr.String() + "]") + } + return addr.String() +} + +func formatForwardedParameterValue(value string) string { + if isToken(value) { + return value + } + return quoteForwardedString(value) +} + +func quoteForwardedString(value string) string { + replacer := strings.NewReplacer(`\`, `\\`, `"`, `\"`) + return `"` + replacer.Replace(value) + `"` +} + +func isToken(value string) bool { + if value == "" { + return false + } + for i := 0; i < len(value); i++ { + if !isTokenChar(value[i]) { + return false + } + } + return true +} + +func isTokenChar(b byte) bool { + if b >= '0' && b <= '9' { + return true + } + if b >= 'A' && b <= 'Z' { + return true + } + if b >= 'a' && b <= 'z' { + return true + } + switch b { + case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~': + return true + default: + return false + } +} + +func appendViaHeader(header http.Header, protocol, receivedBy string) { + if header == nil || receivedBy == "" { + return + } + if protocol == "" { + protocol = "1.1" + } + header.Add("Via", protocol+" "+receivedBy) +} + +func reverseProxyViaProtocol(major, minor int, raw string) string { + if major > 0 { + return strconv.Itoa(major) + "." + strconv.Itoa(minor) + } + if strings.HasPrefix(raw, "HTTP/") { + return strings.TrimPrefix(raw, "HTTP/") + } + return raw +} + +func rewriteReverseProxyURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinReverseProxyURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func joinReverseProxyURLPath(base, incoming *url.URL) (string, string) { + if base.RawPath == "" && incoming.RawPath == "" { + return reverseProxySingleJoiningSlash(base.Path, incoming.Path), "" + } + + baseEscaped := base.EscapedPath() + incomingEscaped := incoming.EscapedPath() + + baseSlash := strings.HasSuffix(baseEscaped, "/") + incomingSlash := strings.HasPrefix(incomingEscaped, "/") + + switch { + case baseSlash && incomingSlash: + return base.Path + incoming.Path[1:], baseEscaped + incomingEscaped[1:] + case !baseSlash && !incomingSlash: + return base.Path + "/" + incoming.Path, baseEscaped + "/" + incomingEscaped + default: + return base.Path + incoming.Path, baseEscaped + incomingEscaped + } +} + +func reverseProxySingleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + default: + return a + b + } +} + +func reverseProxyCopyHeader(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +var reverseProxyHopHeaders = []string{ + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +func removeHopByHopHeaders(header http.Header) { + for _, connectionValue := range header["Connection"] { + for _, token := range strings.Split(connectionValue, ",") { + trimmed := textproto.TrimString(token) + if trimmed != "" { + header.Del(trimmed) + } + } + } + for _, hopHeader := range reverseProxyHopHeaders { + header.Del(hopHeader) + } +} + +func reverseProxyUpgradeType(header http.Header) string { + if !headerValuesContainToken(header["Connection"], "Upgrade") { + return "" + } + return header.Get("Upgrade") +} + +func headerValuesContainToken(values []string, token string) bool { + if token == "" { + return false + } + for _, value := range values { + for _, part := range strings.Split(value, ",") { + if strings.EqualFold(textproto.TrimString(part), token) { + return true + } + } + } + return false +} + +func cleanReverseProxyQueryParams(rawQuery string) string { + if rawQuery == "" { + return "" + } + // Normalize the outgoing query string so the proxy and upstream do not see + // different semantics for non-standard separators or malformed pairs. + // This can change the exact textual form of the original query and may drop + // parts that net/url rejects, but it keeps proxy-chain parsing behavior more + // consistent and reduces parameter-smuggling ambiguity. + values, _ := url.ParseQuery(rawQuery) + return values.Encode() +} + +func reverseProxyBaseResponseWriter(writer ResponseWriter) http.ResponseWriter { + return UnwrapResponseWriter(writer) +} + +func isPrintableASCII(value string) bool { + for i := 0; i < len(value); i++ { + if value[i] < 0x20 || value[i] > 0x7e { + return false + } + } + return true +} diff --git a/reverseproxy_test.go b/reverseproxy_test.go new file mode 100644 index 0000000..f82aff9 --- /dev/null +++ b/reverseproxy_test.go @@ -0,0 +1,570 @@ +package touka + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/textproto" + "net/url" + "strings" + "testing" + "time" +) + +func TestReverseProxyForwardingAndHopHeaders(t *testing.T) { + t.Helper() + + type backendRequestSnapshot struct { + Path string + RawQuery string + Host string + Connection string + RemovedHeader string + Forwarded string + XForwardedFor string + XForwardedHost string + XForwardedProto string + Via []string + TE string + UserAgent string + } + + gotCh := make(chan backendRequestSnapshot, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotCh <- backendRequestSnapshot{ + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + Host: r.Host, + Connection: r.Header.Get("Connection"), + RemovedHeader: r.Header.Get("X-Remove-Me"), + Forwarded: r.Header.Get("Forwarded"), + XForwardedFor: r.Header.Get("X-Forwarded-For"), + XForwardedHost: r.Header.Get("X-Forwarded-Host"), + XForwardedProto: r.Header.Get("X-Forwarded-Proto"), + Via: append([]string(nil), r.Header.Values("Via")...), + TE: r.Header.Get("Te"), + UserAgent: r.Header.Get("User-Agent"), + } + + w.Header().Set("Connection", "X-Backend-Secret") + w.Header().Set("X-Backend-Secret", "remove-me") + w.Header().Add("Via", "1.0 upstream") + w.Header().Add("Trailer", "X-Upstream-Trailer") + w.Header().Set("Content-Type", "text/plain") + _, _ = io.WriteString(w, "proxied") + w.Header().Set("X-Upstream-Trailer", "done") + })) + defer backend.Close() + + target, err := url.Parse(backend.URL + "/base?from=target") + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/api/*path", ReverseProxy(ReverseProxyConfig{ + Target: target, + ForwardedHeaders: ForwardedBoth, + ForwardedBy: "proxy-node", + Via: "proxy.test", + })) + + req := httptest.NewRequest(http.MethodGet, "http://client.example/api/ping?bad=1;smuggle=2&q=2", nil) + req.Host = "client.example" + req.RemoteAddr = "198.51.100.10:4567" + req.Header.Set("Connection", "X-Remove-Me") + req.Header.Set("X-Remove-Me", "client-secret") + req.Header.Set("X-Forwarded-For", "203.0.113.9") + req.Header.Set("X-Forwarded-Host", "edge.example") + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("Forwarded", "for=203.0.113.9") + req.Header.Set("Te", "trailers") + + rr := httptest.NewRecorder() + engine.ServeHTTP(rr, req) + + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + var got backendRequestSnapshot + select { + case got = <-gotCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend snapshot") + } + + if string(body) != "proxied" { + t.Fatalf("unexpected body: %q", string(body)) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if got.Path != "/base/api/ping" { + t.Fatalf("unexpected upstream path: %q", got.Path) + } + if got.RawQuery != "from=target&q=2" { + t.Fatalf("unexpected upstream raw query: %q", got.RawQuery) + } + if got.Host != strings.TrimPrefix(backend.URL, "http://") { + t.Fatalf("unexpected upstream host: %q", got.Host) + } + if got.Connection != "" { + t.Fatalf("connection header should be stripped, got %q", got.Connection) + } + if got.RemovedHeader != "" { + t.Fatalf("connection-token header should be stripped, got %q", got.RemovedHeader) + } + if got.XForwardedFor != "203.0.113.9, 198.51.100.10" { + t.Fatalf("unexpected X-Forwarded-For: %q", got.XForwardedFor) + } + if got.XForwardedHost != "edge.example" { + t.Fatalf("unexpected X-Forwarded-Host: %q", got.XForwardedHost) + } + if got.XForwardedProto != "https" { + t.Fatalf("unexpected X-Forwarded-Proto: %q", got.XForwardedProto) + } + if got.TE != "trailers" { + t.Fatalf("unexpected TE header: %q", got.TE) + } + if got.UserAgent != "" { + t.Fatalf("expected empty user-agent suppression, got %q", got.UserAgent) + } + if !strings.Contains(got.Forwarded, "for=203.0.113.9") { + t.Fatalf("forwarded header missing prior hop: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "for=198.51.100.10") { + t.Fatalf("forwarded header missing client ip: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "by=proxy-node") { + t.Fatalf("forwarded header missing by token: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "host=client.example") { + t.Fatalf("forwarded header missing host: %q", got.Forwarded) + } + if !strings.Contains(got.Forwarded, "proto=http") { + t.Fatalf("forwarded header missing proto: %q", got.Forwarded) + } + if len(got.Via) != 1 || got.Via[0] != "1.1 proxy.test" { + t.Fatalf("unexpected upstream Via headers: %#v", got.Via) + } + if resp.Header.Get("Connection") != "" { + t.Fatalf("response connection header should be stripped, got %q", resp.Header.Get("Connection")) + } + if resp.Header.Get("X-Backend-Secret") != "" { + t.Fatalf("response connection-token header should be stripped, got %q", resp.Header.Get("X-Backend-Secret")) + } + if gotVia := resp.Header.Values("Via"); len(gotVia) != 2 || gotVia[0] != "1.0 upstream" || gotVia[1] != "1.1 proxy.test" { + t.Fatalf("unexpected response Via headers: %#v", gotVia) + } + if resp.Trailer.Get("X-Upstream-Trailer") != "done" { + t.Fatalf("unexpected proxied trailer: %q", resp.Trailer.Get("X-Upstream-Trailer")) + } +} + +func TestReverseProxyDefaultViaFallback(t *testing.T) { + t.Helper() + + viaCh := make(chan []string, 1) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + viaCh <- append([]string(nil), r.Header.Values("Via")...) + w.WriteHeader(http.StatusNoContent) + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{Target: target})) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusNoContent { + t.Fatalf("unexpected status: %d", rr.Code) + } + + select { + case via := <-viaCh: + if len(via) != 1 || via[0] != "1.1 touka-engine" { + t.Fatalf("unexpected default Via header: %#v", via) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for backend Via header") + } +} + +func TestReverseProxyCustomErrorHandler(t *testing.T) { + t.Helper() + + engine := New() + target, err := url.Parse("http://127.0.0.1:1") + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: target, + ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) { + w.WriteHeader(http.StatusGatewayTimeout) + _, _ = io.WriteString(w, fmt.Sprintf("proxy failure: %v", err)) + }, + })) + + rr := PerformRequest(engine, http.MethodGet, "/proxy", nil, nil) + if rr.Code != http.StatusGatewayTimeout { + t.Fatalf("unexpected status: %d", rr.Code) + } + if !strings.Contains(rr.Body.String(), "proxy failure:") { + t.Fatalf("unexpected body: %q", rr.Body.String()) + } +} + +func TestReverseProxyUnannouncedTrailerForwarding(t *testing.T) { + t.Helper() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "later") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "streamed") + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/trailers", ReverseProxy(ReverseProxyConfig{Target: target})) + + rr := PerformRequest(engine, http.MethodGet, "/trailers", nil, nil) + resp := rr.Result() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if string(body) != "streamed" { + t.Fatalf("unexpected body: %q", string(body)) + } + if got := resp.Trailer.Get("X-Unannounced-Trailer"); got != "later" { + t.Fatalf("unexpected unannounced trailer: %q", got) + } +} + +func TestReverseProxyProtocolUpgrade(t *testing.T) { + t.Helper() + + errCh := make(chan error, 8) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !headerValuesContainToken(r.Header["Connection"], "Upgrade") { + errCh <- fmt.Errorf("missing upgrade connection header: %#v", r.Header.Values("Connection")) + w.WriteHeader(http.StatusBadRequest) + return + } + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + errCh <- fmt.Errorf("unexpected upgrade header: %q", r.Header.Get("Upgrade")) + w.WriteHeader(http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + + line, err := brw.ReadString('\n') + if err != nil { + errCh <- fmt.Errorf("backend read failed: %w", err) + return + } + _, _ = io.WriteString(brw, "echo:"+line) + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend echo flush failed: %w", err) + return + } + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{ + Target: target, + Via: "proxy.test", + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + if err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + if !strings.Contains(statusLine, "101") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := textproto.NewReader(reader).ReadMIMEHeader() + if err != nil { + t.Fatalf("read headers: %v", err) + } + respHeader := http.Header(headers) + if !strings.EqualFold(respHeader.Get("Upgrade"), "websocket") { + t.Fatalf("unexpected upgrade response header: %q", respHeader.Get("Upgrade")) + } + if !headerValuesContainToken(respHeader.Values("Connection"), "Upgrade") { + t.Fatalf("unexpected connection response header: %#v", respHeader.Values("Connection")) + } + if gotVia := respHeader.Values("Via"); len(gotVia) != 1 || gotVia[0] != "1.1 proxy.test" { + t.Fatalf("unexpected Via response header: %#v", gotVia) + } + + if _, err := io.WriteString(conn, "ping\n"); err != nil { + t.Fatalf("write tunneled payload: %v", err) + } + message, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("read tunneled payload: %v", err) + } + if message != "echo:ping\n" { + t.Fatalf("unexpected tunneled payload: %q", message) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRejectsEmptyUpgradeProtocol(t *testing.T) { + t.Helper() + + errCh := make(chan error, 4) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + errCh <- errors.New("backend response writer does not support hijack") + return + } + conn, brw, err := hj.Hijack() + if err != nil { + errCh <- fmt.Errorf("backend hijack failed: %w", err) + return + } + defer conn.Close() + + _, _ = io.WriteString(brw, "HTTP/1.1 101 Switching Protocols\r\n\r\n") + if err := brw.Flush(); err != nil { + errCh <- fmt.Errorf("backend flush failed: %w", err) + return + } + })) + defer backend.Close() + + target, err := url.Parse(backend.URL) + if err != nil { + t.Fatalf("parse target: %v", err) + } + + engine := New() + engine.GET("/ws", ReverseProxy(ReverseProxyConfig{Target: target})) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + conn, err := net.DialTimeout("tcp", proxy.Listener.Addr().String(), 5*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + + _, err = fmt.Fprintf(conn, "GET /ws HTTP/1.1\r\nHost: client.example\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n") + if err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } +} + +func TestReverseProxyRestoresHeadersAfter1xx(t *testing.T) { + t.Helper() + + type oneXXInfo struct { + code int + header http.Header + } + + backendTraceCh := make(chan struct{}, 1) + oneXXCh := make(chan oneXXInfo, 1) + + transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.Got1xxResponse == nil { + return nil, errors.New("missing Got1xxResponse trace") + } + backendTraceCh <- struct{}{} + if err := trace.Got1xxResponse(http.StatusEarlyHints, textproto.MIMEHeader{"Link": {"; rel=preload; as=style"}}); err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader("ok")), + ContentLength: 2, + Request: req, + }, nil + }) + + engine := New() + engine.Use(func(c *Context) { + c.Writer.Header().Set("X-Request-Id", "req-123") + c.Next() + }) + engine.GET("/proxy", ReverseProxy(ReverseProxyConfig{ + Target: mustParseURL(t, "http://example.com"), + Transport: transport, + })) + + proxy := httptest.NewServer(engine) + defer proxy.Close() + + client := proxy.Client() + req, err := http.NewRequest(http.MethodGet, proxy.URL+"/proxy", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + oneXXCh <- oneXXInfo{code: code, header: http.Header(header).Clone()} + return nil + }, + })) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("perform request: %v", err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + _ = resp.Body.Close() + + select { + case <-backendTraceCh: + case <-time.After(2 * time.Second): + t.Fatal("expected proxy transport 1xx trace to be invoked") + } + + var oneXX oneXXInfo + select { + case oneXX = <-oneXXCh: + case <-time.After(2 * time.Second): + t.Fatal("expected client to receive 1xx response") + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } + if string(body) != "ok" { + t.Fatalf("unexpected body: %q", string(body)) + } + if got := resp.Header.Get("X-Request-Id"); got != "req-123" { + t.Fatalf("final response lost preserved header: %q", got) + } + if got := resp.Header.Get("Link"); got != "" { + t.Fatalf("interim 1xx header leaked into final response: %q", got) + } + if oneXX.code != http.StatusEarlyHints { + t.Fatalf("unexpected interim status: %d", oneXX.code) + } + if got := oneXX.header.Get("Link"); got != "; rel=preload; as=style" { + t.Fatalf("unexpected interim Link header: %q", got) + } + if got := oneXX.header.Get("X-Request-Id"); got != "" { + t.Fatalf("final-only header leaked into interim response: %q", got) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("parse url %q: %v", raw, err) + } + return u +} diff --git a/serve.go b/serve.go index 7e05b8c..f3ddc5f 100644 --- a/serve.go +++ b/serve.go @@ -211,7 +211,7 @@ func (engine *Engine) Run(addr ...string) error { srv := &http.Server{Addr: address, Handler: engine} // 即使是不支持优雅关闭的 Run,也应用默认和用户配置,以保持行为一致性 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -224,10 +224,14 @@ func (engine *Engine) RunShutdown(addr string, timeouts ...time.Duration) error srv := &http.Server{ Addr: addr, Handler: engine, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, } + srv.RegisterOnShutdown(engine.shutdownCancel) // 应用框架的默认配置和用户提供的自定义配置 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -241,10 +245,14 @@ func (engine *Engine) RunShutdownWithContext(addr string, ctx context.Context, t srv := &http.Server{ Addr: addr, Handler: engine, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, } + srv.RegisterOnShutdown(engine.shutdownCancel) // 应用框架的默认配置和用户提供的自定义配置 - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(srv) } @@ -260,7 +268,7 @@ func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...tim // 配置 HTTP/2 支持 (如果使用默认配置) if engine.useDefaultProtocols { - engine.SetProtocols(&ProtocolsConfig{ + engine.setProtocols(&ProtocolsConfig{ Http1: true, Http2: true, // 默认在 TLS 上启用 HTTP/2 }) @@ -270,11 +278,15 @@ func (engine *Engine) RunTLS(addr string, tlsConfig *tls.Config, timeouts ...tim Addr: addr, Handler: engine, TLSConfig: tlsConfig, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, } + srv.RegisterOnShutdown(engine.shutdownCancel) // 应用框架的默认配置和用户提供的自定义配置 // 优先使用 TLSServerConfigurator,如果未设置,则回退到通用的 ServerConfigurator - //engine.applyDefaultServerConfig(srv) + engine.applyDefaultServerConfig(srv) if engine.TLSServerConfigurator != nil { engine.TLSServerConfigurator(srv) } else if engine.ServerConfigurator != nil { @@ -298,14 +310,18 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con // --- HTTPS 服务器 --- if engine.useDefaultProtocols { - engine.SetProtocols(&ProtocolsConfig{Http1: true, Http2: true}) + engine.setProtocols(&ProtocolsConfig{Http1: true, Http2: true}) } httpsSrv := &http.Server{ Addr: httpsAddr, Handler: engine, TLSConfig: tlsConfig, + BaseContext: func(l net.Listener) context.Context { + return engine.shutdownCtx + }, } - //engine.applyDefaultServerConfig(httpsSrv) + httpsSrv.RegisterOnShutdown(engine.shutdownCancel) + engine.applyDefaultServerConfig(httpsSrv) if engine.TLSServerConfigurator != nil { engine.TLSServerConfigurator(httpsSrv) } else if engine.ServerConfigurator != nil { @@ -339,7 +355,7 @@ func (engine *Engine) RunTLSRedir(httpAddr, httpsAddr string, tlsConfig *tls.Con Addr: httpAddr, Handler: redirectHandler, } - //engine.applyDefaultServerConfig(httpSrv) + engine.applyDefaultServerConfig(httpSrv) if engine.ServerConfigurator != nil { engine.ServerConfigurator(httpSrv) } diff --git a/sse.go b/sse.go index 3b98800..ab6c226 100644 --- a/sse.go +++ b/sse.go @@ -40,8 +40,8 @@ func (e *Event) Render(w io.Writer) error { buf.WriteString("\n") } if len(e.Data) > 0 { - lines := strings.Split(e.Data, "\n") - for _, line := range lines { + lines := strings.SplitSeq(e.Data, "\n") + for line := range lines { buf.WriteString("data: ") buf.WriteString(line) buf.WriteString("\n") diff --git a/touka.go b/touka.go index 837d62d..dd529cb 100644 --- a/touka.go +++ b/touka.go @@ -12,7 +12,7 @@ const ( defaultMemory = 32 << 20 // 32 MB, Gin 的默认值,用于 ParseMultipartForm ) -type H map[string]interface{} // map简写, 类似gin.H +type H map[string]any // map简写, 类似gin.H type Handle func(http.ResponseWriter, *http.Request, Params) diff --git a/tree.go b/tree.go index 09711a1..31246a5 100644 --- a/tree.go +++ b/tree.go @@ -5,7 +5,6 @@ package touka import ( - "bytes" "net/url" "strings" "unicode" @@ -27,12 +26,6 @@ func BytesToString(b []byte) string { return unsafe.String(unsafe.SliceData(b), len(b)) } -var ( - strColon = []byte(":") // 定义字节切片常量, 表示冒号, 用于路径参数识别 - strStar = []byte("*") // 定义字节切片常量, 表示星号, 用于捕获所有路径识别 - strSlash = []byte("/") // 定义字节切片常量, 表示斜杠, 用于路径分隔符识别 -) - // Param 是单个 URL 参数, 由键和值组成. type Param struct { Key string // 参数的键名 @@ -106,17 +99,14 @@ func (n *node) addChild(child *node) { // countParams 计算路径中参数(冒号)和捕获所有(星号)的数量. func countParams(path string) uint16 { - var n uint16 - s := StringToBytes(path) // 将路径字符串转换为字节切片 - n += uint16(bytes.Count(s, strColon)) // 统计冒号的数量 - n += uint16(bytes.Count(s, strStar)) // 统计星号的数量 - return n + colons := strings.Count(path, ":") + stars := strings.Count(path, "*") + return uint16(colons + stars) } // countSections 计算路径中斜杠('/')的数量, 即路径段的数量. func countSections(path string) uint16 { - s := StringToBytes(path) // 将路径字符串转换为字节切片 - return uint16(bytes.Count(s, strSlash)) // 统计斜杠的数量 + return uint16(strings.Count(path, "/")) } // nodeType 定义了节点的类型. @@ -418,10 +408,10 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) fullPath: fullPath, // 设置完整路径 } - n.addChild(child) // 添加子节点 - n.indices = string('/') // 索引设置为 '/' - n = child // 移动到新创建的 catchAll 节点 - n.priority++ // 增加优先级 + n.addChild(child) // 添加子节点 + n.indices = "/" // 索引设置为 '/' + n = child // 移动到新创建的 catchAll 节点 + n.priority++ // 增加优先级 // 第二个节点: 包含变量的节点 child = &node{