mirror of
https://github.com/infinite-iroha/touka.git
synced 2026-06-13 15:47:38 +08:00
improve: MergeCtx 支持 cause 传播, 使用 WithCancelCause/WithDeadlineCause
- 内部改用 context.WithCancelCause 和 WithDeadlineCause, 父 context 取消原因自动传播 - Value() 先检查嵌入 context 再查 parents, 确保 context.Cause() 正确工作 - Done()/Err() 同时监听 cancelCtx 和 deadlineCtx, 支持 deadline 到期 cause - 新增 Cause() 便捷方法 - 单 parent 短路径改用 WithCancelCause 保留 cause - 新增 mergectx_test.go, 覆盖 cause 传播、deadline、Value 查找等场景 - API 兼容: 返回类型保持 CancelFunc 不变 Alina Agent生成
This commit is contained in:
parent
e7c7d5e41f
commit
7487369125
2 changed files with 338 additions and 24 deletions
106
mergectx.go
106
mergectx.go
|
|
@ -12,17 +12,19 @@ import (
|
|||
|
||||
// mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型.
|
||||
type mergedContext struct {
|
||||
// 嵌入一个基础 context, 它持有最早的 deadline 和取消信号.
|
||||
// 嵌入一个基础 context, 用于 Deadline() 和 Value() 查找.
|
||||
context.Context
|
||||
// 保存了所有的父 context, 用于 Value() 方法的查找.
|
||||
parents []context.Context
|
||||
// 用于手动取消此 mergedContext 的函数.
|
||||
cancel context.CancelFunc
|
||||
// cancelCtx 由 CancelCause 管理, 当 cause 取消时其 Done() 关闭.
|
||||
cancelCtx context.Context
|
||||
// deadlineCtx 仅在有 deadline 时非 nil, 用于检测 deadline 到期.
|
||||
deadlineCtx context.Context
|
||||
}
|
||||
|
||||
// MergeCtx 创建并返回一个新的 context.Context.
|
||||
// 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时,
|
||||
// 自动被取消 (逻辑或关系).
|
||||
// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context.
|
||||
//
|
||||
// 新的 context 会继承:
|
||||
// - Deadline: 所有父 context 中最早的截止时间.
|
||||
|
|
@ -32,7 +34,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
|||
return context.WithCancel(context.Background())
|
||||
}
|
||||
if len(parents) == 1 {
|
||||
return context.WithCancel(parents[0])
|
||||
ctx, cancel := context.WithCancelCause(parents[0])
|
||||
return ctx, func() { cancel(nil) }
|
||||
}
|
||||
|
||||
var earliestDeadline time.Time
|
||||
|
|
@ -44,37 +47,78 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C
|
|||
}
|
||||
}
|
||||
|
||||
var baseCtx context.Context
|
||||
var baseCancel context.CancelFunc
|
||||
// baseCtx 提供 CancelCauseFunc 以支持 cause 传播.
|
||||
baseCtx, baseCancel := context.WithCancelCause(context.Background())
|
||||
|
||||
// deadlineCtx 仅用于监听 deadline 到期信号.
|
||||
var deadlineCtx context.Context
|
||||
var deadlineCancel context.CancelFunc
|
||||
if !earliestDeadline.IsZero() {
|
||||
baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline)
|
||||
} else {
|
||||
baseCtx, baseCancel = context.WithCancel(context.Background())
|
||||
deadlineCtx, deadlineCancel = context.WithDeadlineCause(context.Background(), earliestDeadline, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// 嵌入的 context: 有 deadline 时用 deadlineCtx, 否则用 baseCtx.
|
||||
embedCtx := baseCtx
|
||||
if deadlineCtx != nil {
|
||||
embedCtx = deadlineCtx
|
||||
}
|
||||
|
||||
mc := &mergedContext{
|
||||
Context: baseCtx,
|
||||
parents: parents,
|
||||
cancel: baseCancel,
|
||||
Context: embedCtx,
|
||||
parents: parents,
|
||||
cancelCtx: baseCtx,
|
||||
deadlineCtx: deadlineCtx,
|
||||
}
|
||||
|
||||
// 启动一个监控 goroutine.
|
||||
// 启动监控 goroutine.
|
||||
go func() {
|
||||
defer mc.cancel()
|
||||
var once sync.Once
|
||||
doCancel := func(cause error) {
|
||||
once.Do(func() { baseCancel(cause) })
|
||||
}
|
||||
defer doCancel(nil)
|
||||
|
||||
// orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭.
|
||||
// 同时监听 baseCtx.Done() 以便支持手动取消.
|
||||
select {
|
||||
case <-orDone(mc.parents...):
|
||||
case <-mc.Context.Done():
|
||||
parentDone := orDone(mc.parents...)
|
||||
|
||||
if deadlineCtx != nil {
|
||||
defer deadlineCancel()
|
||||
select {
|
||||
case <-parentDone:
|
||||
for _, p := range mc.parents {
|
||||
if p.Err() != nil {
|
||||
doCancel(context.Cause(p))
|
||||
return
|
||||
}
|
||||
}
|
||||
doCancel(nil)
|
||||
case <-deadlineCtx.Done():
|
||||
doCancel(context.DeadlineExceeded)
|
||||
case <-baseCtx.Done():
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-parentDone:
|
||||
for _, p := range mc.parents {
|
||||
if p.Err() != nil {
|
||||
doCancel(context.Cause(p))
|
||||
return
|
||||
}
|
||||
}
|
||||
doCancel(nil)
|
||||
case <-baseCtx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return mc, mc.cancel
|
||||
return mc, func() { baseCancel(nil) }
|
||||
}
|
||||
|
||||
// Value 返回当前Ctx Value
|
||||
// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause),
|
||||
// 再按传入顺序从 parents 中查找.
|
||||
func (mc *mergedContext) Value(key any) any {
|
||||
if v := mc.Context.Value(key); v != nil {
|
||||
return v
|
||||
}
|
||||
for _, p := range mc.parents {
|
||||
if val := p.Value(key); val != nil {
|
||||
return val
|
||||
|
|
@ -90,12 +134,26 @@ func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) {
|
|||
|
||||
// Done 实现了 context.Context 的 Done 方法.
|
||||
func (mc *mergedContext) Done() <-chan struct{} {
|
||||
return mc.Context.Done()
|
||||
if mc.deadlineCtx != nil {
|
||||
return orDone(mc.cancelCtx, mc.deadlineCtx)
|
||||
}
|
||||
return mc.cancelCtx.Done()
|
||||
}
|
||||
|
||||
// Err 实现了 context.Context 的 Err 方法.
|
||||
func (mc *mergedContext) Err() error {
|
||||
return mc.Context.Err()
|
||||
if mc.cancelCtx.Err() != nil {
|
||||
return mc.cancelCtx.Err()
|
||||
}
|
||||
if mc.deadlineCtx != nil {
|
||||
return mc.deadlineCtx.Err()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cause 返回取消原因, 使 context.Cause() 能正确传播 cause.
|
||||
func (mc *mergedContext) Cause() error {
|
||||
return context.Cause(mc.cancelCtx)
|
||||
}
|
||||
|
||||
// orDone 是一个辅助函数, 返回一个 channel.
|
||||
|
|
|
|||
256
mergectx_test.go
Normal file
256
mergectx_test.go
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
package touka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMergeCtx_NoParents(t *testing.T) {
|
||||
ctx, cancel := MergeCtx()
|
||||
defer cancel()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
t.Fatal("expected no error before cancel")
|
||||
}
|
||||
cancel()
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_SingleParent(t *testing.T) {
|
||||
parent, parentCancel := context.WithCancel(context.Background())
|
||||
|
||||
ctx, cancel := MergeCtx(parent)
|
||||
defer cancel()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
t.Fatal("expected no error before parent cancel")
|
||||
}
|
||||
|
||||
parentCancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after parent cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel1()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p1 cancel")
|
||||
}
|
||||
// p2 should still be fine
|
||||
if p2.Err() != nil {
|
||||
t.Fatal("expected p2 to be unaffected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel2()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p2 cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ExternalCancel(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
|
||||
cancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after external cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_CausePropagation(t *testing.T) {
|
||||
testErr := errors.New("test cause")
|
||||
|
||||
p1, cancel1 := context.WithCancelCause(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel1(testErr)
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p1 cancel")
|
||||
}
|
||||
|
||||
cause := context.Cause(ctx)
|
||||
if cause != testErr {
|
||||
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||
}
|
||||
cancel1(nil) // cleanup (already cancelled, no-op)
|
||||
}
|
||||
|
||||
func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) {
|
||||
testErr := errors.New("second parent cause")
|
||||
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancelCause(context.Background())
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
cancel2(testErr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after p2 cancel")
|
||||
}
|
||||
|
||||
cause := context.Cause(ctx)
|
||||
if cause != testErr {
|
||||
t.Fatalf("expected cause %v, got %v", testErr, cause)
|
||||
}
|
||||
|
||||
cancel1()
|
||||
}
|
||||
|
||||
func TestMergeCtx_Deadline_Earliest(t *testing.T) {
|
||||
now := time.Now()
|
||||
early := now.Add(100 * time.Millisecond)
|
||||
late := now.Add(1 * time.Hour)
|
||||
|
||||
p1, cancel1 := context.WithDeadline(context.Background(), late)
|
||||
p2, cancel2 := context.WithDeadline(context.Background(), early)
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Fatal("expected deadline to be set")
|
||||
}
|
||||
if !dl.Equal(early) {
|
||||
t.Fatalf("expected deadline %v, got %v", early, dl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_Deadline_Expires(t *testing.T) {
|
||||
p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancelP()
|
||||
|
||||
ctx, cancel := MergeCtx(p)
|
||||
defer cancel()
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if ctx.Err() == nil {
|
||||
t.Fatal("expected error after deadline expires")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ValueLookup(t *testing.T) {
|
||||
type key struct{}
|
||||
p1 := context.WithValue(context.Background(), key{}, "from_p1")
|
||||
p2 := context.WithValue(context.Background(), key{}, "from_p2")
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
val := ctx.Value(key{})
|
||||
if val != "from_p1" {
|
||||
t.Fatalf("expected 'from_p1', got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) {
|
||||
type key1 struct{}
|
||||
type key2 struct{}
|
||||
p1 := context.WithValue(context.Background(), key1{}, "val1")
|
||||
p2 := context.WithValue(context.Background(), key2{}, "val2")
|
||||
|
||||
ctx, cancel := MergeCtx(p1, p2)
|
||||
defer cancel()
|
||||
|
||||
if v := ctx.Value(key1{}); v != "val1" {
|
||||
t.Fatalf("expected 'val1', got %v", v)
|
||||
}
|
||||
if v := ctx.Value(key2{}); v != "val2" {
|
||||
t.Fatalf("expected 'val2', got %v", v)
|
||||
}
|
||||
if v := ctx.Value("missing"); v != nil {
|
||||
t.Fatalf("expected nil, got %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeCtx_ContextInterface(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
defer cancel2()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, _ = MergeCtx(p1, p2)
|
||||
|
||||
// Verify all Context interface methods work
|
||||
_ = ctx.Done()
|
||||
_ = ctx.Err()
|
||||
_, _ = ctx.Deadline()
|
||||
_ = ctx.Value("any")
|
||||
}
|
||||
|
||||
func TestOrDone_SingleContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := orDone(ctx)
|
||||
|
||||
cancel()
|
||||
<-done // should not block
|
||||
}
|
||||
|
||||
func TestOrDone_MultipleContexts(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
done := orDone(p1, p2)
|
||||
|
||||
cancel1()
|
||||
<-done // should not block
|
||||
}
|
||||
|
||||
func TestOrDone_SecondContextCancels(t *testing.T) {
|
||||
p1, cancel1 := context.WithCancel(context.Background())
|
||||
p2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
|
||||
done := orDone(p1, p2)
|
||||
|
||||
cancel2()
|
||||
<-done // should not block
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue