diff --git a/compat.go b/compat.go index 6a49c89..0be715d 100644 --- a/compat.go +++ b/compat.go @@ -4,7 +4,12 @@ // All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. package touka -import "github.com/fenthope/reco" +import ( + "github.com/WJQSERVER-STUDIO/httpc" + "github.com/fenthope/reco" +) + +// --- reco 兼容函数 --- // GetLogReco 返回底层的 reco.Logger 实例 // 用于需要访问 reco 特定功能的场景 @@ -35,3 +40,13 @@ func (c *Context) GetLoggerReco() *reco.Logger { } return c.engine.LogReco } + +// --- httpc 兼容函数 --- + +// GetHTTPC 返回底层的 httpc.Client 实例 +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context +// +//go:fix inline +func (c *Context) GetHTTPC() *httpc.Client { + return c.Client() +} diff --git a/context.go b/context.go index 324386e..f21ed48 100644 --- a/context.go +++ b/context.go @@ -864,10 +864,29 @@ func (c *Context) GetErrors() []error { return c.Errors } -// Client 返回 Engine 提供的 HTTPClient -// 方便在请求处理函数中进行出站 HTTP 请求 +// Client 返回当前请求的 HTTPClient +// 如果请求处理函数或中间件设置了自定义 HTTPClient,返回该实例; +// 否则返回 Engine 提供的默认实例 +// +// Deprecated: 使用 HTTPC() 替代,新方法会自动关联请求 Context func (c *Context) Client() *httpc.Client { - return c.HTTPClient + if c.HTTPClient != nil { + return c.HTTPClient + } + return c.engine.HTTPClient +} + +// HTTPC 返回自动关联请求 Context 的 HTTP 客户端 +// 当请求被取消时,通过此客户端发起的出站请求也会自动取消 +func (c *Context) HTTPC() *contextHTTPClient { + client := c.HTTPClient + if client == nil { + client = c.engine.HTTPClient + } + return &contextHTTPClient{ + client: client, + ctx: c.ctx, + } } // Context() 返回请求的上下文,用于取消操作 @@ -1130,11 +1149,6 @@ func (c *Context) GetProtocol() string { return c.Request.Proto } -// GetHTTPC 获取框架自带传递的httpc -func (c *Context) GetHTTPC() *httpc.Client { - return c.HTTPClient -} - // GetLogger 获取engine的Logger接口 func (c *Context) GetLogger() Logger { return c.engine.logger diff --git a/context_httpc.go b/context_httpc.go new file mode 100644 index 0000000..3256a3b --- /dev/null +++ b/context_httpc.go @@ -0,0 +1,58 @@ +// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// Copyright 2024 WJQSERVER. All rights reserved. +// All rights reserved by WJQSERVER, related rights can be exercised by the infinite-iroha organization. +package touka + +import ( + "context" + + "github.com/WJQSERVER-STUDIO/httpc" +) + +// contextHTTPClient 包装 httpc.Client,自动关联请求的 Context +// 当请求被取消时,出站 HTTP 请求也会自动取消 +type contextHTTPClient struct { + client *httpc.Client + ctx context.Context +} + +// NewRequestBuilder 创建请求构建器,自动关联请求 Context +func (c *contextHTTPClient) NewRequestBuilder(method, urlStr string) *httpc.RequestBuilder { + return c.client.NewRequestBuilder(method, urlStr).WithContext(c.ctx) +} + +// GET 创建 GET 请求构建器 +func (c *contextHTTPClient) GET(urlStr string) *httpc.RequestBuilder { + return c.client.GET(urlStr).WithContext(c.ctx) +} + +// POST 创建 POST 请求构建器 +func (c *contextHTTPClient) POST(urlStr string) *httpc.RequestBuilder { + return c.client.POST(urlStr).WithContext(c.ctx) +} + +// PUT 创建 PUT 请求构建器 +func (c *contextHTTPClient) PUT(urlStr string) *httpc.RequestBuilder { + return c.client.PUT(urlStr).WithContext(c.ctx) +} + +// DELETE 创建 DELETE 请求构建器 +func (c *contextHTTPClient) DELETE(urlStr string) *httpc.RequestBuilder { + return c.client.DELETE(urlStr).WithContext(c.ctx) +} + +// PATCH 创建 PATCH 请求构建器 +func (c *contextHTTPClient) PATCH(urlStr string) *httpc.RequestBuilder { + return c.client.PATCH(urlStr).WithContext(c.ctx) +} + +// HEAD 创建 HEAD 请求构建器 +func (c *contextHTTPClient) HEAD(urlStr string) *httpc.RequestBuilder { + return c.client.HEAD(urlStr).WithContext(c.ctx) +} + +// OPTIONS 创建 OPTIONS 请求构建器 +func (c *contextHTTPClient) OPTIONS(urlStr string) *httpc.RequestBuilder { + return c.client.OPTIONS(urlStr).WithContext(c.ctx) +} diff --git a/docs/httpc.md b/docs/httpc.md new file mode 100644 index 0000000..8742c18 --- /dev/null +++ b/docs/httpc.md @@ -0,0 +1,188 @@ +# HTTP Client (httpc) + +Touka 内置了 [httpc](https://github.com/WJQSERVER-STUDIO/httpc) HTTP 客户端,方便在请求处理函数中发起出站 HTTP 请求。 + +## 核心特性 + +- **自动 Context 关联**:使用 `HTTPC()` 方法时,出站请求会自动关联当前请求的 Context +- **请求取消传播**:当客户端断开连接时,出站请求会自动取消,避免资源泄漏 +- **链式调用**:保持 httpc 原有的组合式构建器风格 + +## 基本用法 + +### 简单 GET 请求 + +```go +r.GET("/proxy", func(c *touka.Context) { + body, err := c.HTTPC(). + GET("https://api.example.com/data"). + Text() + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.String(200, body) +}) +``` + +### POST JSON 请求 + +```go +r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + c.ShouldBindJSON(&req) + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + err := c.HTTPC(). + POST("https://api.example.com/users"). + SetHeader("Authorization", "Bearer "+token). + SetJSONBody(req). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +### 带查询参数 + +```go +r.GET("/search", func(c *touka.Context) { + query := c.Query("q") + + var result SearchResult + err := c.HTTPC(). + GET("https://api.example.com/search"). + SetQueryParam("q", query). + SetQueryParam("limit", "10"). + DecodeJSON(&result) + if err != nil { + c.JSON(500, touka.H{"error": err.Error()}) + return + } + c.JSON(200, result) +}) +``` + +## API 对比 + +### 旧方式(Deprecated) + +```go +// 需要手动 WithContext,容易忘记 +resp, err := c.Client(). + WithContext(c.Context()). + GET(url). + Execute() +``` + +### 新方式(推荐) + +```go +// 自动关联请求 Context +resp, err := c.HTTPC(). + GET(url). + Execute() +``` + +## Context 取消机制 + +使用 `HTTPC()` 时,当客户端断开连接(如关闭浏览器),出站请求会自动取消: + +```go +r.GET("/long-task", func(c *touka.Context) { + // 这个请求会在客户端断开时自动取消 + resp, err := c.HTTPC(). + GET("https://slow-api.example.com/data"). + Execute() + + // 如果客户端已断开,err 会包含 context.Canceled + if errors.Is(err, context.Canceled) { + return // 客户端已断开,无需处理 + } + // ... +}) +``` + +## 完整 API + +### contextHTTPClient 方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `NewRequestBuilder(method, url)` | `*httpc.RequestBuilder` | 创建通用请求构建器 | +| `GET(url)` | `*httpc.RequestBuilder` | 创建 GET 请求 | +| `POST(url)` | `*httpc.RequestBuilder` | 创建 POST 请求 | +| `PUT(url)` | `*httpc.RequestBuilder` | 创建 PUT 请求 | +| `DELETE(url)` | `*httpc.RequestBuilder` | 创建 DELETE 请求 | +| `PATCH(url)` | `*httpc.RequestBuilder` | 创建 PATCH 请求 | +| `HEAD(url)` | `*httpc.RequestBuilder` | 创建 HEAD 请求 | +| `OPTIONS(url)` | `*httpc.RequestBuilder` | 创建 OPTIONS 请求 | + +### httpc.RequestBuilder 链式方法 + +返回 `*httpc.RequestBuilder`(用于链式调用): + +| 方法 | 说明 | +|------|------| +| `WithContext(ctx)` | 设置 Context(通常不需要,已自动关联) | +| `NoDefaultHeaders()` | 不添加默认 Header | +| `SetHeader(key, value)` | 设置 Header | +| `AddHeader(key, value)` | 添加 Header(可重复) | +| `SetHeaders(map)` | 批量设置 Headers | +| `SetQueryParam(key, value)` | 设置查询参数 | +| `AddQueryParam(key, value)` | 添加查询参数(可重复) | +| `SetQueryParams(map)` | 批量设置查询参数 | +| `SetBody(io.Reader)` | 设置请求 Body | +| `SetRawBody([]byte)` | 设置字节 Body | + +返回 `(*httpc.RequestBuilder, error)`(可能失败): + +| 方法 | 说明 | +|------|------| +| `SetJSONBody(any)` | 设置 JSON Body | +| `SetXMLBody(any)` | 设置 XML Body | +| `SetGOBBody(any)` | 设置 GOB Body | + +### 终结方法 + +| 方法 | 返回类型 | 说明 | +|------|----------|------| +| `Build()` | `(*http.Request, error)` | 构建请求但不执行 | +| `Execute()` | `(*http.Response, error)` | 执行并返回原始响应 | +| `DecodeJSON(v)` | `error` | 执行并解码 JSON | +| `DecodeXML(v)` | `error` | 执行并解码 XML | +| `DecodeGOB(v)` | `error` | 执行并解码 GOB | +| `Text()` | `(string, error)` | 执行并返回文本 | +| `Bytes()` | `([]byte, error)` | 执行并返回字节 | +| `SSE()` | `(*SSEStream, error)` | 建立 SSE 流连接 | + +## 迁移指南 + +### go:fix inline 兼容 + +旧代码 `c.GetHTTPC()` 可通过 `go fix` 自动迁移到 `c.Client()`: + +```bash +go fix ./... +``` + +### 手动迁移 + +| 旧代码 | 新代码 | +|--------|--------| +| `c.GetHTTPC()` | `c.Client()` 或 `c.HTTPC()` | +| `c.Client().WithContext(ctx).GET(url)` | `c.HTTPC().GET(url)` | + +## 示例 + +完整示例请参考 [examples/httpc](../examples/httpc)。 diff --git a/examples/httpc/main.go b/examples/httpc/main.go new file mode 100644 index 0000000..db2be4f --- /dev/null +++ b/examples/httpc/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "net/http" + + "github.com/infinite-iroha/touka" +) + +func main() { + r := touka.Default() + + // 示例 1:简单 GET 请求(自动关联请求 Context) + r.GET("/proxy", func(c *touka.Context) { + // 使用 HTTPC() 方法,自动关联请求 Context + // 当客户端断开连接时,出站请求也会自动取消 + body, err := c.HTTPC(). + GET("https://httpbin.org/get"). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, "%s", body) + }) + + // 示例 2:带 Header 的 POST 请求 + r.POST("/users", func(c *touka.Context) { + var req struct { + Name string `json:"name"` + Email string `json:"email"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, touka.H{"error": err.Error()}) + return + } + + var result struct { + ID int `json:"id"` + Name string `json:"name"` + } + + // 链式调用,保持 httpc 风格 + // 注意:SetJSONBody 返回 (*RequestBuilder, error) + rb, err := c.HTTPC(). + POST("https://httpbin.org/post"). + SetHeader("X-API-Key", "secret"). + SetJSONBody(req) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + if err := rb.DecodeJSON(&result); err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 3:带查询参数的请求 + r.GET("/search", func(c *touka.Context) { + query := c.DefaultQuery("q", "") + page := c.DefaultQuery("page", "1") + + var result struct { + Items []string `json:"items"` + Total int `json:"total"` + } + + err := c.HTTPC(). + GET("https://httpbin.org/get"). + SetQueryParam("q", query). + SetQueryParam("page", page). + DecodeJSON(&result) + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) + }) + + // 示例 4:使用底层 httpc.Client(旧方式,仍可用但不推荐) + r.GET("/legacy", func(c *touka.Context) { + // 旧方式:需要手动 WithContext + body, err := c.Client(). + GET("https://httpbin.org/get"). + WithContext(c.Context()). + Text() + if err != nil { + c.JSON(http.StatusInternalServerError, touka.H{"error": err.Error()}) + return + } + c.String(http.StatusOK, "%s", body) + }) + + fmt.Println("Server running on :8080") + fmt.Println("Try:") + fmt.Println(" curl http://localhost:8080/proxy") + fmt.Println(" curl -X POST -d '{\"name\":\"test\",\"email\":\"test@example.com\"}' http://localhost:8080/users") + fmt.Println(" curl 'http://localhost:8080/search?q=golang&page=1'") + + // r.Run(touka.WithAddr(":8080")) +} diff --git a/mergectx.go b/mergectx.go index e5d3ec4..404f7b1 100644 --- a/mergectx.go +++ b/mergectx.go @@ -11,18 +11,16 @@ import ( ) // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. +// 嵌入 cancelCtx 作为基础 context, 支持 cause 传播. +// deadlineCtx 作为 cancelCtx 的子 context, 确保 deadline 到期时 cancelCtx 也被取消. type mergedContext struct { - // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. context.Context - // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context - // 用于手动取消此 mergedContext 的函数. - cancel context.CancelFunc } // MergeCtx 创建并返回一个新的 context.Context. // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). +// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. // // 新的 context 会继承: // - Deadline: 所有父 context 中最早的截止时间. @@ -32,7 +30,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return context.WithCancel(context.Background()) } if len(parents) == 1 { - return context.WithCancel(parents[0]) + ctx, cancel := context.WithCancelCause(parents[0]) + return ctx, func() { cancel(nil) } } var earliestDeadline time.Time @@ -44,37 +43,71 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - var baseCtx context.Context - var baseCancel context.CancelFunc + // cancelCtx 作为基础 context, 提供 CancelCauseFunc 以支持 cause 传播. + cancelCtx, cancelCause := context.WithCancelCause(context.Background()) + + // deadlineCtx 作为 cancelCtx 的子 context (如果有 deadline). + // 当 cancelCtx 被取消时, deadlineCtx 也会被取消; + // 当 deadline 到期时, deadlineCtx 自行取消, watcher 负责关闭 cancelCtx. + var deadlineCtx context.Context + var deadlineCancel context.CancelFunc if !earliestDeadline.IsZero() { - baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) - } else { - baseCtx, baseCancel = context.WithCancel(context.Background()) + deadlineCtx, deadlineCancel = context.WithDeadlineCause(cancelCtx, earliestDeadline, context.DeadlineExceeded) + } + + // 嵌入的 context: 有 deadline 时用 deadlineCtx (以返回正确的 Deadline), + // 否则用 cancelCtx. + embedCtx := cancelCtx + if deadlineCtx != nil { + embedCtx = deadlineCtx } mc := &mergedContext{ - Context: baseCtx, + Context: embedCtx, parents: parents, - cancel: baseCancel, } - // 启动一个监控 goroutine. + // 启动监控 goroutine, 监听 parent 取消或 deadline 到期. go func() { - defer mc.cancel() + // 将 cancelCtx 加入 orDone, 确保手动 cancel() 时 orDone goroutine 能退出, 防止泄漏. + parentDone := orDone(append(mc.parents, cancelCtx)...) - // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. - // 同时监听 baseCtx.Done() 以便支持手动取消. - select { - case <-orDone(mc.parents...): - case <-mc.Context.Done(): + if deadlineCtx != nil { + defer deadlineCancel() + select { + case <-parentDone: + // parent 取消或手动 cancel() + for _, p := range mc.parents { + if p.Err() != nil { + cancelCause(context.Cause(p)) + return + } + } + // 手动 cancel(), cause 已由 cancelCause() 设置 + case <-deadlineCtx.Done(): + // deadline 到期, 需要关闭 cancelCtx 并设置 cause + cancelCause(context.DeadlineExceeded) + } + } else { + <-parentDone + for _, p := range mc.parents { + if p.Err() != nil { + cancelCause(context.Cause(p)) + return + } + } } }() - return mc, mc.cancel + return mc, func() { cancelCause(nil) } } -// Value 返回当前Ctx Value +// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), +// 再按传入顺序从 parents 中查找. func (mc *mergedContext) Value(key any) any { + if v := mc.Context.Value(key); v != nil { + return v + } for _, p := range mc.parents { if val := p.Value(key); val != nil { return val @@ -83,45 +116,20 @@ func (mc *mergedContext) Value(key any) any { return nil } -// Deadline 实现了 context.Context 的 Deadline 方法. -func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { - return mc.Context.Deadline() -} +// Deadline, Done, Err 均由嵌入的 context.Context 提供. -// Done 实现了 context.Context 的 Done 方法. -func (mc *mergedContext) Done() <-chan struct{} { - return mc.Context.Done() -} - -// Err 实现了 context.Context 的 Err 方法. -func (mc *mergedContext) Err() error { - return mc.Context.Err() -} - -// orDone 是一个辅助函数, 返回一个 channel. -// 当任意一个输入 context 的 Done() channel 关闭时, orDone 返回的 channel 也会关闭. -// 这是一个非阻塞的、不会泄漏 goroutine 的实现. +// orDone 返回一个 channel, 当任意一个输入 context 的 Done() channel 关闭时关闭. func orDone(contexts ...context.Context) <-chan struct{} { done := make(chan struct{}) - var once sync.Once - closeDone := func() { - once.Do(func() { - close(done) - }) - } - - // 为每个父 context 启动一个 goroutine. for _, ctx := range contexts { go func(c context.Context) { select { case <-c.Done(): - closeDone() + once.Do(func() { close(done) }) case <-done: - // orDone 已经被其他 goroutine 关闭了, 当前 goroutine 可以安全退出. } }(ctx) } - return done } diff --git a/mergectx_test.go b/mergectx_test.go new file mode 100644 index 0000000..d6d1225 --- /dev/null +++ b/mergectx_test.go @@ -0,0 +1,256 @@ +package touka + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestMergeCtx_NoParents(t *testing.T) { + ctx, cancel := MergeCtx() + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before cancel") + } + cancel() + if ctx.Err() == nil { + t.Fatal("expected error after cancel") + } +} + +func TestMergeCtx_SingleParent(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + + ctx, cancel := MergeCtx(parent) + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before parent cancel") + } + + parentCancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after parent cancel") + } +} + +func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + // p2 should still be fine + if p2.Err() != nil { + t.Fatal("expected p2 to be unaffected") + } +} + +func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } +} + +func TestMergeCtx_ExternalCancel(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + + cancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after external cancel") + } +} + +func TestMergeCtx_CausePropagation(t *testing.T) { + testErr := errors.New("test cause") + + p1, cancel1 := context.WithCancelCause(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1(testErr) + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + cancel1(nil) // cleanup (already cancelled, no-op) +} + +func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) { + testErr := errors.New("second parent cause") + + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancelCause(context.Background()) + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2(testErr) + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + + cancel1() +} + +func TestMergeCtx_Deadline_Earliest(t *testing.T) { + now := time.Now() + early := now.Add(100 * time.Millisecond) + late := now.Add(1 * time.Hour) + + p1, cancel1 := context.WithDeadline(context.Background(), late) + p2, cancel2 := context.WithDeadline(context.Background(), early) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !dl.Equal(early) { + t.Fatalf("expected deadline %v, got %v", early, dl) + } +} + +func TestMergeCtx_Deadline_Expires(t *testing.T) { + p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelP() + + ctx, cancel := MergeCtx(p) + defer cancel() + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after deadline expires") + } +} + +func TestMergeCtx_ValueLookup(t *testing.T) { + type key struct{} + p1 := context.WithValue(context.Background(), key{}, "from_p1") + p2 := context.WithValue(context.Background(), key{}, "from_p2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + val := ctx.Value(key{}) + if val != "from_p1" { + t.Fatalf("expected 'from_p1', got %v", val) + } +} + +func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) { + type key1 struct{} + type key2 struct{} + p1 := context.WithValue(context.Background(), key1{}, "val1") + p2 := context.WithValue(context.Background(), key2{}, "val2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + if v := ctx.Value(key1{}); v != "val1" { + t.Fatalf("expected 'val1', got %v", v) + } + if v := ctx.Value(key2{}); v != "val2" { + t.Fatalf("expected 'val2', got %v", v) + } + if v := ctx.Value("missing"); v != nil { + t.Fatalf("expected nil, got %v", v) + } +} + +func TestMergeCtx_ContextInterface(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + var ctx context.Context + ctx, _ = MergeCtx(p1, p2) + + // Verify all Context interface methods work + _ = ctx.Done() + _ = ctx.Err() + _, _ = ctx.Deadline() + _ = ctx.Value("any") +} + +func TestOrDone_SingleContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := orDone(ctx) + + cancel() + <-done // should not block +} + +func TestOrDone_MultipleContexts(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + done := orDone(p1, p2) + + cancel1() + <-done // should not block +} + +func TestOrDone_SecondContextCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + done := orDone(p1, p2) + + cancel2() + <-done // should not block +}