From 64cc3aa1f43f649e046a4f9052e2ca74e7ae56d4 Mon Sep 17 00:00:00 2001 From: Reuben Bond <203839+ReubenBond@users.noreply.github.com> Date: Tue, 12 Dec 2023 10:27:58 -0800 Subject: [PATCH] Fix stack deserialization order (#8768) --- .../Codecs/StackCodec.cs | 243 +++++++++--------- .../BuiltInCodecTests.cs | 42 +++ 2 files changed, 164 insertions(+), 121 deletions(-) diff --git a/src/Orleans.Serialization/Codecs/StackCodec.cs b/src/Orleans.Serialization/Codecs/StackCodec.cs index 75d9c5545d..9ab75a7b1f 100644 --- a/src/Orleans.Serialization/Codecs/StackCodec.cs +++ b/src/Orleans.Serialization/Codecs/StackCodec.cs @@ -7,165 +7,166 @@ using Orleans.Serialization.GeneratedCodeHelpers; using Orleans.Serialization.WireProtocol; -namespace Orleans.Serialization.Codecs +namespace Orleans.Serialization.Codecs; + +/// +/// Serializer for . +/// +/// The element type. +[RegisterSerializer] +public sealed class StackCodec : IFieldCodec> { + private readonly Type CodecElementType = typeof(T); + + private readonly IFieldCodec _fieldCodec; + /// - /// Serializer for . + /// Initializes a new instance of the class. /// - /// The element type. - [RegisterSerializer] - public sealed class StackCodec : IFieldCodec> + /// The field codec. + public StackCodec(IFieldCodec fieldCodec) { - private readonly Type CodecElementType = typeof(T); - - private readonly IFieldCodec _fieldCodec; + _fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec); + } - /// - /// Initializes a new instance of the class. - /// - /// The field codec. - public StackCodec(IFieldCodec fieldCodec) + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, Stack value) where TBufferWriter : IBufferWriter + { + if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) { - _fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec); + return; } - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, Stack value) where TBufferWriter : IBufferWriter + writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited); + + if (value.Count > 0) { - if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) + UInt32Codec.WriteField(ref writer, 0, (uint)value.Count); + uint innerFieldIdDelta = 1; + foreach (var element in value) { - return; + _fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element); + innerFieldIdDelta = 0; } + } - writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited); - - if (value.Count > 0) - { - UInt32Codec.WriteField(ref writer, 0, (uint)value.Count); - uint innerFieldIdDelta = 1; - foreach (var element in value) - { - _fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element); - innerFieldIdDelta = 0; - } - } + writer.WriteEndObject(); + } - writer.WriteEndObject(); + /// + public Stack ReadValue(ref Reader reader, Field field) + { + if (field.WireType == WireType.Reference) + { + return ReferenceCodec.ReadReference, TInput>(ref reader, field); } - /// - public Stack ReadValue(ref Reader reader, Field field) + field.EnsureWireTypeTagDelimited(); + + var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session); + T[] array = null; + var i = 0; + uint fieldId = 0; + while (true) { - if (field.WireType == WireType.Reference) + var header = reader.ReadFieldHeader(); + if (header.IsEndBaseOrEndObject) { - return ReferenceCodec.ReadReference, TInput>(ref reader, field); + break; } - field.EnsureWireTypeTagDelimited(); - - var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session); - Stack result = null; - uint fieldId = 0; - while (true) + fieldId += header.FieldIdDelta; + switch (fieldId) { - var header = reader.ReadFieldHeader(); - if (header.IsEndBaseOrEndObject) - { + case 0: + var length = (int)UInt32Codec.ReadValue(ref reader, header); + if (length > 10240 && length > reader.Length) + { + ThrowInvalidSizeException(length); + } + + array = new T[length]; + i = length - 1; break; - } - - fieldId += header.FieldIdDelta; - switch (fieldId) - { - case 0: - var length = (int)UInt32Codec.ReadValue(ref reader, header); - if (length > 10240 && length > reader.Length) - { - ThrowInvalidSizeException(length); - } - - result = new(length); - ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId); - break; - case 1: - if (result is null) - { - ThrowLengthFieldMissing(); - } - - result.Push(_fieldCodec.ReadValue(ref reader, header)); - break; - default: - reader.ConsumeUnknownField(header); - break; - } - } + case 1: + if (array is null) + { + ThrowLengthFieldMissing(); + } - if (result is null) - { - result = new(); - ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId); + array[i--] = _fieldCodec.ReadValue(ref reader, header); + break; + default: + reader.ConsumeUnknownField(header); + break; } - - return result; } - private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException( - $"Declared length of {typeof(Stack)}, {length}, is greater than total length of input."); - - private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field."); + array ??= []; + var result = new Stack(array); + ReferenceCodec.RecordObject(reader.Session, array, placeholderReferenceId); + return result; } + private void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException( + $"Declared length of {typeof(Stack)}, {length}, is greater than total length of input."); + + private void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized stack is missing its length field."); +} + +/// +/// Copier for . +/// +/// The element type. +[RegisterCopier] +public sealed class StackCopier : IDeepCopier>, IBaseCopier> +{ + private readonly Type _fieldType = typeof(Stack); + private readonly IDeepCopier _copier; + /// - /// Copier for . + /// Initializes a new instance of the class. /// - /// The element type. - [RegisterCopier] - public sealed class StackCopier : IDeepCopier>, IBaseCopier> + /// The value copier. + public StackCopier(IDeepCopier valueCopier) + { + _copier = valueCopier; + } + + /// + public Stack DeepCopy(Stack input, CopyContext context) { - private readonly Type _fieldType = typeof(Stack); - private readonly IDeepCopier _copier; - - /// - /// Initializes a new instance of the class. - /// - /// The value copier. - public StackCopier(IDeepCopier valueCopier) + if (context.TryGetCopy>(input, out var result)) { - _copier = valueCopier; + return result; } - /// - public Stack DeepCopy(Stack input, CopyContext context) + if (input.GetType() != _fieldType) { - if (context.TryGetCopy>(input, out var result)) - { - return result; - } - - if (input.GetType() as object != _fieldType as object) - { - return context.DeepCopy(input); - } - - result = new Stack(input.Count); - context.RecordCopy(input, result); - foreach (var item in input) - { - result.Push(_copier.DeepCopy(item, context)); - } + return context.DeepCopy(input); + } - return result; + result = new Stack(input.Count); + context.RecordCopy(input, result); + var array = new T[input.Count]; + input.CopyTo(array, 0); + for (var i = array.Length - 1; i >= 0; --i) + { + result.Push(_copier.DeepCopy(array[i], context)); } - /// - public void DeepCopy(Stack input, Stack output, CopyContext context) + return result; + } + + /// + public void DeepCopy(Stack input, Stack output, CopyContext context) + { + var array = new T[input.Count]; + input.CopyTo(array, 0); + for (var i = array.Length - 1; i >= 0; --i) { - foreach (var item in input) - { - output.Push(_copier.DeepCopy(item, context)); - } + output.Push(_copier.DeepCopy(array[i], context)); } } } diff --git a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs index 6f7cb7c891..6015c0e57e 100644 --- a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs +++ b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs @@ -2028,6 +2028,48 @@ protected override Collection CreateValue() protected override Collection[] TestValues => new[] { null, new Collection(), CreateValue(), CreateValue(), CreateValue() }; } + public class StackCodecTests : FieldCodecTester, StackCodec> + { + public StackCodecTests(ITestOutputHelper output) : base(output) + { + } + + protected override Stack CreateValue() + { + var result = new Stack(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result.Push(Random.Next()); + } + + return result; + } + + protected override bool Equals(Stack left, Stack right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); + protected override Stack[] TestValues => new[] { null, new Stack(), CreateValue(), CreateValue(), CreateValue() }; + } + + public class StackCopierTests : CopierTester, StackCopier> + { + public StackCopierTests(ITestOutputHelper output) : base(output) + { + } + + protected override Stack CreateValue() + { + var result = new Stack(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result.Push(Random.Next()); + } + + return result; + } + + protected override bool Equals(Stack left, Stack right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); + protected override Stack[] TestValues => new[] { null, new Stack(), CreateValue(), CreateValue(), CreateValue() }; + } + public class QueueCodecTests : FieldCodecTester, QueueCodec> { public QueueCodecTests(ITestOutputHelper output) : base(output)