Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PooledBuffer serialization & initialize ActivationMigrationManager on startup #8861

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/Orleans.Runtime/Catalog/ActivationMigrationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks.Sources;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Orleans.Internal;
using Orleans.Runtime.Internal;
using Orleans.Runtime.Scheduler;

Expand Down Expand Up @@ -52,7 +53,7 @@ internal interface IActivationMigrationManager
/// <summary>
/// Migrates grain activations to target hosts and handles migration requests from other hosts.
/// </summary>
internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager
internal class ActivationMigrationManager : SystemTarget, IActivationMigrationManagerSystemTarget, IActivationMigrationManager, ILifecycleParticipant<ISiloLifecycle>
{
private const int MaxBatchSize = 1_000;
private readonly ConcurrentDictionary<SiloAddress, (Task PumpTask, Channel<MigrationWorkItem> WorkItemChannel)> _workers = new();
Expand Down Expand Up @@ -305,6 +306,28 @@ private void RemoveWorker(SiloAddress targetSilo)
}
}

private Task StartAsync(CancellationToken cancellationToken) => Task.CompletedTask;
private async Task StopAsync(CancellationToken cancellationToken)
{
var workerTasks = new List<Task>();
foreach (var (_, value) in _workers)
{
value.WorkItemChannel.Writer.TryComplete();
workerTasks.Add(value.PumpTask);
}

await Task.WhenAll(workerTasks).WithCancellation(cancellationToken);
}

void ILifecycleParticipant<ISiloLifecycle>.Participate(ISiloLifecycle lifecycle)
{
lifecycle.Subscribe(
nameof(ActivationMigrationManager),
ServiceLifecycleStage.RuntimeGrainServices,
ct => this.RunOrQueueTask(() => StartAsync(ct)),
ct => this.RunOrQueueTask(() => StopAsync(ct)));
}

private class MigrationWorkItem : IValueTaskSource
{
private ManualResetValueTaskSourceCore<int> _core = new() { RunContinuationsAsynchronously = true };
Expand Down
1 change: 1 addition & 0 deletions src/Orleans.Runtime/Hosting/DefaultSiloServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ internal static void AddDefaultServices(IServiceCollection services)
services.AddSingleton<MigrationContext.SerializationHooks>();
services.AddSingleton<ActivationMigrationManager>();
services.AddFromExisting<IActivationMigrationManager, ActivationMigrationManager>();
services.AddFromExisting<ILifecycleParticipant<ISiloLifecycle>, ActivationMigrationManager>();
}

private class AllowOrleansTypes : ITypeNameFilter
Expand Down
11 changes: 7 additions & 4 deletions src/Orleans.Serialization/Codecs/ByteArrayCodec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,11 @@ public static Memory<byte> DeepCopy(Memory<byte> input, CopyContext copyContext)
/// Serializer for <see cref="PooledBuffer"/> instances.
/// </summary>
[RegisterSerializer]
public sealed class PooledBufferCodec : IValueSerializer<PooledBuffer>
public sealed class PooledBufferCodec : IFieldCodec<PooledBuffer>
{
public void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, scoped ref PooledBuffer value) where TBufferWriter : IBufferWriter<byte>
public void WriteField<TBufferWriter>(ref Writer<TBufferWriter> writer, uint fieldIdDelta, Type expectedType, PooledBuffer value) where TBufferWriter : IBufferWriter<byte>
{
writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(PooledBuffer), WireType.LengthPrefixed);
writer.WriteVarUInt32((uint)value.Length);
foreach (var segment in value)
{
Expand All @@ -311,11 +312,12 @@ public void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, scoped re
// Senders must not use the value after sending.
// Receivers must dispose of the value after use.
value.Reset();
value = default;
}

public void Deserialize<TInput>(ref Reader<TInput> reader, scoped ref PooledBuffer value)
public PooledBuffer ReadValue<TInput>(ref Reader<TInput> reader, Field field)
{
field.EnsureWireType(WireType.LengthPrefixed);
var value = new PooledBuffer();
const int MaxSpanLength = 4096;
var length = (int)reader.ReadVarUInt32();
while (length > 0)
Expand All @@ -328,6 +330,7 @@ public void Deserialize<TInput>(ref Reader<TInput> reader, scoped ref PooledBuff
}

Debug.Assert(length == 0);
return value;
}
}

Expand Down
56 changes: 56 additions & 0 deletions test/DefaultCluster.Tests/Migration/MigrationTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics;
using Orleans.Core.Internal;
using Orleans.Runtime;
using Orleans.Runtime.Placement;
Expand Down Expand Up @@ -75,6 +76,61 @@ public async Task DirectedGrainMigrationTest()
}
}

/// <summary>
/// Tests that multiple grains can be migrated simultaneously.
/// </summary>
[Fact, TestCategory("BVT")]
public async Task MultiGrainDirectedMigrationTest()
{
var baseId = GetRandomGrainId();
for (var i = 1; i < 100; ++i)
{
var a = GrainFactory.GetGrain<IMigrationTestGrain>(baseId + 2 * i);
var expectedState = Random.Shared.Next();
await a.SetState(expectedState);
var originalAddressA = await a.GetGrainAddress();
var originalHostA = originalAddressA.SiloAddress;

RequestContext.Set(IPlacementDirector.PlacementHintKey, originalHostA);
var b = GrainFactory.GetGrain<IMigrationTestGrain>(baseId + 1 + 2 * i);
await b.SetState(expectedState);
var originalAddressB = await b.GetGrainAddress();
Assert.Equal(originalHostA, originalAddressB.SiloAddress);

var targetHost = Fixture.HostedCluster.GetActiveSilos().Select(s => s.SiloAddress).First(address => address != originalHostA);

// Trigger migration, setting a placement hint to coerce the placement director to use the target silo
RequestContext.Set(IPlacementDirector.PlacementHintKey, targetHost);
var migrateA = a.Cast<IGrainManagementExtension>().MigrateOnIdle();
var migrateB = b.Cast<IGrainManagementExtension>().MigrateOnIdle();
await migrateA;
await migrateB;

while (true)
{
var newAddress = await a.GetGrainAddress();
if (newAddress.ActivationId != originalAddressA.ActivationId)
{
Assert.Equal(targetHost, newAddress.SiloAddress);
break;
}
}

while (true)
{
var newAddress = await b.GetGrainAddress();
if (newAddress.ActivationId != originalAddressB.ActivationId)
{
Assert.Equal(targetHost, newAddress.SiloAddress);
break;
}
}

Assert.Equal(expectedState, await a.GetState());
Assert.Equal(expectedState, await b.GetState());
}
}

/// <summary>
/// Tests that grain migration works for a simple grain which uses <see cref="Grain{TGrainState}"/> for state.
/// The test specifies an alternative location for the grain to migrate to and asserts that it migrates to that location.
Expand Down
1 change: 0 additions & 1 deletion test/NonSilo.Tests/Serialization/BuiltInSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void ValueTupleTypesHasSerializer()
/// <summary>
/// Tests that the default (non-fallback) serializer can handle complex classes.
/// </summary>
/// <param name="serializerToUse"></param>
[Fact, TestCategory("BVT")]
public void Serialize_ComplexAccessibleClass()
{
Expand Down
84 changes: 84 additions & 0 deletions test/Orleans.Serialization.UnitTests/PooledBufferTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,90 @@ static void SerializeObject(SerializerSessionPool pool, Serializer serializer, L
}
}

/// <summary>
/// Tests that the serializer can correctly serialized <see cref="PooledBuffer"/>.
/// </summary>
[Fact]
public void PooledBuffer_SerializerRoundTrip()
{
var serviceProvider = new ServiceCollection()
.AddSerializer()
.BuildServiceProvider();
var serializer = serviceProvider.GetRequiredService<Serializer>();

var random = new Random();
for (var i = 0; i < 10; i++)
{
const int TargetLength = 8120;

// NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this.
var buffer = new PooledBuffer();
while (buffer.Length < TargetLength)
{
var span = buffer.GetSpan(TargetLength - buffer.Length);
var writeLen = Math.Min(span.Length, TargetLength - buffer.Length);
random.NextBytes(span[..writeLen]);
buffer.Advance(writeLen);
}

var bytes = buffer.ToArray();
Assert.Equal(TargetLength, bytes.Length);

var result = serializer.Deserialize<PooledBuffer>(serializer.SerializeToArray(buffer));
Assert.Equal(TargetLength, result.Length);

var resultBytes = result.ToArray();
Assert.Equal(bytes, resultBytes);

// NOTE: we are responsible for disposing a buffer returned from deserialization.
result.Dispose();
}
}

/// <summary>
/// Tests that the serializer can correctly serialized <see cref="PooledBuffer"/> when it's embedded in another structure.
/// </summary>
[Fact]
public void PooledBuffer_SerializerRoundTrip_Embedded()
{
var serviceProvider = new ServiceCollection()
.AddSerializer()
.BuildServiceProvider();
var serializer = serviceProvider.GetRequiredService<Serializer>();

var random = new Random();
for (var i = 0; i < 10; i++)
{
const int TargetLength = 8120;

// NOTE: The serializer is responsible for freeing the buffer provided to it, so we do not free this.
var buffer = new PooledBuffer();
while (buffer.Length < TargetLength)
{
var span = buffer.GetSpan(TargetLength - buffer.Length);
var writeLen = Math.Min(span.Length, TargetLength - buffer.Length);
random.NextBytes(span[..writeLen]);
buffer.Advance(writeLen);
}

var bytes = buffer.ToArray();
Assert.Equal(TargetLength, bytes.Length);

var embed = (Guid: Guid.NewGuid(), Buffer: buffer, Int: 42);
var result = serializer.Deserialize<(Guid Guid, PooledBuffer Buffer, int Int)>(serializer.SerializeToArray(embed));
Assert.Equal(embed.Guid, result.Guid);
Assert.Equal(embed.Int, result.Int);
var resultBuffer = result.Buffer;
Assert.Equal(TargetLength, resultBuffer.Length);

var resultBytes = resultBuffer.ToArray();
Assert.Equal(bytes, resultBytes);

// NOTE: we are responsible for disposing a buffer returned from deserialization.
resultBuffer.Dispose();
}
}

[GenerateSerializer]
public readonly record struct LargeObject(
[property: Id(0)] Guid Id,
Expand Down