Skip to content

Commit

Permalink
Pass CancellationToken to GetAsyncEnumerator (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheAngryByrd authored May 8, 2024
1 parent 722aa5d commit c01c152
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 19 deletions.
18 changes: 11 additions & 7 deletions src/IcedTasks/CancellableTaskBuilderBase.fs
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,17 @@ module CancellableTaskBase =
source: #IAsyncEnumerable<'T>,
body: 'T -> CancellableTaskBaseCode<_, unit, 'Builder>
) : CancellableTaskBaseCode<_, _, 'Builder> =

this.Using(
source.GetAsyncEnumerator CancellationToken.None,
(fun (e: IAsyncEnumerator<'T>) ->
this.WhileAsync(
(fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())),
(fun sm -> (body e.Current).Invoke(&sm))
this.Bind(
this.Source((fun (ct: CancellationToken) -> ValueTask<_> ct)),
(fun ct ->
this.Using(
source.GetAsyncEnumerator ct,
(fun (e: IAsyncEnumerator<'T>) ->
this.WhileAsync(
(fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())),
(fun sm -> (body e.Current).Invoke(&sm))
)
)
)
)
)
35 changes: 35 additions & 0 deletions tests/IcedTasks.Tests/AsyncExTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,41 @@ module AsyncExTests =
)
}


testCaseAsync "IAsyncEnumerator receives CancellationToken"
<| async {
do!
asyncEx {

let mutable index = 0
let loops = 10

let asyncSeq =
AsyncEnumerable.forXtoY
0
loops
(fun _ -> valueTaskUnit { do! Task.Yield() })

use cts = new CancellationTokenSource()

let actual =
asyncEx {
for (i: int) in asyncSeq do
do! Task.Yield()
index <- index + 1
}

do!
Async.StartAsTask(actual, cancellationToken = cts.Token)
|> Async.AwaitTask

Expect.equal
asyncSeq.LastEnumerator.Value.CancellationToken
cts.Token
""
}
}

]
]

Expand Down
33 changes: 33 additions & 0 deletions tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,39 @@ module CancellablePoolingValueTaskTests =
}
)
}


testCaseAsync "IAsyncEnumerator receives CancellationToken"
<| async {
do!
cancellablePoolingValueTask {

let mutable index = 0
let loops = 10

let asyncSeq =
AsyncEnumerable.forXtoY
0
loops
(fun _ -> valueTaskUnit { do! Task.Yield() })

use cts = new CancellationTokenSource()

let actual =
cancellablePoolingValueTask {
for (i: int) in asyncSeq do
do! Task.Yield()
index <- index + 1
}

do! actual cts.Token

Expect.equal
asyncSeq.LastEnumerator.Value.CancellationToken
cts.Token
""
}
}
]


Expand Down
35 changes: 35 additions & 0 deletions tests/IcedTasks.Tests/CancellableTaskTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,41 @@ module CancellableTaskTests =
}
)
}


testCaseAsync "IAsyncEnumerator receives CancellationToken"
<| async {

do!
cancellableTask {

let mutable index = 0
let loops = 10

let asyncSeq =
AsyncEnumerable.forXtoY
0
loops
(fun _ -> valueTaskUnit { do! Task.Yield() })

use cts = new CancellationTokenSource()

let actual =
cancellableTask {
for (i: int) in asyncSeq do
do! Task.Yield()
index <- index + 1
}

do! actual cts.Token

Expect.equal
asyncSeq.LastEnumerator.Value.CancellationToken
cts.Token
""
}

}
]
testList "MergeSources" [

Expand Down
33 changes: 33 additions & 0 deletions tests/IcedTasks.Tests/CancellableValueTaskTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,39 @@ module CancellableValueTaskTests =
)
}


testCaseAsync "IAsyncEnumerator receives CancellationToken"
<| async {
do!
cancellableValueTask {

let mutable index = 0
let loops = 10

let asyncSeq =
AsyncEnumerable.forXtoY
0
loops
(fun _ -> valueTaskUnit { do! Task.Yield() })

use cts = new CancellationTokenSource()

let actual =
cancellableValueTask {
for (i: int) in asyncSeq do
do! Task.Yield()
index <- index + 1
}

do! actual cts.Token

Expect.equal
asyncSeq.LastEnumerator.Value.CancellationToken
cts.Token
""
}
}

]

testList "MergeSources" [
Expand Down
41 changes: 29 additions & 12 deletions tests/IcedTasks.Tests/Expect.fs
Original file line number Diff line number Diff line change
Expand Up @@ -171,23 +171,40 @@ module AsyncEnumerable =
open System.Collections.Generic
open System.Threading

type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) =
type AsyncEnumerator<'T>(current, moveNext, dispose, cancellationToken: CancellationToken) =
member this.CancellationToken = cancellationToken

member this.GetAsyncEnumerator(ct) =
let enumerator = e.GetEnumerator()
interface IAsyncEnumerator<'T> with
member this.Current = current ()
member this.MoveNextAsync() = moveNext ()
member this.DisposeAsync() = dispose ()

{ new IAsyncEnumerator<'T> with
member this.Current = enumerator.Current
type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) =

member this.MoveNextAsync() =
valueTask {
do! beforeMoveNext.Invoke(ct)
return enumerator.MoveNext()
}
let mutable lastEnumerator = None
member this.LastEnumerator = lastEnumerator

member this.DisposeAsync() = valueTaskUnit { enumerator.Dispose() }
member this.GetAsyncEnumerator(ct) =
let enumerator = e.GetEnumerator()

}
lastEnumerator <-
Some
<| AsyncEnumerator(
(fun () -> enumerator.Current),
(fun () ->
valueTask {
do! beforeMoveNext.Invoke(ct)
return enumerator.MoveNext()
}
),
(fun () ->
enumerator.Dispose()
|> ValueTask
),
ct
)

lastEnumerator.Value

interface IAsyncEnumerable<'T> with
member this.GetAsyncEnumerator(ct: CancellationToken) = this.GetAsyncEnumerator(ct)
Expand Down

0 comments on commit c01c152

Please sign in to comment.