diff --git a/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCodec.cs b/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCodec.cs new file mode 100644 index 0000000000..561ae06bcf --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCodec.cs @@ -0,0 +1,42 @@ +using System; +using Google.Protobuf; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Codecs; +using Orleans.Serialization.WireProtocol; + +namespace Orleans.Serialization; + +/// +/// Serializer for . +/// +[RegisterSerializer] +public sealed class ByteStringCodec : IFieldCodec +{ + /// + ByteString IFieldCodec.ReadValue(ref Reader reader, Field field) + { + if (field.WireType == WireType.Reference) + { + return ReferenceCodec.ReadReference(ref reader, field); + } + + field.EnsureWireType(WireType.LengthPrefixed); + var length = reader.ReadVarUInt32(); + var result = UnsafeByteOperations.UnsafeWrap(reader.ReadBytes(length)); + ReferenceCodec.RecordObject(reader.Session, result); + return result; + } + + /// + void IFieldCodec.WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, ByteString value) + { + if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) + { + return; + } + + writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(ByteString), WireType.LengthPrefixed); + writer.WriteVarUInt32((uint)value.Length); + writer.Write(value.Span); + } +} \ No newline at end of file diff --git a/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCopier.cs b/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCopier.cs new file mode 100644 index 0000000000..3f96ea306b --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/ByteStringCopier.cs @@ -0,0 +1,24 @@ +using Google.Protobuf; +using Orleans.Serialization.Cloning; + +namespace Orleans.Serialization; + +/// +/// Copier for . +/// +[RegisterCopier] +public sealed class ByteStringCopier : IDeepCopier +{ + /// + public ByteString DeepCopy(ByteString input, CopyContext context) + { + if (context.TryGetCopy(input, out var result)) + { + return result; + } + + result = ByteString.CopyFrom(input.Span); + context.RecordCopy(input, result); + return result; + } +} \ No newline at end of file diff --git a/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCodec.cs b/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCodec.cs new file mode 100644 index 0000000000..5ca1fe953a --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCodec.cs @@ -0,0 +1,135 @@ +using System; +using System.Buffers; +using Google.Protobuf.Collections; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Codecs; +using Orleans.Serialization.GeneratedCodeHelpers; +using Orleans.Serialization.Session; +using Orleans.Serialization.WireProtocol; + +namespace Orleans.Serialization; + +/// +/// Serializer for . +/// +/// The key type. +/// The value type. +[RegisterSerializer] +public sealed class MapFieldCodec : IFieldCodec> +{ + private readonly Type _keyFieldType = typeof(TKey); + private readonly Type _valueFieldType = typeof(TValue); + + private readonly IFieldCodec _keyCodec; + private readonly IFieldCodec _valueCodec; + + /// + /// Initializes a new instance of the class. + /// + /// The key codec. + /// The value codec. + public MapFieldCodec( + IFieldCodec keyCodec, + IFieldCodec valueCodec) + { + _keyCodec = OrleansGeneratedCodeHelper.UnwrapService(this, keyCodec); + _valueCodec = OrleansGeneratedCodeHelper.UnwrapService(this, valueCodec); + } + + /// + public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, MapField value) where TBufferWriter : IBufferWriter + { + if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) + { + return; + } + + writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited); + + if (value.Count > 0) + { + UInt32Codec.WriteField(ref writer, 0, (uint)value.Count); + uint innerFieldIdDelta = 1; + foreach (var element in value) + { + _keyCodec.WriteField(ref writer, innerFieldIdDelta, _keyFieldType, element.Key); + _valueCodec.WriteField(ref writer, 0, _valueFieldType, element.Value); + innerFieldIdDelta = 0; + } + } + + writer.WriteEndObject(); + } + + /// + public MapField ReadValue(ref Reader reader, Field field) + { + if (field.WireType == WireType.Reference) + { + return ReferenceCodec.ReadReference, TInput>(ref reader, field); + } + + field.EnsureWireTypeTagDelimited(); + + var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session); + TKey key = default; + var valueExpected = false; + MapField result = null; + uint fieldId = 0; + while (true) + { + var header = reader.ReadFieldHeader(); + if (header.IsEndBaseOrEndObject) + { + break; + } + + fieldId += header.FieldIdDelta; + switch (fieldId) + { + case 0: + var length = (int)UInt32Codec.ReadValue(ref reader, header); + if (length > 10240 && length > reader.Length) + { + ThrowInvalidSizeException(length); + } + + result = CreateInstance(reader.Session, placeholderReferenceId); + break; + case 1: + if (result is null) + ThrowLengthFieldMissing(); + + if (!valueExpected) + { + key = _keyCodec.ReadValue(ref reader, header); + valueExpected = true; + } + else + { + result.Add(key, _valueCodec.ReadValue(ref reader, header)); + valueExpected = false; + } + break; + default: + reader.ConsumeUnknownField(header); + break; + } + } + + result ??= CreateInstance(reader.Session, placeholderReferenceId); + return result; + } + + private static MapField CreateInstance(SerializerSession session, uint placeholderReferenceId) + { + var result = new MapField(); + ReferenceCodec.RecordObject(session, result, placeholderReferenceId); + return result; + } + + private static void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException( + $"Declared length of {typeof(MapField)}, {length}, is greater than total length of input."); + + private static void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized MapField is missing its length field."); +} \ No newline at end of file diff --git a/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCopier.cs b/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCopier.cs new file mode 100644 index 0000000000..9d1093a885 --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/MapFieldCopier.cs @@ -0,0 +1,59 @@ +using Google.Protobuf.Collections; +using Orleans.Serialization.Cloning; + +namespace Orleans.Serialization; + +/// +/// Copier for . +/// +/// The type of the t key. +/// The type of the t value. +[RegisterCopier] +public sealed class MapFieldCopier : IDeepCopier>, IBaseCopier> +{ + private readonly IDeepCopier _keyCopier; + private readonly IDeepCopier _valueCopier; + + /// + /// Initializes a new instance of the class. + /// + /// The key copier. + /// The value copier. + public MapFieldCopier(IDeepCopier keyCopier, IDeepCopier valueCopier) + { + _keyCopier = keyCopier; + _valueCopier = valueCopier; + } + + /// + public MapField DeepCopy(MapField input, CopyContext context) + { + if (context.TryGetCopy>(input, out var result)) + { + return result; + } + + if (input.GetType() != typeof(MapField)) + { + return context.DeepCopy(input); + } + + result = new MapField(); + context.RecordCopy(input, result); + foreach (var pair in input) + { + result[_keyCopier.DeepCopy(pair.Key, context)] = _valueCopier.DeepCopy(pair.Value, context); + } + + return result; + } + + /// + public void DeepCopy(MapField input, MapField output, CopyContext context) + { + foreach (var pair in input) + { + output[_keyCopier.DeepCopy(pair.Key, context)] = _valueCopier.DeepCopy(pair.Value, context); + } + } +} \ No newline at end of file diff --git a/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCodec.cs b/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCodec.cs new file mode 100644 index 0000000000..909cc8fca5 --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCodec.cs @@ -0,0 +1,118 @@ +using System; +using System.Buffers; +using System.Runtime.CompilerServices; +using Google.Protobuf.Collections; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Codecs; +using Orleans.Serialization.GeneratedCodeHelpers; +using Orleans.Serialization.WireProtocol; + +namespace Orleans.Serialization; + +/// +/// Serializer for . +/// +/// The element type. +[RegisterSerializer] +public sealed class RepeatedFieldCodec : IFieldCodec> +{ + private readonly Type CodecElementType = typeof(T); + + private readonly IFieldCodec _fieldCodec; + + /// + /// Initializes a new instance of the class. + /// + /// The field codec. + public RepeatedFieldCodec(IFieldCodec fieldCodec) + { + _fieldCodec = OrleansGeneratedCodeHelper.UnwrapService(this, fieldCodec); + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, RepeatedField value) where TBufferWriter : IBufferWriter + { + if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) + { + return; + } + + writer.WriteFieldHeader(fieldIdDelta, expectedType, value.GetType(), WireType.TagDelimited); + + if (value.Count > 0) + { + UInt32Codec.WriteField(ref writer, 0, (uint)value.Count); + uint innerFieldIdDelta = 1; + foreach (var element in value) + { + _fieldCodec.WriteField(ref writer, innerFieldIdDelta, CodecElementType, element); + innerFieldIdDelta = 0; + } + } + + writer.WriteEndObject(); + } + + /// + public RepeatedField ReadValue(ref Reader reader, Field field) + { + if (field.WireType == WireType.Reference) + { + return ReferenceCodec.ReadReference, TInput>(ref reader, field); + } + + field.EnsureWireTypeTagDelimited(); + + var placeholderReferenceId = ReferenceCodec.CreateRecordPlaceholder(reader.Session); + RepeatedField result = null; + uint fieldId = 0; + while (true) + { + var header = reader.ReadFieldHeader(); + if (header.IsEndBaseOrEndObject) + { + break; + } + + fieldId += header.FieldIdDelta; + switch (fieldId) + { + case 0: + var length = (int)UInt32Codec.ReadValue(ref reader, header); + if (length > 10240 && length > reader.Length) + { + ThrowInvalidSizeException(length); + } + + result = new RepeatedField{ Capacity = length }; + ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId); + break; + case 1: + if (result is null) + { + ThrowLengthFieldMissing(); + } + + result.Add(_fieldCodec.ReadValue(ref reader, header)); + break; + default: + reader.ConsumeUnknownField(header); + break; + } + } + + if (result is null) + { + result = new(); + ReferenceCodec.RecordObject(reader.Session, result, placeholderReferenceId); + } + + return result; + } + + private static void ThrowInvalidSizeException(int length) => throw new IndexOutOfRangeException( + $"Declared length of {typeof(RepeatedField)}, {length}, is greater than total length of input."); + + private static void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized RepeatedField is missing its length field."); +} \ No newline at end of file diff --git a/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCopier.cs b/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCopier.cs new file mode 100644 index 0000000000..61aae9ea8e --- /dev/null +++ b/src/Serializers/Orleans.Serialization.Protobuf/RepeatedFieldCopier.cs @@ -0,0 +1,55 @@ +using Google.Protobuf.Collections; +using Orleans.Serialization.Cloning; + +namespace Orleans.Serialization; + +/// +/// Copier for . +/// +/// The element type. +[RegisterCopier] +public sealed class RepeatedFieldCopier : IDeepCopier>, IBaseCopier> +{ + private readonly IDeepCopier _copier; + + /// + /// Initializes a new instance of the class. + /// + /// The value copier. + public RepeatedFieldCopier(IDeepCopier valueCopier) + { + _copier = valueCopier; + } + + /// + public RepeatedField DeepCopy(RepeatedField input, CopyContext context) + { + if (context.TryGetCopy>(input, out var result)) + { + return result; + } + + if (input.GetType() != typeof(RepeatedField)) + { + return context.DeepCopy(input); + } + + result = new RepeatedField { Capacity = input.Count }; + context.RecordCopy(input, result); + foreach (var item in input) + { + result.Add(_copier.DeepCopy(item, context)); + } + + return result; + } + + /// + public void DeepCopy(RepeatedField input, RepeatedField output, CopyContext context) + { + foreach (var item in input) + { + output.Add(_copier.DeepCopy(item, context)); + } + } +} \ No newline at end of file diff --git a/test/Orleans.Serialization.UnitTests/ProtobufSerializerTests.cs b/test/Orleans.Serialization.UnitTests/ProtobufSerializerTests.cs index 0ae491fdba..5ff46c04e8 100644 --- a/test/Orleans.Serialization.UnitTests/ProtobufSerializerTests.cs +++ b/test/Orleans.Serialization.UnitTests/ProtobufSerializerTests.cs @@ -1,6 +1,8 @@ #nullable enable using System; +using System.Linq; using Google.Protobuf; +using Google.Protobuf.Collections; using Microsoft.Extensions.DependencyInjection; using Orleans.Serialization.Cloning; using Orleans.Serialization.Codecs; @@ -101,6 +103,134 @@ protected override void Configure(ISerializerBuilder builder) }; } +[Trait("Category", "BVT")] +public class ProtobufRepeatedFieldCodecTests : FieldCodecTester, RepeatedFieldCodec> +{ + public ProtobufRepeatedFieldCodecTests(ITestOutputHelper output) : base(output) + { + } + + protected override RepeatedField CreateValue() + { + var result = new RepeatedField(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result.Add(Random.Next()); + } + + return result; + } + + protected override bool Equals(RepeatedField left, RepeatedField right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); + protected override RepeatedField[] TestValues => new[] { new RepeatedField(), CreateValue(), CreateValue(), CreateValue() }; +} + +[Trait("Category", "BVT")] +public class ProtobufRepeatedFieldCopierTests : CopierTester, IDeepCopier>> +{ + public ProtobufRepeatedFieldCopierTests(ITestOutputHelper output) : base(output) + { + } + + protected override IDeepCopier> CreateCopier() => ServiceProvider.GetRequiredService().GetDeepCopier>(); + + protected override RepeatedField CreateValue() + { + var result = new RepeatedField(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result.Add(Random.Next()); + } + + return result; + } + + protected override bool Equals(RepeatedField left, RepeatedField right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); + protected override RepeatedField[] TestValues => new[] { new RepeatedField(), CreateValue(), CreateValue(), CreateValue() }; +} + +[Trait("Category", "BVT")] +public class MapFieldCodecTests : FieldCodecTester, MapFieldCodec> +{ + public MapFieldCodecTests(ITestOutputHelper output) : base(output) + { + } + + protected override MapField CreateValue() + { + var result = new MapField(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result[Random.Next().ToString()] = Random.Next(); + } + + return result; + } + + protected override MapField[] TestValues => new[] { new MapField(), CreateValue(), CreateValue(), CreateValue() }; + protected override bool Equals(MapField left, MapField right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); +} + +[Trait("Category", "BVT")] +public class MapFieldCopierTests : CopierTester, MapFieldCopier> +{ + public MapFieldCopierTests(ITestOutputHelper output) : base(output) + { + } + + protected override MapField CreateValue() + { + var result = new MapField(); + for (var i = 0; i < Random.Next(17) + 5; i++) + { + result[Random.Next().ToString()] = Random.Next(); + } + + return result; + } + + protected override MapField[] TestValues => new[] { new MapField(), CreateValue(), CreateValue(), CreateValue() }; + protected override bool Equals(MapField left, MapField right) => object.ReferenceEquals(left, right) || left.SequenceEqual(right); +} + +[Trait("Category", "BVT")] +public class ByteStringCodecTests : FieldCodecTester +{ + public ByteStringCodecTests(ITestOutputHelper output) : base(output) + { + } + + protected override ByteString CreateValue() => Guid.NewGuid().ToByteString(); + + protected override bool Equals(ByteString left, ByteString right) => ReferenceEquals(left, right) || left.SequenceEqual(right); + + protected override ByteString[] TestValues => new[] + { + ByteString.Empty, + ByteString.CopyFrom(Enumerable.Range(0, 4097).Select(b => unchecked((byte)b)).ToArray()), + CreateValue() + }; +} + +[Trait("Category", "BVT")] +public class ByteStringCopierTests : CopierTester +{ + public ByteStringCopierTests(ITestOutputHelper output) : base(output) + { + } + + protected override ByteString CreateValue() => Guid.NewGuid().ToByteString(); + + protected override bool Equals(ByteString left, ByteString right) => ReferenceEquals(left, right) || left.SequenceEqual(right); + + protected override ByteString[] TestValues => new[] + { + ByteString.Empty, + ByteString.CopyFrom(Enumerable.Range(0, 4097).Select(b => unchecked((byte)b)).ToArray()), + CreateValue() + }; +} + public static class ProtobufGuidExtensions { public static ByteString ToByteString(this Guid guid)