Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gracefully dispose IAsyncEnumerable requests #9186

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 55 additions & 34 deletions src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Orleans.Configuration;
using Orleans.Internal;
Expand All @@ -19,7 +20,7 @@ namespace Orleans.Runtime;
internal sealed class AsyncEnumerableGrainExtension : IAsyncEnumerableGrainExtension, IAsyncDisposable, IDisposable
{
private const long EnumeratorExpirationMilliseconds = 10_000;
private readonly Dictionary<Guid, EnumeratorState> _enumerators = new();
private readonly Dictionary<Guid, EnumeratorState> _enumerators = [];
private readonly IGrainContext _grainContext;
private readonly MessagingOptions _messagingOptions;
private readonly IDisposable _timer;
Expand Down Expand Up @@ -47,15 +48,7 @@ static async (state, cancellationToken) => await state.RemoveExpiredAsync(cancel
}

/// <inheritdoc/>
public ValueTask DisposeAsync(Guid requestId)
{
if (_enumerators.Remove(requestId, out var enumerator) && enumerator.Enumerator is { } value)
{
return value.DisposeAsync();
}

return default;
}
public ValueTask DisposeAsync(Guid requestId) => RemoveEnumeratorAsync(requestId);

private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
Expand All @@ -65,7 +58,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
if (state.LastSeenTimer.ElapsedMilliseconds > EnumeratorExpirationMilliseconds
&& state.MoveNextTask is null or { IsCompleted: true })
{
toRemove ??= new List<Guid>();
toRemove ??= [];
toRemove.Add(requestId);
}
}
Expand All @@ -75,13 +68,11 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
foreach (var requestId in toRemove)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
var disposeTask = state.Enumerator.DisposeAsync();
if (!disposeTask.IsCompletedSuccessfully)
var removeTask = RemoveEnumeratorAsync(requestId);
if (!removeTask.IsCompletedSuccessfully)
{
tasks ??= new List<Task>();
tasks.Add(disposeTask.AsTask());
tasks ??= [];
tasks.Add(removeTask.AsTask());
}
}
}
Expand All @@ -97,24 +88,22 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
request.SetTarget(_grainContext);
var enumerable = request.InvokeImplementation();
var enumerator = enumerable.GetAsyncEnumerator();
ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_enumerators, requestId, out bool exists);
if (exists)
{
return ThrowAlreadyExists(enumerator);
return ThrowAlreadyExists();
}

var cts = new CancellationTokenSource();
var enumerator = enumerable.GetAsyncEnumerator(cts.Token);
entry.Enumerator = enumerator;
entry.LastSeenTimer.Restart();
entry.MaxBatchSize = request.MaxBatchSize;
entry.CancellationTokenSource = cts;
Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive.");
return MoveNextAsync(ref entry, requestId, enumerator);

static async ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists(IAsyncEnumerator<T> enumerator)
{
await enumerator.DisposeAsync();
throw new InvalidOperationException("An enumerator with the same ID already exists.");
}
static ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists() => ValueTask.FromException<(EnumerationResult Status, object Value)>(new InvalidOperationException("An enumerator with the same id already exists."));
}

/// <inheritdoc/>
Expand Down Expand Up @@ -237,7 +226,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
}
else
{
_enumerators.Remove(requestId);
await RemoveEnumeratorAsync(requestId);
await typedEnumerator.DisposeAsync();
return (EnumerationResult.Completed, default);
}
Expand All @@ -247,25 +236,29 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
}
catch
{
_enumerators.Remove(requestId);
await RemoveEnumeratorAsync(requestId);
await typedEnumerator.DisposeAsync();
throw;
}
}

private async ValueTask RemoveEnumeratorAsync(Guid requestId)
{
if (_enumerators.Remove(requestId, out var state))
{
await DisposeEnumeratorAsync(state);
}
}

private async ValueTask<(EnumerationResult Status, object Value)> OnComplete<T>(Guid requestId, IAsyncEnumerator<T> enumerator)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
await enumerator.DisposeAsync();
await RemoveEnumeratorAsync(requestId);
return (EnumerationResult.Completed, default);
}

private async ValueTask<(EnumerationResult Status, object Value)> OnError<T>(Guid requestId, IAsyncEnumerator<T> enumerator, Exception exception)
{
_enumerators.Remove(requestId, out var state);
state.MoveNextTask?.Ignore();
await enumerator.DisposeAsync();
await RemoveEnumeratorAsync(requestId);
ExceptionDispatchInfo.Throw(exception);
return default;
}
Expand All @@ -291,15 +284,42 @@ public async ValueTask DisposeAsync()
_enumerators.Clear();

foreach (var enumerator in enumerators)
{
await DisposeEnumeratorAsync(enumerator);
}
}

_timer.Dispose();
}

private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator)
{
try
{
enumerator.CancellationTokenSource.Cancel();
}
catch (Exception exception)
{
var logger = _grainContext.GetComponent<ILogger>();
logger?.LogWarning(exception, "Error cancelling enumerator.");
}

try
{
if (enumerator.MoveNextTask is { } task)
{
if (enumerator.Enumerator is { } value)
{
await task.SuppressThrowing();
await value.DisposeAsync();
}
}
}

_timer.Dispose();
catch (Exception exception)
{
var logger = _grainContext.GetComponent<ILogger>();
logger?.LogWarning(exception, "Error disposing enumerator.");
}
}

/// <inheritdoc/>
Expand All @@ -314,5 +334,6 @@ private struct EnumeratorState
public Task<bool> MoveNextTask;
public CoarseStopwatch LastSeenTimer;
public int MaxBatchSize;
internal CancellationTokenSource CancellationTokenSource;
}
}
10 changes: 5 additions & 5 deletions src/Orleans.Runtime/Catalog/ActivationData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ private DehydrationContextHolder? DehydrationContext

public TComponent? GetComponent<TComponent>() where TComponent : class
{
TComponent? result;
TComponent? result = default;
if (GrainInstance is TComponent grainResult)
{
result = grainResult;
Expand All @@ -260,15 +260,15 @@ private DehydrationContextHolder? DehydrationContext
{
result = (TComponent)resultObj;
}
else if (_shared.GetComponent<TComponent>() is { } sharedComponent)
{
result = sharedComponent;
}
else if (ActivationServices.GetService<TComponent>() is { } component)
{
SetComponent(component);
result = component;
}
else
{
result = _shared.GetComponent<TComponent>();
}

return result;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla
return component;
}

if (typeof(TComponent) == typeof(ILogger))
{
return (TComponent)Logger;
}

if (_components is null) return default;
_components.TryGetValue(typeof(TComponent), out var resultObj);
return (TComponent?)resultObj;
Expand Down
Loading