Use intrinsics for SequenceEqual<byte> vectorization to emit at R2R (#…
benaadams authored Apr 27, 2020
1 parent feddac7 commit 535b998
Showing 1 changed file with 213 additions and 58 deletions.
271 changes: 213 additions & 58 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
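In short, the change rewrites the byte SequenceEqual helper to use Sse2/Avx2 intrinsics directly (keeping the existing Vector<T> path as a fallback), switches offset arithmetic from IntPtr to nuint, and adds dedicated short-length scalar paths, so that the vectorized code can be emitted at ReadyToRun (R2R) compile time instead of waiting for the JIT to re-optimize the method. Callers see no API change; this internal helper is what MemoryExtensions.SequenceEqual ends up using for byte-sized, bitwise-equatable elements. A minimal usage sketch follows (the SequenceEqualUsage class and Main method are illustrative, not part of the commit):

```csharp
// Hedged usage sketch: byte spans compared through the public API that ultimately reaches
// SpanHelpers.SequenceEqual(ref byte, ref byte, nuint) for bitwise-equatable element types.
using System;

class SequenceEqualUsage
{
    static void Main()
    {
        byte[] payload = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
        ReadOnlySpan<byte> a = payload;
        Span<byte> b = (byte[])payload.Clone();

        Console.WriteLine(a.SequenceEqual(b)); // True: same bytes in different buffers
        b[9] = 0;
        Console.WriteLine(a.SequenceEqual(b)); // False: last byte differs
    }
}
```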
@@ -13,8 +13,10 @@
 #pragma warning disable SA1121 // explicitly using type aliases instead of built-in types
 #if TARGET_64BIT
 using nuint = System.UInt64;
+using nint = System.Int64;
 #else
 using nuint = System.UInt32;
+using nint = System.Int32;
 #endif // TARGET_64BIT
 
 namespace System
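For context, these using aliases predate the C# 9 nint/nuint keywords: on 64-bit targets they map to System.UInt64/Int64 and on 32-bit targets to System.UInt32/Int32, so sizeof(nuint) tracks the pointer size without requiring unsafe code. A hedged standalone illustration of the same trick (the nuint_t alias and AliasDemo class are made-up names, and TARGET_64BIT must be defined manually in this sketch):

```csharp
// Illustrative only: the alias technique used above, in a tiny standalone program.
// Define TARGET_64BIT (e.g. <DefineConstants>TARGET_64BIT</DefineConstants>) for the 64-bit mapping.
#if TARGET_64BIT
using nuint_t = System.UInt64;
#else
using nuint_t = System.UInt32;
#endif

class AliasDemo
{
    static void Main()
    {
        // 8 when TARGET_64BIT is defined, otherwise 4; this matches IntPtr.Size
        // when the define mirrors the actual target platform.
        System.Console.WriteLine(sizeof(nuint_t));
        System.Console.WriteLine(System.IntPtr.Size);
    }
}
```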
@@ -1309,85 +1311,210 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte
 
         // Optimized byte-based SequenceEquals. The "length" parameter for this one is declared a nuint rather than int as we also use it for types other than byte
         // where the length can exceed 2Gb once scaled by sizeof(T).
-        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
-        public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
+        public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
         {
-            IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
-            IntPtr lengthToExamine = (IntPtr)(void*)length;
+            bool result;
+            // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
+            if (length >= sizeof(nuint))
+            {
+                // Conditional jmp forward to favor shorter lengths. (See comment at "Equal:" label)
+                // The longer lengths can make back the time due to branch misprediction
+                // better than shorter lengths.
+                goto Longer;
+            }
 
-            if ((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr))
+#if TARGET_64BIT
+            // On 32-bit, this will always be true since sizeof(nuint) == 4
+            if (length < sizeof(uint))
+#endif
             {
-                // Only check that the ref is the same if buffers are large, and hence
-                // its worth avoiding doing unnecessary comparisons
-                if (Unsafe.AreSame(ref first, ref second))
-                    goto Equal;
+                uint differentBits = 0;
+                nuint offset = (length & 2);
+                if (offset != 0)
+                {
+                    differentBits = LoadUShort(ref first);
+                    differentBits -= LoadUShort(ref second);
+                }
+                if ((length & 1) != 0)
+                {
+                    differentBits |= (uint)Unsafe.AddByteOffset(ref first, offset) - (uint)Unsafe.AddByteOffset(ref second, offset);
+                }
+                result = (differentBits == 0);
+                goto Result;
+            }
+#if TARGET_64BIT
+            else
+            {
+                nuint offset = length - sizeof(uint);
+                uint differentBits = LoadUInt(ref first) - LoadUInt(ref second);
+                differentBits |= LoadUInt(ref first, offset) - LoadUInt(ref second, offset);
+                result = (differentBits == 0);
+                goto Result;
+            }
+#endif
+        Longer:
+            // Only check that the ref is the same if buffers are large,
+            // and hence it's worth avoiding doing unnecessary comparisons
+            if (!Unsafe.AreSame(ref first, ref second))
+            {
+                // C# compiler inverts this test, making the outer goto the conditional jmp.
+                goto Vector;
+            }
 
-                if (Vector.IsHardwareAccelerated && (byte*)lengthToExamine >= (byte*)Vector<byte>.Count)
+            // This becomes a conditional jmp forward to not favor it.
+            goto Equal;
+
+        Result:
+            return result;
+        // When the sequence is equal; which is the longest execution, we want it to determine that
+        // as fast as possible so we do not want the early outs to be "predicted not taken" branches.
+        Equal:
+            return true;
+
+        Vector:
+            if (Sse2.IsSupported)
+            {
+                if (Avx2.IsSupported && length >= (nuint)Vector256<byte>.Count)
+                {
-                    lengthToExamine -= Vector<byte>.Count;
-                    while ((byte*)lengthToExamine > (byte*)offset)
+                    Vector256<byte> vecResult;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine != 0)
                     {
-                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        do
                         {
-                            goto NotEqual;
-                        }
-                        offset += Vector<byte>.Count;
+                            vecResult = Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset));
+                            if (Avx2.MoveMask(vecResult) != -1)
+                            {
+                                goto NotEqual;
+                            }
+                            offset += (nuint)Vector256<byte>.Count;
+                        } while (lengthToExamine > offset);
                     }
-                    return LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine);
-                }
 
-                Debug.Assert((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr));
+                    // Do final compare as Vector256<byte>.Count from end rather than start
+                    vecResult = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
+                    if (Avx2.MoveMask(vecResult) == -1)
+                    {
+                        // C# compiler inverts this test, making the outer goto the conditional jmp.
+                        goto Equal;
+                    }
 
-                lengthToExamine -= sizeof(UIntPtr);
-                while ((byte*)lengthToExamine > (byte*)offset)
+                    // This becomes a conditional jmp forward to not favor it.
+                    goto NotEqual;
+                }
+                // Use Vector128.Size as Vector128<byte>.Count doesn't inline at R2R time
+                // https://github.com/dotnet/runtime/issues/32714
+                else if (length >= Vector128.Size)
                 {
-                    if (LoadUIntPtr(ref first, offset) != LoadUIntPtr(ref second, offset))
+                    Vector128<byte> vecResult;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - Vector128.Size;
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine != 0)
                     {
-                        goto NotEqual;
+                        do
+                        {
+                            // We use intrinsics directly as .Equals calls .AsByte() which doesn't inline at R2R time
+                            // https://github.com/dotnet/runtime/issues/32714
+                            vecResult = Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset));
+                            if (Sse2.MoveMask(vecResult) != 0xFFFF)
+                            {
+                                goto NotEqual;
+                            }
+                            offset += Vector128.Size;
+                        } while (lengthToExamine > offset);
                     }
-                    offset += sizeof(UIntPtr);
-                }
-                return LoadUIntPtr(ref first, lengthToExamine) == LoadUIntPtr(ref second, lengthToExamine);
-            }
 
-            Debug.Assert((byte*)lengthToExamine < (byte*)sizeof(UIntPtr));
+                    // Do final compare as Vector128<byte>.Count from end rather than start
+                    vecResult = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
+                    if (Sse2.MoveMask(vecResult) == 0xFFFF)
+                    {
+                        // C# compiler inverts this test, making the outer goto the conditional jmp.
+                        goto Equal;
+                    }
 
-            // On 32-bit, this will never be true since sizeof(UIntPtr) == 4
-#if TARGET_64BIT
-            if ((byte*)lengthToExamine >= (byte*)sizeof(int))
-            {
-                if (LoadInt(ref first, offset) != LoadInt(ref second, offset))
-                {
+                    // This becomes a conditional jmp forward to not favor it.
                     goto NotEqual;
                 }
-                offset += sizeof(int);
-                lengthToExamine -= sizeof(int);
             }
-#endif
 
-            if ((byte*)lengthToExamine >= (byte*)sizeof(short))
+            else if (Vector.IsHardwareAccelerated && length >= (nuint)Vector<byte>.Count)
             {
-                if (LoadShort(ref first, offset) != LoadShort(ref second, offset))
+                nuint offset = 0;
+                nuint lengthToExamine = length - (nuint)Vector<byte>.Count;
+                // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                Debug.Assert(lengthToExamine < length);
+                if (lengthToExamine > 0)
                 {
-                    goto NotEqual;
+                    do
+                    {
+                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        {
+                            goto NotEqual;
+                        }
+                        offset += (nuint)Vector<byte>.Count;
+                    } while (lengthToExamine > offset);
                 }
-                offset += sizeof(short);
-                lengthToExamine -= sizeof(short);
+
+                // Do final compare as Vector<byte>.Count from end rather than start
+                if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
+                {
+                    // C# compiler inverts this test, making the outer goto the conditional jmp.
+                    goto Equal;
+                }
+
+                // This becomes a conditional jmp forward to not favor it.
+                goto NotEqual;
             }
 
-            if (lengthToExamine != IntPtr.Zero)
+#if TARGET_64BIT
+            if (Sse2.IsSupported)
             {
-                Debug.Assert((int)lengthToExamine == 1);
+                Debug.Assert(length <= sizeof(nuint) * 2);
 
-                if (Unsafe.AddByteOffset(ref first, offset) != Unsafe.AddByteOffset(ref second, offset))
+                nuint offset = length - sizeof(nuint);
+                nuint differentBits = LoadNUInt(ref first) - LoadNUInt(ref second);
+                differentBits |= LoadNUInt(ref first, offset) - LoadNUInt(ref second, offset);
+                result = (differentBits == 0);
+                goto Result;
+            }
+            else
+#endif
+            {
+                Debug.Assert(length >= sizeof(nuint));
                 {
-                    goto NotEqual;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - sizeof(nuint);
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
+                    if (lengthToExamine > 0)
+                    {
+                        do
+                        {
+                            // Compare unsigned so we do not do a sign-extend mov on 64-bit
+                            if (LoadNUInt(ref first, offset) != LoadNUInt(ref second, offset))
+                            {
+                                goto NotEqual;
+                            }
+                            offset += sizeof(nuint);
+                        } while (lengthToExamine > offset);
+                    }
+
+                    // Do final compare as sizeof(nuint) from end rather than start
+                    result = (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine));
+                    goto Result;
                 }
             }
 
-        Equal:
-            return true;
-        NotEqual: // Workaround for https://github.com/dotnet/runtime/issues/8795
+        // As there are so many true/false exit points the Jit will coalesce them to one location.
+        // We want them at the end so the conditional early exit jmps are all jmp forwards so the
+        // branch predictor in an uninitialized state will not take them, e.g.
+        // - loops are conditional jmps backwards and predicted
+        // - exceptions are conditional forwards jmps and not predicted
+        NotEqual:
             return false;
         }
 
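The hot loops above hinge on a single idiom: Sse2.CompareEqual sets each byte lane to 0xFF where the two inputs match, and Sse2.MoveMask packs the lane sign bits into an int, so a mask of 0xFFFF means all 16 bytes compared equal (the AVX2 form checks against -1 because Vector256 has 32 lanes). A self-contained sketch of that 16-byte check, outside the commit and with illustrative names (MoveMaskSketch, Equals16Bytes):

```csharp
// Illustrative sketch only: compares exactly 16 bytes with the same pcmpeqb/pmovmskb idiom
// the loops above use (requires compiling with unsafe enabled and an SSE2-capable CPU).
using System;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class MoveMaskSketch
{
    public static unsafe bool Equals16Bytes(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b)
    {
        if (a.Length != 16 || b.Length != 16)
            throw new ArgumentException("Sketch compares exactly 16 bytes.");
        if (!Sse2.IsSupported)
            return a.SequenceEqual(b); // scalar fallback

        fixed (byte* pa = a, pb = b)
        {
            Vector128<byte> equal = Sse2.CompareEqual(Sse2.LoadVector128(pa), Sse2.LoadVector128(pb));
            // One bit per byte lane; 0xFFFF means all 16 lanes matched.
            return Sse2.MoveMask(equal) == 0xFFFF;
        }
    }
}
```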
@@ -1644,27 +1771,55 @@ private static int LocateLastFoundByte(ulong match)
                                                        0x01ul << 48) + 1;
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe short LoadShort(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<short>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static ushort LoadUShort(ref byte start)
+            => Unsafe.ReadUnaligned<ushort>(ref start);
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static ushort LoadUShort(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, offset));
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe int LoadInt(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<int>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static uint LoadUInt(ref byte start)
+            => Unsafe.ReadUnaligned<uint>(ref start);
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
+        private static uint LoadUInt(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static nuint LoadNUInt(ref byte start)
+            => Unsafe.ReadUnaligned<nuint>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static nuint LoadNUInt(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<UIntPtr>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector<byte> LoadVector(ref byte start, IntPtr offset)
+        private static Vector<byte> LoadVector(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector<byte> LoadVector(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
+        private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
+        private static Vector128<byte> LoadVector128(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]

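One more technique from the scalar paths is worth a note: for lengths between sizeof(nuint) and twice that, the new code loads the first and the last native-sized word (the two reads overlap when the length is not a multiple of the word size) and ORs their differences, so short buffers are compared without a per-byte loop or extra branches. A standalone sketch under those assumptions, with the word size fixed at 8 bytes and illustrative names (OverlapCompareSketch, Equal8To16):

```csharp
using System;
using System.Runtime.InteropServices;

static class OverlapCompareSketch
{
    // Compares two spans of identical length in [8, 16] using two possibly-overlapping
    // 8-byte reads, mirroring the differentBits pattern above: a subtraction is zero only
    // when the two words are identical, so OR-ing both differences tests both words at once.
    public static bool Equal8To16(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b)
    {
        if (a.Length != b.Length || a.Length < 8 || a.Length > 16)
            throw new ArgumentException("Sketch only handles equal lengths of 8 to 16 bytes.");

        int last = a.Length - 8; // second read starts here; overlaps the first when length < 16

        ulong differentBits = MemoryMarshal.Read<ulong>(a) - MemoryMarshal.Read<ulong>(b);
        differentBits |= MemoryMarshal.Read<ulong>(a.Slice(last)) - MemoryMarshal.Read<ulong>(b.Slice(last));
        return differentBits == 0;
    }
}
```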