Skip to content

Commit

Permalink
Always read grain state during activation if it has not been rehydrat…
Browse files Browse the repository at this point in the history
…ed (#8944)

* Reduce memory footprint of StateStorageBridge<TState>
  • Loading branch information
ReubenBond authored Apr 13, 2024
1 parent 27e801b commit deda4ba
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 64 deletions.
12 changes: 2 additions & 10 deletions src/Orleans.Runtime/Core/GrainRuntime.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
using System;
using Microsoft.Extensions.Logging;
using Orleans.Core;
using Orleans.Timers;
using Orleans.Storage;
using Orleans.Serialization.Serializers;

namespace Orleans.Runtime
{
internal class GrainRuntime : IGrainRuntime
{
private readonly ILoggerFactory loggerFactory;
private readonly IActivatorProvider activatorProvider;
private readonly IServiceProvider serviceProvider;
private readonly ITimerRegistry timerRegistry;
private readonly IGrainFactory grainFactory;
Expand All @@ -19,17 +15,13 @@ public GrainRuntime(
ILocalSiloDetails localSiloDetails,
IGrainFactory grainFactory,
ITimerRegistry timerRegistry,
IServiceProvider serviceProvider,
ILoggerFactory loggerFactory,
IActivatorProvider activatorProvider)
IServiceProvider serviceProvider)
{
SiloAddress = localSiloDetails.SiloAddress;
SiloIdentity = SiloAddress.ToString();
this.grainFactory = grainFactory;
this.timerRegistry = timerRegistry;
this.serviceProvider = serviceProvider;
this.loggerFactory = loggerFactory;
this.activatorProvider = activatorProvider;
}

public string SiloIdentity { get; }
Expand Down Expand Up @@ -85,7 +77,7 @@ public IStorage<TGrainState> GetStorage<TGrainState>(IGrainContext grainContext)
if (grainContext is null) throw new ArgumentNullException(nameof(grainContext));
var grainType = grainContext.GrainInstance?.GetType() ?? throw new ArgumentNullException(nameof(IGrainContext.GrainInstance));
IGrainStorage grainStorage = GrainStorageHelpers.GetGrainStorage(grainType, ServiceProvider);
return new StateStorageBridge<TGrainState>("state", grainContext, grainStorage, this.loggerFactory, this.activatorProvider);
return new StateStorageBridge<TGrainState>("state", grainContext, grainStorage);
}

public static void CheckRuntimeContext(IGrainContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ private static void ThrowMissingProviderException(IGrainContext context, IPersis

internal sealed class PersistentState<TState> : StateStorageBridge<TState>, IPersistentState<TState>, ILifecycleObserver
{
public PersistentState(string stateName, IGrainContext context, IGrainStorage storageProvider) : base(stateName, context, storageProvider, context.ActivationServices.GetRequiredService<ILoggerFactory>(), context.ActivationServices.GetRequiredService<IActivatorProvider>())
public PersistentState(string stateName, IGrainContext context, IGrainStorage storageProvider) : base(stateName, context, storageProvider)
{
var lifecycle = context.ObservableLifecycle;
lifecycle.Subscribe(RuntimeTypeNameFormatter.Format(GetType()), GrainLifecycleStage.SetupState, this);
Expand Down
2 changes: 2 additions & 0 deletions src/Orleans.Runtime/Hosting/DefaultSiloServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
using System.Collections.Generic;
using Microsoft.Extensions.Configuration;
using Orleans.Serialization.Internal;
using Orleans.Core;

namespace Orleans.Hosting
{
Expand Down Expand Up @@ -346,6 +347,7 @@ internal static void AddDefaultServices(ISiloBuilder builder)
services.TryAddSingleton<IGrainStorageSerializer, JsonGrainStorageSerializer>();
services.TryAddSingleton<IPersistentStateFactory, PersistentStateFactory>();
services.TryAddSingleton(typeof(IAttributeToFactoryMapper<PersistentStateAttribute>), typeof(PersistentStateAttributeMapper));
services.TryAddSingleton<StateStorageBridgeSharedMap>();

// IAsyncEnumerable support
services.AddScoped<IAsyncEnumerableGrainExtension, AsyncEnumerableGrainExtension>();
Expand Down
9 changes: 3 additions & 6 deletions src/Orleans.Runtime/Hosting/StorageProviderHostExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Orleans.Storage;
using System.Xml.Linq;
using Microsoft.Extensions.DependencyInjection;
using Orleans.GrainDirectory;
using Orleans.Providers;
using Microsoft.Extensions.DependencyInjection.Extensions;

Expand All @@ -27,16 +21,19 @@ public static IServiceCollection AddGrainStorage<T>(this IServiceCollection coll
where T : IGrainStorage
{
collection.AddKeyedSingleton<IGrainStorage>(name, (sp, key) => implementationFactory(sp, key as string));

// Check if it is the default implementation
if (string.Equals(name, ProviderConstants.DEFAULT_STORAGE_PROVIDER_NAME, StringComparison.Ordinal))
{
collection.TryAddSingleton(sp => sp.GetKeyedService<IGrainStorage>(ProviderConstants.DEFAULT_STORAGE_PROVIDER_NAME));
}

// Check if the grain storage implements ILifecycleParticipant<ISiloLifecycle>
if (typeof(ILifecycleParticipant<ISiloLifecycle>).IsAssignableFrom(typeof(T)))
{
collection.AddSingleton(s => (ILifecycleParticipant<ISiloLifecycle>)s.GetRequiredKeyedService<IGrainStorage>(name));
}

return collection;
}
}
Expand Down
98 changes: 70 additions & 28 deletions src/Orleans.Runtime/Storage/StateStorageBridge.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Orleans.Runtime;
using Orleans.Serialization.Activators;
Expand All @@ -19,11 +21,8 @@ namespace Orleans.Core
/// <seealso cref="IStorage{TState}" />
public class StateStorageBridge<TState> : IStorage<TState>, IGrainMigrationParticipant
{
private readonly string _name;
private readonly IGrainContext _grainContext;
private readonly IGrainStorage _store;
private readonly ILogger _logger;
private readonly IActivator<TState> _activator;
private readonly StateStorageBridgeShared<TState> _shared;
private GrainState<TState>? _grainState;

/// <inheritdoc/>
Expand All @@ -32,7 +31,12 @@ public TState State
get
{
GrainRuntime.CheckRuntimeContext(RuntimeContext.Current);
return GrainState.State;
if (_grainState is { } grainState)
{
return grainState.State;
}

return default!;
}

set
Expand All @@ -42,28 +46,32 @@ public TState State
}
}

private GrainState<TState> GrainState => _grainState ??= new GrainState<TState>(_activator.Create());
internal bool IsStateInitialized => _grainState != null;
private GrainState<TState> GrainState => _grainState ??= new GrainState<TState>(_shared.Activator.Create());
internal bool IsStateInitialized { get; private set; }

/// <inheritdoc/>
public string? Etag { get => GrainState.ETag; set => GrainState.ETag = value; }
public string? Etag { get => _grainState?.ETag; set => GrainState.ETag = value; }

/// <inheritdoc/>
public bool RecordExists => GrainState.RecordExists;
public bool RecordExists => IsStateInitialized switch
{
true => GrainState.RecordExists,
_ => throw new InvalidOperationException("State has not yet been loaded")
};

[Obsolete("Use StateStorageBridge(string, IGrainContext, IGrainStorage) instead.")]
public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store, ILoggerFactory loggerFactory, IActivatorProvider activatorProvider) : this(name, grainContext, store)
{ }

public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store, ILoggerFactory loggerFactory, IActivatorProvider activatorProvider)
public StateStorageBridge(string name, IGrainContext grainContext, IGrainStorage store)
{
ArgumentNullException.ThrowIfNull(name);
ArgumentNullException.ThrowIfNull(grainContext);
ArgumentNullException.ThrowIfNull(store);
ArgumentNullException.ThrowIfNull(loggerFactory);
ArgumentNullException.ThrowIfNull(activatorProvider);

_logger = loggerFactory.CreateLogger(store.GetType());
_name = name;
_grainContext = grainContext;
_store = store;
_activator = activatorProvider.GetActivator<TState>();
var sharedInstances = ActivatorUtilities.GetServiceOrCreateInstance<StateStorageBridgeSharedMap>(grainContext.ActivationServices);
_shared = sharedInstances.Get<TState>(name, store);
}

/// <inheritdoc />
Expand All @@ -74,7 +82,8 @@ public async Task ReadStateAsync()
GrainRuntime.CheckRuntimeContext(RuntimeContext.Current);

var sw = ValueStopwatch.StartNew();
await _store.ReadStateAsync(_name, _grainContext.GrainId, GrainState);
await _shared.Store.ReadStateAsync(_shared.Name, _grainContext.GrainId, GrainState);
IsStateInitialized = true;
StorageInstruments.OnStorageRead(sw.Elapsed);
}
catch (Exception exc)
Expand All @@ -92,7 +101,7 @@ public async Task WriteStateAsync()
GrainRuntime.CheckRuntimeContext(RuntimeContext.Current);

var sw = ValueStopwatch.StartNew();
await _store.WriteStateAsync(_name, _grainContext.GrainId, GrainState);
await _shared.Store.WriteStateAsync(_shared.Name, _grainContext.GrainId, GrainState);
StorageInstruments.OnStorageWrite(sw.Elapsed);
}
catch (Exception exc)
Expand All @@ -110,12 +119,13 @@ public async Task ClearStateAsync()
GrainRuntime.CheckRuntimeContext(RuntimeContext.Current);

var sw = ValueStopwatch.StartNew();

// Clear (most likely Delete) state from external storage
await _store.ClearStateAsync(_name, _grainContext.GrainId, GrainState);
await _shared.Store.ClearStateAsync(_shared.Name, _grainContext.GrainId, GrainState);
sw.Stop();

// Reset the in-memory copy of the state
GrainState.State = _activator.Create();
GrainState.State = _shared.Activator.Create();

// Update counters
StorageInstruments.OnStorageDelete(sw.Elapsed);
Expand All @@ -131,11 +141,11 @@ public void OnDehydrate(IDehydrationContext dehydrationContext)
{
try
{
dehydrationContext.TryAddValue($"state.{_name}", _grainState);
dehydrationContext.TryAddValue(_shared.MigrationContextKey, _grainState);
}
catch (Exception exception)
{
_logger.LogError(exception, "Failed to dehydrate state named {StateName} for grain {GrainId}", _name, _grainContext.GrainId);
_shared.Logger.LogError(exception, "Failed to dehydrate state named {StateName} for grain {GrainId}", _shared.Name, _grainContext.GrainId);

// We must throw here since we do not know that the dehydration context is in a clean state after this.
throw;
Expand All @@ -146,34 +156,66 @@ public void OnRehydrate(IRehydrationContext rehydrationContext)
{
try
{
rehydrationContext.TryGetValue($"state.{_name}", out _grainState);
if (rehydrationContext.TryGetValue<GrainState<TState>>(_shared.MigrationContextKey, out var grainState))
{
_grainState = grainState;
IsStateInitialized = true;
}
}
catch (Exception exception)
{
// It is ok to swallow this exception, since state rehydration is best-effort.
_logger.LogError(exception, "Failed to rehydrate state named {StateName} for grain {GrainId}", _name, _grainContext.GrainId);
_shared.Logger.LogError(exception, "Failed to rehydrate state named {StateName} for grain {GrainId}", _shared.Name, _grainContext.GrainId);
}
}

[DoesNotReturn]
private void OnError(Exception exception, ErrorCode id, string operation)
{
string? errorCode = null;
(_store as IRestExceptionDecoder)?.DecodeException(exception, out _, out errorCode, true);
(_shared.Store as IRestExceptionDecoder)?.DecodeException(exception, out _, out errorCode, true);
var errorString = errorCode is { Length: > 0 } ? $" Error: {errorCode}" : null;

var grainId = _grainContext.GrainId;
var providerName = _store.GetType().Name;
_logger.LogError((int)id, exception, "Error from storage provider {ProviderName}.{StateName} during {Operation} for grain {GrainId}{ErrorCode}", providerName, _name, operation, grainId, errorString);
var providerName = _shared.Store.GetType().Name;
_shared.Logger.LogError((int)id, exception, "Error from storage provider {ProviderName}.{StateName} during {Operation} for grain {GrainId}{ErrorCode}", providerName, _shared.Name, operation, grainId, errorString);

// If error is not specialization of OrleansException, wrap it
if (exception is not OrleansException)
{
var errMsg = $"Error from storage provider {providerName}.{_name} during {operation} for grain {grainId}{errorString}{Environment.NewLine} {LogFormatter.PrintException(exception)}";
var errMsg = $"Error from storage provider {providerName}.{_shared.Name} during {operation} for grain {grainId}{errorString}{Environment.NewLine} {LogFormatter.PrintException(exception)}";
throw new OrleansException(errMsg, exception);
}

ExceptionDispatchInfo.Throw(exception);
}
}

internal sealed class StateStorageBridgeSharedMap(ILoggerFactory loggerFactory, IActivatorProvider activatorProvider)
{
private readonly ConcurrentDictionary<(string Name, IGrainStorage Store, Type StateType), object> _instances = new();
private readonly ILoggerFactory _loggerFactory = loggerFactory;
private readonly IActivatorProvider _activatorProvider = activatorProvider;

public StateStorageBridgeShared<TState> Get<TState>(string name, IGrainStorage store)
=> (StateStorageBridgeShared<TState>)_instances.GetOrAdd(
(name, store, typeof(TState)),
static (key, self) => new StateStorageBridgeShared<TState>(
key.Name,
key.Store,
self._loggerFactory.CreateLogger(key.Store.GetType()),
self._activatorProvider.GetActivator<TState>()),
this);
}

internal sealed class StateStorageBridgeShared<TState>(string name, IGrainStorage store, ILogger logger, IActivator<TState> activator)
{
private string? _migrationContextKey;

public readonly string Name = name;
public readonly IGrainStorage Store = store;
public readonly ILogger Logger = logger;
public readonly IActivator<TState> Activator = activator;
public string MigrationContextKey => _migrationContextKey ??= $"state.{Name}";
}
}
5 changes: 1 addition & 4 deletions src/Orleans.Streaming/PubSub/PubSubRendezvousGrain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ namespace Orleans.Streams
internal sealed class PubSubGrainStateStorageFactory
{
private readonly IServiceProvider _serviceProvider;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger<PubSubGrainStateStorageFactory> _logger;

public PubSubGrainStateStorageFactory(IServiceProvider serviceProvider, ILoggerFactory loggerFactory)
{
_serviceProvider = serviceProvider;
_loggerFactory = loggerFactory;
_logger = loggerFactory.CreateLogger<PubSubGrainStateStorageFactory>();
}

Expand Down Expand Up @@ -57,8 +55,7 @@ public StateStorageBridge<PubSubGrainState> GetStorage(PubSubRendezvousGrain gra
storage = _serviceProvider.GetRequiredKeyedService<IGrainStorage>(ProviderConstants.DEFAULT_PUBSUB_PROVIDER_NAME);
}

var activatorProvider = _serviceProvider.GetRequiredService<IActivatorProvider>();
return new(nameof(PubSubRendezvousGrain), grain.GrainContext, storage, _loggerFactory, activatorProvider);
return new(nameof(PubSubRendezvousGrain), grain.GrainContext, storage);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Orleans.Runtime;
using Orleans.Transactions.Abstractions;
using Orleans.Storage;
using Orleans.Serialization.Serializers;

namespace Orleans.Transactions
{
public class NamedTransactionalStateStorageFactory : INamedTransactionalStateStorageFactory
{
private readonly IGrainContextAccessor contextAccessor;
private readonly ILoggerFactory loggerFactory;

public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor, ILoggerFactory loggerFactory)
[Obsolete("Use the NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor) constructor.")]
public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) : this(contextAccessor)
{
}

public NamedTransactionalStateStorageFactory(IGrainContextAccessor contextAccessor)
{
this.contextAccessor = contextAccessor;
this.loggerFactory = loggerFactory;
}

public ITransactionalStateStorage<TState> Create<TState>(string storageName, string stateName)
Expand All @@ -37,8 +38,7 @@ public ITransactionalStateStorage<TState> Create<TState>(string storageName, str

if (grainStorage != null)
{
IActivatorProvider activatorProvider = currentContext.ActivationServices.GetRequiredService<IActivatorProvider>();
return new TransactionalStateStorageProviderWrapper<TState>(grainStorage, stateName, currentContext, this.loggerFactory, activatorProvider);
return new TransactionalStateStorageProviderWrapper<TState>(grainStorage, stateName, currentContext);
}

throw (string.IsNullOrEmpty(storageName))
Expand Down
Loading

0 comments on commit deda4ba

Please sign in to comment.