From c95eb34eca3451ad49d455cc2bd62d8db436e77b Mon Sep 17 00:00:00 2001 From: Jimmy Byrd Date: Wed, 8 May 2024 08:58:07 -0400 Subject: [PATCH] Pass CancellationToken to GetAsyncEnumerator --- src/IcedTasks/CancellableTaskBuilderBase.fs | 18 ++++---- tests/IcedTasks.Tests/AsyncExTests.fs | 35 ++++++++++++++++ .../CancellablePoolingValueTaskTests.fs | 33 +++++++++++++++ tests/IcedTasks.Tests/CancellableTaskTests.fs | 35 ++++++++++++++++ .../CancellableValueTaskTests.fs | 33 +++++++++++++++ tests/IcedTasks.Tests/Expect.fs | 41 +++++++++++++------ 6 files changed, 176 insertions(+), 19 deletions(-) diff --git a/src/IcedTasks/CancellableTaskBuilderBase.fs b/src/IcedTasks/CancellableTaskBuilderBase.fs index 134c84d..b1fd3cd 100644 --- a/src/IcedTasks/CancellableTaskBuilderBase.fs +++ b/src/IcedTasks/CancellableTaskBuilderBase.fs @@ -741,13 +741,17 @@ module CancellableTaskBase = source: #IAsyncEnumerable<'T>, body: 'T -> CancellableTaskBaseCode<_, unit, 'Builder> ) : CancellableTaskBaseCode<_, _, 'Builder> = - - this.Using( - source.GetAsyncEnumerator CancellationToken.None, - (fun (e: IAsyncEnumerator<'T>) -> - this.WhileAsync( - (fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())), - (fun sm -> (body e.Current).Invoke(&sm)) + this.Bind( + this.Source((fun (ct: CancellationToken) -> ValueTask<_> ct)), + (fun ct -> + this.Using( + source.GetAsyncEnumerator ct, + (fun (e: IAsyncEnumerator<'T>) -> + this.WhileAsync( + (fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())), + (fun sm -> (body e.Current).Invoke(&sm)) + ) + ) ) ) ) diff --git a/tests/IcedTasks.Tests/AsyncExTests.fs b/tests/IcedTasks.Tests/AsyncExTests.fs index 0500695..dd13630 100644 --- a/tests/IcedTasks.Tests/AsyncExTests.fs +++ b/tests/IcedTasks.Tests/AsyncExTests.fs @@ -741,6 +741,41 @@ module AsyncExTests = ) } + + testCaseAsync "IAsyncEnumerator receives CancellationToken" + <| async { + do! + asyncEx { + + let mutable index = 0 + let loops = 10 + + let asyncSeq = + AsyncEnumerable.forXtoY + 0 + loops + (fun _ -> valueTaskUnit { do! Task.Yield() }) + + use cts = new CancellationTokenSource() + + let actual = + asyncEx { + for (i: int) in asyncSeq do + do! Task.Yield() + index <- index + 1 + } + + do! + Async.StartAsTask(actual, cancellationToken = cts.Token) + |> Async.AwaitTask + + Expect.equal + asyncSeq.LastEnumerator.Value.CancellationToken + cts.Token + "" + } + } + ] ] diff --git a/tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs b/tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs index dcb1df8..5809886 100644 --- a/tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs +++ b/tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs @@ -851,6 +851,39 @@ module CancellablePoolingValueTaskTests = } ) } + + + testCaseAsync "IAsyncEnumerator receives CancellationToken" + <| async { + do! + cancellablePoolingValueTask { + + let mutable index = 0 + let loops = 10 + + let asyncSeq = + AsyncEnumerable.forXtoY + 0 + loops + (fun _ -> valueTaskUnit { do! Task.Yield() }) + + use cts = new CancellationTokenSource() + + let actual = + cancellablePoolingValueTask { + for (i: int) in asyncSeq do + do! Task.Yield() + index <- index + 1 + } + + do! actual cts.Token + + Expect.equal + asyncSeq.LastEnumerator.Value.CancellationToken + cts.Token + "" + } + } ] diff --git a/tests/IcedTasks.Tests/CancellableTaskTests.fs b/tests/IcedTasks.Tests/CancellableTaskTests.fs index e29f6a8..54a1b12 100644 --- a/tests/IcedTasks.Tests/CancellableTaskTests.fs +++ b/tests/IcedTasks.Tests/CancellableTaskTests.fs @@ -812,6 +812,41 @@ module CancellableTaskTests = } ) } + + + testCaseAsync "IAsyncEnumerator receives CancellationToken" + <| async { + + do! + cancellableTask { + + let mutable index = 0 + let loops = 10 + + let asyncSeq = + AsyncEnumerable.forXtoY + 0 + loops + (fun _ -> valueTaskUnit { do! Task.Yield() }) + + use cts = new CancellationTokenSource() + + let actual = + cancellableTask { + for (i: int) in asyncSeq do + do! Task.Yield() + index <- index + 1 + } + + do! actual cts.Token + + Expect.equal + asyncSeq.LastEnumerator.Value.CancellationToken + cts.Token + "" + } + + } ] testList "MergeSources" [ diff --git a/tests/IcedTasks.Tests/CancellableValueTaskTests.fs b/tests/IcedTasks.Tests/CancellableValueTaskTests.fs index c3e36f0..c2d289f 100644 --- a/tests/IcedTasks.Tests/CancellableValueTaskTests.fs +++ b/tests/IcedTasks.Tests/CancellableValueTaskTests.fs @@ -851,6 +851,39 @@ module CancellableValueTaskTests = ) } + + testCaseAsync "IAsyncEnumerator receives CancellationToken" + <| async { + do! + cancellableValueTask { + + let mutable index = 0 + let loops = 10 + + let asyncSeq = + AsyncEnumerable.forXtoY + 0 + loops + (fun _ -> valueTaskUnit { do! Task.Yield() }) + + use cts = new CancellationTokenSource() + + let actual = + cancellableValueTask { + for (i: int) in asyncSeq do + do! Task.Yield() + index <- index + 1 + } + + do! actual cts.Token + + Expect.equal + asyncSeq.LastEnumerator.Value.CancellationToken + cts.Token + "" + } + } + ] testList "MergeSources" [ diff --git a/tests/IcedTasks.Tests/Expect.fs b/tests/IcedTasks.Tests/Expect.fs index 7884c5d..f2bc584 100644 --- a/tests/IcedTasks.Tests/Expect.fs +++ b/tests/IcedTasks.Tests/Expect.fs @@ -171,23 +171,40 @@ module AsyncEnumerable = open System.Collections.Generic open System.Threading - type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) = + type AsyncEnumerator<'T>(current, moveNext, dispose, cancellationToken: CancellationToken) = + member this.CancellationToken = cancellationToken - member this.GetAsyncEnumerator(ct) = - let enumerator = e.GetEnumerator() + interface IAsyncEnumerator<'T> with + member this.Current = current () + member this.MoveNextAsync() = moveNext () + member this.DisposeAsync() = dispose () - { new IAsyncEnumerator<'T> with - member this.Current = enumerator.Current + type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) = - member this.MoveNextAsync() = - valueTask { - do! beforeMoveNext.Invoke(ct) - return enumerator.MoveNext() - } + let mutable lastEnumerator = None + member this.LastEnumerator = lastEnumerator - member this.DisposeAsync() = valueTaskUnit { enumerator.Dispose() } + member this.GetAsyncEnumerator(ct) = + let enumerator = e.GetEnumerator() - } + lastEnumerator <- + Some + <| AsyncEnumerator( + (fun () -> enumerator.Current), + (fun () -> + valueTask { + do! beforeMoveNext.Invoke(ct) + return enumerator.MoveNext() + } + ), + (fun () -> + enumerator.Dispose() + |> ValueTask + ), + ct + ) + + lastEnumerator.Value interface IAsyncEnumerable<'T> with member this.GetAsyncEnumerator(ct: CancellationToken) = this.GetAsyncEnumerator(ct)