Skip to content

Commit

Permalink
Gracefully dispose IAsyncEnumerable requests (#9186)
Browse files Browse the repository at this point in the history
  • Loading branch information
ReubenBond authored Oct 16, 2024
1 parent a7510b2 commit 188b6b5
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 39 deletions.
89 changes: 55 additions & 34 deletions src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Orleans.Configuration;
using Orleans.Internal;
Expand All @@ -19,7 +20,7 @@ namespace Orleans.Runtime;
internal sealed class AsyncEnumerableGrainExtension : IAsyncEnumerableGrainExtension, IAsyncDisposable, IDisposable
{
private const long EnumeratorExpirationMilliseconds = 10_000;
private readonly Dictionary<Guid, EnumeratorState> _enumerators = new();
private readonly Dictionary<Guid, EnumeratorState> _enumerators = [];
private readonly IGrainContext _grainContext;
private readonly MessagingOptions _messagingOptions;
private readonly IDisposable _timer;
Expand Down Expand Up @@ -47,15 +48,7 @@ static async (state, cancellationToken) => await state.RemoveExpiredAsync(cancel
}

/// <inheritdoc/>
public ValueTask DisposeAsync(Guid requestId)
{
if (_enumerators.Remove(requestId, out var enumerator) && enumerator.Enumerator is { } value)
{
return value.DisposeAsync();
}

return default;
}
public ValueTask DisposeAsync(Guid requestId) => RemoveEnumeratorAsync(requestId);

private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
Expand All @@ -65,7 +58,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
if (state.LastSeenTimer.ElapsedMilliseconds > EnumeratorExpirationMilliseconds
&& state.MoveNextTask is null or { IsCompleted: true })
{
toRemove ??= new List<Guid>();
toRemove ??= [];
toRemove.Add(requestId);
}
}
Expand All @@ -75,13 +68,11 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
foreach (var requestId in toRemove)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
var disposeTask = state.Enumerator.DisposeAsync();
if (!disposeTask.IsCompletedSuccessfully)
var removeTask = RemoveEnumeratorAsync(requestId);
if (!removeTask.IsCompletedSuccessfully)
{
tasks ??= new List<Task>();
tasks.Add(disposeTask.AsTask());
tasks ??= [];
tasks.Add(removeTask.AsTask());
}
}
}
Expand All @@ -97,24 +88,22 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
request.SetTarget(_grainContext);
var enumerable = request.InvokeImplementation();
var enumerator = enumerable.GetAsyncEnumerator();
ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_enumerators, requestId, out bool exists);
if (exists)
{
return ThrowAlreadyExists(enumerator);
return ThrowAlreadyExists();
}

var cts = new CancellationTokenSource();
var enumerator = enumerable.GetAsyncEnumerator(cts.Token);
entry.Enumerator = enumerator;
entry.LastSeenTimer.Restart();
entry.MaxBatchSize = request.MaxBatchSize;
entry.CancellationTokenSource = cts;
Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive.");
return MoveNextAsync(ref entry, requestId, enumerator);

static async ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists(IAsyncEnumerator<T> enumerator)
{
await enumerator.DisposeAsync();
throw new InvalidOperationException("An enumerator with the same ID already exists.");
}
static ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists() => ValueTask.FromException<(EnumerationResult Status, object Value)>(new InvalidOperationException("An enumerator with the same id already exists."));
}

/// <inheritdoc/>
Expand Down Expand Up @@ -237,7 +226,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
}
else
{
_enumerators.Remove(requestId);
await RemoveEnumeratorAsync(requestId);
await typedEnumerator.DisposeAsync();
return (EnumerationResult.Completed, default);
}
Expand All @@ -247,25 +236,29 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
}
catch
{
_enumerators.Remove(requestId);
await RemoveEnumeratorAsync(requestId);
await typedEnumerator.DisposeAsync();
throw;
}
}

private async ValueTask RemoveEnumeratorAsync(Guid requestId)
{
if (_enumerators.Remove(requestId, out var state))
{
await DisposeEnumeratorAsync(state);
}
}

private async ValueTask<(EnumerationResult Status, object Value)> OnComplete<T>(Guid requestId, IAsyncEnumerator<T> enumerator)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
await enumerator.DisposeAsync();
await RemoveEnumeratorAsync(requestId);
return (EnumerationResult.Completed, default);
}

private async ValueTask<(EnumerationResult Status, object Value)> OnError<T>(Guid requestId, IAsyncEnumerator<T> enumerator, Exception exception)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
await enumerator.DisposeAsync();
await RemoveEnumeratorAsync(requestId);
ExceptionDispatchInfo.Throw(exception);
return default;
}
Expand All @@ -291,15 +284,42 @@ public async ValueTask DisposeAsync()
_enumerators.Clear();

foreach (var enumerator in enumerators)
{
await DisposeEnumeratorAsync(enumerator);
}
}

_timer.Dispose();
}

private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator)
{
try
{
enumerator.CancellationTokenSource.Cancel();
}
catch (Exception exception)
{
var logger = _grainContext.GetComponent<ILogger>();
logger?.LogWarning(exception, "Error cancelling enumerator.");
}

try
{
if (enumerator.MoveNextTask is { } task)
{
if (enumerator.Enumerator is { } value)
{
await task.SuppressThrowing();
await value.DisposeAsync();
}
}
}

_timer.Dispose();
catch (Exception exception)
{
var logger = _grainContext.GetComponent<ILogger>();
logger?.LogWarning(exception, "Error disposing enumerator.");
}
}

/// <inheritdoc/>
Expand All @@ -314,5 +334,6 @@ private struct EnumeratorState
public Task<bool> MoveNextTask;
public CoarseStopwatch LastSeenTimer;
public int MaxBatchSize;
internal CancellationTokenSource CancellationTokenSource;
}
}
10 changes: 5 additions & 5 deletions src/Orleans.Runtime/Catalog/ActivationData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ private DehydrationContextHolder? DehydrationContext

public TComponent? GetComponent<TComponent>() where TComponent : class
{
TComponent? result;
TComponent? result = default;
if (GrainInstance is TComponent grainResult)
{
result = grainResult;
Expand All @@ -260,15 +260,15 @@ private DehydrationContextHolder? DehydrationContext
{
result = (TComponent)resultObj;
}
else if (_shared.GetComponent<TComponent>() is { } sharedComponent)
{
result = sharedComponent;
}
else if (ActivationServices.GetService<TComponent>() is { } component)
{
SetComponent(component);
result = component;
}
else
{
result = _shared.GetComponent<TComponent>();
}

return result;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla
return component;
}

if (typeof(TComponent) == typeof(ILogger))
{
return (TComponent)Logger;
}

if (_components is null) return default;
_components.TryGetValue(typeof(TComponent), out var resultObj);
return (TComponent?)resultObj;
Expand Down

0 comments on commit 188b6b5

Please sign in to comment.