Skip to content

Commit

Permalink
Fix stack deserialization order (#8768)
Browse files Browse the repository at this point in the history
  • Loading branch information
ReubenBond authored Dec 12, 2023
1 parent cd09126 commit 64cc3aa
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 121 deletions.
243 changes: 122 additions & 121 deletions src/Orleans.Serialization/Codecs/StackCodec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,165 +7,166 @@
using Orleans.Serialization.GeneratedCodeHelpers;
using Orleans.Serialization.WireProtocol;

namespace Orleans.Serialization.Codecs
namespace Orleans.Serialization.Codecs;

/// <summary>
/// Serializer for <see cref="Stack{T}"/>.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
[RegisterSerializer]
public sealed class StackCodec<T> : IFieldCodec<Stack<T>>
{
private readonly Type CodecElementType = typeof(T);

private readonly IFieldCodec<T> _fieldCodec;

/// <summary>
/// Serializer for <see cref="Stack{T}"/>.
/// Initializes a new instance of the <see cref="StackCodec{T}"/> class.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
[RegisterSerializer]
public sealed class StackCodec<T> : IFieldCodec<Stack<T>>
/// <param name="fieldCodec">The field codec.</param>
public StackCodec(IFieldCodec<T> fieldCodec)
{
private readonly Type CodecElementType = typeof(T);

private readonly IFieldCodec<T> _fieldCodec;
_fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec);
}

/// <summary>
/// Initializes a new instance of the <see cref="StackCodec{T}"/> class.
/// </summary>
/// <param name="fieldCodec">The field codec.</param>
public StackCodec(IFieldCodec<T> fieldCodec)
/// <inheritdoc/>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteField<TBufferWriter>(ref Writer<TBufferWriter> writer, uint fieldIdDelta, Type expectedType, Stack<T> value) where TBufferWriter : IBufferWriter<byte>
{
if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value))
{
_fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec);
return;
}

/// <inheritdoc/>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void WriteField<TBufferWriter>(ref Writer<TBufferWriter> writer, uint fieldIdDelta, Type expectedType, Stack<T> value) where TBufferWriter : IBufferWriter<byte>
writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited);

if (value.Count > 0)
{
if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value))
UInt32Codec.WriteField(ref writer, 0, (uint)value.Count);
uint innerFieldIdDelta = 1;
foreach (var element in value)
{
return;
_fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element);
innerFieldIdDelta = 0;
}
}

writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited);

if (value.Count > 0)
{
UInt32Codec.WriteField(ref writer, 0, (uint)value.Count);
uint innerFieldIdDelta = 1;
foreach (var element in value)
{
_fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element);
innerFieldIdDelta = 0;
}
}
writer.WriteEndObject();
}

writer.WriteEndObject();
/// <inheritdoc/>
public Stack<T> ReadValue<TInput>(ref Reader<TInput> reader, Field field)
{
if (field.WireType == WireType.Reference)
{
return ReferenceCodec.ReadReference<Stack<T>, TInput>(ref reader, field);
}

/// <inheritdoc/>
public Stack<T> ReadValue<TInput>(ref Reader<TInput> reader, Field field)
field.EnsureWireTypeTagDelimited();

var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session);
T[] array = null;
var i = 0;
uint fieldId = 0;
while (true)
{
if (field.WireType == WireType.Reference)
var header = reader.ReadFieldHeader();
if (header.IsEndBaseOrEndObject)
{
return ReferenceCodec.ReadReference<Stack<T>, TInput>(ref reader, field);
break;
}

field.EnsureWireTypeTagDelimited();

var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session);
Stack<T> result = null;
uint fieldId = 0;
while (true)
fieldId += header.FieldIdDelta;
switch (fieldId)
{
var header = reader.ReadFieldHeader();
if (header.IsEndBaseOrEndObject)
{
case 0:
var length = (int)UInt32Codec.ReadValue(ref reader, header);
if (length > 10240 && length > reader.Length)
{
ThrowInvalidSizeException(length);
}

array = new T[length];
i = length - 1;
break;
}

fieldId += header.FieldIdDelta;
switch (fieldId)
{
case 0:
var length = (int)UInt32Codec.ReadValue(ref reader, header);
if (length > 10240 && length > reader.Length)
{
ThrowInvalidSizeException(length);
}

result = new(length);
ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId);
break;
case 1:
if (result is null)
{
ThrowLengthFieldMissing();
}

result.Push(_fieldCodec.ReadValue(ref reader, header));
break;
default:
reader.ConsumeUnknownField(header);
break;
}
}
case 1:
if (array is null)
{
ThrowLengthFieldMissing();
}

if (result is null)
{
result = new();
ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId);
array[i--] = _fieldCodec.ReadValue(ref reader, header);
break;
default:
reader.ConsumeUnknownField(header);
break;
}

return result;
}

private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException(
$"Declared length of {typeof(Stack<T>)}, {length}, is greater than total length of input.");

private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field.");
array ??= [];
var result = new Stack<T>(array);
ReferenceCodec.RecordObject(reader.Session, array, placeholderReferenceId);
return result;
}

private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException(
$"Declared length of {typeof(Stack<T>)}, {length}, is greater than total length of input.");

private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field.");
}

/// <summary>
/// Copier for <see cref="Stack{T}"/>.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
[RegisterCopier]
public sealed class StackCopier<T> : IDeepCopier<Stack<T>>, IBaseCopier<Stack<T>>
{
private readonly Type _fieldType = typeof(Stack<T>);
private readonly IDeepCopier<T> _copier;

/// <summary>
/// Copier for <see cref="Stack{T}"/>.
/// Initializes a new instance of the <see cref="StackCopier{T}"/> class.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
[RegisterCopier]
public sealed class StackCopier<T> : IDeepCopier<Stack<T>>, IBaseCopier<Stack<T>>
/// <param name="valueCopier">The value copier.</param>
public StackCopier(IDeepCopier<T> valueCopier)
{
_copier = valueCopier;
}

/// <inheritdoc/>
public Stack<T> DeepCopy(Stack<T> input, CopyContext context)
{
private readonly Type _fieldType = typeof(Stack<T>);
private readonly IDeepCopier<T> _copier;

/// <summary>
/// Initializes a new instance of the <see cref="StackCopier{T}"/> class.
/// </summary>
/// <param name="valueCopier">The value copier.</param>
public StackCopier(IDeepCopier<T> valueCopier)
if (context.TryGetCopy<Stack<T>>(input, out var result))
{
_copier = valueCopier;
return result;
}

/// <inheritdoc/>
public Stack<T> DeepCopy(Stack<T> input, CopyContext context)
if (input.GetType() != _fieldType)
{
if (context.TryGetCopy<Stack<T>>(input, out var result))
{
return result;
}

if (input.GetType() as object != _fieldType as object)
{
return context.DeepCopy(input);
}

result = new Stack<T>(input.Count);
context.RecordCopy(input, result);
foreach (var item in input)
{
result.Push(_copier.DeepCopy(item, context));
}
return context.DeepCopy(input);
}

return result;
result = new Stack<T>(input.Count);
context.RecordCopy(input, result);
var array = new T[input.Count];
input.CopyTo(array, 0);
for (var i = array.Length - 1; i >= 0; --i)
{
result.Push(_copier.DeepCopy(array[i], context));
}

/// <inheritdoc/>
public void DeepCopy(Stack<T> input, Stack<T> output, CopyContext context)
return result;
}

/// <inheritdoc/>
public void DeepCopy(Stack<T> input, Stack<T> output, CopyContext context)
{
var array = new T[input.Count];
input.CopyTo(array, 0);
for (var i = array.Length - 1; i >= 0; --i)
{
foreach (var item in input)
{
output.Push(_copier.DeepCopy(item, context));
}
output.Push(_copier.DeepCopy(array[i], context));
}
}
}
42 changes: 42 additions & 0 deletions test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,48 @@ protected override Collection<int> CreateValue()
protected override Collection<int>[] TestValues => new[] { null, new Collection<int>(), CreateValue(), CreateValue(), CreateValue() };
}

public class StackCodecTests : FieldCodecTester<Stack<int>, StackCodec<int>>
{
public StackCodecTests(ITestOutputHelper output) : base(output)
{
}

protected override Stack<int> CreateValue()
{
var result = new Stack<int>();
for (var i = 0; i < Random.Next(17) + 5; i++)
{
result.Push(Random.Next());
}

return result;
}

protected override bool Equals(Stack<int> left, Stack<int> right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right);
protected override Stack<int>[] TestValues => new[] { null, new Stack<int>(), CreateValue(), CreateValue(), CreateValue() };
}

public class StackCopierTests : CopierTester<Stack<int>, StackCopier<int>>
{
public StackCopierTests(ITestOutputHelper output) : base(output)
{
}

protected override Stack<int> CreateValue()
{
var result = new Stack<int>();
for (var i = 0; i < Random.Next(17) + 5; i++)
{
result.Push(Random.Next());
}

return result;
}

protected override bool Equals(Stack<int> left, Stack<int> right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right);
protected override Stack<int>[] TestValues => new[] { null, new Stack<int>(), CreateValue(), CreateValue(), CreateValue() };
}

public class QueueCodecTests : FieldCodecTester<Queue<int>, QueueCodec<int>>
{
public QueueCodecTests(ITestOutputHelper output) : base(output)
Expand Down

0 comments on commit 64cc3aa

Please sign in to comment.