Skip to content

Commit

Permalink
Add BitArray serialization codec (#9121)
Browse files Browse the repository at this point in the history
  • Loading branch information
ReubenBond authored Aug 21, 2024
1 parent 609f72e commit 818ef3a
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 7 deletions.
90 changes: 90 additions & 0 deletions src/Orleans.Serialization/Codecs/ByteArrayCodec.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,104 @@
using System;
using System.Buffers;
using System.Collections;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Orleans.Serialization.Buffers;
using Orleans.Serialization.Cloning;
using Orleans.Serialization.Serializers;
using Orleans.Serialization.WireProtocol;

namespace Orleans.Serialization.Codecs
{
/// <summary>
/// Serializer for <see cref="BitArray"/> arrays.
/// </summary>
[RegisterSerializer]
public sealed partial class BitArrayCodec : IFieldCodec<BitArray>
{
#if NET8_0_OR_GREATER
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "m_array")]
extern static ref int[] GetSetArray(BitArray bitArray);
#else
private static int[] GetSetArray(BitArray bitArray) => typeof(BitArray).GetField("m_array", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic).GetValue(bitArray) as int[];
#endif

BitArray IFieldCodec<BitArray>.ReadValue<TInput>(ref Reader<TInput> reader, Field field) => ReadValue(ref reader, field);

/// <summary>
/// Reads a value.
/// </summary>
/// <typeparam name="TInput">The reader input type.</typeparam>
/// <param name="reader">The reader.</param>
/// <param name="field">The field.</param>
/// <returns>The value.</returns>
public static BitArray ReadValue<TInput>(ref Reader<TInput> reader, Field field)
{
if (field.WireType == WireType.Reference)
{
return ReferenceCodec.ReadReference<BitArray, TInput>(ref reader, field);
}

field.EnsureWireType(WireType.LengthPrefixed);
var numBytes = reader.ReadVarUInt32();
var result = new BitArray((int)numBytes * 8, false);
var resultArray = GetSetArray(result);
reader.ReadBytes(MemoryMarshal.AsBytes(resultArray.AsSpan()).Slice(0, (int)numBytes));

ReferenceCodec.RecordObject(reader.Session, result);
return result;
}

void IFieldCodec<BitArray>.WriteField<TBufferWriter>(ref Writer<TBufferWriter> writer, uint fieldIdDelta, Type expectedType, BitArray value)
{
if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value))
{
return;
}

writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(BitArray), WireType.LengthPrefixed);
var numBytes = GetByteArrayLengthFromBitLength(value.Length);
writer.WriteVarUInt32((uint)numBytes);
writer.Write(MemoryMarshal.AsBytes(GetSetArray(value).AsSpan()).Slice(0, numBytes));

static int GetByteArrayLengthFromBitLength(int n)
{
const int BitShiftPerByte = 3;
Debug.Assert(n >= 0);
return (int)((uint)(n - 1 + (1 << BitShiftPerByte)) >> BitShiftPerByte);
}
}
}

/// <summary>
/// Copier for <see cref="byte"/> arrays.
/// </summary>
[RegisterCopier]
public sealed class BitArrayCopier : IDeepCopier<BitArray>
{
/// <inheritdoc/>
BitArray IDeepCopier<BitArray>.DeepCopy(BitArray input, CopyContext context) => DeepCopy(input, context);

/// <summary>
/// Creates a deep copy of the provided input.
/// </summary>
/// <param name="input">The input.</param>
/// <param name="context">The context.</param>
/// <returns>A copy of <paramref name="input" />.</returns>
public static BitArray DeepCopy(BitArray input, CopyContext context)
{
if (context.TryGetCopy<BitArray>(input, out var result))
{
return result;
}

result = new(input);
context.RecordCopy(input, result);
return result;
}
}

/// <summary>
/// Serializer for <see cref="byte"/> arrays.
/// </summary>
Expand Down
44 changes: 37 additions & 7 deletions test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void EveryCodecHasTests()
{ } t => t,
};
typesWithCopiers.Add(typeArg);
}
}
}
}

Expand Down Expand Up @@ -1045,6 +1045,36 @@ public class ObjectCopierTests(ITestOutputHelper output) : CopierTester<object,
protected override bool IsImmutable => true;
}

public class BitArrayCodecTests(ITestOutputHelper output) : FieldCodecTester<BitArray, BitArrayCodec>(output)
{
protected override BitArray CreateValue() => new BitArray(Guid.NewGuid().ToByteArray());

protected override bool Equals(BitArray left, BitArray right) => ReferenceEquals(left, right) || left.Length == right.Length && Enumerable.Range(0, left.Length).All(i => left[i] == right[i]);

protected override BitArray[] TestValues =>
[
null,
new BitArray(0, false),
new BitArray(Enumerable.Range(0, Random.Next(4097)).Select(b => unchecked((byte)b)).ToArray()),
CreateValue(),
];
}

public class BitArrayCopierTests(ITestOutputHelper output) : CopierTester<BitArray, BitArrayCopier>(output)
{
protected override BitArray CreateValue() => new BitArray(Guid.NewGuid().ToByteArray());

protected override bool Equals(BitArray left, BitArray right) => ReferenceEquals(left, right) || left.Length == right.Length && Enumerable.Range(0, left.Length).All(i => left[i] == right[i]);

protected override BitArray[] TestValues =>
[
null,
new BitArray(0, false),
new BitArray(Enumerable.Range(0, Random.Next(4097)).Select(b => unchecked((byte)b)).ToArray()),
CreateValue(),
];
}

public class ByteArrayCodecTests(ITestOutputHelper output) : FieldCodecTester<byte[], ByteArrayCodec>(output)
{
protected override byte[] CreateValue() => Guid.NewGuid().ToByteArray();
Expand Down Expand Up @@ -1084,7 +1114,7 @@ protected override Memory<int> CreateValue()
}

protected override bool Equals(Memory<int> left, Memory<int> right) => left.Span.SequenceEqual(right.Span);
protected override Memory<int>[] TestValues => [null, new Memory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override Memory<int>[] TestValues => [default, new Memory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class MemoryCopierTests(ITestOutputHelper output) : CopierTester<Memory<int>, MemoryCopier<int>>(output)
Expand All @@ -1098,7 +1128,7 @@ protected override Memory<int> CreateValue()
}

protected override bool Equals(Memory<int> left, Memory<int> right) => left.Span.SequenceEqual(right.Span);
protected override Memory<int>[] TestValues => [null, new Memory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override Memory<int>[] TestValues => [default, new Memory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class ReadOnlyMemoryCodecTests(ITestOutputHelper output) : FieldCodecTester<ReadOnlyMemory<int>, ReadOnlyMemoryCodec<int>>(output)
Expand All @@ -1112,7 +1142,7 @@ protected override ReadOnlyMemory<int> CreateValue()
}

protected override bool Equals(ReadOnlyMemory<int> left, ReadOnlyMemory<int> right) => left.Span.SequenceEqual(right.Span);
protected override ReadOnlyMemory<int>[] TestValues => [null, new ReadOnlyMemory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override ReadOnlyMemory<int>[] TestValues => [default, new ReadOnlyMemory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class ReadOnlyMemoryCopierTests(ITestOutputHelper output) : CopierTester<ReadOnlyMemory<int>, ReadOnlyMemoryCopier<int>>(output)
Expand All @@ -1126,7 +1156,7 @@ protected override ReadOnlyMemory<int> CreateValue()
}

protected override bool Equals(ReadOnlyMemory<int> left, ReadOnlyMemory<int> right) => left.Span.SequenceEqual(right.Span);
protected override ReadOnlyMemory<int>[] TestValues => [null, new ReadOnlyMemory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override ReadOnlyMemory<int>[] TestValues => [default, new ReadOnlyMemory<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class ArraySegmentCodecTests(ITestOutputHelper output) : FieldCodecTester<ArraySegment<int>, ArraySegmentCodec<int>>(output)
Expand All @@ -1140,7 +1170,7 @@ protected override ArraySegment<int> CreateValue()
}

protected override bool Equals(ArraySegment<int> left, ArraySegment<int> right) => left.SequenceEqual(right);
protected override ArraySegment<int>[] TestValues => [null, new ArraySegment<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override ArraySegment<int>[] TestValues => [default, new ArraySegment<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class ArraySegmentCopierTests(ITestOutputHelper output) : CopierTester<ArraySegment<int>, ArraySegmentCopier<int>>(output)
Expand All @@ -1154,7 +1184,7 @@ protected override ArraySegment<int> CreateValue()
}

protected override bool Equals(ArraySegment<int> left, ArraySegment<int> right) => left.SequenceEqual(right);
protected override ArraySegment<int>[] TestValues => [null, new ArraySegment<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
protected override ArraySegment<int>[] TestValues => [default, new ArraySegment<int>([], 0, 0), CreateValue(), CreateValue(), CreateValue()];
}

public class ArrayCodecTests(ITestOutputHelper output) : FieldCodecTester<int[], ArrayCodec<int>>(output)
Expand Down

0 comments on commit 818ef3a

Please sign in to comment.