mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
Merge pull request #91 from infinite-iroha/feat/httpc-context-integration
Some checks are pending
Go Test / test (push) Waiting to run
Some checks are pending
Go Test / test (push) Waiting to run
feat: httpc 集成、MergeCtx cause 传播
This commit is contained in:
commit
01395dc942
7 changed files with 701 additions and 59 deletions
17
compat.go
17
compat.go
|
|
@ -4,7 +4,12 @@
|
||||||
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
package touka
|
package touka
|
||||||
|
|
||||||
import "github.com/fenthope/reco"
|
import (
|
||||||
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
|
"github.com/fenthope/reco"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- reco 兼容函数 ---
|
||||||
|
|
||||||
// GetLogReco 返回底层的 reco.Logger 实例
|
// GetLogReco 返回底层的 reco.Logger 实例
|
||||||
// 用于需要访问 reco 特定功能的场景
|
// 用于需要访问 reco 特定功能的场景
|
||||||
|
|
@ -35,3 +40,13 @@ func (c *Context) GetLoggerReco() *reco.Logger {
|
||||||
}
|
}
|
||||||
return c.engine.LogReco
|
return c.engine.LogReco
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- httpc 兼容函数 ---
|
||||||
|
|
||||||
|
// GetHTTPC 返回底层的 httpc.Client 实例
|
||||||
|
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||||
|
//
|
||||||
|
//go:fix inline
|
||||||
|
func (c *Context) GetHTTPC() *httpc.Client {
|
||||||
|
return c.Client()
|
||||||
|
}
|
||||||
|
|
|
||||||
30
context.go
30
context.go
|
|
@ -864,10 +864,29 @@ func (c *Context) GetErrors() []error {
|
||||||
return c.Errors
|
return c.Errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client 返回 Engine 提供的 HTTPClient
|
// Client 返回当前请求的 HTTPClient
|
||||||
// 方便在请求处理函数中进行出站 HTTP 请求
|
// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例;
|
||||||
|
// 否则返回 Engine 提供的默认实例
|
||||||
|
//
|
||||||
|
// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context
|
||||||
func (c *Context) Client() *httpc.Client {
|
func (c *Context) Client() *httpc.Client {
|
||||||
return c.HTTPClient
|
if c.HTTPClient != nil {
|
||||||
|
return c.HTTPClient
|
||||||
|
}
|
||||||
|
return c.engine.HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPC 返回自动关联请求 Context 的 HTTP 客户端
|
||||||
|
// 当请求被取消时,通过此客户端发起的出站请求也会自动取消
|
||||||
|
func (c *Context) HTTPC() *contextHTTPClient {
|
||||||
|
client := c.HTTPClient
|
||||||
|
if client == nil {
|
||||||
|
client = c.engine.HTTPClient
|
||||||
|
}
|
||||||
|
return &contextHTTPClient{
|
||||||
|
client: client,
|
||||||
|
ctx: c.ctx,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context() 返回请求的上下文,用于取消操作
|
// Context() 返回请求的上下文,用于取消操作
|
||||||
|
|
@ -1130,11 +1149,6 @@ func (c *Context) GetProtocol() string {
|
||||||
return c.Request.Proto
|
return c.Request.Proto
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHTTPC 获取框架自带传递的httpc
|
|
||||||
func (c *Context) GetHTTPC() *httpc.Client {
|
|
||||||
return c.HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogger 获取engine的Logger接口
|
// GetLogger 获取engine的Logger接口
|
||||||
func (c *Context) GetLogger() Logger {
|
func (c *Context) GetLogger() Logger {
|
||||||
return c.engine.logger
|
return c.engine.logger
|
||||||
|
|
|
||||||
58
context_httpc.go
Normal file
58
context_httpc.go
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
// Copyright 2024 WJQSERVER. All rights reserved.
|
||||||
|
// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization.
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/WJQSERVER-STUDIO/httpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context
|
||||||
|
// 当请求被取消时,出站 HTTP 请求也会自动取消
|
||||||
|
type contextHTTPClient struct {
|
||||||
|
client *httpc.Client
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequestBuilder 创建请求构建器,自动关联请求 Context
|
||||||
|
func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET 创建 GET 请求构建器
|
||||||
|
func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.GET(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// POST 创建 POST 请求构建器
|
||||||
|
func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.POST(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PUT 创建 PUT 请求构建器
|
||||||
|
func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.PUT(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DELETE 创建 DELETE 请求构建器
|
||||||
|
func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.DELETE(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PATCH 创建 PATCH 请求构建器
|
||||||
|
func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.PATCH(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HEAD 创建 HEAD 请求构建器
|
||||||
|
func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.HEAD(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPTIONS 创建 OPTIONS 请求构建器
|
||||||
|
func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder {
|
||||||
|
return c.client.OPTIONS(urlStr).WithContext(c.ctx)
|
||||||
|
}
|
||||||
188
docs/httpc.md
Normal file
188
docs/httpc.md
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
# HTTP Client (httpc)
|
||||||
|
|
||||||
|
Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。
|
||||||
|
|
||||||
|
## 核心特性
|
||||||
|
|
||||||
|
- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context
|
||||||
|
- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏
|
||||||
|
- **链式调用**:保持 httpc 原有的组合式构建器风格
|
||||||
|
|
||||||
|
## 基本用法
|
||||||
|
|
||||||
|
### 简单 GET 请求
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
body, err := c.HTTPC().
|
||||||
|
GET("https://api.example.com/data").
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(200, body)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST JSON 请求
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.POST("/users", func(c *touka.Context) {
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.HTTPC().
|
||||||
|
POST("https://api.example.com/users").
|
||||||
|
SetHeader("Authorization", "Bearer "+token).
|
||||||
|
SetJSONBody(req).
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, result)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 带查询参数
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/search", func(c *touka.Context) {
|
||||||
|
query := c.Query("q")
|
||||||
|
|
||||||
|
var result SearchResult
|
||||||
|
err := c.HTTPC().
|
||||||
|
GET("https://api.example.com/search").
|
||||||
|
SetQueryParam("q", query).
|
||||||
|
SetQueryParam("limit", "10").
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, result)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 对比
|
||||||
|
|
||||||
|
### 旧方式(Deprecated)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 需要手动 WithContext,容易忘记
|
||||||
|
resp, err := c.Client().
|
||||||
|
WithContext(c.Context()).
|
||||||
|
GET(url).
|
||||||
|
Execute()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 新方式(推荐)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 自动关联请求 Context
|
||||||
|
resp, err := c.HTTPC().
|
||||||
|
GET(url).
|
||||||
|
Execute()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Context 取消机制
|
||||||
|
|
||||||
|
使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.GET("/long-task", func(c *touka.Context) {
|
||||||
|
// 这个请求会在客户端断开时自动取消
|
||||||
|
resp, err := c.HTTPC().
|
||||||
|
GET("https://slow-api.example.com/data").
|
||||||
|
Execute()
|
||||||
|
|
||||||
|
// 如果客户端已断开,err 会包含 context.Canceled
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return // 客户端已断开,无需处理
|
||||||
|
}
|
||||||
|
// ...
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 完整 API
|
||||||
|
|
||||||
|
### contextHTTPClient 方法
|
||||||
|
|
||||||
|
| 方法 | 返回类型 | 说明 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 |
|
||||||
|
| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 |
|
||||||
|
| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 |
|
||||||
|
| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 |
|
||||||
|
| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 |
|
||||||
|
| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 |
|
||||||
|
| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 |
|
||||||
|
| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 |
|
||||||
|
|
||||||
|
### httpc.RequestBuilder 链式方法
|
||||||
|
|
||||||
|
返回 `*httpc.RequestBuilder`(用于链式调用):
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) |
|
||||||
|
| `NoDefaultHeaders()` | 不添加默认 Header |
|
||||||
|
| `SetHeader(key, value)` | 设置 Header |
|
||||||
|
| `AddHeader(key, value)` | 添加 Header(可重复) |
|
||||||
|
| `SetHeaders(map)` | 批量设置 Headers |
|
||||||
|
| `SetQueryParam(key, value)` | 设置查询参数 |
|
||||||
|
| `AddQueryParam(key, value)` | 添加查询参数(可重复) |
|
||||||
|
| `SetQueryParams(map)` | 批量设置查询参数 |
|
||||||
|
| `SetBody(io.Reader)` | 设置请求 Body |
|
||||||
|
| `SetRawBody([]byte)` | 设置字节 Body |
|
||||||
|
|
||||||
|
返回 `(*httpc.RequestBuilder, error)`(可能失败):
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `SetJSONBody(any)` | 设置 JSON Body |
|
||||||
|
| `SetXMLBody(any)` | 设置 XML Body |
|
||||||
|
| `SetGOBBody(any)` | 设置 GOB Body |
|
||||||
|
|
||||||
|
### 终结方法
|
||||||
|
|
||||||
|
| 方法 | 返回类型 | 说明 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `Build()` | `(*http.Request, error)` | 构建请求但不执行 |
|
||||||
|
| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 |
|
||||||
|
| `DecodeJSON(v)` | `error` | 执行并解码 JSON |
|
||||||
|
| `DecodeXML(v)` | `error` | 执行并解码 XML |
|
||||||
|
| `DecodeGOB(v)` | `error` | 执行并解码 GOB |
|
||||||
|
| `Text()` | `(string, error)` | 执行并返回文本 |
|
||||||
|
| `Bytes()` | `([]byte, error)` | 执行并返回字节 |
|
||||||
|
| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 |
|
||||||
|
|
||||||
|
## 迁移指南
|
||||||
|
|
||||||
|
### go:fix inline 兼容
|
||||||
|
|
||||||
|
旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go fix ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 手动迁移
|
||||||
|
|
||||||
|
| 旧代码 | 新代码 |
|
||||||
|
|--------|--------|
|
||||||
|
| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` |
|
||||||
|
| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` |
|
||||||
|
|
||||||
|
## 示例
|
||||||
|
|
||||||
|
完整示例请参考 [examples/httpc](../examples/httpc)。
|
||||||
103
examples/httpc/main.go
Normal file
103
examples/httpc/main.go
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/infinite-iroha/touka"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := touka.Default()
|
||||||
|
|
||||||
|
// 示例 1:简单 GET 请求(自动关联请求 Context)
|
||||||
|
r.GET("/proxy", func(c *touka.Context) {
|
||||||
|
// 使用 HTTPC() 方法,自动关联请求 Context
|
||||||
|
// 当客户端断开连接时,出站请求也会自动取消
|
||||||
|
body, err := c.HTTPC().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(http.StatusOK, "%s", body)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 2:带 Header 的 POST 请求
|
||||||
|
r.POST("/users", func(c *touka.Context) {
|
||||||
|
var req struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 链式调用,保持 httpc 风格
|
||||||
|
// 注意:SetJSONBody 返回 (*RequestBuilder, error)
|
||||||
|
rb, err := c.HTTPC().
|
||||||
|
POST("https://httpbin.org/post").
|
||||||
|
SetHeader("X-API-Key", "secret").
|
||||||
|
SetJSONBody(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := rb.DecodeJSON(&result); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 3:带查询参数的请求
|
||||||
|
r.GET("/search", func(c *touka.Context) {
|
||||||
|
query := c.DefaultQuery("q", "")
|
||||||
|
page := c.DefaultQuery("page", "1")
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Items []string `json:"items"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.HTTPC().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
SetQueryParam("q", query).
|
||||||
|
SetQueryParam("page", page).
|
||||||
|
DecodeJSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐)
|
||||||
|
r.GET("/legacy", func(c *touka.Context) {
|
||||||
|
// 旧方式:需要手动 WithContext
|
||||||
|
body, err := c.Client().
|
||||||
|
GET("https://httpbin.org/get").
|
||||||
|
WithContext(c.Context()).
|
||||||
|
Text()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.String(http.StatusOK, "%s", body)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("Server running on :8080")
|
||||||
|
fmt.Println("Try:")
|
||||||
|
fmt.Println(" curl http://localhost:8080/proxy")
|
||||||
|
fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users")
|
||||||
|
fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'")
|
||||||
|
|
||||||
|
// r.Run(touka.WithAddr(":8080"))
|
||||||
|
}
|
||||||
108
mergectx.go
108
mergectx.go
|
|
@ -11,18 +11,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||||
|
// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播.
|
||||||
|
// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消.
|
||||||
type mergedContext struct {
|
type mergedContext struct {
|
||||||
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
|
||||||
context.Context
|
context.Context
|
||||||
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
|
||||||
parents []context.Context
|
parents []context.Context
|
||||||
// 用于手动取消此 mergedContext 的函数.
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeCtx 创建并返回一个新的 context.Context.
|
// MergeCtx 创建并返回一个新的 context.Context.
|
||||||
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||||
// 自动被取消 (逻辑或关系).
|
// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context.
|
||||||
//
|
//
|
||||||
// 新的 context 会继承:
|
// 新的 context 会继承:
|
||||||
// - Deadline: 所有父 context 中最早的截止时间.
|
// - Deadline: 所有父 context 中最早的截止时间.
|
||||||
|
|
@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
||||||
return context.WithCancel(context.Background())
|
return context.WithCancel(context.Background())
|
||||||
}
|
}
|
||||||
if len(parents) == 1 {
|
if len(parents) == 1 {
|
||||||
return context.WithCancel(parents[0])
|
ctx, cancel := context.WithCancelCause(parents[0])
|
||||||
|
return ctx, func() { cancel(nil) }
|
||||||
}
|
}
|
||||||
|
|
||||||
var earliestDeadline time.Time
|
var earliestDeadline time.Time
|
||||||
|
|
@ -44,37 +43,71 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var baseCtx context.Context
|
// cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播.
|
||||||
var baseCancel context.CancelFunc
|
cancelCtx, cancelCause := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
|
// deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline).
|
||||||
|
// 当 cancelCtx 被取消时, deadlineCtx 也会被取消;
|
||||||
|
// 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx.
|
||||||
|
var deadlineCtx context.Context
|
||||||
|
var deadlineCancel context.CancelFunc
|
||||||
if !earliestDeadline.IsZero() {
|
if !earliestDeadline.IsZero() {
|
||||||
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded)
|
||||||
} else {
|
}
|
||||||
baseCtx, baseCancel = context.WithCancel(context.Background())
|
|
||||||
|
// 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline),
|
||||||
|
// 否则用 cancelCtx.
|
||||||
|
embedCtx := cancelCtx
|
||||||
|
if deadlineCtx != nil {
|
||||||
|
embedCtx = deadlineCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
mc := &mergedContext{
|
mc := &mergedContext{
|
||||||
Context: baseCtx,
|
Context: embedCtx,
|
||||||
parents: parents,
|
parents: parents,
|
||||||
cancel: baseCancel,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动一个监控 goroutine.
|
// 启动监控 goroutine, 监听 parent 取消或 deadline 到期.
|
||||||
go func() {
|
go func() {
|
||||||
defer mc.cancel()
|
// 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏.
|
||||||
|
parentDone := orDone(append(mc.parents, cancelCtx)...)
|
||||||
|
|
||||||
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
if deadlineCtx != nil {
|
||||||
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
defer deadlineCancel()
|
||||||
select {
|
select {
|
||||||
case <-orDone(mc.parents...):
|
case <-parentDone:
|
||||||
case <-mc.Context.Done():
|
// parent 取消或手动 cancel()
|
||||||
|
for _, p := range mc.parents {
|
||||||
|
if p.Err() != nil {
|
||||||
|
cancelCause(context.Cause(p))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 手动 cancel(), cause 已由 cancelCause() 设置
|
||||||
|
case <-deadlineCtx.Done():
|
||||||
|
// deadline 到期, 需要关闭 cancelCtx 并设置 cause
|
||||||
|
cancelCause(context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
<-parentDone
|
||||||
|
for _, p := range mc.parents {
|
||||||
|
if p.Err() != nil {
|
||||||
|
cancelCause(context.Cause(p))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return mc, mc.cancel
|
return mc, func() { cancelCause(nil) }
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value 返回当前Ctx Value
|
// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause),
|
||||||
|
// 再按传入顺序从 parents 中查找.
|
||||||
func (mc *mergedContext) Value(key any) any {
|
func (mc *mergedContext) Value(key any) any {
|
||||||
|
if v := mc.Context.Value(key); v != nil {
|
||||||
|
return v
|
||||||
|
}
|
||||||
for _, p := range mc.parents {
|
for _, p := range mc.parents {
|
||||||
if val := p.Value(key); val != nil {
|
if val := p.Value(key); val != nil {
|
||||||
return val
|
return val
|
||||||
|
|
@ -83,45 +116,20 @@ func (mc *mergedContext) Value(key any) any {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deadline 实现了 context.Context 的 Deadline 方法.
|
// Deadline, Done, Err 均由嵌入的 context.Context 提供.
|
||||||
func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
|
||||||
return mc.Context.Deadline()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Done 实现了 context.Context 的 Done 方法.
|
// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭.
|
||||||
func (mc *mergedContext) Done() <-chan struct{} {
|
|
||||||
return mc.Context.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Err 实现了 context.Context 的 Err 方法.
|
|
||||||
func (mc *mergedContext) Err() error {
|
|
||||||
return mc.Context.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// orDone 是一个辅助函数, 返回一个 channel.
|
|
||||||
// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭.
|
|
||||||
// 这是一个非阻塞的、不会泄漏 goroutine 的实现.
|
|
||||||
func orDone(contexts ...context.Context) <-chan struct{} {
|
func orDone(contexts ...context.Context) <-chan struct{} {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
closeDone := func() {
|
|
||||||
once.Do(func() {
|
|
||||||
close(done)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// 为每个父 context 启动一个 goroutine.
|
|
||||||
for _, ctx := range contexts {
|
for _, ctx := range contexts {
|
||||||
go func(c context.Context) {
|
go func(c context.Context) {
|
||||||
select {
|
select {
|
||||||
case <-c.Done():
|
case <-c.Done():
|
||||||
closeDone()
|
once.Do(func() { close(done) })
|
||||||
case <-done:
|
case <-done:
|
||||||
// orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出.
|
|
||||||
}
|
}
|
||||||
}(ctx)
|
}(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return done
|
return done
|
||||||
}
|
}
|
||||||
|
|
|
||||||
256
mergectx_test.go
Normal file
256
mergectx_test.go
Normal file
|
|
@ -0,0 +1,256 @@
|
||||||
|
package touka
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeCtx_NoParents(t *testing.T) {
|
||||||
|
ctx, cancel := MergeCtx()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
t.Fatal("expected no error before cancel")
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_SingleParent(t *testing.T) {
|
||||||
|
parent, parentCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(parent)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
t.Fatal("expected no error before parent cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
parentCancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after parent cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p1 cancel")
|
||||||
|
}
|
||||||
|
// p2 should still be fine
|
||||||
|
if p2.Err() != nil {
|
||||||
|
t.Fatal("expected p2 to be unaffected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel2()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p2 cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ExternalCancel(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after external cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_CausePropagation(t *testing.T) {
|
||||||
|
testErr := errors.New("test cause")
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithCancelCause(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel1(testErr)
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p1 cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
if cause != testErr {
|
||||||
|
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||||
|
}
|
||||||
|
cancel1(nil) // cleanup (already cancelled, no-op)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
|
||||||
|
testErr := errors.New("second parent cause")
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cancel2(testErr)
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after p2 cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
cause := context.Cause(ctx)
|
||||||
|
if cause != testErr {
|
||||||
|
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
early := now.Add(100 * time.Millisecond)
|
||||||
|
late := now.Add(1 * time.Hour)
|
||||||
|
|
||||||
|
p1, cancel1 := context.WithDeadline(context.Background(), late)
|
||||||
|
p2, cancel2 := context.WithDeadline(context.Background(), early)
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dl, ok := ctx.Deadline()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected deadline to be set")
|
||||||
|
}
|
||||||
|
if !dl.Equal(early) {
|
||||||
|
t.Fatalf("expected deadline %v, got %v", early, dl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_Deadline_Expires(t *testing.T) {
|
||||||
|
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancelP()
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
t.Fatal("expected error after deadline expires")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ValueLookup(t *testing.T) {
|
||||||
|
type key struct{}
|
||||||
|
p1 := context.WithValue(context.Background(), key{}, "from_p1")
|
||||||
|
p2 := context.WithValue(context.Background(), key{}, "from_p2")
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
val := ctx.Value(key{})
|
||||||
|
if val != "from_p1" {
|
||||||
|
t.Fatalf("expected 'from_p1', got %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
|
||||||
|
type key1 struct{}
|
||||||
|
type key2 struct{}
|
||||||
|
p1 := context.WithValue(context.Background(), key1{}, "val1")
|
||||||
|
p2 := context.WithValue(context.Background(), key2{}, "val2")
|
||||||
|
|
||||||
|
ctx, cancel := MergeCtx(p1, p2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if v := ctx.Value(key1{}); v != "val1" {
|
||||||
|
t.Fatalf("expected 'val1', got %v", v)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(key2{}); v != "val2" {
|
||||||
|
t.Fatalf("expected 'val2', got %v", v)
|
||||||
|
}
|
||||||
|
if v := ctx.Value("missing"); v != nil {
|
||||||
|
t.Fatalf("expected nil, got %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeCtx_ContextInterface(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, _ = MergeCtx(p1, p2)
|
||||||
|
|
||||||
|
// Verify all Context interface methods work
|
||||||
|
_ = ctx.Done()
|
||||||
|
_ = ctx.Err()
|
||||||
|
_, _ = ctx.Deadline()
|
||||||
|
_ = ctx.Value("any")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_SingleContext(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := orDone(ctx)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_MultipleContexts(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
done := orDone(p1, p2)
|
||||||
|
|
||||||
|
cancel1()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrDone_SecondContextCancels(t *testing.T) {
|
||||||
|
p1, cancel1 := context.WithCancel(context.Background())
|
||||||
|
p2, cancel2 := context.WithCancel(context.Background())
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
done := orDone(p1, p2)
|
||||||
|
|
||||||
|
cancel2()
|
||||||
|
<-done // should not block
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue