diff --git a/src/Polly.Core/ToBeRemoved/TimeProvider.cs b/src/Polly.Core/ToBeRemoved/TimeProvider.cs index 58c519d7e9e..573be648b64 100644 --- a/src/Polly.Core/ToBeRemoved/TimeProvider.cs +++ b/src/Polly.Core/ToBeRemoved/TimeProvider.cs @@ -3,7 +3,15 @@ using System.Diagnostics.CodeAnalysis; -#pragma warning disable S3872 // Parameter names should not duplicate the names of their methods +#pragma warning disable + +namespace System.Threading +{ + internal interface ITimer : IDisposable, IAsyncDisposable + { + bool Change(TimeSpan dueTime, TimeSpan period); + } +} namespace System { @@ -53,9 +61,6 @@ public DateTimeOffset GetLocalNow() // This one is not on TimeProvider, temporarly we need to use it public virtual Task Delay(TimeSpan delay, CancellationToken cancellationToken = default) => Task.Delay(delay, cancellationToken); - // This one is not on TimeProvider, temporarly we need to use it - public virtual void CancelAfter(CancellationTokenSource source, TimeSpan delay) => source.CancelAfter(delay); - public TimeSpan GetElapsedTime(long startingTimestamp, long endingTimestamp) { long timestampFrequency = TimestampFrequency; @@ -69,6 +74,104 @@ public TimeSpan GetElapsedTime(long startingTimestamp, long endingTimestamp) public TimeSpan GetElapsedTime(long startingTimestamp) => GetElapsedTime(startingTimestamp, GetTimestamp()); + public virtual ITimer CreateTimer(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period) + { + if (callback is null) + { + throw new ArgumentNullException(nameof(callback)); + } + + return new SystemTimeProviderTimer(dueTime, period, callback, state); + } + + [ExcludeFromCodeCoverage] + private sealed class SystemTimeProviderTimer : ITimer + { + private readonly Timer _timer; + + public SystemTimeProviderTimer(TimeSpan dueTime, TimeSpan period, TimerCallback callback, object? state) + { + (uint duration, uint periodTime) = CheckAndGetValues(dueTime, period); + + // We need to ensure the timer roots itself. Timer created with a duration and period argument + // only roots the state object, so to root the timer we need the state object to reference the + // timer recursively. + var timerState = new TimerState(callback, state); + timerState.Timer = _timer = new Timer(static s => + { + TimerState ts = (TimerState)s!; + ts.Callback(ts.State); + }, timerState, duration, periodTime); + } + + private sealed class TimerState + { + public TimerState(TimerCallback callback, object? state) + { + Callback = callback; + State = state; + } + + public TimerCallback Callback { get; } + + public object? State { get; } + + public Timer? Timer { get; set; } + } + + public bool Change(TimeSpan dueTime, TimeSpan period) + { + (uint duration, uint periodTime) = CheckAndGetValues(dueTime, period); + try + { + return _timer.Change(duration, periodTime); + } + catch (ObjectDisposedException) + { + return false; + } + } + + public void Dispose() => _timer.Dispose(); + + public ValueTask DisposeAsync() + { + _timer.Dispose(); + return default; + } + + private static (uint duration, uint periodTime) CheckAndGetValues(TimeSpan dueTime, TimeSpan periodTime) + { + long dueTm = (long)dueTime.TotalMilliseconds; + long periodTm = (long)periodTime.TotalMilliseconds; + + const uint MaxSupportedTimeout = 0xfffffffe; + + if (dueTm < -1) + { + throw new ArgumentOutOfRangeException(nameof(dueTime)); + } + + if (dueTm > MaxSupportedTimeout) + { + throw new ArgumentOutOfRangeException(nameof(dueTime)); + } + + if (periodTm < -1) + { + throw new ArgumentOutOfRangeException(nameof(periodTm)); + } + + if (periodTm > MaxSupportedTimeout) + { + throw new ArgumentOutOfRangeException(nameof(periodTm)); + } + + return ((uint)dueTm, (uint)periodTm); + } + } + + [ExcludeFromCodeCoverage] private sealed class SystemTimeProvider : TimeProvider { internal SystemTimeProvider() @@ -77,3 +180,280 @@ internal SystemTimeProvider() } } } + +namespace System.Threading.Tasks +{ + /// + /// Provide extensions methods for operations with . + /// + /// + /// The Microsoft.Bcl.TimeProvider library interfaces are intended solely for use in building against pre-.NET 8 surface area. + /// If your code is being built against .NET 8 or higher, then this library should not be utilized. + /// + [ExcludeFromCodeCoverage] + internal static class TimeProviderTaskExtensions + { + private sealed class DelayState : TaskCompletionSource + { + public DelayState(CancellationToken cancellationToken) + : base(TaskCreationOptions.RunContinuationsAsynchronously) => CancellationToken = cancellationToken; + + public ITimer? Timer { get; set; } + public CancellationToken CancellationToken { get; } + public CancellationTokenRegistration Registration { get; set; } + } + + private sealed class WaitAsyncState : TaskCompletionSource + { + public WaitAsyncState(CancellationToken cancellationToken) + : base(TaskCreationOptions.RunContinuationsAsynchronously) => CancellationToken = cancellationToken; + + public readonly CancellationTokenSource ContinuationCancellation = new(); + public CancellationToken CancellationToken { get; } + public CancellationTokenRegistration Registration; + public ITimer? Timer; + } + + /// Creates a task that completes after a specified time interval. + /// The with which to interpret . + /// The to wait before completing the returned task, or to wait indefinitely. + /// A cancellation token to observe while waiting for the task to complete. + /// A task that represents the time delay. + /// The argument is null. + /// represents a negative time interval other than . + public static Task Delay(this TimeProvider timeProvider, TimeSpan delay, CancellationToken cancellationToken = default) + { + if (timeProvider == TimeProvider.System) + { + return Task.Delay(delay, cancellationToken); + } + + if (timeProvider is null) + { + throw new ArgumentNullException(nameof(timeProvider)); + } + + if (delay != Timeout.InfiniteTimeSpan && delay < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(delay)); + } + + if (delay == TimeSpan.Zero) + { + return Task.CompletedTask; + } + + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + DelayState state = new(cancellationToken); + + state.Timer = timeProvider.CreateTimer(static delayState => + { + DelayState s = (DelayState)delayState!; + s.TrySetResult(true); + s.Registration.Dispose(); + s.Timer?.Dispose(); + }, state, delay, Timeout.InfiniteTimeSpan); + + state.Registration = cancellationToken.Register(static delayState => + { + DelayState s = (DelayState)delayState!; + s.TrySetCanceled(s.CancellationToken); + s.Registration.Dispose(); + s.Timer?.Dispose(); + }, state); + + // There are race conditions where the timer fires after we have attached the cancellation callback but before the + // registration is stored in state.Registration, or where cancellation is requested prior to the registration being + // stored into state.Registration, or where the timer could fire after it's been created but before it's been stored + // in state.Timer. In such cases, the cancellation registration and/or the Timer might be stored into state after the + // callbacks and thus left undisposed. So, we do a subsequent check here. If the task isn't completed by this point, + // then the callbacks won't have called TrySetResult (the callbacks invoke TrySetResult before disposing of the fields), + // in which case it will see both the timer and registration set and be able to Dispose them. If the task is completed + // by this point, then this is guaranteed to see s.Timer as non-null because it was deterministically set above. + if (state.Task.IsCompleted) + { + state.Registration.Dispose(); + state.Timer.Dispose(); + } + + return state.Task; + } + + /// + /// Gets a that will complete when this completes, + /// when the specified timeout expires, or when the specified has cancellation requested. + /// + /// The task for which to wait on until completion. + /// The timeout after which the should be faulted with a if it hasn't otherwise completed. + /// The with which to interpret . + /// The to monitor for a cancellation request. + /// The representing the asynchronous wait. It may or may not be the same instance as the current instance. + /// The argument is null. + /// The argument is null. + /// represents a negative time interval other than . + public static Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider timeProvider, CancellationToken cancellationToken = default) + { + if (task is null) + { + throw new ArgumentNullException(nameof(task)); + } + + if (timeout != Timeout.InfiniteTimeSpan && timeout < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(timeout)); + } + + if (timeProvider is null) + { + throw new ArgumentNullException(nameof(timeProvider)); + } + + if (task.IsCompleted) + { + return task; + } + + if (timeout == Timeout.InfiniteTimeSpan && !cancellationToken.CanBeCanceled) + { + return task; + } + + if (timeout == TimeSpan.Zero) + { + Task.FromException(new TimeoutException()); + } + + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + WaitAsyncState state = new(cancellationToken); + + state.Timer = timeProvider.CreateTimer(static s => + { + var state = (WaitAsyncState)s!; + + state.TrySetException(new TimeoutException()); + + state.Registration.Dispose(); + state.Timer?.Dispose(); + state.ContinuationCancellation.Cancel(); + }, state, timeout, Timeout.InfiniteTimeSpan); + + _ = task.ContinueWith(static (t, s) => + { + var state = (WaitAsyncState)s!; + + if (t.IsFaulted) + { + state.TrySetException(t.Exception!.InnerExceptions); + } + else if (t.IsCanceled) + { + state.TrySetCanceled(); + } + else + { + state.TrySetResult(true); + } + + state.Registration.Dispose(); + state.Timer?.Dispose(); + }, state, state.ContinuationCancellation.Token, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + state.Registration = cancellationToken.Register(static s => + { + var state = (WaitAsyncState)s!; + + state.TrySetCanceled(state.CancellationToken); + + state.Timer?.Dispose(); + state.ContinuationCancellation.Cancel(); + }, state); + + // See explanation in Delay for this final check + if (state.Task.IsCompleted) + { + state.Registration.Dispose(); + state.Timer.Dispose(); + } + + return state.Task; + } + + /// + /// Gets a that will complete when this completes, + /// when the specified timeout expires, or when the specified has cancellation requested. + /// + /// The task for which to wait on until completion. + /// The timeout after which the should be faulted with a if it hasn't otherwise completed. + /// The with which to interpret . + /// The to monitor for a cancellation request. + /// The representing the asynchronous wait. It may or may not be the same instance as the current instance. + /// The argument is null. + /// The argument is null. + /// represents a negative time interval other than . + public static async Task WaitAsync(this Task task, TimeSpan timeout, TimeProvider timeProvider, CancellationToken cancellationToken = default) + { + await ((Task)task).WaitAsync(timeout, timeProvider, cancellationToken).ConfigureAwait(false); + return task.Result; + } + + /// Initializes a new instance of the class that will be canceled after the specified . + /// The with which to interpret the . + /// The time interval to wait before canceling this . + /// The is negative and not equal to + /// or greater than maximum allowed timer duration. + /// that will be canceled after the specified . + /// + /// + /// The countdown for the delay starts during the call to the constructor. When the delay expires, + /// the constructed is canceled if it has + /// not been canceled already. + /// + /// + /// If running on .NET versions earlier than .NET 8.0, there is a constraint when invoking on the resultant object. + /// This action will not terminate the initial timer indicated by . However, this restriction does not apply on .NET 8.0 and later versions. + /// + /// + public static CancellationTokenSource CreateCancellationTokenSource(this TimeProvider timeProvider, TimeSpan delay) + { + if (timeProvider is null) + { + throw new ArgumentNullException(nameof(timeProvider)); + } + + if (delay != Timeout.InfiniteTimeSpan && delay < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(delay)); + } + + if (timeProvider == TimeProvider.System) + { + return new CancellationTokenSource(delay); + } + + var cts = new CancellationTokenSource(); + + ITimer timer = timeProvider.CreateTimer(static s => + { + try + { + ((CancellationTokenSource)s!).Cancel(); + } + catch (ObjectDisposedException) + { + // ok + } + }, cts, delay, Timeout.InfiniteTimeSpan); + + cts.Token.Register(static t => ((ITimer)t!).Dispose(), timer); + return cts; + } + } +} diff --git a/src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs b/src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs index f42b6c511cf..0510fa5cb5b 100644 --- a/src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs +++ b/src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs @@ -10,14 +10,12 @@ private sealed class DisposableCancellationTokenSourcePool : CancellationTokenSo protected override CancellationTokenSource GetCore(TimeSpan delay) { - var source = new CancellationTokenSource(); - - if (IsCancellable(delay)) + if (!IsCancellable(delay)) { - _timeProvider.CancelAfter(source, delay); + return new CancellationTokenSource(); } - return source; + return _timeProvider.CreateCancellationTokenSource(delay); } public override void Return(CancellationTokenSource source) => source.Dispose(); diff --git a/src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs b/src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs index 2d2c4091e3b..93b7acdfc11 100644 --- a/src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs +++ b/src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs @@ -5,14 +5,11 @@ internal abstract partial class CancellationTokenSourcePool #if NET6_0_OR_GREATER private sealed class PooledCancellationTokenSourcePool : CancellationTokenSourcePool { - public static readonly PooledCancellationTokenSourcePool SystemInstance = new(TimeProvider.System); + public static readonly PooledCancellationTokenSourcePool SystemInstance = new(); - public PooledCancellationTokenSourcePool(TimeProvider timeProvider) => _timeProvider = timeProvider; + private readonly ObjectPool _pool; - private readonly ObjectPool _pool = new( - static () => new CancellationTokenSource(), - static cts => true); - private readonly TimeProvider _timeProvider; + public PooledCancellationTokenSourcePool() => _pool = new(static () => new CancellationTokenSource(), static cts => true); protected override CancellationTokenSource GetCore(TimeSpan delay) { @@ -20,7 +17,7 @@ protected override CancellationTokenSource GetCore(TimeSpan delay) if (IsCancellable(delay)) { - _timeProvider.CancelAfter(source, delay); + source.CancelAfter(delay); } return source; diff --git a/src/Polly.Core/Utils/CancellationTokenSourcePool.cs b/src/Polly.Core/Utils/CancellationTokenSourcePool.cs index 942069dadf9..cabb1f291d7 100644 --- a/src/Polly.Core/Utils/CancellationTokenSourcePool.cs +++ b/src/Polly.Core/Utils/CancellationTokenSourcePool.cs @@ -11,8 +11,11 @@ public static CancellationTokenSourcePool Create(TimeProvider timeProvider) { return PooledCancellationTokenSourcePool.SystemInstance; } -#endif + return new DisposableCancellationTokenSourcePool(timeProvider); +#else + return new DisposableCancellationTokenSourcePool(timeProvider); +#endif } public CancellationTokenSource Get(TimeSpan delay) diff --git a/test/Polly.Core.Tests/Hedging/HedgingTimeProvider.cs b/test/Polly.Core.Tests/Hedging/HedgingTimeProvider.cs index d3607aa3ddc..b71c289f8b1 100644 --- a/test/Polly.Core.Tests/Hedging/HedgingTimeProvider.cs +++ b/test/Polly.Core.Tests/Hedging/HedgingTimeProvider.cs @@ -26,8 +26,6 @@ public void Advance(TimeSpan diff) public override long GetTimestamp() => TimeStampProvider(); - public override void CancelAfter(CancellationTokenSource source, TimeSpan delay) => throw new NotSupportedException(); - public override Task Delay(TimeSpan delayValue, CancellationToken cancellationToken = default) { var entry = new DelayEntry(delayValue, new TaskCompletionSource(), _utcNow.Add(delayValue)); diff --git a/test/Polly.Core.Tests/Helpers/FakeTimeProvider.cs b/test/Polly.Core.Tests/Helpers/FakeTimeProvider.cs index e436d536d0a..6946ac944a6 100644 --- a/test/Polly.Core.Tests/Helpers/FakeTimeProvider.cs +++ b/test/Polly.Core.Tests/Helpers/FakeTimeProvider.cs @@ -164,6 +164,14 @@ public override long GetTimestamp() /// A string representing the provider's current time. public override string ToString() => GetUtcNow().ToString("yyyy-MM-ddTHH:mm:ss.fff", CultureInfo.InvariantCulture); + /// + public override ITimer CreateTimer(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period) + { + var timer = new Timer(this, callback, state); + _ = timer.Change(dueTime, period); + return timer; + } + internal void RemoveWaiter(Waiter waiter) { lock (Waiters) @@ -277,3 +285,91 @@ public void InvokeCallback() _callback(_state); } } + +// This implements the timer abstractions and is a thin wrapper around a waiter object. +// The main role of this type is to create the waiter, add it to the waiter list, and ensure it gets +// removed from the waiter list when the dispose is disposed or collected. +internal sealed class Timer : ITimer +{ + private const uint MaxSupportedTimeout = 0xfffffffe; + + private Waiter? _waiter; + private FakeTimeProvider? _timeProvider; + private TimerCallback? _callback; + private object? _state; + + public Timer(FakeTimeProvider timeProvider, TimerCallback callback, object? state) + { + _timeProvider = timeProvider; + _callback = callback; + _state = state; + } + + public bool Change(TimeSpan dueTime, TimeSpan period) + { + var dueTimeMs = (long)dueTime.TotalMilliseconds; + var periodMs = (long)period.TotalMilliseconds; + + if (_timeProvider == null) + { + // timer has been disposed + return false; + } + + if (_waiter != null) + { + // remove any previous waiter + _timeProvider.RemoveWaiter(_waiter); + _waiter = null; + } + + if (dueTimeMs < 0) + { + // this waiter will never wake up, so just bail + return true; + } + + if (periodMs < 0 || periodMs == Timeout.Infinite) + { + // normalize + period = TimeSpan.Zero; + } + + _waiter = new Waiter(_callback!, _state, period.Ticks); + _timeProvider.AddWaiter(_waiter, dueTime.Ticks); + return true; + } + + // In case the timer is not disposed, this will remove the Waiter instance from the provider. + ~Timer() => Dispose(false); + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public ValueTask DisposeAsync() + { + Dispose(true); + GC.SuppressFinalize(this); +#if NET5_0_OR_GREATER + return ValueTask.CompletedTask; +#else + return default; +#endif + } + + private void Dispose(bool _) + { + if (_waiter != null) + { + _timeProvider!.RemoveWaiter(_waiter); + _waiter = null; + } + + _timeProvider = null; + _callback = null; + _state = null; + } +} diff --git a/test/Polly.Core.Tests/Helpers/MockTimeProvider.cs b/test/Polly.Core.Tests/Helpers/MockTimeProvider.cs index 8e027185cef..70d0d139f22 100644 --- a/test/Polly.Core.Tests/Helpers/MockTimeProvider.cs +++ b/test/Polly.Core.Tests/Helpers/MockTimeProvider.cs @@ -26,10 +26,4 @@ public MockTimeProvider SetupDelayCancelled(TimeSpan delay, CancellationToken ca Setup(x => x.Delay(delay, cancellationToken)).ThrowsAsync(new OperationCanceledException()); return this; } - - public MockTimeProvider SetupCancelAfterNow(TimeSpan delay) - { - Setup(v => v.CancelAfter(It.IsAny(), delay)).Callback((cts, _) => cts.Cancel()); - return this; - } } diff --git a/test/Polly.Core.Tests/Timeout/TimeoutResilienceStrategyTests.cs b/test/Polly.Core.Tests/Timeout/TimeoutResilienceStrategyTests.cs index 9d547ccf240..7cac4a0177e 100644 --- a/test/Polly.Core.Tests/Timeout/TimeoutResilienceStrategyTests.cs +++ b/test/Polly.Core.Tests/Timeout/TimeoutResilienceStrategyTests.cs @@ -1,3 +1,4 @@ +using Microsoft.Extensions.Time.Testing; using Moq; using Polly.Telemetry; using Polly.Timeout; @@ -7,16 +8,16 @@ namespace Polly.Core.Tests.Timeout; public class TimeoutResilienceStrategyTests : IDisposable { private readonly ResilienceStrategyTelemetry _telemetry; - private readonly MockTimeProvider _timeProvider; + private readonly FakeTimeProvider _timeProvider = new(); private readonly TimeoutStrategyOptions _options; private readonly CancellationTokenSource _cancellationSource; private readonly TimeSpan _delay = TimeSpan.FromSeconds(12); + private readonly Mock _diagnosticSource = new(); public TimeoutResilienceStrategyTests() { _telemetry = TestUtilities.CreateResilienceTelemetry(_diagnosticSource.Object); - _timeProvider = new MockTimeProvider(); _options = new TimeoutStrategyOptions(); _cancellationSource = new CancellationTokenSource(); } @@ -55,6 +56,8 @@ public async Task Execute_EnsureOnTimeoutCalled() var called = false; SetTimeout(_delay); + + var executionTime = _delay + TimeSpan.FromSeconds(1); _options.OnTimeout = args => { args.Exception.Should().BeAssignableTo(); @@ -65,10 +68,14 @@ public async Task Execute_EnsureOnTimeoutCalled() return default; }; - _timeProvider.SetupCancelAfterNow(_delay); - var sut = CreateSut(); - await sut.Invoking(s => sut.ExecuteAsync(async token => await Task.Delay(_delay, token)).AsTask()).Should().ThrowAsync(); + await sut.Invoking(s => sut.ExecuteAsync(async token => + { + var delay = _timeProvider.Delay(executionTime, token); + _timeProvider.Advance(_delay); + await delay; + }) + .AsTask()).Should().ThrowAsync(); called.Should().BeTrue(); _diagnosticSource.VerifyAll(); @@ -94,26 +101,37 @@ public async Task Execute_Timeout(bool defaultCancellationToken) using var cts = new CancellationTokenSource(); CancellationToken token = defaultCancellationToken ? default : cts.Token; SetTimeout(TimeSpan.FromSeconds(2)); - _timeProvider.SetupCancelAfterNow(TimeSpan.FromSeconds(2)); var sut = CreateSut(); await sut - .Invoking(s => s.ExecuteAsync(async token => await Delay(token), token).AsTask()) + .Invoking(s => s.ExecuteAsync(async token => + { + var delay = _timeProvider.Delay(TimeSpan.FromSeconds(4), token); + _timeProvider.Advance(TimeSpan.FromSeconds(2)); + await delay; + }, + token) + .AsTask()) .Should().ThrowAsync() .WithMessage("The operation didn't complete within the allowed timeout of '00:00:02'."); - - _timeProvider.VerifyAll(); } [Fact] public async Task Execute_Timeout_EnsureStackTrace() { - using var cts = new CancellationTokenSource(); SetTimeout(TimeSpan.FromSeconds(2)); - _timeProvider.SetupCancelAfterNow(TimeSpan.FromSeconds(2)); var sut = CreateSut(); - var outcome = await sut.ExecuteOutcomeAsync(async (c, _) => { await Delay(c.CancellationToken); return Outcome.FromResult("dummy"); }, ResilienceContext.Get(), "state"); + var outcome = await sut.ExecuteOutcomeAsync(async (c, _) => + { + var delay = _timeProvider.Delay(TimeSpan.FromSeconds(4), c.CancellationToken); + _timeProvider.Advance(TimeSpan.FromSeconds(2)); + await delay; + + return Outcome.FromResult("dummy"); + }, + ResilienceContext.Get(), + "state"); outcome.Exception.Should().BeOfType(); outcome.Exception!.StackTrace.Should().Contain("Execute_Timeout_EnsureStackTrace"); } @@ -132,15 +150,18 @@ public async Task Execute_Cancelled_EnsureNoTimeout() return default; }; - _timeProvider.Setup(v => v.CancelAfter(It.IsAny(), delay)); - var sut = CreateSut(); - await sut.Invoking(s => s.ExecuteAsync(async token => await Delay(token, () => cts.Cancel()), cts.Token).AsTask()) - .Should() - .ThrowAsync(); + await sut.Invoking(s => s.ExecuteAsync(async token => + { + var task = _timeProvider.Delay(delay, token); + cts.Cancel(); + await task; + }, + cts.Token).AsTask()) + .Should() + .ThrowAsync(); - _timeProvider.VerifyAll(); onTimeoutCalled.Should().BeFalse(); _diagnosticSource.Verify(v => v.IsEnabled("OnTimeout"), Times.Never()); @@ -153,7 +174,7 @@ public async Task Execute_NoTimeoutOrCancellation_EnsureCancellationTokenRestore using var cts = new CancellationTokenSource(); SetTimeout(TimeSpan.FromSeconds(10)); - _timeProvider.Setup(v => v.CancelAfter(It.IsAny(), delay)); + _timeProvider.Advance(delay); var sut = CreateSut(); @@ -178,7 +199,6 @@ public async Task Execute_EnsureCancellationTokenRegistrationNotExecutedOnSynchr // Arrange using var cts = new CancellationTokenSource(); SetTimeout(TimeSpan.FromSeconds(10)); - _timeProvider.Setup(v => v.CancelAfter(It.IsAny(), TimeSpan.FromSeconds(10))); var sut = CreateSut(); @@ -196,7 +216,13 @@ public async Task Execute_EnsureCancellationTokenRegistrationNotExecutedOnSynchr // Act try { - await sut.ExecuteAsync(async token => await Delay(token, () => cts.Cancel()), cts.Token); + await sut.ExecuteAsync(async token => + { + Task delayTask = Task.Delay(TimeSpan.FromSeconds(10), token); + cts.Cancel(); + await delayTask; + }, + cts.Token); } catch (OperationCanceledException) { @@ -209,19 +235,5 @@ public async Task Execute_EnsureCancellationTokenRegistrationNotExecutedOnSynchr private void SetTimeout(TimeSpan timeout) => _options.TimeoutGenerator = args => new ValueTask(timeout); - private TimeoutResilienceStrategy CreateSut() => new(_options, _timeProvider.Object, _telemetry); - - private static Task Delay(CancellationToken token, Action? onWaiting = null) - { - Task delayTask = Task.CompletedTask; - - try - { - return Task.Delay(TimeSpan.FromSeconds(2), token); - } - finally - { - onWaiting?.Invoke(); - } - } + private TimeoutResilienceStrategy CreateSut() => new(_options, _timeProvider, _telemetry); } diff --git a/test/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs b/test/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs index 03f27f36660..648e29cc198 100644 --- a/test/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs +++ b/test/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs @@ -1,3 +1,4 @@ +using Microsoft.Extensions.Time.Testing; using Moq; using Polly.Utils; @@ -8,7 +9,7 @@ public class CancellationTokenSourcePoolTests public static IEnumerable TimeProviders() { yield return new object[] { TimeProvider.System }; - yield return new object[] { new MockTimeProvider() }; + yield return new object[] { new FakeTimeProvider() }; } [Fact] @@ -66,24 +67,17 @@ public void RentReturn_NotReusable_EnsureProperBehavior(object timeProvider) [Theory] public async Task Rent_Cancellable_EnsureCancelled(object timeProvider) { - if (timeProvider is Mock fakeTimeProvider) - { - fakeTimeProvider - .Setup(v => v.CancelAfter(It.IsAny(), TimeSpan.FromMilliseconds(1))) - .Callback((source, _) => source.Cancel()); - } - else - { - fakeTimeProvider = null!; - } - var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider)); var cts = pool.Get(TimeSpan.FromMilliseconds(1)); + if (timeProvider is FakeTimeProvider fakeTimeProvider) + { + fakeTimeProvider.Advance(TimeSpan.FromSeconds(1)); + } + await Task.Delay(100); cts.IsCancellationRequested.Should().BeTrue(); - fakeTimeProvider?.VerifyAll(); } [MemberData(nameof(TimeProviders))] @@ -96,12 +90,6 @@ public async Task Rent_NotCancellable_EnsureNotCancelled(object timeProvider) await Task.Delay(20); cts.IsCancellationRequested.Should().BeFalse(); - - if (timeProvider is Mock fakeTimeProvider) - { - fakeTimeProvider - .Verify(v => v.CancelAfter(It.IsAny(), It.IsAny()), Times.Never()); - } } private static TimeProvider GetTimeProvider(object timeProvider) => timeProvider switch