Skip to content

Commit

Permalink
Fix potential overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
benaadams committed Mar 17, 2020
1 parent 0353f74 commit 6b3970a
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
Vector256<byte> vecResult;
nuint offset = 0;
nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
// Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
Debug.Assert(lengthToExamine < length);
if (lengthToExamine != 0)
{
do
Expand All @@ -1393,7 +1395,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
}

// Do final compare as Vector256<byte>.Count from end rather than start
Debug.Assert(lengthToExamine >= 0);
vecResult = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
if (Avx2.MoveMask(vecResult) == -1)
{
Expand All @@ -1411,6 +1412,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
Vector128<byte> vecResult;
nuint offset = 0;
nuint lengthToExamine = length - Vector128.Size;
// Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
Debug.Assert(lengthToExamine < length);
if (lengthToExamine != 0)
{
do
Expand All @@ -1427,7 +1430,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
}

// Do final compare as Vector128<byte>.Count from end rather than start
Debug.Assert(lengthToExamine >= 0);
vecResult = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
if (Sse2.MoveMask(vecResult) == 0xFFFF)
{
Expand All @@ -1439,10 +1441,12 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
goto NotEqual;
}
}
else if (Vector.IsHardwareAccelerated && (nint)length >= Vector<byte>.Count)
else if (Vector.IsHardwareAccelerated && length >= (nuint)Vector<byte>.Count)
{
nint offset = 0;
nint lengthToExamine = (nint)length - Vector<byte>.Count;
nuint offset = 0;
nuint lengthToExamine = length - (nuint)Vector<byte>.Count;
// Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
Debug.Assert(lengthToExamine < length);
if (lengthToExamine > 0)
{
do
Expand All @@ -1451,12 +1455,11 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
{
goto NotEqual;
}
offset += Vector<byte>.Count;
offset += (nuint)Vector<byte>.Count;
} while (lengthToExamine > offset);
}

// Do final compare as Vector<byte>.Count from end rather than start
Debug.Assert(lengthToExamine >= 0);
if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
{
// C# compiler inverts this test, making the outer goto the conditional jmp.
Expand Down Expand Up @@ -1485,6 +1488,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
{
nuint offset = 0;
nuint lengthToExamine = length - sizeof(nuint);
// Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
Debug.Assert(lengthToExamine < length);
if (lengthToExamine > 0)
{
do
Expand All @@ -1499,7 +1504,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
}

// Do final compare as sizeof(nuint) from end rather than start
Debug.Assert(lengthToExamine >= 0);
result = (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine));
goto Result;
}
Expand Down Expand Up @@ -1772,23 +1776,23 @@ private static ushort LoadUShort(ref byte start)

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ushort LoadUShort(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
=> Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint LoadUInt(ref byte start)
=> Unsafe.ReadUnaligned<uint>(ref start);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint LoadUInt(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
=> Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nuint LoadNUInt(ref byte start)
=> Unsafe.ReadUnaligned<nuint>(ref start);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nuint LoadNUInt(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
=> Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
Expand All @@ -1799,24 +1803,24 @@ private static Vector<byte> LoadVector(ref byte start, IntPtr offset)
=> Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector<byte> LoadVector(ref byte start, nint offset)
=> Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
private static Vector<byte> LoadVector(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
=> Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<byte> LoadVector128(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
=> Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
=> Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector256<byte> LoadVector256(ref byte start, nuint offset)
=> Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
=> Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe IntPtr GetByteVectorSpanLength(IntPtr offset, int length)
Expand Down

0 comments on commit 6b3970a

Please sign in to comment.