Skip to content

Commit

Permalink
Fix PooledBuffer serialization (#8852)
Browse files Browse the repository at this point in the history
* Fix MigrationContext serialization

* Ensure ActivationMigrationManager is started during silo startup

* Add additional PooledBuffer tests + cleanup
  • Loading branch information
ReubenBond authored Feb 14, 2024
1 parent 34f7e0b commit 57f3812
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 21 deletions.
20 changes: 5 additions & 15 deletions src/Orleans.Core/Lifecycle/MigrationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ public bool TryGetValue<T>(string key, out T? value)
IEnumerator<string> IEnumerable<string>.GetEnumerator() => new Enumerator(this);
IEnumerator IEnumerable.GetEnumerator() => new Enumerator(this);

private sealed class Enumerator : IEnumerator<string>, IEnumerator
private sealed class Enumerator(MigrationContext context) : IEnumerator<string>, IEnumerator
{
private Dictionary<string, (int Offset, int Length)>.KeyCollection.Enumerator _value;
public Enumerator(MigrationContext context) => _value = context._indices.Keys.GetEnumerator();
private Dictionary<string, (int Offset, int Length)>.KeyCollection.Enumerator _value = context._indices.Keys.GetEnumerator();

public string Current => _value.Current;
object IEnumerator.Current => Current;
public void Dispose() => _value.Dispose();
Expand All @@ -133,18 +133,8 @@ public void Reset()
}
}

internal sealed class SerializationHooks
internal sealed class SerializationHooks(SerializerSessionPool serializerSessionPool)
{
private readonly SerializerSessionPool _serializerSessionPool;

public SerializationHooks(SerializerSessionPool serializerSessionPool)
{
_serializerSessionPool = serializerSessionPool;
}

public void OnDeserializing(MigrationContext context)
{
context._sessionPool = _serializerSessionPool;
}
public void OnDeserializing(MigrationContext context) => context._sessionPool = serializerSessionPool;
}
}
25 changes: 24 additions & 1 deletion src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks.Sources;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Orleans.Internal;
using Orleans.Runtime.Internal;
using Orleans.Runtime.Scheduler;

Expand Down Expand Up @@ -52,7 +53,7 @@ internal interface IActivationMigrationManager
/// <summary>
/// Migrates grain activations to target hosts and handles migration requests from other hosts.
/// </summary>
internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager
internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager, ILifecycleParticipant<ISiloLifecycle>
{
private const int MaxBatchSize = 1_000;
private readonly ConcurrentDictionary<SiloAddress, (Task PumpTask, Channel<MigrationWorkItem> WorkItemChannel)> _workers = new();
Expand Down Expand Up @@ -305,6 +306,28 @@ private void RemoveWorker(SiloAddress targetSilo)
}
}

private Task StartAsync(CancellationToken cancellationToken) => Task.CompletedTask;
private async Task StopAsync(CancellationToken cancellationToken)
{
var workerTasks = new List<Task>();
foreach (var (_, value) in _workers)
{
value.WorkItemChannel.Writer.TryComplete();
workerTasks.Add(value.PumpTask);
}

await Task.WhenAll(workerTasks).WithCancellation(cancellationToken);
}

void ILifecycleParticipant<ISiloLifecycle>.Participate(ISiloLifecycle lifecycle)
{
lifecycle.Subscribe(
nameof(ActivationMigrationManager),
ServiceLifecycleStage.RuntimeGrainServices,
ct => this.RunOrQueueTask(() => StartAsync(ct)),
ct => this.RunOrQueueTask(() => StopAsync(ct)));
}

private class MigrationWorkItem : IValueTaskSource
{
private ManualResetValueTaskSourceCore<int> _core = new() { RunContinuationsAsynchronously = true };
Expand Down
1 change: 1 addition & 0 deletions src/Orleans.Runtime/Hosting/DefaultSiloServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ internal static void AddDefaultServices(ISiloBuilder builder)
services.AddSingleton<MigrationContext.SerializationHooks>();
services.AddSingleton<ActivationMigrationManager>();
services.AddFromExisting<IActivationMigrationManager, ActivationMigrationManager>();
services.AddFromExisting<ILifecycleParticipant<ISiloLifecycle>, ActivationMigrationManager>();

ApplyConfiguration(builder);
}
Expand Down
11 changes: 7 additions & 4 deletions src/Orleans.Serialization/Codecs/ByteArrayCodec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,11 @@ public static Memory<byte> DeepCopy(Memory<byte> input, CopyContext copyContext)
/// Serializer for <see cref="PooledBuffer"/> instances.
/// </summary>
[RegisterSerializer]
public sealed class PooledBufferCodec : IValueSerializer<PooledBuffer>
public sealed class PooledBufferCodec : IFieldCodec<PooledBuffer>
{
public void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, scoped ref PooledBuffer value) where TBufferWriter : IBufferWriter<byte>
public void WriteField<TBufferWriter>(ref Writer<TBufferWriter> writer, uint fieldIdDelta, Type expectedType, PooledBuffer value) where TBufferWriter : IBufferWriter<byte>
{
writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(PooledBuffer), WireType.LengthPrefixed);
writer.WriteVarUInt32((uint)value.Length);
foreach (var segment in value)
{
Expand All @@ -311,11 +312,12 @@ public void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, scoped re
// Senders must not use the value after sending.
// Receivers must dispose of the value after use.
value.Reset();
value = default;
}

public void Deserialize<TInput>(ref Reader<TInput> reader, scoped ref PooledBuffer value)
public PooledBuffer ReadValue<TInput>(ref Reader<TInput> reader, Field field)
{
field.EnsureWireType(WireType.LengthPrefixed);
var value = new PooledBuffer();
const int MaxSpanLength = 4096;
var length = (int)reader.ReadVarUInt32();
while (length > 0)
Expand All @@ -328,6 +330,7 @@ public void Deserialize<TInput>(ref Reader<TInput> reader, scoped ref PooledBuff
}

Debug.Assert(length == 0);
return value;
}
}

Expand Down
56 changes: 56 additions & 0 deletions test/DefaultCluster.Tests/Migration/MigrationTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics;
using Orleans.Core.Internal;
using Orleans.Runtime;
using Orleans.Runtime.Placement;
Expand Down Expand Up @@ -75,6 +76,61 @@ public async Task DirectedGrainMigrationTest()
}
}

/// <summary>
/// Tests that multiple grains can be migrated simultaneously.
/// </summary>
[Fact, TestCategory("BVT")]
public async Task MultiGrainDirectedMigrationTest()
{
var baseId = GetRandomGrainId();
for (var i = 1; i < 100; ++i)
{
var a = GrainFactory.GetGrain<IMigrationTestGrain>(baseId + 2 * i);
var expectedState = Random.Shared.Next();
await a.SetState(expectedState);
var originalAddressA = await a.GetGrainAddress();
var originalHostA = originalAddressA.SiloAddress;

RequestContext.Set(IPlacementDirector.PlacementHintKey, originalHostA);
var b = GrainFactory.GetGrain<IMigrationTestGrain>(baseId + 1 + 2 * i);
await b.SetState(expectedState);
var originalAddressB = await b.GetGrainAddress();
Assert.Equal(originalHostA, originalAddressB.SiloAddress);

var targetHost = Fixture.HostedCluster.GetActiveSilos().Select(s => s.SiloAddress).First(address => address != originalHostA);

// Trigger migration, setting a placement hint to coerce the placement director to use the target silo
RequestContext.Set(IPlacementDirector.PlacementHintKey, targetHost);
var migrateA = a.Cast<IGrainManagementExtension>().MigrateOnIdle();
var migrateB = b.Cast<IGrainManagementExtension>().MigrateOnIdle();
await migrateA;
await migrateB;

while (true)
{
var newAddress = await a.GetGrainAddress();
if (newAddress.ActivationId != originalAddressA.ActivationId)
{
Assert.Equal(targetHost, newAddress.SiloAddress);
break;
}
}

while (true)
{
var newAddress = await b.GetGrainAddress();
if (newAddress.ActivationId != originalAddressB.ActivationId)
{
Assert.Equal(targetHost, newAddress.SiloAddress);
break;
}
}

Assert.Equal(expectedState, await a.GetState());
Assert.Equal(expectedState, await b.GetState());
}
}

/// <summary>
/// Tests that grain migration works for a simple grain which uses <see cref="Grain{TGrainState}"/> for state.
/// The test specifies an alternative location for the grain to migrate to and asserts that it migrates to that location.
Expand Down
1 change: 0 additions & 1 deletion test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void ValueTupleTypesHasSerializer()
/// <summary>
/// Tests that the default (non-fallback) serializer can handle complex classes.
/// </summary>
/// <param name="serializerToUse"></param>
[Fact, TestCategory("BVT")]
public void Serialize_ComplexAccessibleClass()
{
Expand Down
84 changes: 84 additions & 0 deletions test/Orleans.Serialization.UnitTests/PooledBufferTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,90 @@ static void SerializeObject(SerializerSessionPool pool, Serializer serializer, L
}
}

/// <summary>
/// Tests that the serializer can correctly serialized <see cref="PooledBuffer"/>.
/// </summary>
[Fact]
public void PooledBuffer_SerializerRoundTrip()
{
var serviceProvider = new ServiceCollection()
.AddSerializer()
.BuildServiceProvider();
var serializer = serviceProvider.GetRequiredService<Serializer>();

var random = new Random();
for (var i = 0; i < 10; i++)
{
const int TargetLength = 8120;

// NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this.
var buffer = new PooledBuffer();
while (buffer.Length < TargetLength)
{
var span = buffer.GetSpan(TargetLength - buffer.Length);
var writeLen = Math.Min(span.Length, TargetLength - buffer.Length);
random.NextBytes(span[..writeLen]);
buffer.Advance(writeLen);
}

var bytes = buffer.ToArray();
Assert.Equal(TargetLength, bytes.Length);

var result = serializer.Deserialize<PooledBuffer>(serializer.SerializeToArray(buffer));
Assert.Equal(TargetLength, result.Length);

var resultBytes = result.ToArray();
Assert.Equal(bytes, resultBytes);

// NOTE: we are responsible for disposing a buffer returned from deserialization.
result.Dispose();
}
}

/// <summary>
/// Tests that the serializer can correctly serialized <see cref="PooledBuffer"/> when it's embedded in another structure.
/// </summary>
[Fact]
public void PooledBuffer_SerializerRoundTrip_Embedded()
{
var serviceProvider = new ServiceCollection()
.AddSerializer()
.BuildServiceProvider();
var serializer = serviceProvider.GetRequiredService<Serializer>();

var random = new Random();
for (var i = 0; i < 10; i++)
{
const int TargetLength = 8120;

// NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this.
var buffer = new PooledBuffer();
while (buffer.Length < TargetLength)
{
var span = buffer.GetSpan(TargetLength - buffer.Length);
var writeLen = Math.Min(span.Length, TargetLength - buffer.Length);
random.NextBytes(span[..writeLen]);
buffer.Advance(writeLen);
}

var bytes = buffer.ToArray();
Assert.Equal(TargetLength, bytes.Length);

var embed = (Guid: Guid.NewGuid(), Buffer: buffer, Int: 42);
var result = serializer.Deserialize<(Guid Guid, PooledBuffer Buffer, int Int)>(serializer.SerializeToArray(embed));
Assert.Equal(embed.Guid, result.Guid);
Assert.Equal(embed.Int, result.Int);
var resultBuffer = result.Buffer;
Assert.Equal(TargetLength, resultBuffer.Length);

var resultBytes = resultBuffer.ToArray();
Assert.Equal(bytes, resultBytes);

// NOTE: we are responsible for disposing a buffer returned from deserialization.
resultBuffer.Dispose();
}
}

[GenerateSerializer]
public readonly record struct LargeObject(
[property: Id(0)] Guid Id,
Expand Down

0 comments on commit 57f3812

Please sign in to comment.