From 64cc3aa1f43f649e046a4f9052e2ca74e7ae56d4 Mon Sep 17 00:00:00 2001
From: Reuben Bond <203839+ReubenBond@users.noreply.github.com>
Date: Tue, 12 Dec 2023 10:27:58 -0800
Subject: [PATCH] Fix stack deserialization order (#8768)
---
.../Codecs/StackCodec.cs | 243 +++++++++---------
.../BuiltInCodecTests.cs | 42 +++
2 files changed, 164 insertions(+), 121 deletions(-)
diff --git a/src/Orleans.Serialization/Codecs/StackCodec.cs b/src/Orleans.Serialization/Codecs/StackCodec.cs
index 75d9c5545d..9ab75a7b1f 100644
--- a/src/Orleans.Serialization/Codecs/StackCodec.cs
+++ b/src/Orleans.Serialization/Codecs/StackCodec.cs
@@ -7,165 +7,166 @@
using Orleans.Serialization.GeneratedCodeHelpers;
using Orleans.Serialization.WireProtocol;
-namespace Orleans.Serialization.Codecs
+namespace Orleans.Serialization.Codecs;
+
+///
+/// Serializer for .
+///
+/// The element type.
+[RegisterSerializer]
+public sealed class StackCodec : IFieldCodec>
{
+ private readonly Type CodecElementType = typeof(T);
+
+ private readonly IFieldCodec _fieldCodec;
+
///
- /// Serializer for .
+ /// Initializes a new instance of the class.
///
- /// The element type.
- [RegisterSerializer]
- public sealed class StackCodec : IFieldCodec>
+ /// The field codec.
+ public StackCodec(IFieldCodec fieldCodec)
{
- private readonly Type CodecElementType = typeof(T);
-
- private readonly IFieldCodec _fieldCodec;
+ _fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec);
+ }
- ///
- /// Initializes a new instance of the class.
- ///
- /// The field codec.
- public StackCodec(IFieldCodec fieldCodec)
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, Stack value) where TBufferWriter : IBufferWriter
+ {
+ if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value))
{
- _fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec);
+ return;
}
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, Stack value) where TBufferWriter : IBufferWriter
+ writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited);
+
+ if (value.Count > 0)
{
- if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value))
+ UInt32Codec.WriteField(ref writer, 0, (uint)value.Count);
+ uint innerFieldIdDelta = 1;
+ foreach (var element in value)
{
- return;
+ _fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element);
+ innerFieldIdDelta = 0;
}
+ }
- writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited);
-
- if (value.Count > 0)
- {
- UInt32Codec.WriteField(ref writer, 0, (uint)value.Count);
- uint innerFieldIdDelta = 1;
- foreach (var element in value)
- {
- _fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element);
- innerFieldIdDelta = 0;
- }
- }
+ writer.WriteEndObject();
+ }
- writer.WriteEndObject();
+ ///
+ public Stack ReadValue(ref Reader reader, Field field)
+ {
+ if (field.WireType == WireType.Reference)
+ {
+ return ReferenceCodec.ReadReference, TInput>(ref reader, field);
}
- ///
- public Stack ReadValue(ref Reader reader, Field field)
+ field.EnsureWireTypeTagDelimited();
+
+ var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session);
+ T[] array = null;
+ var i = 0;
+ uint fieldId = 0;
+ while (true)
{
- if (field.WireType == WireType.Reference)
+ var header = reader.ReadFieldHeader();
+ if (header.IsEndBaseOrEndObject)
{
- return ReferenceCodec.ReadReference, TInput>(ref reader, field);
+ break;
}
- field.EnsureWireTypeTagDelimited();
-
- var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session);
- Stack result = null;
- uint fieldId = 0;
- while (true)
+ fieldId += header.FieldIdDelta;
+ switch (fieldId)
{
- var header = reader.ReadFieldHeader();
- if (header.IsEndBaseOrEndObject)
- {
+ case 0:
+ var length = (int)UInt32Codec.ReadValue(ref reader, header);
+ if (length > 10240 && length > reader.Length)
+ {
+ ThrowInvalidSizeException(length);
+ }
+
+ array = new T[length];
+ i = length - 1;
break;
- }
-
- fieldId += header.FieldIdDelta;
- switch (fieldId)
- {
- case 0:
- var length = (int)UInt32Codec.ReadValue(ref reader, header);
- if (length > 10240 && length > reader.Length)
- {
- ThrowInvalidSizeException(length);
- }
-
- result = new(length);
- ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId);
- break;
- case 1:
- if (result is null)
- {
- ThrowLengthFieldMissing();
- }
-
- result.Push(_fieldCodec.ReadValue(ref reader, header));
- break;
- default:
- reader.ConsumeUnknownField(header);
- break;
- }
- }
+ case 1:
+ if (array is null)
+ {
+ ThrowLengthFieldMissing();
+ }
- if (result is null)
- {
- result = new();
- ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId);
+ array[i--] = _fieldCodec.ReadValue(ref reader, header);
+ break;
+ default:
+ reader.ConsumeUnknownField(header);
+ break;
}
-
- return result;
}
- private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException(
- $"Declared length of {typeof(Stack)}, {length}, is greater than total length of input.");
-
- private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field.");
+ array ??= [];
+ var result = new Stack(array);
+ ReferenceCodec.RecordObject(reader.Session, array, placeholderReferenceId);
+ return result;
}
+ private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException(
+ $"Declared length of {typeof(Stack)}, {length}, is greater than total length of input.");
+
+ private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field.");
+}
+
+///
+/// Copier for .
+///
+/// The element type.
+[RegisterCopier]
+public sealed class StackCopier : IDeepCopier>, IBaseCopier>
+{
+ private readonly Type _fieldType = typeof(Stack);
+ private readonly IDeepCopier _copier;
+
///
- /// Copier for .
+ /// Initializes a new instance of the class.
///
- /// The element type.
- [RegisterCopier]
- public sealed class StackCopier : IDeepCopier>, IBaseCopier>
+ /// The value copier.
+ public StackCopier(IDeepCopier valueCopier)
+ {
+ _copier = valueCopier;
+ }
+
+ ///
+ public Stack DeepCopy(Stack input, CopyContext context)
{
- private readonly Type _fieldType = typeof(Stack);
- private readonly IDeepCopier _copier;
-
- ///
- /// Initializes a new instance of the class.
- ///
- /// The value copier.
- public StackCopier(IDeepCopier valueCopier)
+ if (context.TryGetCopy>(input, out var result))
{
- _copier = valueCopier;
+ return result;
}
- ///
- public Stack DeepCopy(Stack input, CopyContext context)
+ if (input.GetType() != _fieldType)
{
- if (context.TryGetCopy>(input, out var result))
- {
- return result;
- }
-
- if (input.GetType() as object != _fieldType as object)
- {
- return context.DeepCopy(input);
- }
-
- result = new Stack(input.Count);
- context.RecordCopy(input, result);
- foreach (var item in input)
- {
- result.Push(_copier.DeepCopy(item, context));
- }
+ return context.DeepCopy(input);
+ }
- return result;
+ result = new Stack(input.Count);
+ context.RecordCopy(input, result);
+ var array = new T[input.Count];
+ input.CopyTo(array, 0);
+ for (var i = array.Length - 1; i >= 0; --i)
+ {
+ result.Push(_copier.DeepCopy(array[i], context));
}
- ///
- public void DeepCopy(Stack input, Stack output, CopyContext context)
+ return result;
+ }
+
+ ///
+ public void DeepCopy(Stack input, Stack output, CopyContext context)
+ {
+ var array = new T[input.Count];
+ input.CopyTo(array, 0);
+ for (var i = array.Length - 1; i >= 0; --i)
{
- foreach (var item in input)
- {
- output.Push(_copier.DeepCopy(item, context));
- }
+ output.Push(_copier.DeepCopy(array[i], context));
}
}
}
diff --git a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs
index 6f7cb7c891..6015c0e57e 100644
--- a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs
+++ b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs
@@ -2028,6 +2028,48 @@ protected override Collection CreateValue()
protected override Collection[] TestValues => new[] { null, new Collection(), CreateValue(), CreateValue(), CreateValue() };
}
+ public class StackCodecTests : FieldCodecTester, StackCodec>
+ {
+ public StackCodecTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ protected override Stack CreateValue()
+ {
+ var result = new Stack();
+ for (var i = 0; i < Random.Next(17) + 5; i++)
+ {
+ result.Push(Random.Next());
+ }
+
+ return result;
+ }
+
+ protected override bool Equals(Stack left, Stack right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right);
+ protected override Stack[] TestValues => new[] { null, new Stack(), CreateValue(), CreateValue(), CreateValue() };
+ }
+
+ public class StackCopierTests : CopierTester, StackCopier>
+ {
+ public StackCopierTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ protected override Stack CreateValue()
+ {
+ var result = new Stack();
+ for (var i = 0; i < Random.Next(17) + 5; i++)
+ {
+ result.Push(Random.Next());
+ }
+
+ return result;
+ }
+
+ protected override bool Equals(Stack left, Stack right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right);
+ protected override Stack[] TestValues => new[] { null, new Stack(), CreateValue(), CreateValue(), CreateValue() };
+ }
+
public class QueueCodecTests : FieldCodecTester, QueueCodec>
{
public QueueCodecTests(ITestOutputHelper output) : base(output)