From 831cf05a7c4f839939a6c5e1e8bee238ddc6baac Mon Sep 17 00:00:00 2001 From: Clayton Coleman Date: Fri, 10 Mar 2023 10:48:14 -0600 Subject: [PATCH] wait: Deprecate legacy Poll methods for new context aware methods The Poll* methods predate context in Go, and the current implementation will return ErrWaitTimeout even if the context is cancelled, which prevents callers who are using Poll* from handling that error directly (for instance, if you want to cancel a function in a controlled fashion but still report cleanup errors to logs, you want to know the difference between 'didn't cancel', 'cancelled cleanly', and 'hit an error). This commit adds two new methods that reflect how modern Go uses context in polling while preserving all Kubernetes-specific behavior: PollUntilContextCancel PollUntilContextTimeout These methods can be used for infinite polling (normal context), timed polling (deadline context), and cancellable poll (cancel context). All other Poll/Wait methods are marked as deprecated for removal in the future. The ErrWaitTimeout error will no longer be returned from the Poll* methods, but will continue to be returned from ExponentialBackoff*. Users updating to use these new methods are responsible for converting their error handling as appropriate. A convenience helper `Interrupted(err) bool` has been added that should be used instead of checking `err == ErrWaitTimeout`. In a future release ErrWaitTimeout will be made private to prevent incorrect use. The helper can be used with all polling methods since context cancellation and deadline are semantically equivalent to ErrWaitTimeout. A new `ErrorInterrupted(cause error)` method should be used instead of returning ErrWaitTimeout in custom code. The convenience method PollUntilContextTimeout is added because deadline context creation is verbose and the cancel function must be called to properly cleanup the context - many of the current poll users would see code sizes increase. To reduce the overall method surface area, the distinction between PollImmediate and Poll has been reduced to a single boolean on PollUntilContextCancel so we do not need multiple helper methods. The existing methods were not altered because ecosystem callers have been observed to use ErrWaitTimeout to mean "any error that my condition func did not return" which prevents cancellation errors from being returned from the existing methods. Callers must make a deliberate migration. Callers migrating to `PollWithContextCancel` should: 1. Pass a context with a deadline or timeout if they were previously using `Poll*Until*` and check `err` for `context.DeadlineExceeded` instead of `ErrWaitTimeout` (more specific) or use `Interrupted(err)` for a generic check. 2. Callers that were waiting forever or for context cancellation should ensure they are checking `context.Canceled` instead of `ErrWaitTimeout` to detect when the poll was stopped early. Callers of `ExponentialBackoffWithContext` should use `Interrupted(err)` instead of directly checking `err == ErrWaitTimeout`. No other changes are needed. Code that returns `ErrWaitTimeout` should instead define a local cause and return `wait.ErrorInterrupted(cause)`, which will be recognized by `wait.Interrupted()`. If nil is passed the previous message will be used but clients are highly recommended to use typed checks vs message checks. As a consequence of this change the new methods are more efficient - Poll uses one less goroutine. Kubernetes-commit: 133dd6157887f26aa91f648ea3103936d67d747b --- pkg/util/wait/backoff.go | 213 +++++++++++++++-- pkg/util/wait/delay.go | 51 ++++ pkg/util/wait/error.go | 80 ++++++- pkg/util/wait/error_test.go | 144 ++++++++++++ pkg/util/wait/loop.go | 86 +++++++ pkg/util/wait/loop_test.go | 447 +++++++++++++++++++++++++++++++++++ pkg/util/wait/poll.go | 82 ++++++- pkg/util/wait/timer.go | 121 ++++++++++ pkg/util/wait/wait.go | 25 +- pkg/util/wait/wait_test.go | 454 +++++++++++++++++++++++++----------- 10 files changed, 1537 insertions(+), 166 deletions(-) create mode 100644 pkg/util/wait/delay.go create mode 100644 pkg/util/wait/error_test.go create mode 100644 pkg/util/wait/loop.go create mode 100644 pkg/util/wait/loop_test.go create mode 100644 pkg/util/wait/timer.go diff --git a/pkg/util/wait/backoff.go b/pkg/util/wait/backoff.go index ed419d105..418761925 100644 --- a/pkg/util/wait/backoff.go +++ b/pkg/util/wait/backoff.go @@ -19,6 +19,7 @@ package wait import ( "context" "math" + "sync" "time" "k8s.io/apimachinery/pkg/util/runtime" @@ -51,33 +52,104 @@ type Backoff struct { Cap time.Duration } -// Step (1) returns an amount of time to sleep determined by the -// original Duration and Jitter and (2) mutates the provided Backoff -// to update its Steps and Duration. +// Step returns an amount of time to sleep determined by the original +// Duration and Jitter. The backoff is mutated to update its Steps and +// Duration. A nil Backoff always has a zero-duration step. func (b *Backoff) Step() time.Duration { - if b.Steps < 1 { - if b.Jitter > 0 { - return Jitter(b.Duration, b.Jitter) - } - return b.Duration + if b == nil { + return 0 } - b.Steps-- + var nextDuration time.Duration + nextDuration, b.Duration, b.Steps = delay(b.Steps, b.Duration, b.Cap, b.Factor, b.Jitter) + return nextDuration +} +// DelayFunc returns a function that will compute the next interval to +// wait given the arguments in b. It does not mutate the original backoff +// but the function is safe to use only from a single goroutine. +func (b Backoff) DelayFunc() DelayFunc { + steps := b.Steps duration := b.Duration + cap := b.Cap + factor := b.Factor + jitter := b.Jitter + + return func() time.Duration { + var nextDuration time.Duration + // jitter is applied per step and is not cumulative over multiple steps + nextDuration, duration, steps = delay(steps, duration, cap, factor, jitter) + return nextDuration + } +} - // calculate the next step - if b.Factor != 0 { - b.Duration = time.Duration(float64(b.Duration) * b.Factor) - if b.Cap > 0 && b.Duration > b.Cap { - b.Duration = b.Cap - b.Steps = 0 +// Timer returns a timer implementation appropriate to this backoff's parameters +// for use with wait functions. +func (b Backoff) Timer() Timer { + if b.Steps > 1 || b.Jitter != 0 { + return &variableTimer{new: internalClock.NewTimer, fn: b.DelayFunc()} + } + if b.Duration > 0 { + return &fixedTimer{new: internalClock.NewTicker, interval: b.Duration} + } + return newNoopTimer() +} + +// delay implements the core delay algorithm used in this package. +func delay(steps int, duration, cap time.Duration, factor, jitter float64) (_ time.Duration, next time.Duration, nextSteps int) { + // when steps is non-positive, do not alter the base duration + if steps < 1 { + if jitter > 0 { + return Jitter(duration, jitter), duration, 0 } + return duration, duration, 0 + } + steps-- + + // calculate the next step's interval + if factor != 0 { + next = time.Duration(float64(duration) * factor) + if cap > 0 && next > cap { + next = cap + steps = 0 + } + } else { + next = duration } - if b.Jitter > 0 { - duration = Jitter(duration, b.Jitter) + // add jitter for this step + if jitter > 0 { + duration = Jitter(duration, jitter) } - return duration + + return duration, next, steps + +} + +// DelayWithReset returns a DelayFunc that will return the appropriate next interval to +// wait. Every resetInterval the backoff parameters are reset to their initial state. +// This method is safe to invoke from multiple goroutines, but all calls will advance +// the backoff state when Factor is set. If Factor is zero, this method is the same as +// invoking b.DelayFunc() since Steps has no impact without Factor. If resetInterval is +// zero no backoff will be performed as the same calling DelayFunc with a zero factor +// and steps. +func (b Backoff) DelayWithReset(c clock.Clock, resetInterval time.Duration) DelayFunc { + if b.Factor <= 0 { + return b.DelayFunc() + } + if resetInterval <= 0 { + b.Steps = 0 + b.Factor = 0 + return b.DelayFunc() + } + return (&backoffManager{ + backoff: b, + initialBackoff: b, + resetInterval: resetInterval, + + clock: c, + lastStart: c.Now(), + timer: nil, + }).Step } // Until loops until stop channel is closed, running f every period. @@ -187,15 +259,65 @@ func JitterUntilWithContext(ctx context.Context, f func(context.Context), period JitterUntil(func() { f(ctx) }, period, jitterFactor, sliding, ctx.Done()) } -// BackoffManager manages backoff with a particular scheme based on its underlying implementation. It provides -// an interface to return a timer for backoff, and caller shall backoff until Timer.C() drains. If the second Backoff() -// is called before the timer from the first Backoff() call finishes, the first timer will NOT be drained and result in -// undetermined behavior. -// The BackoffManager is supposed to be called in a single-threaded environment. +// backoffManager provides simple backoff behavior in a threadsafe manner to a caller. +type backoffManager struct { + backoff Backoff + initialBackoff Backoff + resetInterval time.Duration + + clock clock.Clock + + lock sync.Mutex + lastStart time.Time + timer clock.Timer +} + +// Step returns the expected next duration to wait. +func (b *backoffManager) Step() time.Duration { + b.lock.Lock() + defer b.lock.Unlock() + + switch { + case b.resetInterval == 0: + b.backoff = b.initialBackoff + case b.clock.Now().Sub(b.lastStart) > b.resetInterval: + b.backoff = b.initialBackoff + b.lastStart = b.clock.Now() + } + return b.backoff.Step() +} + +// Backoff implements BackoffManager.Backoff, it returns a timer so caller can block on the timer +// for exponential backoff. The returned timer must be drained before calling Backoff() the second +// time. +func (b *backoffManager) Backoff() clock.Timer { + b.lock.Lock() + defer b.lock.Unlock() + if b.timer == nil { + b.timer = b.clock.NewTimer(b.Step()) + } else { + b.timer.Reset(b.Step()) + } + return b.timer +} + +// Timer returns a new Timer instance that shares the clock and the reset behavior with all other +// timers. +func (b *backoffManager) Timer() Timer { + return DelayFunc(b.Step).Timer(b.clock) +} + +// BackoffManager manages backoff with a particular scheme based on its underlying implementation. type BackoffManager interface { + // Backoff returns a shared clock.Timer that is Reset on every invocation. This method is not + // safe for use from multiple threads. It returns a timer for backoff, and caller shall backoff + // until Timer.C() drains. If the second Backoff() is called before the timer from the first + // Backoff() call finishes, the first timer will NOT be drained and result in undetermined + // behavior. Backoff() clock.Timer } +// Deprecated: Will be removed when the legacy polling functions are removed. type exponentialBackoffManagerImpl struct { backoff *Backoff backoffTimer clock.Timer @@ -208,6 +330,27 @@ type exponentialBackoffManagerImpl struct { // NewExponentialBackoffManager returns a manager for managing exponential backoff. Each backoff is jittered and // backoff will not exceed the given max. If the backoff is not called within resetDuration, the backoff is reset. // This backoff manager is used to reduce load during upstream unhealthiness. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct, use DelayWithReset() to get a DelayFunc that periodically resets itself, and then +// invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewExponentialBackoffManager(init, max, reset, factor, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// delayFn := wait.Backoff{ +// Duration: init, +// Cap: max, +// Steps: int(math.Ceil(float64(max) / float64(init))), // now a required argument +// Factor: factor, +// Jitter: jitter, +// }.DelayWithReset(reset, clock) +// wait.BackoffUntil(..., delayFn.Timer(), ...) func NewExponentialBackoffManager(initBackoff, maxBackoff, resetDuration time.Duration, backoffFactor, jitter float64, c clock.Clock) BackoffManager { return &exponentialBackoffManagerImpl{ backoff: &Backoff{ @@ -248,6 +391,7 @@ func (b *exponentialBackoffManagerImpl) Backoff() clock.Timer { return b.backoffTimer } +// Deprecated: Will be removed when the legacy polling functions are removed. type jitteredBackoffManagerImpl struct { clock clock.Clock duration time.Duration @@ -257,6 +401,19 @@ type jitteredBackoffManagerImpl struct { // NewJitteredBackoffManager returns a BackoffManager that backoffs with given duration plus given jitter. If the jitter // is negative, backoff will not be jittered. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct and invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewJitteredBackoffManager(duration, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// wait.BackoffUntil(..., wait.Backoff{Duration: duration, Jitter: jitter}.Timer(), ...) func NewJitteredBackoffManager(duration time.Duration, jitter float64, c clock.Clock) BackoffManager { return &jitteredBackoffManagerImpl{ clock: c, @@ -296,6 +453,9 @@ func (j *jitteredBackoffManagerImpl) Backoff() clock.Timer { // 3. a sleep truncated by the cap on duration has been completed. // In case (1) the returned error is what the condition function returned. // In all other cases, ErrWaitTimeout is returned. +// +// Since backoffs are often subject to cancellation, we recommend using +// ExponentialBackoffWithContext and passing a context to the method. func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { for backoff.Steps > 0 { if ok, err := runConditionWithCrashProtection(condition); err != nil || ok { @@ -309,8 +469,11 @@ func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { return ErrWaitTimeout } -// ExponentialBackoffWithContext works with a request context and a Backoff. It ensures that the retry wait never -// exceeds the deadline specified by the request context. +// ExponentialBackoffWithContext repeats a condition check with exponential backoff. +// It immediately returns an error if the condition returns an error, the context is cancelled +// or hits the deadline, or if the maximum attempts defined in backoff is exceeded (ErrWaitTimeout). +// If an error is returned by the condition the backoff stops immediately. The condition will +// never be invoked more than backoff.Steps times. func ExponentialBackoffWithContext(ctx context.Context, backoff Backoff, condition ConditionWithContextFunc) error { for backoff.Steps > 0 { select { diff --git a/pkg/util/wait/delay.go b/pkg/util/wait/delay.go new file mode 100644 index 000000000..1d3dcaa74 --- /dev/null +++ b/pkg/util/wait/delay.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "sync" + "time" + + "k8s.io/utils/clock" +) + +// DelayFunc returns the next time interval to wait. +type DelayFunc func() time.Duration + +// Timer takes an arbitrary delay function and returns a timer that can handle arbitrary interval changes. +// Use Backoff{...}.Timer() for simple delays and more efficient timers. +func (fn DelayFunc) Timer(c clock.Clock) Timer { + return &variableTimer{fn: fn, new: c.NewTimer} +} + +// Until takes an arbitrary delay function and runs until cancelled or the condition indicates exit. This +// offers all of the functionality of the methods in this package. +func (fn DelayFunc) Until(ctx context.Context, immediate, sliding bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, &variableTimer{fn: fn, new: internalClock.NewTimer}, immediate, sliding, condition) +} + +// Concurrent returns a version of this DelayFunc that is safe for use by multiple goroutines that +// wish to share a single delay timer. +func (fn DelayFunc) Concurrent() DelayFunc { + var lock sync.Mutex + return func() time.Duration { + lock.Lock() + defer lock.Unlock() + return fn() + } +} diff --git a/pkg/util/wait/error.go b/pkg/util/wait/error.go index 5172f08df..dd75801d8 100644 --- a/pkg/util/wait/error.go +++ b/pkg/util/wait/error.go @@ -16,7 +16,81 @@ limitations under the License. package wait -import "errors" +import ( + "context" + "errors" +) -// ErrWaitTimeout is returned when the condition exited without success. -var ErrWaitTimeout = errors.New("timed out waiting for the condition") +// ErrWaitTimeout is returned when the condition was not satisfied in time. +// +// Deprecated: This type will be made private in favor of Interrupted() +// for checking errors or ErrorInterrupted(err) for returning a wrapped error. +var ErrWaitTimeout = ErrorInterrupted(errors.New("timed out waiting for the condition")) + +// Interrupted returns true if the error indicates a Poll, ExponentialBackoff, or +// Until loop exited for any reason besides the condition returning true or an +// error. A loop is considered interrupted if the calling context is cancelled, +// the context reaches its deadline, or a backoff reaches its maximum allowed +// steps. +// +// Callers should use this method instead of comparing the error value directly to +// ErrWaitTimeout, as methods that cancel a context may not return that error. +// +// Instead of: +// +// err := wait.Poll(...) +// if err == wait.ErrWaitTimeout { +// log.Infof("Wait for operation exceeded") +// } else ... +// +// Use: +// +// err := wait.Poll(...) +// if wait.Interrupted(err) { +// log.Infof("Wait for operation exceeded") +// } else ... +func Interrupted(err error) bool { + switch { + case errors.Is(err, errWaitTimeout), + errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded): + return true + default: + return false + } +} + +// errInterrupted +type errInterrupted struct { + cause error +} + +// ErrorInterrupted returns an error that indicates the wait was ended +// early for a given reason. If no cause is provided a generic error +// will be used but callers are encouraged to provide a real cause for +// clarity in debugging. +func ErrorInterrupted(cause error) error { + switch cause.(type) { + case errInterrupted: + // no need to wrap twice since errInterrupted is only needed + // once in a chain + return cause + default: + return errInterrupted{cause} + } +} + +// errWaitTimeout is the private version of the previous ErrWaitTimeout +// and is private to prevent direct comparison. Use ErrorInterrupted(err) +// to get an error that will return true for Interrupted(err). +var errWaitTimeout = errInterrupted{} + +func (e errInterrupted) Unwrap() error { return e.cause } +func (e errInterrupted) Is(target error) bool { return target == errWaitTimeout } +func (e errInterrupted) Error() string { + if e.cause == nil { + // returns the same error message as historical behavior + return "timed out waiting for the condition" + } + return e.cause.Error() +} diff --git a/pkg/util/wait/error_test.go b/pkg/util/wait/error_test.go new file mode 100644 index 000000000..0c96f0619 --- /dev/null +++ b/pkg/util/wait/error_test.go @@ -0,0 +1,144 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "errors" + "fmt" + "testing" +) + +type errWrapper struct { + wrapped error +} + +func (w errWrapper) Unwrap() error { + return w.wrapped +} +func (w errWrapper) Error() string { + return fmt.Sprintf("wrapped: %v", w.wrapped) +} + +type errNotWrapper struct { + wrapped error +} + +func (w errNotWrapper) Error() string { + return fmt.Sprintf("wrapped: %v", w.wrapped) +} + +func TestInterrupted(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + err: ErrWaitTimeout, + want: true, + }, + { + err: context.Canceled, + want: true, + }, { + err: context.DeadlineExceeded, + want: true, + }, + { + err: errWrapper{ErrWaitTimeout}, + want: true, + }, + { + err: errWrapper{context.Canceled}, + want: true, + }, + { + err: errWrapper{context.DeadlineExceeded}, + want: true, + }, + { + err: ErrorInterrupted(nil), + want: true, + }, + { + err: ErrorInterrupted(errors.New("unknown")), + want: true, + }, + { + err: ErrorInterrupted(context.Canceled), + want: true, + }, + { + err: ErrorInterrupted(ErrWaitTimeout), + want: true, + }, + + { + err: nil, + }, + { + err: errors.New("not a cancellation"), + }, + { + err: errNotWrapper{ErrWaitTimeout}, + }, + { + err: errNotWrapper{context.Canceled}, + }, + { + err: errNotWrapper{context.DeadlineExceeded}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Interrupted(tt.err); got != tt.want { + t.Errorf("Interrupted() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestErrorInterrupted(t *testing.T) { + internalErr := errInterrupted{} + if ErrorInterrupted(internalErr) != internalErr { + t.Fatalf("error should not be wrapped twice") + } + + internalErr = errInterrupted{errInterrupted{}} + if ErrorInterrupted(internalErr) != internalErr { + t.Fatalf("object should be identical") + } + + in := errors.New("test") + actual, expected := ErrorInterrupted(in), (errInterrupted{in}) + if actual != expected { + t.Fatalf("did not wrap error") + } + if !errors.Is(actual, errWaitTimeout) { + t.Fatalf("does not obey errors.Is contract") + } + if actual.Error() != in.Error() { + t.Fatalf("unexpected error output") + } + if !Interrupted(actual) { + t.Fatalf("is not Interrupted") + } + if Interrupted(in) { + t.Fatalf("should not be Interrupted") + } +} diff --git a/pkg/util/wait/loop.go b/pkg/util/wait/loop.go new file mode 100644 index 000000000..51864d70f --- /dev/null +++ b/pkg/util/wait/loop.go @@ -0,0 +1,86 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "time" + + "k8s.io/apimachinery/pkg/util/runtime" +) + +// loopConditionUntilContext executes the provided condition at intervals defined by +// the provided timer until the provided context is cancelled, the condition returns +// true, or the condition returns an error. If sliding is true, the period is computed +// after condition runs. If it is false then period includes the runtime for condition. +// If immediate is false the first delay happens before any call to condition. The +// returned error is the error returned by the last condition or the context error if +// the context was terminated. +// +// This is the common loop construct for all polling in the wait package. +func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding bool, condition ConditionWithContextFunc) error { + defer t.Stop() + + var timeCh <-chan time.Time + doneCh := ctx.Done() + + // if we haven't requested immediate execution, delay once + if !immediate { + timeCh = t.C() + select { + case <-doneCh: + return ctx.Err() + case <-timeCh: + } + } + + for { + // checking ctx.Err() is slightly faster than checking a select + if err := ctx.Err(); err != nil { + return err + } + + if !sliding { + t.Next() + } + if ok, err := func() (bool, error) { + defer runtime.HandleCrash() + return condition(ctx) + }(); err != nil || ok { + return err + } + if sliding { + t.Next() + } + + if timeCh == nil { + timeCh = t.C() + } + + // NOTE: b/c there is no priority selection in golang + // it is possible for this to race, meaning we could + // trigger t.C and doneCh, and t.C select falls through. + // In order to mitigate we re-check doneCh at the beginning + // of every loop to guarantee at-most one extra execution + // of condition. + select { + case <-doneCh: + return ctx.Err() + case <-timeCh: + } + } +} diff --git a/pkg/util/wait/loop_test.go b/pkg/util/wait/loop_test.go new file mode 100644 index 000000000..c5849250a --- /dev/null +++ b/pkg/util/wait/loop_test.go @@ -0,0 +1,447 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "k8s.io/utils/clock" + testingclock "k8s.io/utils/clock/testing" +) + +func timerWithClock(t Timer, c clock.WithTicker) Timer { + switch t := t.(type) { + case *fixedTimer: + t.new = c.NewTicker + case *variableTimer: + t.new = c.NewTimer + default: + panic("unrecognized timer type, cannot inject clock") + } + return t +} + +func Test_loopConditionWithContextImmediateDelay(t *testing.T) { + fakeClock := testingclock.NewFakeClock(time.Time{}) + backoff := Backoff{Duration: time.Second} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expectedError := errors.New("Expected error") + var attempt int + f := ConditionFunc(func() (bool, error) { + attempt++ + return false, expectedError + }) + + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + if err := loopConditionUntilContext(ctx, timerWithClock(backoff.Timer(), fakeClock), false, true, f.WithContext()); err == nil || err != expectedError { + t.Errorf("unexpected error: %v", err) + } + }() + + for !fakeClock.HasWaiters() { + time.Sleep(time.Microsecond) + } + + fakeClock.Step(time.Second - time.Millisecond) + if attempt != 0 { + t.Fatalf("should still be waiting for condition") + } + fakeClock.Step(2 * time.Millisecond) + + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatalf("should have exited after a single loop") + } + if attempt != 1 { + t.Fatalf("expected attempt") + } +} + +func Test_loopConditionUntilContext_semantic(t *testing.T) { + defaultCallback := func(_ int) (bool, error) { + return false, nil + } + + conditionErr := errors.New("condition failed") + + tests := []struct { + name string + immediate bool + sliding bool + context func() (context.Context, context.CancelFunc) + callback func(calls int) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error + }{ + { + name: "condition successful is only one attempt", + callback: func(attempts int) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + }, + { + name: "delayed condition successful causes return and attempts", + callback: func(attempts int) (bool, error) { + return attempts > 1, nil + }, + attemptsExpected: 2, + }, + { + name: "delayed condition successful causes return and attempts many times", + callback: func(attempts int) (bool, error) { + return attempts >= 100, nil + }, + attemptsExpected: 100, + }, + { + name: "condition returns error even if ok is true", + callback: func(_ int) (bool, error) { + return true, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "condition exits after an error", + callback: func(_ int) (bool, error) { + return false, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "context already canceled no attempts expected", + context: cancelledContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.Canceled, + }, + { + name: "context cancelled after 5 attempts", + context: defaultContext, + callback: defaultCallback, + cancelContextAfter: 5, + attemptsExpected: 5, + errExpected: context.Canceled, + }, + { + name: "context at deadline no attempts expected", + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, + } + + for _, test := range tests { + for _, immediate := range []bool{true, false} { + t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) { + for _, sliding := range []bool{true, false} { + t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() + + timer := Backoff{Duration: time.Microsecond}.Timer() + attempts := 0 + err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { + attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() + return test.callback(attempts) + }) + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + }) + }) + } + }) + } + } +} + +type timerWrapper struct { + timer clock.Timer + resets []time.Duration + onReset func(d time.Duration) +} + +func (w *timerWrapper) C() <-chan time.Time { return w.timer.C() } +func (w *timerWrapper) Stop() bool { return w.timer.Stop() } +func (w *timerWrapper) Reset(d time.Duration) bool { + w.resets = append(w.resets, d) + b := w.timer.Reset(d) + if w.onReset != nil { + w.onReset(d) + } + return b +} + +func Test_loopConditionUntilContext_timings(t *testing.T) { + // Verify that timings returned by the delay func are passed to the timer, and that + // the timer advancing is enough to drive the state machine. Not a deep verification + // of the behavior of the loop, but tests that we drive the scenario to completion. + tests := []struct { + name string + delayFn DelayFunc + immediate bool + sliding bool + context func() (context.Context, context.CancelFunc) + callback func(calls int, lastInterval time.Duration) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error + expectedIntervals func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) + }{ + { + name: "condition success", + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second, 2 * time.Second}) { + return + } + if reflect.DeepEqual(delaysRequested, []time.Duration{time.Second}) { + return + } + }, + }, + { + name: "condition success and immediate", + immediate: true, + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second}) { + return + } + if reflect.DeepEqual(delaysRequested, []time.Duration{}) { + return + } + }, + }, + { + name: "condition success and sliding", + sliding: true, + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second}) { + return + } + if !reflect.DeepEqual(delays, delaysRequested) { + t.Fatalf("sliding non-immediate should have equal delays: %v", cmp.Diff(delays, delaysRequested)) + } + }, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s/sliding=%t/immediate=%t", test.name, test.sliding, test.immediate), func(t *testing.T) { + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() + + fakeClock := &testingclock.FakeClock{} + var fakeTimers []*timerWrapper + timerFn := func(d time.Duration) clock.Timer { + t := fakeClock.NewTimer(d) + fakeClock.Step(d + 1) + w := &timerWrapper{timer: t, resets: []time.Duration{d}, onReset: func(d time.Duration) { + fakeClock.Step(d + 1) + }} + fakeTimers = append(fakeTimers, w) + return w + } + + delayFn := test.delayFn + if delayFn == nil { + delayFn = Backoff{Duration: time.Microsecond}.DelayFunc() + } + var delays []time.Duration + wrappedDelayFn := func() time.Duration { + d := delayFn() + delays = append(delays, d) + return d + } + timer := &variableTimer{fn: wrappedDelayFn, new: timerFn} + + attempts := 0 + err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { + attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() + lastInterval := time.Duration(-1) + if len(delays) > 0 { + lastInterval = delays[len(delays)-1] + } + return test.callback(attempts, lastInterval) + }) + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + switch len(fakeTimers) { + case 0: + test.expectedIntervals(t, delays, nil) + case 1: + test.expectedIntervals(t, delays, fakeTimers[0].resets) + default: + t.Fatalf("expected zero or one timers: %#v", fakeTimers) + } + }) + } +} + +// Test_loopConditionUntilContext_timings runs actual timing loops and calculates the delta. This +// test depends on high precision wakeups which depends on low CPU contention so it is not a +// candidate to run during normal unit test execution (nor is it a benchmark or example). Instead, +// it can be run manually if there is a scenario where we suspect the timings are off and other +// tests haven't caught it. A final sanity test that would have to be run serially in isolation. +func Test_loopConditionUntilContext_Elapsed(t *testing.T) { + const maxAttempts = 10 + // TODO: this may be too aggressive, but the overhead should be minor + const estimatedLoopOverhead = time.Millisecond + // estimate how long this delay can be + intervalMax := func(backoff Backoff) time.Duration { + d := backoff.Duration + if backoff.Jitter > 0 { + d += time.Duration(backoff.Jitter * float64(d)) + } + return d + } + // estimate how short this delay can be + intervalMin := func(backoff Backoff) time.Duration { + d := backoff.Duration + return d + } + + // Because timing is dependent other factors in test environments, such as + // whether the OS or go runtime scheduler wake the timers, excess duration + // is logged by default and can be converted to a fatal error for testing. + // fail := t.Fatalf + fail := t.Logf + + for _, test := range []struct { + name string + backoff Backoff + t reflect.Type + }{ + {name: "variable timer with jitter", backoff: Backoff{Duration: time.Millisecond, Jitter: 1.0}, t: reflect.TypeOf(&variableTimer{})}, + {name: "fixed timer", backoff: Backoff{Duration: time.Millisecond}, t: reflect.TypeOf(&fixedTimer{})}, + {name: "no-op timer", backoff: Backoff{}, t: reflect.TypeOf(noopTimer{})}, + } { + t.Run(test.name, func(t *testing.T) { + var attempts int + start := time.Now() + timer := test.backoff.Timer() + if test.t != reflect.ValueOf(timer).Type() { + t.Fatalf("unexpected timer type %T: expected %v", timer, test.t) + } + if err := loopConditionUntilContext(context.Background(), timer, false, false, func(_ context.Context) (bool, error) { + attempts++ + if attempts > maxAttempts { + t.Fatalf("should not reach %d attempts", maxAttempts+1) + } + return attempts >= maxAttempts, nil + }); err != nil { + t.Fatal(err) + } + duration := time.Since(start) + if min := maxAttempts * intervalMin(test.backoff); duration < min { + fail("elapsed duration %v < expected min duration %v", duration, min) + } + if max := maxAttempts * (intervalMax(test.backoff) + estimatedLoopOverhead); duration > max { + fail("elapsed duration %v > expected max duration %v", duration, max) + } + }) + } +} + +func Benchmark_loopConditionUntilContext_ZeroDuration(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := loopConditionUntilContext(ctx, Backoff{Duration: 0}.Timer(), true, false, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} + +func Benchmark_loopConditionUntilContext_ShortDuration(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := loopConditionUntilContext(ctx, Backoff{Duration: time.Microsecond}.Timer(), true, false, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} diff --git a/pkg/util/wait/poll.go b/pkg/util/wait/poll.go index 564e9b9d2..32e8688ca 100644 --- a/pkg/util/wait/poll.go +++ b/pkg/util/wait/poll.go @@ -21,6 +21,33 @@ import ( "time" ) +// PollUntilContextCancel tries a condition func until it returns true, an error, or the context +// is cancelled or hits a deadline. condition will be invoked after the first interval if the +// context is not cancelled first. The returned error will be from ctx.Err(), the condition's +// err return value, or nil. If invoking condition takes longer than interval the next condition +// will be invoked immediately. When using very short intervals, condition may be invoked multiple +// times before a context cancellation is detected. If immediate is true, condition will be +// invoked before waiting and guarantees that condition is invoked at least once, regardless of +// whether the context has been cancelled. +func PollUntilContextCancel(ctx context.Context, interval time.Duration, immediate bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + +// PollUntilContextTimeout will terminate polling after timeout duration by setting a context +// timeout. This is provided as a convenience function for callers not currently executing under +// a deadline and is equivalent to: +// +// deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) +// err := PollUntilContextCancel(ctx, interval, immediate, condition) +// +// The deadline context will be cancelled if the Poll succeeds before the timeout, simplifying +// inline usage. All other behavior is identical to PollWithContextTimeout. +func PollUntilContextTimeout(ctx context.Context, interval, timeout time.Duration, immediate bool, condition ConditionWithContextFunc) error { + deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) + defer deadlineCancel() + return loopConditionUntilContext(deadlineCtx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + // Poll tries a condition func until it returns true, an error, or the timeout // is reached. // @@ -31,6 +58,10 @@ import ( // window is too short. // // If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func Poll(interval, timeout time.Duration, condition ConditionFunc) error { return PollWithContext(context.Background(), interval, timeout, condition.WithContext()) } @@ -46,6 +77,10 @@ func Poll(interval, timeout time.Duration, condition ConditionFunc) error { // window is too short. // // If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, timeout), condition) } @@ -55,6 +90,10 @@ func PollWithContext(ctx context.Context, interval, timeout time.Duration, condi // // PollUntil always waits interval before the first run of 'condition'. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { return PollUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) } @@ -64,6 +103,10 @@ func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan st // // PollUntilWithContext always waits interval before the first run of 'condition'. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, 0), condition) } @@ -74,6 +117,10 @@ func PollUntilWithContext(ctx context.Context, interval time.Duration, condition // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollInfinite(interval time.Duration, condition ConditionFunc) error { return PollInfiniteWithContext(context.Background(), interval, condition.WithContext()) } @@ -84,6 +131,10 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error { // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, 0), condition) } @@ -98,6 +149,10 @@ func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condit // window is too short. // // If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { return PollImmediateWithContext(context.Background(), interval, timeout, condition.WithContext()) } @@ -112,6 +167,10 @@ func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) err // window is too short. // // If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, timeout), condition) } @@ -120,6 +179,10 @@ func PollImmediateWithContext(ctx context.Context, interval, timeout time.Durati // // PollImmediateUntil runs the 'condition' before waiting for the interval. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { return PollImmediateUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) } @@ -129,6 +192,10 @@ func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh // // PollImmediateUntilWithContext runs the 'condition' before waiting for the interval. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, 0), condition) } @@ -139,6 +206,10 @@ func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) error { return PollImmediateInfiniteWithContext(context.Background(), interval, condition.WithContext()) } @@ -150,6 +221,10 @@ func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) erro // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, 0), condition) } @@ -163,6 +238,8 @@ func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duratio // wait: user specified WaitFunc function that controls at what interval the condition // function should be invoked periodically and whether it is bound by a timeout. // condition: user specified ConditionWithContextFunc function. +// +// Deprecated: will be removed in favor of loopConditionUntilContext. func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, condition ConditionWithContextFunc) error { if immediate { done, err := runConditionWithCrashProtectionWithContext(ctx, condition) @@ -176,7 +253,8 @@ func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, conditi select { case <-ctx.Done(): - // returning ctx.Err() will break backward compatibility + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead return ErrWaitTimeout default: return waitForWithContext(ctx, wait, condition) @@ -193,6 +271,8 @@ func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, conditi // // Output ticks are not buffered. If the channel is not ready to receive an // item, the tick is skipped. +// +// Deprecated: Will be removed in a future release. func poller(interval, timeout time.Duration) waitWithContextFunc { return waitWithContextFunc(func(ctx context.Context) <-chan struct{} { ch := make(chan struct{}) diff --git a/pkg/util/wait/timer.go b/pkg/util/wait/timer.go new file mode 100644 index 000000000..3efba3213 --- /dev/null +++ b/pkg/util/wait/timer.go @@ -0,0 +1,121 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "time" + + "k8s.io/utils/clock" +) + +// Timer abstracts how wait functions interact with time runtime efficiently. Test +// code may implement this interface directly but package consumers are encouraged +// to use the Backoff type as the primary mechanism for acquiring a Timer. The +// interface is a simplification of clock.Timer to prevent misuse. Timers are not +// expected to be safe for calls from multiple goroutines. +type Timer interface { + // C returns a channel that will receive a struct{} each time the timer fires. + // The channel should not be waited on after Stop() is invoked. It is allowed + // to cache the returned value of C() for the lifetime of the Timer. + C() <-chan time.Time + // Next is invoked by wait functions to signal timers that the next interval + // should begin. You may only use Next() if you have drained the channel C(). + // You should not call Next() after Stop() is invoked. + Next() + // Stop releases the timer. It is safe to invoke if no other methods have been + // called. + Stop() +} + +type noopTimer struct { + closedCh <-chan time.Time +} + +// newNoopTimer creates a timer with a unique channel to avoid contention +// for the channel's lock across multiple unrelated timers. +func newNoopTimer() noopTimer { + ch := make(chan time.Time) + close(ch) + return noopTimer{closedCh: ch} +} + +func (t noopTimer) C() <-chan time.Time { + return t.closedCh +} +func (noopTimer) Next() {} +func (noopTimer) Stop() {} + +type variableTimer struct { + fn DelayFunc + t clock.Timer + new func(time.Duration) clock.Timer +} + +func (t *variableTimer) C() <-chan time.Time { + if t.t == nil { + d := t.fn() + t.t = t.new(d) + } + return t.t.C() +} +func (t *variableTimer) Next() { + if t.t == nil { + return + } + d := t.fn() + t.t.Reset(d) +} +func (t *variableTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +type fixedTimer struct { + interval time.Duration + t clock.Ticker + new func(time.Duration) clock.Ticker +} + +func (t *fixedTimer) C() <-chan time.Time { + if t.t == nil { + t.t = t.new(t.interval) + } + return t.t.C() +} +func (t *fixedTimer) Next() { + // no-op for fixed timers +} +func (t *fixedTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +var ( + // RealTimer can be passed to methods that need a clock.Timer. + RealTimer = clock.RealClock{}.NewTimer +) + +var ( + // internalClock is used for test injection of clocks + internalClock = clock.RealClock{} +) diff --git a/pkg/util/wait/wait.go b/pkg/util/wait/wait.go index c6e516dfc..6805e8cf9 100644 --- a/pkg/util/wait/wait.go +++ b/pkg/util/wait/wait.go @@ -137,13 +137,18 @@ func (c channelContext) Err() error { func (c channelContext) Deadline() (time.Time, bool) { return time.Time{}, false } func (c channelContext) Value(key any) any { return nil } -// runConditionWithCrashProtection runs a ConditionFunc with crash protection +// runConditionWithCrashProtection runs a ConditionFunc with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) { - return runConditionWithCrashProtectionWithContext(context.TODO(), condition.WithContext()) + defer runtime.HandleCrash() + return condition() } -// runConditionWithCrashProtectionWithContext runs a -// ConditionWithContextFunc with crash protection. +// runConditionWithCrashProtectionWithContext runs a ConditionWithContextFunc +// with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. func runConditionWithCrashProtectionWithContext(ctx context.Context, condition ConditionWithContextFunc) (bool, error) { defer runtime.HandleCrash() return condition(ctx) @@ -151,6 +156,9 @@ func runConditionWithCrashProtectionWithContext(ctx context.Context, condition C // waitFunc creates a channel that receives an item every time a test // should be executed and is closed when the last test should be invoked. +// +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. type waitFunc func(done <-chan struct{}) <-chan struct{} // WithContext converts the WaitFunc to an equivalent WaitWithContextFunc @@ -166,7 +174,8 @@ func (w waitFunc) WithContext() waitWithContextFunc { // When the specified context gets cancelled or expires the function // stops sending item and returns immediately. // -// Deprecated: Will be removed when the legacy Poll methods are removed. +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. type waitWithContextFunc func(ctx context.Context) <-chan struct{} // waitForWithContext continually checks 'fn' as driven by 'wait'. @@ -186,7 +195,8 @@ type waitWithContextFunc func(ctx context.Context) <-chan struct{} // "uniform pseudo-random", the `fn` might still run one or multiple times, // though eventually `waitForWithContext` will return. // -// Deprecated: Will be removed when the legacy Poll methods are removed. +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. func waitForWithContext(ctx context.Context, wait waitWithContextFunc, fn ConditionWithContextFunc) error { waitCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -205,7 +215,8 @@ func waitForWithContext(ctx context.Context, wait waitWithContextFunc, fn Condit return ErrWaitTimeout } case <-ctx.Done(): - // returning ctx.Err() will break backward compatibility + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead return ErrWaitTimeout } } diff --git a/pkg/util/wait/wait_test.go b/pkg/util/wait/wait_test.go index 82ff8866f..c8dd0bf58 100644 --- a/pkg/util/wait/wait_test.go +++ b/pkg/util/wait/wait_test.go @@ -114,7 +114,12 @@ func TestNonSlidingUntilWithContext(t *testing.T) { func TestUntilReturnsImmediately(t *testing.T) { now := time.Now() ch := make(chan struct{}) + var attempts int Until(func() { + attempts++ + if attempts > 1 { + t.Fatalf("invoked after close of channel") + } close(ch) }, 30*time.Second, ch) if now.Add(25 * time.Second).Before(time.Now()) { @@ -233,15 +238,24 @@ func TestJitterUntilNegativeFactor(t *testing.T) { if now.Add(3 * time.Second).Before(time.Now()) { t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func") } - } func TestExponentialBackoff(t *testing.T) { + // exits immediately + i := 0 + err := ExponentialBackoff(Backoff{Factor: 1.0}, func() (bool, error) { + i++ + return false, nil + }) + if err != ErrWaitTimeout || i != 0 { + t.Errorf("unexpected error: %v", err) + } + opts := Backoff{Factor: 1.0, Steps: 3} // waits up to steps - i := 0 - err := ExponentialBackoff(opts, func() (bool, error) { + i = 0 + err = ExponentialBackoff(opts, func() (bool, error) { i++ return false, nil }) @@ -339,7 +353,7 @@ func (fp *fakePoller) GetwaitFunc() waitFunc { func TestPoll(t *testing.T) { invocations := 0 - f := ConditionFunc(func() (bool, error) { + f := ConditionWithContextFunc(func(ctx context.Context) (bool, error) { invocations++ return true, nil }) @@ -347,7 +361,7 @@ func TestPoll(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f.WithContext()); err != nil { + if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f); err != nil { t.Fatalf("unexpected error %v", err) } fp.wg.Wait() @@ -540,7 +554,7 @@ func Test_waitFor(t *testing.T) { } } -// Test_waitForWithEarlyClosing_waitFunc tests waitFor when the waitFunc closes its channel. The waitFor should +// Test_waitForWithEarlyClosing_waitFunc tests WaitFor when the waitFunc closes its channel. The WaitFor should // always return ErrWaitTimeout. func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) { stopCh := make(chan struct{}) @@ -597,12 +611,12 @@ func Test_waitForWithClosedChannel(t *testing.T) { func Test_waitForWithContextCancelsContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - waitFunc := poller(time.Millisecond, ForeverTestTimeout) + waitFn := poller(time.Millisecond, ForeverTestTimeout) var ctxPassedToWait context.Context waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} { ctxPassedToWait = ctx - return waitFunc(ctx) + return waitFn(ctx) }, func(ctx context.Context) (bool, error) { time.Sleep(10 * time.Millisecond) return true, nil @@ -633,14 +647,14 @@ func TestPollUntil(t *testing.T) { close(stopCh) go func() { - // release the condition func if needed - for { - <-called + // release the condition func if needed + for range called { } }() // make sure we finished the poll <-pollDone + close(called) } func TestBackoff_Step(t *testing.T) { @@ -648,6 +662,8 @@ func TestBackoff_Step(t *testing.T) { initial *Backoff want []time.Duration }{ + {initial: nil, want: []time.Duration{0, 0, 0, 0}}, + {initial: &Backoff{Duration: time.Second, Steps: -1}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Steps: 0}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 1.0, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}}, @@ -658,13 +674,19 @@ func TestBackoff_Step(t *testing.T) { } for seed := int64(0); seed < 5; seed++ { for _, tt := range tests { - initial := *tt.initial + var initial *Backoff + if tt.initial != nil { + copied := *tt.initial + initial = &copied + } else { + initial = nil + } t.Run(fmt.Sprintf("%#v seed=%d", initial, seed), func(t *testing.T) { rand.Seed(seed) for i := 0; i < len(tt.want); i++ { got := initial.Step() t.Logf("[%d]=%s", i, got) - if initial.Jitter > 0 { + if initial != nil && initial.Jitter > 0 { if got == tt.want[i] { // this is statistically unlikely to happen by chance t.Errorf("Backoff.Step(%d) = %v, no jitter", i, got) @@ -779,11 +801,105 @@ func TestExponentialBackoffManagerWithRealClock(t *testing.T) { } } -func TestExponentialBackoffWithContext(t *testing.T) { - defaultCtx := func() context.Context { - return context.Background() +func TestBackoffDelayWithResetExponential(t *testing.T) { + fc := testingclock.NewFakeClock(time.Now()) + backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 10) + durations := []time.Duration{1, 2, 4, 8, 10, 10, 10} + for i := 0; i < len(durations); i++ { + generatedBackoff := backoff() + if generatedBackoff != durations[i] { + t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i]) + } } + fc.Step(11) + resetDuration := backoff() + if resetDuration != 1 { + t.Errorf("after reset, backoff should be 1, but got %d", resetDuration) + } +} + +func TestBackoffDelayWithResetEmpty(t *testing.T) { + fc := testingclock.NewFakeClock(time.Now()) + backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 0) + // we reset to initial duration because the resetInterval is 0, immediate + durations := []time.Duration{1, 1, 1, 1, 1, 1, 1} + for i := 0; i < len(durations); i++ { + generatedBackoff := backoff() + if generatedBackoff != durations[i] { + t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i]) + } + } + + fc.Step(11) + resetDuration := backoff() + if resetDuration != 1 { + t.Errorf("after reset, backoff should be 1, but got %d", resetDuration) + } +} + +func TestBackoffDelayWithResetJitter(t *testing.T) { + // positive jitter + backoff := Backoff{Duration: 1, Jitter: 1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0) + for i := 0; i < 5; i++ { + value := backoff() + if value < 1 || value > 2 { + t.Errorf("backoff out of range: %d", value) + } + } + + // negative jitter, shall be a fixed backoff + backoff = Backoff{Duration: 1, Jitter: -1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0) + value := backoff() + if value != 1 { + t.Errorf("backoff should be 1, but got %d", value) + } +} + +func TestBackoffDelayWithResetWithRealClockJitter(t *testing.T) { + backoff := Backoff{Duration: 1 * time.Millisecond, Jitter: 0}.DelayWithReset(&clock.RealClock{}, 0) + for i := 0; i < 5; i++ { + start := time.Now() + <-RealTimer(backoff()).C() + passed := time.Since(start) + if passed < 1*time.Millisecond { + t.Errorf("backoff should be at least 1ms, but got %s", passed.String()) + } + } +} + +func TestBackoffDelayWithResetWithRealClockExponential(t *testing.T) { + // backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms + durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10} + backoff := Backoff{Duration: 1 * time.Millisecond, Cap: 10 * time.Millisecond, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(&clock.RealClock{}, 1*time.Hour) + + for i := range durationFactors { + start := time.Now() + <-RealTimer(backoff()).C() + passed := time.Since(start) + if passed < durationFactors[i]*time.Millisecond { + t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String()) + } + } +} + +func defaultContext() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) +} +func cancelledContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx, cancel +} +func deadlinedContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + for ctx.Err() != context.DeadlineExceeded { + time.Sleep(501 * time.Microsecond) + } + return ctx, cancel +} + +func TestExponentialBackoffWithContext(t *testing.T) { defaultCallback := func(_ int) (bool, error) { return false, nil } @@ -791,17 +907,18 @@ func TestExponentialBackoffWithContext(t *testing.T) { conditionErr := errors.New("condition failed") tests := []struct { - name string - steps int - ctxGetter func() context.Context - callback func(calls int) (bool, error) - attemptsExpected int - errExpected error + name string + steps int + zeroDuration bool + context func() (context.Context, context.CancelFunc) + callback func(calls int) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error }{ { name: "no attempts expected with zero backoff steps", steps: 0, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 0, errExpected: ErrWaitTimeout, @@ -809,15 +926,13 @@ func TestExponentialBackoffWithContext(t *testing.T) { { name: "condition returns false with single backoff step", steps: 1, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 1, errExpected: ErrWaitTimeout, }, { - name: "condition returns true with single backoff step", - steps: 1, - ctxGetter: defaultCtx, + name: "condition returns true with single backoff step", + steps: 1, callback: func(_ int) (bool, error) { return true, nil }, @@ -827,15 +942,13 @@ func TestExponentialBackoffWithContext(t *testing.T) { { name: "condition always returns false with multiple backoff steps", steps: 5, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 5, errExpected: ErrWaitTimeout, }, { - name: "condition returns true after certain attempts with multiple backoff steps", - steps: 5, - ctxGetter: defaultCtx, + name: "condition returns true after certain attempts with multiple backoff steps", + steps: 5, callback: func(attempts int) (bool, error) { if attempts == 3 { return true, nil @@ -846,9 +959,8 @@ func TestExponentialBackoffWithContext(t *testing.T) { errExpected: nil, }, { - name: "condition returns error no further attempts expected", - steps: 5, - ctxGetter: defaultCtx, + name: "condition returns error no further attempts expected", + steps: 5, callback: func(_ int) (bool, error) { return true, conditionErr }, @@ -856,30 +968,118 @@ func TestExponentialBackoffWithContext(t *testing.T) { errExpected: conditionErr, }, { - name: "context already canceled no attempts expected", + name: "context already canceled no attempts expected", + steps: 5, + context: cancelledContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.Canceled, + }, + { + name: "context at deadline no attempts expected", + steps: 5, + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, + { + name: "no attempts expected with zero backoff steps", + steps: 0, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns false with single backoff step", + steps: 1, + callback: defaultCallback, + attemptsExpected: 1, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns true with single backoff step", + steps: 1, + callback: func(_ int) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "condition always returns false with multiple backoff steps but is cancelled at step 4", + steps: 5, + callback: defaultCallback, + attemptsExpected: 4, + cancelContextAfter: 4, + errExpected: context.Canceled, + }, + { + name: "condition returns true after certain attempts with multiple backoff steps and zero duration", + steps: 5, + zeroDuration: true, + callback: func(attempts int) (bool, error) { + if attempts == 3 { + return true, nil + } + return false, nil + }, + attemptsExpected: 3, + errExpected: nil, + }, + { + name: "condition returns error no further attempts expected", steps: 5, - ctxGetter: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx + callback: func(_ int) (bool, error) { + return true, conditionErr }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "context already canceled no attempts expected", + steps: 5, + context: cancelledContext, callback: defaultCallback, attemptsExpected: 0, errExpected: context.Canceled, }, + { + name: "context at deadline no attempts expected", + steps: 5, + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { backoff := Backoff{ - Duration: 1 * time.Millisecond, + Duration: 1 * time.Microsecond, Factor: 1.0, Steps: test.steps, } + if test.zeroDuration { + backoff.Duration = 0 + } + + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() attempts := 0 - err := ExponentialBackoffWithContext(test.ctxGetter(), backoff, func(_ context.Context) (bool, error) { + err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) { attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() return test.callback(attempts) }) @@ -894,6 +1094,26 @@ func TestExponentialBackoffWithContext(t *testing.T) { } } +func BenchmarkExponentialBackoffWithContext(b *testing.B) { + backoff := Backoff{ + Duration: 0, + Factor: 0, + Steps: 101, + } + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} + func TestPollImmediateUntilWithContext(t *testing.T) { fakeErr := errors.New("my error") tests := []struct { @@ -911,9 +1131,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, fakeErr } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: fakeErr, attemptsExpected: 1, }, @@ -924,9 +1141,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return true, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: nil, attemptsExpected: 1, }, @@ -937,12 +1151,8 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, nil } }, - context: func() (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, - errExpected: ErrWaitTimeout, + context: cancelledContext, + errExpected: ErrWaitTimeout, // this should be context.Canceled but that would break callers that assume all errors are ErrWaitTimeout attemptsExpected: 1, }, { @@ -956,9 +1166,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return true, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: nil, attemptsExpected: 4, }, @@ -969,18 +1176,19 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, cancelContextAfterNthAttempt: 4, - errExpected: ErrWaitTimeout, + errExpected: ErrWaitTimeout, // this should be context.Canceled, but this method cannot change attemptsExpected: 4, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() var attempts int @@ -1018,10 +1226,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected error }{ { - name: "condition returns done=true on first attempt, no retry is attempted", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition returns done=true on first attempt, no retry is attempted", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1030,10 +1236,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: nil, }, { - name: "condition always returns done=false, timeout error expected", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition always returns done=false, timeout error expected", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1043,10 +1247,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: ErrWaitTimeout, }, { - name: "condition returns an error on first attempt, the error is returned", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition returns an error on first attempt, the error is returned", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1055,12 +1257,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: fakeErr, }, { - name: "context is cancelled, context cancelled error expected", - context: func() (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx, cancel - }, + name: "context is cancelled, context cancelled error expected", + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1086,7 +1284,11 @@ func Test_waitForWithContext(t *testing.T) { ticker := test.waitFunc() err := func() error { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() return waitForWithContext(ctx, ticker.WithContext(), conditionWrapper) @@ -1102,7 +1304,7 @@ func Test_waitForWithContext(t *testing.T) { } } -func TestPollInternal(t *testing.T) { +func Test_poll(t *testing.T) { fakeErr := errors.New("fake error") tests := []struct { name string @@ -1117,13 +1319,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, condition returns an error", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1134,13 +1329,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, condition returns true", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1151,13 +1339,7 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, context is cancelled, condition return false", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1168,13 +1350,7 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, context is cancelled", immediate: false, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1185,9 +1361,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, condition returns an error", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1198,9 +1371,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, condition returns true", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1211,9 +1381,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns true", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1230,9 +1397,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns error", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1249,9 +1413,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns false", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1268,9 +1429,6 @@ func TestPollInternal(t *testing.T) { { name: "condition always returns false, timeout error expected", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1282,9 +1440,27 @@ func TestPollInternal(t *testing.T) { { name: "context is cancelled after N attempts, timeout error expected", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() waitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + // just tick twice + go func() { + ch <- struct{}{} + ch <- struct{}{} + }() + return ch + } }, + cancelContextAfter: 2, + attemptsExpected: 2, + errExpected: ErrWaitTimeout, + }, + { + name: "context is cancelled after N attempts, context error not expected (legacy behavior)", + immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1315,7 +1491,11 @@ func TestPollInternal(t *testing.T) { ticker = test.waitFunc() } err := func() error { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() conditionWrapper := func(ctx context.Context) (done bool, err error) { @@ -1342,3 +1522,17 @@ func TestPollInternal(t *testing.T) { }) } } + +func Benchmark_poll(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := poll(ctx, true, poller(time.Microsecond, 0), func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +}