Skip to content

Commit

Permalink
Propagate cancellation more thoroughly
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Nov 29, 2019
1 parent 9af5cf8 commit 50c37af
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 17 deletions.
25 changes: 21 additions & 4 deletions doc/asyncenumerable.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Notice how it is not necessary (or desirable) to wrap the resulting `IAsyncEnume
C# 8 lets you consume such an async enumerable using `await foreach`:

```cs
await foreach (int number in this.clientProxy.GenerateNumbersAsync(token).WithCancellation(token))
await foreach (int number in this.clientProxy.GenerateNumbersAsync(token))
{
Console.WriteLine(number);
}
Expand All @@ -50,13 +50,30 @@ await foreach (int number in this.clientProxy.GenerateNumbersAsync(token).WithCa
All the foregoing is simple C# 8 async enumerable syntax and use cases.
StreamJsonRpc lets you use this natural syntax over an RPC connection.

We pass `token` in once to the method that calls to the RPC server, and again to the `WithCancellation`
extension method so that the token is applied to each iteration of the loop over the enumerable.

A remoted `IAsyncEnumerable<T>` can only be enumerated once.
Calling `IAsyncEnumerable<T>.GetAsyncEnumerator(CancellationToken)` more than once will
result in an `InvalidOperationException` being thrown.

When *not* using the dynamically generated proxies, acquiring and enumerating an `IAsyncEnumerator<T>` looks like this:

```cs
var enumerable = await this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(
"GetNumbersAsync", cancellationToken);

await foreach (var item in enumerable.WithCancellation(cancellationToken))
{
// processing
}
```

We pass `cancellationToken` into `InvokeWithCancellationAsync` so we can cancel the initial call.
We pass it again to the `WithCancellation` extension method inside the `foreach` expression
so that the token is applied to each iteration of the loop over the enumerable when
we may be awaiting a network call.

Using the `WithCancellation` extension method is not necessary when using dynamically generated proxies
because they automatically propagate the token from the first call to the enumerator.

### Transmitting large collections

Most C# iterator methods return `IEnumerable<T>` and produce values synchronously.
Expand Down
95 changes: 87 additions & 8 deletions src/StreamJsonRpc.Tests/AsyncEnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Globalization;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -49,12 +50,16 @@ protected interface IServer2

protected interface IServer
{
IAsyncEnumerable<int> GetValuesFromEnumeratedSourceAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersInBatchesAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersWithReadAheadAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersAsync(CancellationToken cancellationToken);

IAsyncEnumerable<int> GetNumbersNoCancellationAsync();

IAsyncEnumerable<int> WaitTillCanceledBeforeFirstItemAsync(CancellationToken cancellationToken);

Task<IAsyncEnumerable<int>> WaitTillCanceledBeforeReturningAsync(CancellationToken cancellationToken);
Expand Down Expand Up @@ -127,6 +132,20 @@ public async Task GetIAsyncEnumerableAsReturnType(bool useProxy)
Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount);
}

[Fact]
public async Task GetIAsyncEnumerableAsReturnType_WithProxy_NoCancellation()
{
int realizedValuesCount = 0;
IAsyncEnumerable<int> enumerable = this.clientProxy.Value.GetNumbersNoCancellationAsync();
await foreach (int number in enumerable)
{
realizedValuesCount++;
this.Logger.WriteLine(number.ToString(CultureInfo.InvariantCulture));
}

Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount);
}

[Theory]
[PairwiseData]
public async Task GetIAsyncEnumerableAsMemberWithinReturnType(bool useProxy)
Expand Down Expand Up @@ -289,24 +308,40 @@ public async Task Cancellation_AfterFirstMoveNext(bool useProxy)
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
}

[Theory]
[PairwiseData]
public async Task Cancellation_AfterFirstMoveNext_NaturalForEach(bool useProxy)
[Fact]
public async Task Cancellation_AfterFirstMoveNext_NaturalForEach_Proxy()
{
IAsyncEnumerable<int> enumerable = useProxy
? this.clientProxy.Value.GetNumbersAsync(this.TimeoutToken)
: await this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(nameof(Server.GetNumbersAsync), cancellationToken: this.TimeoutToken);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);
IAsyncEnumerable<int> enumerable = this.clientProxy.Value.GetNumbersAsync(cts.Token);

int iterations = 0;
await Assert.ThrowsAsync<OperationCanceledException>(async delegate
{
await foreach (var item in enumerable)
{
iterations++;
cts.Cancel();
}
}).WithCancellation(this.TimeoutToken);

Assert.Equal(1, iterations);
}

[Fact]
public async Task Cancellation_AfterFirstMoveNext_NaturalForEach_NoProxy()
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);
var enumerable = await this.clientRpc.InvokeWithCancellationAsync<IAsyncEnumerable<int>>(nameof(Server.GetNumbersAsync), cancellationToken: cts.Token);

int iterations = 0;
await Assert.ThrowsAsync<OperationCanceledException>(async delegate
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);
await foreach (var item in enumerable.WithCancellation(cts.Token))
{
iterations++;
cts.Cancel();
}
});
}).WithCancellation(this.TimeoutToken);

Assert.Equal(1, iterations);
}
Expand Down Expand Up @@ -424,6 +459,14 @@ public async Task ArgumentEnumerable_ForciblyDisposedAndReleasedWhenNotDisposedW
await Assert.ThrowsAsync<InvalidOperationException>(() => this.server.ArgEnumeratorAfterReturn).WithCancellation(this.TimeoutToken);
}

[SkippableFact]
[Trait("GC", "")]
public async Task ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod()
{
WeakReference enumerable = await this.ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod_Helper();
AssertCollectedObject(enumerable);
}

[Theory]
[InlineData(1, 0, 2, Server.ValuesReturnedByEnumerables)]
[InlineData(2, 2, 2, Server.ValuesReturnedByEnumerables)]
Expand Down Expand Up @@ -500,6 +543,28 @@ private async Task<WeakReference> ArgumentEnumerable_ForciblyDisposedAndReleased
return result;
}

[MethodImpl(MethodImplOptions.NoInlining)]
private async Task<WeakReference> ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod_Helper()
{
this.server.EnumeratedSource = ImmutableList.Create(1, 2, 3);
WeakReference weakReferenceToSource = new WeakReference(this.server.EnumeratedSource);
var cts = CancellationTokenSource.CreateLinkedTokenSource(this.TimeoutToken);

// Start up th emethod and get the first item.
var enumerable = this.clientProxy.Value.GetValuesFromEnumeratedSourceAsync(cts.Token);
var enumerator = enumerable.GetAsyncEnumerator(cts.Token);
Assert.True(await enumerator.MoveNextAsync());

// Now remove the only strong reference to the source object other than what would be captured by the async iterator method.
this.server.EnumeratedSource = this.server.EnumeratedSource.Clear();

// Now array for the server method to be canceled
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());

return weakReferenceToSource;
}

protected class Server : IServer
{
/// <summary>
Expand All @@ -524,12 +589,26 @@ protected class Server : IServer

public AsyncManualResetEvent ValueGenerated { get; } = new AsyncManualResetEvent();

public ImmutableList<int> EnumeratedSource { get; set; } = ImmutableList<int>.Empty;

public async IAsyncEnumerable<int> GetValuesFromEnumeratedSourceAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
foreach (var item in this.EnumeratedSource)
{
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
yield return item;
}
}

public IAsyncEnumerable<int> GetNumbersInBatchesAsync(CancellationToken cancellationToken)
=> this.GetNumbersAsync(cancellationToken).WithJsonRpcSettings(new JsonRpcEnumerableSettings { MinBatchSize = MinBatchSize });

public IAsyncEnumerable<int> GetNumbersWithReadAheadAsync(CancellationToken cancellationToken)
=> this.GetNumbersAsync(cancellationToken).WithJsonRpcSettings(new JsonRpcEnumerableSettings { MaxReadAhead = MaxReadAhead, MinBatchSize = MinBatchSize });

public IAsyncEnumerable<int> GetNumbersNoCancellationAsync() => this.GetNumbersAsync(CancellationToken.None);

public async IAsyncEnumerable<int> GetNumbersAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
try
Expand Down
26 changes: 21 additions & 5 deletions src/StreamJsonRpc/ProxyGeneration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ internal static TypeInfo Get(TypeInfo serviceInterface)

il.EmitCall(OpCodes.Callvirt, invokingMethod, null);

AdaptReturnType(method, returnTypeIsValueTask, returnTypeIsIAsyncEnumerable, il, invokingMethod);
AdaptReturnType(method, returnTypeIsValueTask, returnTypeIsIAsyncEnumerable, il, invokingMethod, cancellationTokenParameter);

il.Emit(OpCodes.Ret);
}
Expand Down Expand Up @@ -346,7 +346,7 @@ internal static TypeInfo Get(TypeInfo serviceInterface)

il.EmitCall(OpCodes.Callvirt, invokingMethod, null);

AdaptReturnType(method, returnTypeIsValueTask, returnTypeIsIAsyncEnumerable, il, invokingMethod);
AdaptReturnType(method, returnTypeIsValueTask, returnTypeIsIAsyncEnumerable, il, invokingMethod, cancellationTokenParameter);

il.Emit(OpCodes.Ret);
}
Expand Down Expand Up @@ -374,7 +374,8 @@ internal static TypeInfo Get(TypeInfo serviceInterface)
/// <param name="returnTypeIsIAsyncEnumerable"><c>true</c> if the return type is <see cref="IAsyncEnumerable{TResult}"/>; <c>false</c> otherwise.</param>
/// <param name="il">The IL emitter for the method.</param>
/// <param name="invokingMethod">The Invoke method on <see cref="JsonRpc"/> that IL was just emitted to invoke.</param>
private static void AdaptReturnType(MethodInfo method, bool returnTypeIsValueTask, bool returnTypeIsIAsyncEnumerable, ILGenerator il, MethodInfo invokingMethod)
/// <param name="cancellationTokenParameter">The <see cref="CancellationToken"/> parameter in the proxy method, if there is one.</param>
private static void AdaptReturnType(MethodInfo method, bool returnTypeIsValueTask, bool returnTypeIsIAsyncEnumerable, ILGenerator il, MethodInfo invokingMethod, ParameterInfo? cancellationTokenParameter)
{
if (returnTypeIsValueTask)
{
Expand All @@ -384,6 +385,19 @@ private static void AdaptReturnType(MethodInfo method, bool returnTypeIsValueTas
else if (returnTypeIsIAsyncEnumerable)
{
// We must convert the Task<IAsyncEnumerable<T>> to IAsyncEnumerable<T>
// Push a CancellationToken to the stack as well. Use the one this method was given if available, otherwise push CancellationToken.None.
if (cancellationTokenParameter != null)
{
il.Emit(OpCodes.Ldarg, cancellationTokenParameter.Position + 1);
}
else
{
LocalBuilder local = il.DeclareLocal(typeof(CancellationToken));
il.Emit(OpCodes.Ldloca, local);
il.Emit(OpCodes.Initobj, typeof(CancellationToken));
il.Emit(OpCodes.Ldloc, local);
}

Type proxyEnumerableType = typeof(AsyncEnumerableProxy<>).MakeGenericType(method.ReturnType.GenericTypeArguments[0]);
ConstructorInfo ctor = proxyEnumerableType.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance).Single();
il.Emit(OpCodes.Newobj, ctor);
Expand Down Expand Up @@ -596,15 +610,17 @@ private static IEnumerable<T> FindAllOnThisAndOtherInterfaces<T>(TypeInfo interf
private class AsyncEnumerableProxy<T> : IAsyncEnumerable<T>
{
private readonly Task<IAsyncEnumerable<T>> enumerableTask;
private readonly CancellationToken defaultCancellationToken;

internal AsyncEnumerableProxy(Task<IAsyncEnumerable<T>> enumerableTask)
internal AsyncEnumerableProxy(Task<IAsyncEnumerable<T>> enumerableTask, CancellationToken defaultCancellationToken)
{
this.enumerableTask = enumerableTask ?? throw new ArgumentNullException(nameof(enumerableTask));
this.defaultCancellationToken = defaultCancellationToken;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
{
return new AsyncEnumeratorProxy(this.enumerableTask, cancellationToken);
return new AsyncEnumeratorProxy(this.enumerableTask, cancellationToken.CanBeCanceled ? cancellationToken : this.defaultCancellationToken);
}

private class AsyncEnumeratorProxy : IAsyncEnumerator<T>
Expand Down

0 comments on commit 50c37af

Please sign in to comment.