diff --git a/src/Orleans.Runtime/Core/GrainRuntime.cs b/src/Orleans.Runtime/Core/GrainRuntime.cs index 8654a79353..454f646a2c 100644 --- a/src/Orleans.Runtime/Core/GrainRuntime.cs +++ b/src/Orleans.Runtime/Core/GrainRuntime.cs @@ -1,16 +1,12 @@ using System; -using Microsoft.Extensions.Logging; using Orleans.Core; using Orleans.Timers; using Orleans.Storage; -using Orleans.Serialization.Serializers; namespace Orleans.Runtime { internal class GrainRuntime : IGrainRuntime { - private readonly ILoggerFactory loggerFactory; - private readonly IActivatorProvider activatorProvider; private readonly IServiceProvider serviceProvider; private readonly ITimerRegistry timerRegistry; private readonly IGrainFactory grainFactory; @@ -19,17 +15,13 @@ public GrainRuntime( ILocalSiloDetails localSiloDetails, IGrainFactory grainFactory, ITimerRegistry timerRegistry, - IServiceProvider serviceProvider, - ILoggerFactory loggerFactory, - IActivatorProvider activatorProvider) + IServiceProvider serviceProvider) { SiloAddress = localSiloDetails.SiloAddress; SiloIdentity = SiloAddress.ToString(); this.grainFactory = grainFactory; this.timerRegistry = timerRegistry; this.serviceProvider = serviceProvider; - this.loggerFactory = loggerFactory; - this.activatorProvider = activatorProvider; } public string SiloIdentity { get; } @@ -85,7 +77,7 @@ public IStorage GetStorage(IGrainContext grainContext) if (grainContext is null) throw new ArgumentNullException(nameof(grainContext)); var grainType = grainContext.GrainInstance?.GetType() ?? throw new ArgumentNullException(nameof(IGrainContext.GrainInstance)); IGrainStorage grainStorage = GrainStorageHelpers.GetGrainStorage(grainType, ServiceProvider); - return new StateStorageBridge("state", grainContext, grainStorage, this.loggerFactory, this.activatorProvider); + return new StateStorageBridge("state", grainContext, grainStorage); } public static void CheckRuntimeContext(IGrainContext context) diff --git a/src/Orleans.Runtime/Facet/Persistent/PersistentStateStorageFactory.cs b/src/Orleans.Runtime/Facet/Persistent/PersistentStateStorageFactory.cs index 27a549660b..6b7ef3b4fc 100644 --- a/src/Orleans.Runtime/Facet/Persistent/PersistentStateStorageFactory.cs +++ b/src/Orleans.Runtime/Facet/Persistent/PersistentStateStorageFactory.cs @@ -56,7 +56,7 @@ private static void ThrowMissingProviderException(IGrainContext context, IPersis internal sealed class PersistentState : StateStorageBridge, IPersistentState, ILifecycleObserver { - public PersistentState(string stateName, IGrainContext context, IGrainStorage storageProvider) : base(stateName, context, storageProvider, context.ActivationServices.GetRequiredService(), context.ActivationServices.GetRequiredService()) + public PersistentState(string stateName, IGrainContext context, IGrainStorage storageProvider) : base(stateName, context, storageProvider) { var lifecycle = context.ObservableLifecycle; lifecycle.Subscribe(RuntimeTypeNameFormatter.Format(GetType()), GrainLifecycleStage.SetupState, this); diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 5883c5d434..2b7471cbc6 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -41,6 +41,7 @@ using System.Collections.Generic; using Microsoft.Extensions.Configuration; using Orleans.Serialization.Internal; +using Orleans.Core; namespace Orleans.Hosting { @@ -346,6 +347,7 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(typeof(IAttributeToFactoryMapper), typeof(PersistentStateAttributeMapper)); + services.TryAddSingleton(); // IAsyncEnumerable support services.AddScoped(); diff --git a/src/Orleans.Runtime/Hosting/StorageProviderHostExtensions.cs b/src/Orleans.Runtime/Hosting/StorageProviderHostExtensions.cs index 2527121702..93b65d98bd 100644 --- a/src/Orleans.Runtime/Hosting/StorageProviderHostExtensions.cs +++ b/src/Orleans.Runtime/Hosting/StorageProviderHostExtensions.cs @@ -1,12 +1,6 @@ using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using Orleans.Storage; -using System.Xml.Linq; using Microsoft.Extensions.DependencyInjection; -using Orleans.GrainDirectory; using Orleans.Providers; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -27,16 +21,19 @@ public static IServiceCollection AddGrainStorage(this IServiceCollection coll where T : IGrainStorage { collection.AddKeyedSingleton(name, (sp, key) => implementationFactory(sp, key as string)); + // Check if it is the default implementation if (string.Equals(name, ProviderConstants.DEFAULT_STORAGE_PROVIDER_NAME, StringComparison.Ordinal)) { collection.TryAddSingleton(sp => sp.GetKeyedService(ProviderConstants.DEFAULT_STORAGE_PROVIDER_NAME)); } + // Check if the grain storage implements ILifecycleParticipant if (typeof(ILifecycleParticipant).IsAssignableFrom(typeof(T))) { collection.AddSingleton(s => (ILifecycleParticipant)s.GetRequiredKeyedService(name)); } + return collection; } } diff --git a/src/Orleans.Runtime/Storage/StateStorageBridge.cs b/src/Orleans.Runtime/Storage/StateStorageBridge.cs index 0fc0902586..1b1b5712b5 100644 --- a/src/Orleans.Runtime/Storage/StateStorageBridge.cs +++ b/src/Orleans.Runtime/Storage/StateStorageBridge.cs @@ -1,8 +1,10 @@ #nullable enable using System; +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Runtime.ExceptionServices; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Orleans.Runtime; using Orleans.Serialization.Activators; @@ -19,11 +21,8 @@ namespace Orleans.Core /// public class StateStorageBridge : IStorage, IGrainMigrationParticipant { - private readonly string _name; private readonly IGrainContext _grainContext; - private readonly IGrainStorage _store; - private readonly ILogger _logger; - private readonly IActivator _activator; + private readonly StateStorageBridgeShared _shared; private GrainState? _grainState; /// @@ -32,7 +31,12 @@ public TState State get { GrainRuntime.CheckRuntimeContext(RuntimeContext.Current); - return GrainState.State; + if (_grainState is { } grainState) + { + return grainState.State; + } + + return default!; } set @@ -42,28 +46,32 @@ public TState State } } - private GrainState GrainState => _grainState ??= new GrainState(_activator.Create()); - internal bool IsStateInitialized => _grainState != null; + private GrainState GrainState => _grainState ??= new GrainState(_shared.Activator.Create()); + internal bool IsStateInitialized { get; private set; } /// - public string? Etag { get => GrainState.ETag; set => GrainState.ETag = value; } + public string? Etag { get => _grainState?.ETag; set => GrainState.ETag = value; } /// - public bool RecordExists => GrainState.RecordExists; + public bool RecordExists => IsStateInitialized switch + { + true => GrainState.RecordExists, + _ => throw new InvalidOperationException("State has not yet been loaded") + }; + + [Obsolete("Use StateStorageBridge(string, IGrainContext, IGrainStorage) instead.")] + public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store, ILoggerFactory loggerFactory, IActivatorProvider activatorProvider) : this(name, grainContext, store) + { } - public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store, ILoggerFactory loggerFactory, IActivatorProvider activatorProvider) + public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store) { ArgumentNullException.ThrowIfNull(name); ArgumentNullException.ThrowIfNull(grainContext); ArgumentNullException.ThrowIfNull(store); - ArgumentNullException.ThrowIfNull(loggerFactory); - ArgumentNullException.ThrowIfNull(activatorProvider); - _logger = loggerFactory.CreateLogger(store.GetType()); - _name = name; _grainContext = grainContext; - _store = store; - _activator = activatorProvider.GetActivator(); + var sharedInstances = ActivatorUtilities.GetServiceOrCreateInstance(grainContext.ActivationServices); + _shared = sharedInstances.Get(name, store); } /// @@ -74,7 +82,8 @@ public async Task ReadStateAsync() GrainRuntime.CheckRuntimeContext(RuntimeContext.Current); var sw = ValueStopwatch.StartNew(); - await _store.ReadStateAsync(_name, _grainContext.GrainId, GrainState); + await _shared.Store.ReadStateAsync(_shared.Name, _grainContext.GrainId, GrainState); + IsStateInitialized = true; StorageInstruments.OnStorageRead(sw.Elapsed); } catch (Exception exc) @@ -92,7 +101,7 @@ public async Task WriteStateAsync() GrainRuntime.CheckRuntimeContext(RuntimeContext.Current); var sw = ValueStopwatch.StartNew(); - await _store.WriteStateAsync(_name, _grainContext.GrainId, GrainState); + await _shared.Store.WriteStateAsync(_shared.Name, _grainContext.GrainId, GrainState); StorageInstruments.OnStorageWrite(sw.Elapsed); } catch (Exception exc) @@ -110,12 +119,13 @@ public async Task ClearStateAsync() GrainRuntime.CheckRuntimeContext(RuntimeContext.Current); var sw = ValueStopwatch.StartNew(); + // Clear (most likely Delete) state from external storage - await _store.ClearStateAsync(_name, _grainContext.GrainId, GrainState); + await _shared.Store.ClearStateAsync(_shared.Name, _grainContext.GrainId, GrainState); sw.Stop(); // Reset the in-memory copy of the state - GrainState.State = _activator.Create(); + GrainState.State = _shared.Activator.Create(); // Update counters StorageInstruments.OnStorageDelete(sw.Elapsed); @@ -131,11 +141,11 @@ public void OnDehydrate(IDehydrationContext dehydrationContext) { try { - dehydrationContext.TryAddValue($"state.{_name}", _grainState); + dehydrationContext.TryAddValue(_shared.MigrationContextKey, _grainState); } catch (Exception exception) { - _logger.LogError(exception, "Failed to dehydrate state named {StateName} for grain {GrainId}", _name, _grainContext.GrainId); + _shared.Logger.LogError(exception, "Failed to dehydrate state named {StateName} for grain {GrainId}", _shared.Name, _grainContext.GrainId); // We must throw here since we do not know that the dehydration context is in a clean state after this. throw; @@ -146,12 +156,16 @@ public void OnRehydrate(IRehydrationContext rehydrationContext) { try { - rehydrationContext.TryGetValue($"state.{_name}", out _grainState); + if (rehydrationContext.TryGetValue>(_shared.MigrationContextKey, out var grainState)) + { + _grainState = grainState; + IsStateInitialized = true; + } } catch (Exception exception) { // It is ok to swallow this exception, since state rehydration is best-effort. - _logger.LogError(exception, "Failed to rehydrate state named {StateName} for grain {GrainId}", _name, _grainContext.GrainId); + _shared.Logger.LogError(exception, "Failed to rehydrate state named {StateName} for grain {GrainId}", _shared.Name, _grainContext.GrainId); } } @@ -159,21 +173,49 @@ public void OnRehydrate(IRehydrationContext rehydrationContext) private void OnError(Exception exception, ErrorCode id, string operation) { string? errorCode = null; - (_store as IRestExceptionDecoder)?.DecodeException(exception, out _, out errorCode, true); + (_shared.Store as IRestExceptionDecoder)?.DecodeException(exception, out _, out errorCode, true); var errorString = errorCode is { Length: > 0 } ? $" Error: {errorCode}" : null; var grainId = _grainContext.GrainId; - var providerName = _store.GetType().Name; - _logger.LogError((int)id, exception, "Error from storage provider {ProviderName}.{StateName} during {Operation} for grain {GrainId}{ErrorCode}", providerName, _name, operation, grainId, errorString); + var providerName = _shared.Store.GetType().Name; + _shared.Logger.LogError((int)id, exception, "Error from storage provider {ProviderName}.{StateName} during {Operation} for grain {GrainId}{ErrorCode}", providerName, _shared.Name, operation, grainId, errorString); // If error is not specialization of OrleansException, wrap it if (exception is not OrleansException) { - var errMsg = $"Error from storage provider {providerName}.{_name} during {operation} for grain {grainId}{errorString}{Environment.NewLine} {LogFormatter.PrintException(exception)}"; + var errMsg = $"Error from storage provider {providerName}.{_shared.Name} during {operation} for grain {grainId}{errorString}{Environment.NewLine} {LogFormatter.PrintException(exception)}"; throw new OrleansException(errMsg, exception); } ExceptionDispatchInfo.Throw(exception); } } + + internal sealed class StateStorageBridgeSharedMap(ILoggerFactory loggerFactory, IActivatorProvider activatorProvider) + { + private readonly ConcurrentDictionary<(string Name, IGrainStorage Store, Type StateType), object> _instances = new(); + private readonly ILoggerFactory _loggerFactory = loggerFactory; + private readonly IActivatorProvider _activatorProvider = activatorProvider; + + public StateStorageBridgeShared Get(string name, IGrainStorage store) + => (StateStorageBridgeShared)_instances.GetOrAdd( + (name, store, typeof(TState)), + static (key, self) => new StateStorageBridgeShared( + key.Name, + key.Store, + self._loggerFactory.CreateLogger(key.Store.GetType()), + self._activatorProvider.GetActivator()), + this); + } + + internal sealed class StateStorageBridgeShared(string name, IGrainStorage store, ILogger logger, IActivator activator) + { + private string? _migrationContextKey; + + public readonly string Name = name; + public readonly IGrainStorage Store = store; + public readonly ILogger Logger = logger; + public readonly IActivator Activator = activator; + public string MigrationContextKey => _migrationContextKey ??= $"state.{Name}"; + } } diff --git a/src/Orleans.Streaming/PubSub/PubSubRendezvousGrain.cs b/src/Orleans.Streaming/PubSub/PubSubRendezvousGrain.cs index 60c7fcd27f..ee87430531 100644 --- a/src/Orleans.Streaming/PubSub/PubSubRendezvousGrain.cs +++ b/src/Orleans.Streaming/PubSub/PubSubRendezvousGrain.cs @@ -20,13 +20,11 @@ namespace Orleans.Streams internal sealed class PubSubGrainStateStorageFactory { private readonly IServiceProvider _serviceProvider; - private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; public PubSubGrainStateStorageFactory(IServiceProvider serviceProvider, ILoggerFactory loggerFactory) { _serviceProvider = serviceProvider; - _loggerFactory = loggerFactory; _logger = loggerFactory.CreateLogger(); } @@ -57,8 +55,7 @@ public StateStorageBridge GetStorage(PubSubRendezvousGrain gra storage = _serviceProvider.GetRequiredKeyedService(ProviderConstants.DEFAULT_PUBSUB_PROVIDER_NAME); } - var activatorProvider = _serviceProvider.GetRequiredService(); - return new(nameof(PubSubRendezvousGrain), grain.GrainContext, storage, _loggerFactory, activatorProvider); + return new(nameof(PubSubRendezvousGrain), grain.GrainContext, storage); } } diff --git a/src/Orleans.Transactions/State/NamedTransactionalStateStorageFactory.cs b/src/Orleans.Transactions/State/NamedTransactionalStateStorageFactory.cs index 72e56b3a83..db645b5b95 100644 --- a/src/Orleans.Transactions/State/NamedTransactionalStateStorageFactory.cs +++ b/src/Orleans.Transactions/State/NamedTransactionalStateStorageFactory.cs @@ -1,22 +1,23 @@ using System; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using Orleans.Runtime; using Orleans.Transactions.Abstractions; using Orleans.Storage; -using Orleans.Serialization.Serializers; namespace Orleans.Transactions { public class NamedTransactionalStateStorageFactory : INamedTransactionalStateStorageFactory { private readonly IGrainContextAccessor contextAccessor; - private readonly ILoggerFactory loggerFactory; - public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor, ILoggerFactory loggerFactory) + [Obsolete("Use the NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor) constructor.")] + public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) : this(contextAccessor) + { + } + + public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor) { this.contextAccessor = contextAccessor; - this.loggerFactory = loggerFactory; } public ITransactionalStateStorage Create(string storageName, string stateName) @@ -37,8 +38,7 @@ public ITransactionalStateStorage Create(string storageName, str if (grainStorage != null) { - IActivatorProvider activatorProvider = currentContext.ActivationServices.GetRequiredService(); - return new TransactionalStateStorageProviderWrapper(grainStorage, stateName, currentContext, this.loggerFactory, activatorProvider); + return new TransactionalStateStorageProviderWrapper(grainStorage, stateName, currentContext); } throw (string.IsNullOrEmpty(storageName)) diff --git a/src/Orleans.Transactions/State/TransactionalStateStorageProviderWrapper.cs b/src/Orleans.Transactions/State/TransactionalStateStorageProviderWrapper.cs index e0059333a2..e13fbe6b0f 100644 --- a/src/Orleans.Transactions/State/TransactionalStateStorageProviderWrapper.cs +++ b/src/Orleans.Transactions/State/TransactionalStateStorageProviderWrapper.cs @@ -2,10 +2,8 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; using Orleans.Core; using Orleans.Runtime; -using Orleans.Serialization.Serializers; using Orleans.Storage; using Orleans.Transactions.Abstractions; @@ -17,20 +15,16 @@ internal sealed class TransactionalStateStorageProviderWrapper : ITransa { private readonly IGrainStorage grainStorage; private readonly IGrainContext context; - private readonly ILoggerFactory loggerFactory; - private readonly IActivatorProvider activatorProvider; private readonly string stateName; private StateStorageBridge>? stateStorage; [MemberNotNull(nameof(stateStorage))] private StateStorageBridge> StateStorage => stateStorage ??= GetStateStorage(); - public TransactionalStateStorageProviderWrapper(IGrainStorage grainStorage, string stateName, IGrainContext context, ILoggerFactory loggerFactory, IActivatorProvider activatorProvider) + public TransactionalStateStorageProviderWrapper(IGrainStorage grainStorage, string stateName, IGrainContext context) { this.grainStorage = grainStorage; this.context = context; - this.loggerFactory = loggerFactory; - this.activatorProvider = activatorProvider; this.stateName = stateName; } @@ -104,7 +98,7 @@ public async Task Store(string expectedETag, TransactionalStateMetaData private StateStorageBridge> GetStateStorage() { - return new(this.stateName, context, grainStorage, loggerFactory, activatorProvider); + return new(this.stateName, context, grainStorage); } }