Skip to content

Commit

Permalink
Add contexts that use FakeClock rather than the system time. (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
DPJacques authored Nov 28, 2024
1 parent 7e524bd commit 91d2c0a
Show file tree
Hide file tree
Showing 4 changed files with 416 additions and 41 deletions.
70 changes: 35 additions & 35 deletions clockwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,44 +157,53 @@ func (fc *FakeClock) NewTicker(d time.Duration) Ticker {
ft = &fakeTicker{
firer: newFirer(),
d: d,
reset: func(d time.Duration) { fc.set(ft, d) },
stop: func() { fc.stop(ft) },
reset: func(d time.Duration) {
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(ft, d)
},
stop: func() { fc.stop(ft) },
}
fc.set(ft, d)
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(ft, d)
return ft
}

// NewTimer returns a Timer that will fire only after calls to
// fakeClock.Advance() have moved the clock past the given duration.
func (fc *FakeClock) NewTimer(d time.Duration) Timer {
return fc.newTimer(d, nil)
t, _ := fc.newTimer(d, nil)
return t
}

// AfterFunc mimics [time.AfterFunc]; it returns a Timer that will invoke the
// given function only after calls to fakeClock.Advance() have moved the clock
// past the given duration.
func (fc *FakeClock) AfterFunc(d time.Duration, f func()) Timer {
return fc.newTimer(d, f)
t, _ := fc.newTimer(d, f)
return t
}

// newTimer returns a new timer, using an optional afterFunc.
func (fc *FakeClock) newTimer(d time.Duration, afterfunc func()) *fakeTimer {
var ft *fakeTimer
ft = &fakeTimer{
firer: newFirer(),
reset: func(d time.Duration) bool {
fc.l.Lock()
defer fc.l.Unlock()
// fc.l must be held across the calls to stopExpirer & setExpirer.
stopped := fc.stopExpirer(ft)
fc.setExpirer(ft, d)
return stopped
},
stop: func() bool { return fc.stop(ft) },
// newTimer returns a new timer using an optional afterFunc and the time that
// timer expires.
func (fc *FakeClock) newTimer(d time.Duration, afterfunc func()) (*fakeTimer, time.Time) {
ft := newFakeTimer(fc, afterfunc)
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(ft, d)
return ft, ft.expiry()
}

afterFunc: afterfunc,
}
fc.set(ft, d)
// newTimerAtTime is like newTimer, but uses a time instead of a duration.
//
// It is used to ensure FakeClock's lock is held constant through calling
// fc.After(t.Sub(fc.Now())). It should not be exposed externally.
func (fc *FakeClock) newTimerAtTime(t time.Time, afterfunc func()) *fakeTimer {
ft := newFakeTimer(fc, afterfunc)
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(ft, t.Sub(fc.time))
return ft
}

Expand Down Expand Up @@ -289,13 +298,6 @@ func (fc *FakeClock) stopExpirer(e expirer) bool {
return true
}

// set sets an expirer to expire at a future point in time.
func (fc *FakeClock) set(e expirer, d time.Duration) {
fc.l.Lock()
defer fc.l.Unlock()
fc.setExpirer(e, d)
}

// setExpirer sets an expirer to expire at a future point in time.
//
// The caller must hold fc.l.
Expand All @@ -316,16 +318,14 @@ func (fc *FakeClock) setExpirer(e expirer, d time.Duration) {
})

// Notify blockers of our new waiter.
var blocked []*blocker
count := len(fc.waiters)
for _, b := range fc.blockers {
fc.blockers = slices.DeleteFunc(fc.blockers, func(b *blocker) bool {
if b.count <= count {
close(b.ch)
continue
return true
}
blocked = append(blocked, b)
}
fc.blockers = blocked
return false
})
}

// firer is used by fakeTimer and fakeTicker used to help implement expirer.
Expand Down
143 changes: 141 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package clockwork

import (
"context"
"fmt"
"sync"
"time"
)

// contextKey is private to this package so we can ensure uniqueness here. This
// type identifies context values provided by this package.
type contextKey string

// keyClock provides a clock for injecting during tests. If absent, a real clock should be used.
// keyClock provides a clock for injecting during tests. If absent, a real clock
// should be used.
var keyClock = contextKey("clock") // clockwork.Clock

// AddToContext creates a derived context that references the specified clock.
Expand All @@ -21,10 +25,145 @@ func AddToContext(ctx context.Context, clock Clock) context.Context {
return context.WithValue(ctx, keyClock, clock)
}

// FromContext extracts a clock from the context. If not present, a real clock is returned.
// FromContext extracts a clock from the context. If not present, a real clock
// is returned.
func FromContext(ctx context.Context) Clock {
if clock, ok := ctx.Value(keyClock).(Clock); ok {
return clock
}
return NewRealClock()
}

// ErrFakeClockDeadlineExceeded is the error returned by [context.Context] when
// the deadline passes on a context which uses a [FakeClock].
//
// It wraps a [context.DeadlineExceeded] error, i.e.:
//
// // The following is true for any Context whose deadline has been exceeded,
// // including contexts made with clockwork.WithDeadline or clockwork.WithTimeout.
//
// errors.Is(ctx.Err(), context.DeadlineExceeded)
//
// // The following can only be true for contexts made
// // with clockwork.WithDeadline or clockwork.WithTimeout.
//
// errors.Is(ctx.Err(), clockwork.ErrFakeClockDeadlineExceeded)
var ErrFakeClockDeadlineExceeded error = fmt.Errorf("clockwork.FakeClock: %w", context.DeadlineExceeded)

// WithDeadline returns a context with a deadline based on a [FakeClock].
//
// The returned context ignores parent cancelation if the parent was cancelled
// with a [context.DeadlineExceeded] error. Any other error returned by the
// parent is treated normally, cancelling the returned context.
//
// If the parent is cancelled with a [context.DeadlineExceeded] error, the only
// way to then cancel the returned context is by calling the returned
// context.CancelFunc.
func WithDeadline(parent context.Context, clock Clock, t time.Time) (context.Context, context.CancelFunc) {
if fc, ok := clock.(*FakeClock); ok {
return newFakeClockContext(parent, t, fc.newTimerAtTime(t, nil).Chan())
}
return context.WithDeadline(parent, t)
}

// WithTimeout returns a context with a timeout based on a [FakeClock].
//
// The returned context follows the same behaviors as [WithDeadline].
func WithTimeout(parent context.Context, clock Clock, d time.Duration) (context.Context, context.CancelFunc) {
if fc, ok := clock.(*FakeClock); ok {
t, deadline := fc.newTimer(d, nil)
return newFakeClockContext(parent, deadline, t.Chan())
}
return context.WithTimeout(parent, d)
}

// fakeClockContext implements context.Context, using a fake clock for its
// deadline.
//
// It ignores parent cancellation if the parent is cancelled with
// context.DeadlineExceeded.
type fakeClockContext struct {
parent context.Context
deadline time.Time // The user-facing deadline based on the fake clock's time.

// Tracks timeout/deadline cancellation.
timerDone <-chan time.Time

// Tracks manual calls to the cancel function.
cancel func() // Closes cancelCalled wrapped in a sync.Once.
cancelCalled chan struct{}

// The user-facing data from the context.Context interface.
ctxDone chan struct{} // Returned by Done().
err error // nil until ctxDone is ready to be closed.
}

func newFakeClockContext(parent context.Context, deadline time.Time, timer <-chan time.Time) (context.Context, context.CancelFunc) {
cancelCalled := make(chan struct{})
ctx := &fakeClockContext{
parent: parent,
deadline: deadline,
timerDone: timer,
cancelCalled: cancelCalled,
ctxDone: make(chan struct{}),
cancel: sync.OnceFunc(func() {
close(cancelCalled)
}),
}
ready := make(chan struct{}, 1)
go ctx.runCancel(ready)
<-ready // Wait until the cancellation goroutine is running.
return ctx, ctx.cancel
}

func (c *fakeClockContext) Deadline() (time.Time, bool) {
return c.deadline, true
}

func (c *fakeClockContext) Done() <-chan struct{} {
return c.ctxDone
}

func (c *fakeClockContext) Err() error {
<-c.Done() // Don't return the error before it is ready.
return c.err
}

func (c *fakeClockContext) Value(key any) any {
return c.parent.Value(key)
}

// runCancel runs the fakeClockContext's cancel goroutine and returns the
// fakeClockContext's cancel function.
//
// fakeClockContext is then cancelled when any of the following occur:
//
// - The fakeClockContext.done channel is closed by its timer.
// - The returned CancelFunc is executed.
// - The fakeClockContext's parent context is cancelled with an error other
// than context.DeadlineExceeded.
func (c *fakeClockContext) runCancel(ready chan struct{}) {
parentDone := c.parent.Done()

// Close ready when done, just in case the ready signal races with other
// branches of our select statement below.
defer close(ready)

for c.err == nil {
select {
case <-c.timerDone:
c.err = ErrFakeClockDeadlineExceeded
case <-c.cancelCalled:
c.err = context.Canceled
case <-parentDone:
c.err = c.parent.Err()

case ready <- struct{}{}:
// Signals the cancellation goroutine has begun, in an attempt to minimize
// race conditions related to goroutine startup time.
ready = nil // This case statement can only fire once.
}
}
close(c.ctxDone)
return
}
Loading

0 comments on commit 91d2c0a

Please sign in to comment.