diff --git a/mergectx.go b/mergectx.go index e5d3ec4..2e36c09 100644 --- a/mergectx.go +++ b/mergectx.go @@ -12,17 +12,19 @@ import ( // mergedContext 实现了 context.Context 接口, 是 Merge 函数返回的实际类型. type mergedContext struct { - // 嵌入一个基础 context, 它持有最早的 deadline 和取消信号. + // 嵌入一个基础 context, 用于 Deadline() 和 Value() 查找. context.Context // 保存了所有的父 context, 用于 Value() 方法的查找. parents []context.Context - // 用于手动取消此 mergedContext 的函数. - cancel context.CancelFunc + // cancelCtx 由 CancelCause 管理, 当 cause 取消时其 Done() 关闭. + cancelCtx context.Context + // deadlineCtx 仅在有 deadline 时非 nil, 用于检测 deadline 到期. + deadlineCtx context.Context } // MergeCtx 创建并返回一个新的 context.Context. // 这个新的 context 会在任何一个传入的父 contexts 被取消时, 或者当返回的 CancelFunc 被调用时, -// 自动被取消 (逻辑或关系). +// 自动被取消 (逻辑或关系). 父 context 的取消原因 (cause) 会自动传播到返回的 context. // // 新的 context 会继承: // - Deadline: 所有父 context 中最早的截止时间. @@ -32,7 +34,8 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C return context.WithCancel(context.Background()) } if len(parents) == 1 { - return context.WithCancel(parents[0]) + ctx, cancel := context.WithCancelCause(parents[0]) + return ctx, func() { cancel(nil) } } var earliestDeadline time.Time @@ -44,37 +47,78 @@ func MergeCtx(parents ...context.Context) (ctx context.Context, cancel context.C } } - var baseCtx context.Context - var baseCancel context.CancelFunc + // baseCtx 提供 CancelCauseFunc 以支持 cause 传播. + baseCtx, baseCancel := context.WithCancelCause(context.Background()) + + // deadlineCtx 仅用于监听 deadline 到期信号. + var deadlineCtx context.Context + var deadlineCancel context.CancelFunc if !earliestDeadline.IsZero() { - baseCtx, baseCancel = context.WithDeadline(context.Background(), earliestDeadline) - } else { - baseCtx, baseCancel = context.WithCancel(context.Background()) + deadlineCtx, deadlineCancel = context.WithDeadlineCause(context.Background(), earliestDeadline, context.DeadlineExceeded) + } + + // 嵌入的 context: 有 deadline 时用 deadlineCtx, 否则用 baseCtx. + embedCtx := baseCtx + if deadlineCtx != nil { + embedCtx = deadlineCtx } mc := &mergedContext{ - Context: baseCtx, - parents: parents, - cancel: baseCancel, + Context: embedCtx, + parents: parents, + cancelCtx: baseCtx, + deadlineCtx: deadlineCtx, } - // 启动一个监控 goroutine. + // 启动监控 goroutine. go func() { - defer mc.cancel() + var once sync.Once + doCancel := func(cause error) { + once.Do(func() { baseCancel(cause) }) + } + defer doCancel(nil) - // orDone 会返回一个 channel, 当任何一个父 context 被取消时, 这个 channel 就会关闭. - // 同时监听 baseCtx.Done() 以便支持手动取消. - select { - case <-orDone(mc.parents...): - case <-mc.Context.Done(): + parentDone := orDone(mc.parents...) + + if deadlineCtx != nil { + defer deadlineCancel() + select { + case <-parentDone: + for _, p := range mc.parents { + if p.Err() != nil { + doCancel(context.Cause(p)) + return + } + } + doCancel(nil) + case <-deadlineCtx.Done(): + doCancel(context.DeadlineExceeded) + case <-baseCtx.Done(): + } + } else { + select { + case <-parentDone: + for _, p := range mc.parents { + if p.Err() != nil { + doCancel(context.Cause(p)) + return + } + } + doCancel(nil) + case <-baseCtx.Done(): + } } }() - return mc, mc.cancel + return mc, func() { baseCancel(nil) } } -// Value 返回当前Ctx Value +// Value 返回当前Ctx Value. 先检查嵌入的 context (以支持 context.Cause), +// 再按传入顺序从 parents 中查找. func (mc *mergedContext) Value(key any) any { + if v := mc.Context.Value(key); v != nil { + return v + } for _, p := range mc.parents { if val := p.Value(key); val != nil { return val @@ -90,12 +134,26 @@ func (mc *mergedContext) Deadline() (deadline time.Time, ok bool) { // Done 实现了 context.Context 的 Done 方法. func (mc *mergedContext) Done() <-chan struct{} { - return mc.Context.Done() + if mc.deadlineCtx != nil { + return orDone(mc.cancelCtx, mc.deadlineCtx) + } + return mc.cancelCtx.Done() } // Err 实现了 context.Context 的 Err 方法. func (mc *mergedContext) Err() error { - return mc.Context.Err() + if mc.cancelCtx.Err() != nil { + return mc.cancelCtx.Err() + } + if mc.deadlineCtx != nil { + return mc.deadlineCtx.Err() + } + return nil +} + +// Cause 返回取消原因, 使 context.Cause() 能正确传播 cause. +func (mc *mergedContext) Cause() error { + return context.Cause(mc.cancelCtx) } // orDone 是一个辅助函数, 返回一个 channel. diff --git a/mergectx_test.go b/mergectx_test.go new file mode 100644 index 0000000..d6d1225 --- /dev/null +++ b/mergectx_test.go @@ -0,0 +1,256 @@ +package touka + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestMergeCtx_NoParents(t *testing.T) { + ctx, cancel := MergeCtx() + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before cancel") + } + cancel() + if ctx.Err() == nil { + t.Fatal("expected error after cancel") + } +} + +func TestMergeCtx_SingleParent(t *testing.T) { + parent, parentCancel := context.WithCancel(context.Background()) + + ctx, cancel := MergeCtx(parent) + defer cancel() + + if ctx.Err() != nil { + t.Fatal("expected no error before parent cancel") + } + + parentCancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after parent cancel") + } +} + +func TestMergeCtx_MultipleParents_FirstCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + // p2 should still be fine + if p2.Err() != nil { + t.Fatal("expected p2 to be unaffected") + } +} + +func TestMergeCtx_MultipleParents_SecondCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } +} + +func TestMergeCtx_ExternalCancel(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + + cancel() + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after external cancel") + } +} + +func TestMergeCtx_CausePropagation(t *testing.T) { + testErr := errors.New("test cause") + + p1, cancel1 := context.WithCancelCause(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel1(testErr) + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p1 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + cancel1(nil) // cleanup (already cancelled, no-op) +} + +func TestMergeCtx_CausePropagation_SecondParent(t *testing.T) { + testErr := errors.New("second parent cause") + + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancelCause(context.Background()) + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + cancel2(testErr) + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after p2 cancel") + } + + cause := context.Cause(ctx) + if cause != testErr { + t.Fatalf("expected cause %v, got %v", testErr, cause) + } + + cancel1() +} + +func TestMergeCtx_Deadline_Earliest(t *testing.T) { + now := time.Now() + early := now.Add(100 * time.Millisecond) + late := now.Add(1 * time.Hour) + + p1, cancel1 := context.WithDeadline(context.Background(), late) + p2, cancel2 := context.WithDeadline(context.Background(), early) + defer cancel1() + defer cancel2() + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !dl.Equal(early) { + t.Fatalf("expected deadline %v, got %v", early, dl) + } +} + +func TestMergeCtx_Deadline_Expires(t *testing.T) { + p, cancelP := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancelP() + + ctx, cancel := MergeCtx(p) + defer cancel() + + <-ctx.Done() + + if ctx.Err() == nil { + t.Fatal("expected error after deadline expires") + } +} + +func TestMergeCtx_ValueLookup(t *testing.T) { + type key struct{} + p1 := context.WithValue(context.Background(), key{}, "from_p1") + p2 := context.WithValue(context.Background(), key{}, "from_p2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + val := ctx.Value(key{}) + if val != "from_p1" { + t.Fatalf("expected 'from_p1', got %v", val) + } +} + +func TestMergeCtx_ValueLookup_SecondParent(t *testing.T) { + type key1 struct{} + type key2 struct{} + p1 := context.WithValue(context.Background(), key1{}, "val1") + p2 := context.WithValue(context.Background(), key2{}, "val2") + + ctx, cancel := MergeCtx(p1, p2) + defer cancel() + + if v := ctx.Value(key1{}); v != "val1" { + t.Fatalf("expected 'val1', got %v", v) + } + if v := ctx.Value(key2{}); v != "val2" { + t.Fatalf("expected 'val2', got %v", v) + } + if v := ctx.Value("missing"); v != nil { + t.Fatalf("expected nil, got %v", v) + } +} + +func TestMergeCtx_ContextInterface(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + defer cancel2() + + var ctx context.Context + ctx, _ = MergeCtx(p1, p2) + + // Verify all Context interface methods work + _ = ctx.Done() + _ = ctx.Err() + _, _ = ctx.Deadline() + _ = ctx.Value("any") +} + +func TestOrDone_SingleContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := orDone(ctx) + + cancel() + <-done // should not block +} + +func TestOrDone_MultipleContexts(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + done := orDone(p1, p2) + + cancel1() + <-done // should not block +} + +func TestOrDone_SecondContextCancels(t *testing.T) { + p1, cancel1 := context.WithCancel(context.Background()) + p2, cancel2 := context.WithCancel(context.Background()) + defer cancel1() + + done := orDone(p1, p2) + + cancel2() + <-done // should not block +}