Skip to content

Commit

Permalink
context: add APIs for writing and reading cancelation cause
Browse files Browse the repository at this point in the history
Extend the context package to allow users to specify why a context was
canceled in the form of an error, the "cause". Users write the cause
by calling WithCancelCause to construct a derived context, then
calling cancel(cause) to cancel the context with the provided cause.
Users retrieve the cause by calling context.Cause(ctx), which returns
the cause of the first cancelation for ctx or any of its parents.

The cause is implemented as a field of cancelCtx, since only cancelCtx
can be canceled. Calling cancel copies the cause to all derived (child)
cancelCtxs. Calling Cause(ctx) finds the nearest parent cancelCtx by
looking up the context value keyed by cancelCtxKey.

API changes:
+pkg context, func Cause(Context) error
+pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc)
+pkg context, type CancelCauseFunc func(error)

Fixes #26356
Fixes #51365

Change-Id: I15b62bd454c014db3f4f1498b35204451509e641
Reviewed-on: https://go-review.googlesource.com/c/go/+/375977
Reviewed-by: Damien Neil <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Sameer Ajmani <[email protected]>
Auto-Submit: Sameer Ajmani <[email protected]>
  • Loading branch information
Sajmani authored and gopherbot committed Nov 8, 2022
1 parent 0b7aa9f commit 5b42f89
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 20 deletions.
3 changes: 3 additions & 0 deletions api/next/51365.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pkg context, func Cause(Context) error #51365
pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc) #51365
pkg context, type CancelCauseFunc func(error) #51365
94 changes: 76 additions & 18 deletions src/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
// fires. The go vet tool checks that CancelFuncs are used on all
// control-flow paths.
//
// The WithCancelCause function returns a CancelCauseFunc, which
// takes an error and records it as the cancelation cause. Calling
// Cause on the canceled context or any of its children retrieves
// the cause. If no cause is specified, Cause(ctx) returns the same
// value as ctx.Err().
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
Expand Down Expand Up @@ -230,17 +236,63 @@ type CancelFunc func()
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := withCancel(parent)
return c, func() { c.cancel(true, Canceled, nil) }
}

// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancelation cause.
// This cause can be retrieved by calling Cause on the canceled Context or on
// any of its derived Contexts.
//
// If the context has already been canceled, CancelCauseFunc does not set the cause.
// For example, if childContext is derived from parentContext:
// - if parentContext is canceled with cause1 before childContext is canceled with cause2,
// then Cause(parentContext) == Cause(childContext) == cause1
// - if childContext is canceled with cause2 before parentContext is canceled with cause1,
// then Cause(parentContext) == cause1 and Cause(childContext) == cause2
type CancelCauseFunc func(cause error)

// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc.
// Calling cancel with a non-nil error (the "cause") records that error in ctx;
// it can then be retrieved using Cause(ctx).
// Calling cancel with nil sets the cause to Canceled.
//
// Example use:
//
// ctx, cancel := context.WithCancelCause(parent)
// cancel(myError)
// ctx.Err() // returns context.Canceled
// context.Cause(ctx) // returns myError
func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
c := withCancel(parent)
return c, func(cause error) { c.cancel(true, Canceled, cause) }
}

func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := newCancelCtx(parent)
propagateCancel(parent, &c)
return &c, func() { c.cancel(true, Canceled) }
propagateCancel(parent, c)
return c
}

// Cause returns a non-nil error explaining why c was canceled.
// The first cancelation of c or one of its parents sets the cause.
// If that cancelation happened via a call to CancelCauseFunc(err),
// then Cause returns err.
// Otherwise Cause(c) returns the same value as c.Err().
// Cause returns nil if c has not been canceled yet.
func Cause(c Context) error {
if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
return cc.cause
}
return nil
}

// newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx {
return cancelCtx{Context: parent}
func newCancelCtx(parent Context) *cancelCtx {
return &cancelCtx{Context: parent}
}

// goroutines counts the number of goroutines ever created; for testing.
Expand All @@ -256,7 +308,7 @@ func propagateCancel(parent Context, child canceler) {
select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err())
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
Expand All @@ -265,7 +317,7 @@ func propagateCancel(parent Context, child canceler) {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err)
child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
Expand All @@ -278,7 +330,7 @@ func propagateCancel(parent Context, child canceler) {
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err())
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
Expand Down Expand Up @@ -326,7 +378,7 @@ func removeChild(parent Context, child canceler) {
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err error)
cancel(removeFromParent bool, err, cause error)
Done() <-chan struct{}
}

Expand All @@ -346,6 +398,7 @@ type cancelCtx struct {
done atomic.Value // of chan struct{}, created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
cause error // set to non-nil by the first cancel call
}

func (c *cancelCtx) Value(key any) any {
Expand Down Expand Up @@ -394,16 +447,21 @@ func (c *cancelCtx) String() string {

// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
// cancel sets c.cause to cause if this is the first time c is canceled.
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
if cause == nil {
cause = err
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
c.cause = cause
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
Expand All @@ -412,7 +470,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) {
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()
Expand Down Expand Up @@ -446,24 +504,24 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(false, Canceled) }
c.cancel(true, DeadlineExceeded, nil) // deadline has already passed
return c, func() { c.cancel(false, Canceled, nil) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded)
c.cancel(true, DeadlineExceeded, nil)
})
}
return c, func() { c.cancel(true, Canceled) }
return c, func() { c.cancel(true, Canceled, nil) }
}

// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
*cancelCtx
timer *time.Timer // Under cancelCtx.mu.

deadline time.Time
Expand All @@ -479,8 +537,8 @@ func (c *timerCtx) String() string {
time.Until(c.deadline).String() + "])"
}

func (c *timerCtx) cancel(removeFromParent bool, err error) {
c.cancelCtx.cancel(false, err)
func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
c.cancelCtx.cancel(false, err, cause)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
Expand Down Expand Up @@ -581,7 +639,7 @@ func value(c Context, key any) any {
c = ctx.Context
case *timerCtx:
if key == &cancelCtxKey {
return &ctx.cancelCtx
return ctx.cancelCtx
}
c = ctx.Context
case *emptyCtx:
Expand Down
153 changes: 151 additions & 2 deletions src/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,9 @@ func XTestCancelRemoves(t testingT) {
}

func XTestWithCancelCanceledParent(t testingT) {
parent, pcancel := WithCancel(Background())
pcancel()
parent, pcancel := WithCancelCause(Background())
cause := fmt.Errorf("Because!")
pcancel(cause)

c, _ := WithCancel(parent)
select {
Expand All @@ -662,6 +663,9 @@ func XTestWithCancelCanceledParent(t testingT) {
if got, want := c.Err(), Canceled; got != want {
t.Errorf("child not canceled; got = %v, want = %v", got, want)
}
if got, want := Cause(c), cause; got != want {
t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
}
}

func XTestWithValueChecksKey(t testingT) {
Expand Down Expand Up @@ -785,3 +789,148 @@ func XTestCustomContextGoroutines(t testingT) {
defer cancel7()
checkNoGoroutine()
}

func XTestCause(t testingT) {
var (
parentCause = fmt.Errorf("parentCause")
childCause = fmt.Errorf("childCause")
)
for _, test := range []struct {
name string
ctx Context
err error
cause error
}{
{
name: "Background",
ctx: Background(),
err: nil,
cause: nil,
},
{
name: "TODO",
ctx: TODO(),
err: nil,
cause: nil,
},
{
name: "WithCancel",
ctx: func() Context {
ctx, cancel := WithCancel(Background())
cancel()
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause",
ctx: func() Context {
ctx, cancel := WithCancelCause(Background())
cancel(parentCause)
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause nil",
ctx: func() Context {
ctx, cancel := WithCancelCause(Background())
cancel(nil)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: parent cause before child",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelParent(parentCause)
cancelChild(childCause)
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause: parent cause after child",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelChild(childCause)
cancelParent(parentCause)
return ctx
}(),
err: Canceled,
cause: childCause,
},
{
name: "WithCancelCause: parent cause before nil",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancel(ctx)
cancelParent(parentCause)
cancelChild()
return ctx
}(),
err: Canceled,
cause: parentCause,
},
{
name: "WithCancelCause: parent cause after nil",
ctx: func() Context {
ctx, cancelParent := WithCancelCause(Background())
ctx, cancelChild := WithCancel(ctx)
cancelChild()
cancelParent(parentCause)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: child cause after nil",
ctx: func() Context {
ctx, cancelParent := WithCancel(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelParent()
cancelChild(childCause)
return ctx
}(),
err: Canceled,
cause: Canceled,
},
{
name: "WithCancelCause: child cause before nil",
ctx: func() Context {
ctx, cancelParent := WithCancel(Background())
ctx, cancelChild := WithCancelCause(ctx)
cancelChild(childCause)
cancelParent()
return ctx
}(),
err: Canceled,
cause: childCause,
},
{
name: "WithTimeout",
ctx: func() Context {
ctx, cancel := WithTimeout(Background(), 0)
cancel()
return ctx
}(),
err: DeadlineExceeded,
cause: DeadlineExceeded,
},
} {
if got, want := test.ctx.Err(), test.err; want != got {
t.Errorf("%s: ctx.Err() = %v want %v", test.name, got, want)
}
if got, want := Cause(test.ctx), test.cause; want != got {
t.Errorf("%s: Cause(ctx) = %v want %v", test.name, got, want)
}
}
}
1 change: 1 addition & 0 deletions src/context/x_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey
func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail(t) }
func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) }
func TestCause(t *testing.T) { XTestCause(t) }

0 comments on commit 5b42f89

Please sign in to comment.