diff --git a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs index 3c25f8259f..45569d03cb 100644 --- a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs +++ b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs @@ -6,6 +6,7 @@ using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Orleans.Configuration; using Orleans.Internal; @@ -19,7 +20,7 @@ namespace Orleans.Runtime; internal sealed class AsyncEnumerableGrainExtension : IAsyncEnumerableGrainExtension, IAsyncDisposable, IDisposable { private const long EnumeratorExpirationMilliseconds = 10_000; - private readonly Dictionary _enumerators = new(); + private readonly Dictionary _enumerators = []; private readonly IGrainContext _grainContext; private readonly MessagingOptions _messagingOptions; private readonly IDisposable _timer; @@ -47,15 +48,7 @@ static async (state, cancellationToken) => await state.RemoveExpiredAsync(cancel } /// - public ValueTask DisposeAsync(Guid requestId) - { - if (_enumerators.Remove(requestId, out var enumerator) && enumerator.Enumerator is { } value) - { - return value.DisposeAsync(); - } - - return default; - } + public ValueTask DisposeAsync(Guid requestId) => RemoveEnumeratorAsync(requestId); private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) { @@ -65,7 +58,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) if (state.LastSeenTimer.ElapsedMilliseconds > EnumeratorExpirationMilliseconds && state.MoveNextTask is null or { IsCompleted: true }) { - toRemove ??= new List(); + toRemove ??= []; toRemove.Add(requestId); } } @@ -75,13 +68,11 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) { foreach (var requestId in toRemove) { - _enumerators.Remove(requestId, out var state); - state.MoveNextTask?.Ignore(); - var disposeTask = state.Enumerator.DisposeAsync(); - if (!disposeTask.IsCompletedSuccessfully) + var removeTask = RemoveEnumeratorAsync(requestId); + if (!removeTask.IsCompletedSuccessfully) { - tasks ??= new List(); - tasks.Add(disposeTask.AsTask()); + tasks ??= []; + tasks.Add(removeTask.AsTask()); } } } @@ -97,24 +88,22 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) { request.SetTarget(_grainContext); var enumerable = request.InvokeImplementation(); - var enumerator = enumerable.GetAsyncEnumerator(); ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_enumerators, requestId, out bool exists); if (exists) { - return ThrowAlreadyExists(enumerator); + return ThrowAlreadyExists(); } + var cts = new CancellationTokenSource(); + var enumerator = enumerable.GetAsyncEnumerator(cts.Token); entry.Enumerator = enumerator; entry.LastSeenTimer.Restart(); entry.MaxBatchSize = request.MaxBatchSize; + entry.CancellationTokenSource = cts; Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive."); return MoveNextAsync(ref entry, requestId, enumerator); - static async ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists(IAsyncEnumerator enumerator) - { - await enumerator.DisposeAsync(); - throw new InvalidOperationException("An enumerator with the same ID already exists."); - } + static ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists() => ValueTask.FromException<(EnumerationResult Status, object Value)>(new InvalidOperationException("An enumerator with the same id already exists.")); } /// @@ -237,7 +226,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) } else { - _enumerators.Remove(requestId); + await RemoveEnumeratorAsync(requestId); await typedEnumerator.DisposeAsync(); return (EnumerationResult.Completed, default); } @@ -247,25 +236,29 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) } catch { - _enumerators.Remove(requestId); + await RemoveEnumeratorAsync(requestId); await typedEnumerator.DisposeAsync(); throw; } } + private async ValueTask RemoveEnumeratorAsync(Guid requestId) + { + if (_enumerators.Remove(requestId, out var state)) + { + await DisposeEnumeratorAsync(state); + } + } + private async ValueTask<(EnumerationResult Status, object Value)> OnComplete(Guid requestId, IAsyncEnumerator enumerator) { - _enumerators.Remove(requestId, out var state); - state.MoveNextTask?.Ignore(); - await enumerator.DisposeAsync(); + await RemoveEnumeratorAsync(requestId); return (EnumerationResult.Completed, default); } private async ValueTask<(EnumerationResult Status, object Value)> OnError(Guid requestId, IAsyncEnumerator enumerator, Exception exception) { - _enumerators.Remove(requestId, out var state); - state.MoveNextTask?.Ignore(); - await enumerator.DisposeAsync(); + await RemoveEnumeratorAsync(requestId); ExceptionDispatchInfo.Throw(exception); return default; } @@ -291,15 +284,42 @@ public async ValueTask DisposeAsync() _enumerators.Clear(); foreach (var enumerator in enumerators) + { + await DisposeEnumeratorAsync(enumerator); + } + } + + _timer.Dispose(); + } + + private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator) + { + try + { + enumerator.CancellationTokenSource.Cancel(); + } + catch (Exception exception) + { + var logger = _grainContext.GetComponent(); + logger?.LogWarning(exception, "Error cancelling enumerator."); + } + + try + { + if (enumerator.MoveNextTask is { } task) { if (enumerator.Enumerator is { } value) { + await task.SuppressThrowing(); await value.DisposeAsync(); } } } - - _timer.Dispose(); + catch (Exception exception) + { + var logger = _grainContext.GetComponent(); + logger?.LogWarning(exception, "Error disposing enumerator."); + } } /// @@ -314,5 +334,6 @@ private struct EnumeratorState public Task MoveNextTask; public CoarseStopwatch LastSeenTimer; public int MaxBatchSize; + internal CancellationTokenSource CancellationTokenSource; } } diff --git a/src/Orleans.Runtime/Catalog/ActivationData.cs b/src/Orleans.Runtime/Catalog/ActivationData.cs index 345f71490d..db49d93f04 100644 --- a/src/Orleans.Runtime/Catalog/ActivationData.cs +++ b/src/Orleans.Runtime/Catalog/ActivationData.cs @@ -247,7 +247,7 @@ private DehydrationContextHolder? DehydrationContext public TComponent? GetComponent() where TComponent : class { - TComponent? result; + TComponent? result = default; if (GrainInstance is TComponent grainResult) { result = grainResult; @@ -260,15 +260,15 @@ private DehydrationContextHolder? DehydrationContext { result = (TComponent)resultObj; } + else if (_shared.GetComponent() is { } sharedComponent) + { + result = sharedComponent; + } else if (ActivationServices.GetService() is { } component) { SetComponent(component); result = component; } - else - { - result = _shared.GetComponent(); - } return result; } diff --git a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs index 35c40dbeda..1c5715dea5 100644 --- a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs +++ b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs @@ -106,6 +106,11 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla return component; } + if (typeof(TComponent) == typeof(ILogger)) + { + return (TComponent)Logger; + } + if (_components is null) return default; _components.TryGetValue(typeof(TComponent), out var resultObj); return (TComponent?)resultObj;