From ce800f9dbfae23d175a48ce49de8c2020e7f2b76 Mon Sep 17 00:00:00 2001 From: Reuben Bond Date: Fri, 10 May 2024 15:19:00 -0700 Subject: [PATCH] Fix behavior of DictionaryBaseCodec when values are added from constructor (#8993) (cherry picked from commit f89629b2b19a35a155f980746185ad99824224aa) # Conflicts: # test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs --- .../Codecs/DictionaryCodec.cs | 41 +++++++++++++------ .../BuiltInCodecTests.cs | 34 +++++++++++++++ 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/src/Orleans.Serialization/Codecs/DictionaryCodec.cs b/src/Orleans.Serialization/Codecs/DictionaryCodec.cs index c1cceb661a1..1c0184bdc3c 100644 --- a/src/Orleans.Serialization/Codecs/DictionaryCodec.cs +++ b/src/Orleans.Serialization/Codecs/DictionaryCodec.cs @@ -1,6 +1,8 @@ +#nullable enable using System; using System.Buffers; using System.Collections.Generic; +using System.Linq; using System.Reflection; using Orleans.Serialization.Buffers; using Orleans.Serialization.Cloning; @@ -17,7 +19,7 @@ namespace Orleans.Serialization.Codecs /// The key type. /// The value type. [RegisterSerializer] - public sealed class DictionaryCodec : IFieldCodec> + public sealed class DictionaryCodec : IFieldCodec> where TKey : notnull { private readonly Type _keyFieldType = typeof(TKey); private readonly Type _valueFieldType = typeof(TValue); @@ -83,10 +85,10 @@ public Dictionary ReadValue(ref Reader reader, Fie field.EnsureWireTypeTagDelimited(); var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session); - TKey key = default; + TKey? key = default; var valueExpected = false; - Dictionary result = null; - IEqualityComparer comparer = null; + Dictionary? result = null; + IEqualityComparer? comparer = null; uint fieldId = 0; while (true) { @@ -122,7 +124,7 @@ public Dictionary ReadValue(ref Reader reader, Fie } else { - result.Add(key, _valueCodec.ReadValue(ref reader, header)); + result!.Add(key!, _valueCodec.ReadValue(ref reader, header)); valueExpected = false; } break; @@ -136,7 +138,7 @@ public Dictionary ReadValue(ref Reader reader, Fie return result; } - private Dictionary CreateInstance(int length, IEqualityComparer comparer, SerializerSession session, uint placeholderReferenceId) + private Dictionary CreateInstance(int length, IEqualityComparer? comparer, SerializerSession session, uint placeholderReferenceId) { var result = new Dictionary(length, comparer); ReferenceCodec.RecordObject(session, result, placeholderReferenceId); @@ -155,10 +157,11 @@ private Dictionary CreateInstance(int length, IEqualityComparerThe type of the t key. /// The type of the t value. [RegisterCopier] - public sealed class DictionaryCopier : IDeepCopier>, IBaseCopier> + public sealed class DictionaryCopier : IDeepCopier>, IBaseCopier> where TKey : notnull { private readonly IDeepCopier _keyCopier; private readonly IDeepCopier _valueCopier; + private readonly ConstructorInfo _baseConstructor; /// /// Initializes a new instance of the class. @@ -169,6 +172,7 @@ public DictionaryCopier(IDeepCopier keyCopier, IDeepCopier valueCo { _keyCopier = keyCopier; _valueCopier = valueCopier; + _baseConstructor = typeof(Dictionary).GetConstructor([typeof(int), typeof(IEqualityComparer)])!; } /// @@ -197,6 +201,12 @@ public Dictionary DeepCopy(Dictionary input, CopyCon /// public void DeepCopy(Dictionary input, Dictionary output, CopyContext context) { + output.Clear(); + if (input.Comparer is { } comparer) + { + _baseConstructor.Invoke(output, [input.Count, comparer]); + } + foreach (var pair in input) { output[_keyCopier.DeepCopy(pair.Key, context)] = _valueCopier.DeepCopy(pair.Value, context); @@ -210,7 +220,7 @@ public void DeepCopy(Dictionary input, Dictionary ou /// The key type. /// The value type. [RegisterSerializer] - public sealed class DictionaryBaseCodec : IBaseCodec> + public sealed class DictionaryBaseCodec : IBaseCodec> where TKey : notnull { private readonly Type _keyFieldType = typeof(TKey); private readonly Type _valueFieldType = typeof(TValue); @@ -234,7 +244,7 @@ public DictionaryBaseCodec( _keyCodec = OrleansGeneratedCodeHelper.UnwrapService(this, keyCodec); _valueCodec = OrleansGeneratedCodeHelper.UnwrapService(this, valueCodec); _comparerCodec = OrleansGeneratedCodeHelper.UnwrapService(this, comparerCodec); - _baseConstructor = typeof(Dictionary).GetConstructor(new Type[] { typeof(int), typeof(IEqualityComparer) }); + _baseConstructor = typeof(Dictionary).GetConstructor([typeof(int), typeof(IEqualityComparer)])!; } void IBaseCodec>.Serialize(ref Writer writer, Dictionary value) @@ -259,9 +269,13 @@ void IBaseCodec>.Serialize(ref Writer>.Deserialize(ref Reader reader, Dictionary value) { - TKey key = default; + // If the dictionary has some values added by the default constructor, clear them. + // If those values are in the serialized payload, they will be added below. + value.Clear(); + + TKey? key = default; var valueExpected = false; - IEqualityComparer comparer = null; + IEqualityComparer? comparer = null; uint fieldId = 0; bool hasLengthField = false; while (true) @@ -286,7 +300,8 @@ void IBaseCodec>.Deserialize(ref Reader } hasLengthField = true; - _baseConstructor.Invoke(value, new object[] { length, comparer }); + _baseConstructor.Invoke(value, [length, comparer]); + break; case 2: if (!hasLengthField) @@ -301,7 +316,7 @@ void IBaseCodec>.Deserialize(ref Reader } else { - value.Add(key, _valueCodec.ReadValue(ref reader, header)); + value.Add(key!, _valueCodec.ReadValue(ref reader, header)); valueExpected = false; } break; diff --git a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs index 6f7cb7c8910..ffd921ad6a6 100644 --- a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs +++ b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs @@ -2154,6 +2154,40 @@ protected override Dictionary CreateValue() protected override bool Equals(Dictionary left, Dictionary right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); } + [GenerateSerializer] + public class TypeWithDictionaryBase : Dictionary + { + public TypeWithDictionaryBase() : this(true) { } + public TypeWithDictionaryBase(bool addDefaultValue) + { + if (addDefaultValue) + { + this["key"] = 1; + } + } + + [Id(0)] + public int OtherProperty { get; set; } + + public override string ToString() => $"[OtherProperty: {OtherProperty}, Values: [{string.Join(", ", this.Select(kvp => $"[{kvp.Key}] = '{kvp.Value}'"))}]]"; + } + + public class DictionaryBaseCodecTests(ITestOutputHelper output) : FieldCodecTester>(output) + { + protected override TypeWithDictionaryBase[] TestValues => [null, new(), new(addDefaultValue: false), new() { ["foo"] = 15 }, new() { ["foo"] = 15, OtherProperty = 123 }]; + + protected override TypeWithDictionaryBase CreateValue() => new() { OtherProperty = Random.Next() }; + protected override bool Equals(TypeWithDictionaryBase left, TypeWithDictionaryBase right) => ReferenceEquals(left, right) || left.SequenceEqual(right) && left.OtherProperty == right.OtherProperty; + } + + public class DictionaryBaseCopierTests(ITestOutputHelper output) : CopierTester>(output) + { + protected override TypeWithDictionaryBase[] TestValues => [null, new(), new(addDefaultValue: false), new() { ["foo"] = 15 }, new() { ["foo"] = 15, OtherProperty = 123 }]; + + protected override TypeWithDictionaryBase CreateValue() => new() { OtherProperty = Random.Next() }; + protected override bool Equals(TypeWithDictionaryBase left, TypeWithDictionaryBase right) => ReferenceEquals(left, right) || left.SequenceEqual(right) && left.OtherProperty == right.OtherProperty; + } + public class DictionaryWithComparerCodecTests : FieldCodecTester, DictionaryCodec> { protected override int[] MaxSegmentSizes => new[] { 1024 };