From 6d6ed8bc723b615972f9531d37299dea50d8372b Mon Sep 17 00:00:00 2001
From: Ben Adams
Date: Sun, 16 Feb 2020 02:57:45 +0000
Subject: [PATCH 1/4] Intrinsicify SequenceEqual

---
 .../src/System/SpanHelpers.Byte.cs | 219 +++++++++++++-----
 1 file changed, 165 insertions(+), 54 deletions(-)

diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
index 4d84fd38518ef..833b365cc5b47 100644
--- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
@@ -13,8 +13,10 @@
 #pragma warning disable SA1121 // explicitly using type aliases instead of built-in types
 #if TARGET_64BIT
 using nuint = System.UInt64;
+using nint = System.Int64;
 #else
 using nuint = System.UInt32;
+using nint = System.Int32;
 #endif // TARGET_64BIT
 
 namespace System
@@ -1310,84 +1312,169 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte
         // Optimized byte-based SequenceEquals. The "length" parameter for this one is declared as nuint rather than int as we also use it for types other than byte
         // where the length can exceed 2GB once scaled by sizeof(T).
         [MethodImpl(MethodImplOptions.AggressiveOptimization)]
-        public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
+        public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
         {
-            IntPtr offset = (IntPtr)0; // Use IntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
-            IntPtr lengthToExamine = (IntPtr)(void*)length;
-
-            if ((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr))
+            // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
+            if (length >= sizeof(nuint))
             {
-                // Only check that the ref is the same if buffers are large, and hence
-                // it's worth avoiding doing unnecessary comparisons
-                if (Unsafe.AreSame(ref first, ref second))
-                    goto Equal;
+                // Conditional jmp forward to favor shorter lengths. (See comment at "Equal:" label)
+                // The longer lengths can make back the time lost to branch misprediction
+                // better than the shorter lengths can.
+                goto Longer;
+            }
 
-                if (Vector.IsHardwareAccelerated && (byte*)lengthToExamine >= (byte*)Vector<byte>.Count)
+#if TARGET_64BIT
+            // On 32-bit, this will never be true since sizeof(nuint) == 4
+            if (length >= sizeof(uint))
+            {
+                nint offset = 0;
+                if (length > sizeof(uint))
                 {
-                    lengthToExamine -= Vector<byte>.Count;
-                    while ((byte*)lengthToExamine > (byte*)offset)
+                    // Set offset for next compare to sizeof(uint) from end
+                    offset = (nint)length - sizeof(uint);
+                    // Compare start
+                    if (LoadUInt(ref first) != LoadUInt(ref second))
                     {
-                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
-                        {
-                            goto NotEqual;
-                        }
-                        offset += Vector<byte>.Count;
+                        goto NotEqual;
                     }
-                    return LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine);
                 }
-                Debug.Assert((byte*)lengthToExamine >= (byte*)sizeof(UIntPtr));
-
-                lengthToExamine -= sizeof(UIntPtr);
-                while ((byte*)lengthToExamine > (byte*)offset)
+                // Compare end
+                return LoadUInt(ref first, offset) == LoadUInt(ref second, offset);
+            }
+#endif
+            if (length >= sizeof(ushort))
+            {
+                nint offset = 0;
+                if (length > sizeof(ushort))
                 {
-                    if (LoadUIntPtr(ref first, offset) != LoadUIntPtr(ref second, offset))
+                    // Set offset for next compare to sizeof(ushort) from end
+                    offset = (nint)length - sizeof(ushort);
+                    // Compare start
+                    if (LoadUShort(ref first) != LoadUShort(ref second))
                     {
                         goto NotEqual;
                     }
-                    offset += sizeof(UIntPtr);
                 }
-                return LoadUIntPtr(ref first, lengthToExamine) == LoadUIntPtr(ref second, lengthToExamine);
+
+                // Compare end
+                return LoadUShort(ref first, offset) == LoadUShort(ref second, offset);
             }
 
-            Debug.Assert((byte*)lengthToExamine < (byte*)sizeof(UIntPtr));
+            if (length != 0)
+            {
+                // Only length 1 possible
+                Debug.Assert((int)length == 1);
+                return first == second;
+            }
 
-            // On 32-bit, this will never be true since sizeof(UIntPtr) == 4
-#if TARGET_64BIT
-            if ((byte*)lengthToExamine >= (byte*)sizeof(int))
+            Debug.Assert((int)length == 0);
+            goto Equal;
+
+        Longer:
+            // Only check that the ref is the same if buffers are large, and hence
+            // it's worth avoiding doing unnecessary comparisons
+            if (Unsafe.AreSame(ref first, ref second))
+            {
+                // This becomes a conditional jmp forward to not favor it. (See comment at "Equal:" label)
+                return true;
+            }
+
+            if (Sse.IsSupported)
             {
-                if (LoadInt(ref first, offset) != LoadInt(ref second, offset))
+                if (Avx.IsSupported && (nint)length >= Vector256<byte>.Count)
                 {
-                    goto NotEqual;
+                    nint offset = 0;
+                    nint lengthToExamine = (nint)length - Vector256<byte>.Count;
+                    if (lengthToExamine != 0)
+                    {
+                        do
+                        {
+                            if (!LoadVector256(ref first, offset).Equals(LoadVector256(ref second, offset)))
+                            {
+                                goto NotEqual;
+                            }
+                            offset += Vector256<byte>.Count;
+                        } while (lengthToExamine > offset);
+                    }
+
+                    // Do final compare as Vector256<byte>.Count from end rather than start
+                    Debug.Assert(lengthToExamine >= 0);
+                    return LoadVector256(ref first, lengthToExamine).Equals(LoadVector256(ref second, lengthToExamine));
                 }
-                offset += sizeof(int);
-                lengthToExamine -= sizeof(int);
-            }
-#endif
-
-            if ((byte*)lengthToExamine >= (byte*)sizeof(short))
-            {
-                if (LoadShort(ref first, offset) != LoadShort(ref second, offset))
+                else if ((nint)length >= Vector128<byte>.Count)
                 {
-                    goto NotEqual;
+                    nint offset = 0;
+                    nint lengthToExamine = (nint)length - Vector128<byte>.Count;
+                    if (lengthToExamine != 0)
+                    {
+                        do
+                        {
+                            if (!LoadVector128(ref first, offset).Equals(LoadVector128(ref second, offset)))
+                            {
+                                goto NotEqual;
+                            }
+                            offset += Vector128<byte>.Count;
+                        } while (lengthToExamine > offset);
+                    }
+
+                    // Do final compare as Vector128<byte>.Count from end rather than start
+                    Debug.Assert(lengthToExamine >= 0);
+                    return LoadVector128(ref first, lengthToExamine).Equals(LoadVector128(ref second, lengthToExamine));
                 }
-                offset += sizeof(short);
-                lengthToExamine -= sizeof(short);
             }
-
-            if (lengthToExamine != IntPtr.Zero)
+            else if (Vector.IsHardwareAccelerated && (nint)length >= Vector<byte>.Count)
             {
-                Debug.Assert((int)lengthToExamine == 1);
-
-                if (Unsafe.AddByteOffset(ref first, offset) != Unsafe.AddByteOffset(ref second, offset))
+                nint offset = 0;
+                nint lengthToExamine = (nint)length - Vector<byte>.Count;
+                if (lengthToExamine > 0)
                 {
-                    goto NotEqual;
+                    do
+                    {
+                        if (LoadVector(ref first, offset) != LoadVector(ref second, offset))
+                        {
+                            goto NotEqual;
+                        }
+                        offset += Vector<byte>.Count;
+                    } while (lengthToExamine > offset);
                 }
+
+                // Do final compare as Vector<byte>.Count from end rather than start
+                Debug.Assert(lengthToExamine >= 0);
+                return LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine);
             }
 
+            Debug.Assert(length >= sizeof(nuint));
+            {
+                nint offset = 0;
+                nint lengthToExamine = (nint)length - sizeof(nuint);
+                if (lengthToExamine > 0)
+                {
+                    do
+                    {
+                        // Compare unsigned so as not to do a sign-extend mov on 64 bit
+                        if (LoadNUInt(ref first, offset) != LoadNUInt(ref second, offset))
+                        {
+                            goto NotEqual;
+                        }
+                        offset += sizeof(nuint);
+                    } while (lengthToExamine > offset);
+                }
+
+                // Do final compare as sizeof(nuint) from end rather than start
+                Debug.Assert(lengthToExamine >= 0);
+                return LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine);
+            }
+
+            // As there are so many true/false exit points the Jit will coalesce them to one location.
+            // We want them at the end so the conditional early exit jmps are all jmp forwards so the
+            // branch predictor in an uninitialized state will not take them e.g.
+            // - loops are conditional jmps backwards and predicted
+            // - exceptions are conditional forwards jmps and not predicted
+            // When the sequence is equal; which is the longest execution, we want it to determine that
+            // as fast as possible so we do not want the early outs to be "predicted not taken" branches.
         Equal:
             return true;
-
-        NotEqual: // Workaround for https://github.com/dotnet/runtime/issues/8795
+        NotEqual:
             return false;
         }
@@ -1644,29 +1731,53 @@ private static int LocateLastFoundByte(ulong match)
             0x01ul << 48) + 1;
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe short LoadShort(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<short>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static ushort LoadUShort(ref byte start)
+            => Unsafe.ReadUnaligned<ushort>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static ushort LoadUShort(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static uint LoadUInt(ref byte start)
+            => Unsafe.ReadUnaligned<uint>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static uint LoadUInt(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe int LoadInt(ref byte start, IntPtr offset)
-            => Unsafe.ReadUnaligned<int>(ref Unsafe.AddByteOffset(ref start, offset));
+        private static nuint LoadNUInt(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
+        private static UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<UIntPtr>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector<byte> LoadVector(ref byte start, IntPtr offset)
+        private static Vector<byte> LoadVector(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
+        private static Vector<byte> LoadVector(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static unsafe Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
+        private static Vector128<byte> LoadVector128(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static Vector256<byte> LoadVector256(ref byte start, nint offset)
+            => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static unsafe IntPtr GetByteVectorSpanLength(IntPtr offset, int length)
             => (IntPtr)((length - (int)(byte*)offset) & ~(Vector<byte>.Count - 1));
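
The load-bearing trick in this first patch is the overlapping final compare: read sizeof(uint) bytes from the start, then read sizeof(uint) bytes ending exactly at the end of the buffer, and let the two windows overlap so every byte is covered with no scalar tail loop. A minimal standalone sketch of that technique, using the public BinaryPrimitives API and span parameters instead of the patch's internal ref-based helpers (the helper name and the 4..8 byte restriction are illustrative only):

    using System;
    using System.Buffers.Binary;

    static class OverlappingCompareSketch
    {
        // Compares two equal-length buffers of 4..8 bytes with exactly two
        // uint loads per buffer; the loads overlap whenever length < 8.
        internal static bool Equal4To8(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b)
        {
            if (a.Length != b.Length || a.Length < 4 || a.Length > 8)
                throw new ArgumentException("sketch handles equal lengths of 4..8 only");

            // Difference of the first 4 bytes...
            uint start = BinaryPrimitives.ReadUInt32LittleEndian(a)
                       ^ BinaryPrimitives.ReadUInt32LittleEndian(b);
            // ...and of the last 4 bytes, read back from the end.
            int offset = a.Length - sizeof(uint);
            uint end = BinaryPrimitives.ReadUInt32LittleEndian(a.Slice(offset))
                     ^ BinaryPrimitives.ReadUInt32LittleEndian(b.Slice(offset));
            // Equal iff no bit differs in either window.
            return (start | end) == 0;
        }
    }
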
From 64c3d8246e444805655dc880abf15899aa47e6a1 Mon Sep 17 00:00:00 2001
From: Ben Adams
Date: Sat, 14 Mar 2020 03:58:28 +0000
Subject: [PATCH 2/4] Feedback + R2R

---
 .../src/System/SpanHelpers.Byte.cs | 167 ++++++++++--------
 1 file changed, 96 insertions(+), 71 deletions(-)

diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
index 833b365cc5b47..aeb9529b2b151 100644
--- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
@@ -1311,7 +1311,6 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte
 
         // Optimized byte-based SequenceEquals. The "length" parameter for this one is declared as nuint rather than int as we also use it for types other than byte
         // where the length can exceed 2GB once scaled by sizeof(T).
-        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
         public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
         {
             // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
@@ -1327,49 +1326,33 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
             // On 32-bit, this will never be true since sizeof(nuint) == 4
             if (length >= sizeof(uint))
             {
-                nint offset = 0;
-                if (length > sizeof(uint))
-                {
-                    // Set offset for next compare to sizeof(uint) from end
-                    offset = (nint)length - sizeof(uint);
-                    // Compare start
-                    if (LoadUInt(ref first) != LoadUInt(ref second))
-                    {
-                        goto NotEqual;
-                    }
-                }
-
-                // Compare end
-                return LoadUInt(ref first, offset) == LoadUInt(ref second, offset);
+                nuint offset = length - sizeof(uint);
+                uint differentBits = LoadUInt(ref first) - LoadUInt(ref second);
+                differentBits |= LoadUInt(ref first, offset) - LoadUInt(ref second, offset);
+                return (differentBits == 0);
             }
 #endif
-            if (length >= sizeof(ushort))
             {
-                nint offset = 0;
-                if (length > sizeof(ushort))
+                nuint offset = 0;
+                uint differentBits = 0;
+                if ((length & 2) != 0)
                 {
-                    // Set offset for next compare to sizeof(ushort) from end
-                    offset = (nint)length - sizeof(ushort);
-                    // Compare start
-                    if (LoadUShort(ref first) != LoadUShort(ref second))
-                    {
-                        goto NotEqual;
-                    }
+                    offset = 2;
+                    differentBits = LoadUShort(ref first);
+                    differentBits -= LoadUShort(ref second);
                 }
-
-                // Compare end
-                return LoadUShort(ref first, offset) == LoadUShort(ref second, offset);
+                if ((length & 1) != 0)
+                {
+                    differentBits |= (uint)Unsafe.AddByteOffset(ref first, offset) - (uint)Unsafe.AddByteOffset(ref second, offset);
+                }
+                return (differentBits == 0);
             }
 
-            if (length != 0)
-            {
-                // Only length 1 possible
-                Debug.Assert((int)length == 1);
-                return first == second;
-            }
+            // When the sequence is equal; which is the longest execution, we want it to determine that
+            // as fast as possible so we do not want the early outs to be "predicted not taken" branches.
+        Equal:
+            return true;
 
-            Debug.Assert((int)length == 0);
-            goto Equal;
         Longer:
             // Only check that the ref is the same if buffers are large, and hence
             // it's worth avoiding doing unnecessary comparisons
@@ -1379,47 +1362,67 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                 return true;
             }
 
-            if (Sse.IsSupported)
+            if (Sse2.IsSupported)
             {
-                if (Avx.IsSupported && (nint)length >= Vector256<byte>.Count)
+                if (Avx2.IsSupported && length >= (nuint)Vector256<byte>.Count)
                 {
-                    nint offset = 0;
-                    nint lengthToExamine = (nint)length - Vector256<byte>.Count;
+                    Vector256<byte> result;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
                     if (lengthToExamine != 0)
                     {
                         do
                         {
-                            if (!LoadVector256(ref first, offset).Equals(LoadVector256(ref second, offset)))
+                            result = Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset));
+                            if (Avx2.MoveMask(result) != -1)
                             {
                                 goto NotEqual;
                             }
-                            offset += Vector256<byte>.Count;
+                            offset += (nuint)Vector256<byte>.Count;
                         } while (lengthToExamine > offset);
                     }
 
                     // Do final compare as Vector256<byte>.Count from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    return LoadVector256(ref first, lengthToExamine).Equals(LoadVector256(ref second, lengthToExamine));
+                    result = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
+                    if (Avx2.MoveMask(result) == -1)
+                    {
+                        goto Equal;
+                    }
+
+                    goto NotEqual;
                 }
-                else if ((nint)length >= Vector128<byte>.Count)
+                // Use Vector128.Size as Vector128<byte>.Count doesn't inline at R2R time
+                // https://github.com/dotnet/runtime/issues/32714
+                else if (length >= Vector128.Size)
                 {
-                    nint offset = 0;
-                    nint lengthToExamine = (nint)length - Vector128<byte>.Count;
+                    Vector128<byte> result;
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - Vector128.Size;
                     if (lengthToExamine != 0)
                     {
                         do
                         {
-                            if (!LoadVector128(ref first, offset).Equals(LoadVector128(ref second, offset)))
+                            // We use intrinsics directly as .Equals calls .AsByte() which doesn't inline at R2R time
+                            // https://github.com/dotnet/runtime/issues/32714
+                            result = Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset));
+                            if (Sse2.MoveMask(result) != 0xFFFF)
                             {
                                 goto NotEqual;
                             }
-                            offset += Vector128<byte>.Count;
+                            offset += Vector128.Size;
                         } while (lengthToExamine > offset);
                     }
 
                     // Do final compare as Vector128<byte>.Count from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    return LoadVector128(ref first, lengthToExamine).Equals(LoadVector128(ref second, lengthToExamine));
+                    result = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
+                    if (Sse2.MoveMask(result) == 0xFFFF)
+                    {
+                        goto Equal;
+                    }
+
+                    goto NotEqual;
                 }
             }
             else if (Vector.IsHardwareAccelerated && (nint)length >= Vector<byte>.Count)
@@ -1440,28 +1443,50 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
 
                 // Do final compare as Vector<byte>.Count from end rather than start
                 Debug.Assert(lengthToExamine >= 0);
-                return LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine);
+                if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
+                {
+                    goto Equal;
+                }
+
+                goto NotEqual;
             }
 
-            Debug.Assert(length >= sizeof(nuint));
+#if TARGET_64BIT
+            if (Sse2.IsSupported)
             {
-                nint offset = 0;
-                nint lengthToExamine = (nint)length - sizeof(nuint);
-                if (lengthToExamine > 0)
+                Debug.Assert(length <= sizeof(nuint) * 2);
+
+                nuint offset = length - sizeof(nuint);
+                nuint differentBits = LoadNUInt(ref first) - LoadNUInt(ref second);
+                differentBits |= LoadNUInt(ref first, offset) - LoadNUInt(ref second, offset);
+                return (differentBits == 0);
+            }
+            else
+#endif
+            {
+                Debug.Assert(length >= sizeof(nuint));
                 {
-                    do
+                    nuint offset = 0;
+                    nuint lengthToExamine = length - sizeof(nuint);
+                    if (lengthToExamine > 0)
                     {
-                        // Compare unsigned so as not to do a sign-extend mov on 64 bit
-                        if (LoadNUInt(ref first, offset) != LoadNUInt(ref second, offset))
+                        do
                         {
-                            goto NotEqual;
-                        }
-                        offset += sizeof(nuint);
-                    } while (lengthToExamine > offset);
+                            // Compare unsigned so as not to do a sign-extend mov on 64 bit
+                            if (LoadNUInt(ref first, offset) != LoadNUInt(ref second, offset))
+                            {
+                                goto NotEqual;
+                            }
+                            offset += sizeof(nuint);
+                        } while (lengthToExamine > offset);
+                    }
 
                     // Do final compare as sizeof(nuint) from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    return LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine);
+                    if (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine))
+                    {
+                        goto Equal;
+                    }
                 }
             }
@@ -1470,10 +1495,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
             // branch predictor in an uninitialized state will not take them e.g.
             // - loops are conditional jmps backwards and predicted
             // - exceptions are conditional forwards jmps and not predicted
-            // When the sequence is equal; which is the longest execution, we want it to determine that
-            // as fast as possible so we do not want the early outs to be "predicted not taken" branches.
-        Equal:
-            return true;
         NotEqual:
             return false;
         }
@@ -1735,7 +1756,7 @@ private static ushort LoadUShort(ref byte start)
             => Unsafe.ReadUnaligned<ushort>(ref start);
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static ushort LoadUShort(ref byte start, nint offset)
+        private static ushort LoadUShort(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1743,11 +1764,15 @@ private static uint LoadUInt(ref byte start)
             => Unsafe.ReadUnaligned<uint>(ref start);
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static uint LoadUInt(ref byte start, nint offset)
+        private static uint LoadUInt(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static nuint LoadNUInt(ref byte start, nint offset)
+        private static nuint LoadNUInt(ref byte start)
+            => Unsafe.ReadUnaligned<nuint>(ref start);
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static nuint LoadNUInt(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1767,7 +1792,7 @@ private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static Vector128<byte> LoadVector128(ref byte start, nint offset)
+        private static Vector128<byte> LoadVector128(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -1775,7 +1800,7 @@ private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static Vector256<byte> LoadVector256(ref byte start, nint offset)
+        private static Vector256<byte> LoadVector256(ref byte start, nuint offset)
             => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
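
The SSE2/AVX2 rewrite above replaces Vector<T>'s operator== with the explicit CompareEqual + MoveMask idiom: CompareEqual sets every matching byte lane to 0xFF, and MoveMask collapses the lanes' high bits into an integer, so 16 equal bytes produce the mask 0xFFFF (and 32 equal bytes produce -1 in the 256-bit form). A small pointer-based sketch of the 128-bit idiom, separate from the patch's ref/offset plumbing (names are illustrative):

    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class MoveMaskSketch
    {
        // Compares count16 consecutive 16-byte blocks; unaligned loads are fine.
        internal static unsafe bool Blocks16Equal(byte* a, byte* b, int count16)
        {
            for (int i = 0; i < count16; i++)
            {
                Vector128<byte> va = Sse2.LoadVector128(a + i * 16);
                Vector128<byte> vb = Sse2.LoadVector128(b + i * 16);
                // An all-equal block collapses to the mask 0xFFFF.
                if (Sse2.MoveMask(Sse2.CompareEqual(va, vb)) != 0xFFFF)
                    return false;
            }
            return true;
        }
    }
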
From 0353f743e944224e11474eab7a1490e8d9474f68 Mon Sep 17 00:00:00 2001
From: Ben Adams
Date: Mon, 16 Mar 2020 00:37:47 +0000
Subject: [PATCH 3/4] Tweak jmps

---
 .../src/System/SpanHelpers.Byte.cs | 87 +++++++++++--------
 1 file changed, 51 insertions(+), 36 deletions(-)

diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
index aeb9529b2b151..a3fceeba6ea2f 100644
--- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
@@ -1313,6 +1313,7 @@ public static unsafe int LastIndexOfAny(ref byte searchSpace, byte value0, byte
         // where the length can exceed 2GB once scaled by sizeof(T).
         public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
         {
+            bool result;
             // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
             if (length >= sizeof(nuint))
             {
@@ -1323,21 +1324,14 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
             }
 
 #if TARGET_64BIT
-            // On 32-bit, this will never be true since sizeof(nuint) == 4
-            if (length >= sizeof(uint))
-            {
-                nuint offset = length - sizeof(uint);
-                uint differentBits = LoadUInt(ref first) - LoadUInt(ref second);
-                differentBits |= LoadUInt(ref first, offset) - LoadUInt(ref second, offset);
-                return (differentBits == 0);
-            }
+            // On 32-bit, this will always be true since sizeof(nuint) == 4
+            if (length < sizeof(uint))
 #endif
             {
-                nuint offset = 0;
                 uint differentBits = 0;
-                if ((length & 2) != 0)
+                nuint offset = (length & 2);
+                if (offset != 0)
                 {
-                    offset = 2;
                     differentBits = LoadUShort(ref first);
                     differentBits -= LoadUShort(ref second);
                 }
@@ -1345,36 +1339,52 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                 {
                     differentBits |= (uint)Unsafe.AddByteOffset(ref first, offset) - (uint)Unsafe.AddByteOffset(ref second, offset);
                 }
-                return (differentBits == 0);
+                result = (differentBits == 0);
+                goto Result;
+            }
+#if TARGET_64BIT
+            else
+            {
+                nuint offset = length - sizeof(uint);
+                uint differentBits = LoadUInt(ref first) - LoadUInt(ref second);
+                differentBits |= LoadUInt(ref first, offset) - LoadUInt(ref second, offset);
+                result = (differentBits == 0);
+                goto Result;
+            }
+#endif
+
+        Longer:
+            // Only check that the ref is the same if buffers are large,
+            // and hence it's worth avoiding doing unnecessary comparisons
+            if (!Unsafe.AreSame(ref first, ref second))
+            {
+                // C# compiler inverts this test, making the outer goto the conditional jmp.
+                goto Vector;
             }
+            // This becomes a conditional jmp forward to not favor it.
+            goto Equal;
+
+        Result:
+            return result;
             // When the sequence is equal; which is the longest execution, we want it to determine that
             // as fast as possible so we do not want the early outs to be "predicted not taken" branches.
         Equal:
             return true;
 
-        Longer:
-            // Only check that the ref is the same if buffers are large, and hence
-            // it's worth avoiding doing unnecessary comparisons
-            if (Unsafe.AreSame(ref first, ref second))
-            {
-                // This becomes a conditional jmp forward to not favor it. (See comment at "Equal:" label)
-                return true;
-            }
-
+        Vector:
             if (Sse2.IsSupported)
             {
                 if (Avx2.IsSupported && length >= (nuint)Vector256<byte>.Count)
                 {
-                    Vector256<byte> result;
+                    Vector256<byte> vecResult;
                     nuint offset = 0;
                     nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
                     if (lengthToExamine != 0)
                     {
                         do
                         {
-                            result = Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset));
-                            if (Avx2.MoveMask(result) != -1)
+                            vecResult = Avx2.CompareEqual(LoadVector256(ref first, offset), LoadVector256(ref second, offset));
+                            if (Avx2.MoveMask(vecResult) != -1)
                             {
                                 goto NotEqual;
                             }
@@ -1384,19 +1394,21 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
 
                     // Do final compare as Vector256<byte>.Count from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    result = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
-                    if (Avx2.MoveMask(result) == -1)
+                    vecResult = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
+                    if (Avx2.MoveMask(vecResult) == -1)
                     {
+                        // C# compiler inverts this test, making the outer goto the conditional jmp.
                         goto Equal;
                     }
 
+                    // This becomes a conditional jmp forward to not favor it.
                     goto NotEqual;
                 }
                 // Use Vector128.Size as Vector128<byte>.Count doesn't inline at R2R time
                 // https://github.com/dotnet/runtime/issues/32714
                 else if (length >= Vector128.Size)
                 {
-                    Vector128<byte> result;
+                    Vector128<byte> vecResult;
                     nuint offset = 0;
                     nuint lengthToExamine = length - Vector128.Size;
                     if (lengthToExamine != 0)
                     {
                         do
                         {
                             // We use intrinsics directly as .Equals calls .AsByte() which doesn't inline at R2R time
                             // https://github.com/dotnet/runtime/issues/32714
-                            result = Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset));
-                            if (Sse2.MoveMask(result) != 0xFFFF)
+                            vecResult = Sse2.CompareEqual(LoadVector128(ref first, offset), LoadVector128(ref second, offset));
+                            if (Sse2.MoveMask(vecResult) != 0xFFFF)
                             {
                                 goto NotEqual;
                             }
@@ -1416,12 +1428,14 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
 
                     // Do final compare as Vector128<byte>.Count from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    result = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
-                    if (Sse2.MoveMask(result) == 0xFFFF)
+                    vecResult = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
+                    if (Sse2.MoveMask(vecResult) == 0xFFFF)
                     {
+                        // C# compiler inverts this test, making the outer goto the conditional jmp.
                         goto Equal;
                     }
 
+                    // This becomes a conditional jmp forward to not favor it.
                     goto NotEqual;
                 }
             }
@@ -1445,9 +1459,11 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                 Debug.Assert(lengthToExamine >= 0);
                 if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
                 {
+                    // C# compiler inverts this test, making the outer goto the conditional jmp.
                     goto Equal;
                 }
 
+                // This becomes a conditional jmp forward to not favor it.
                 goto NotEqual;
             }
@@ -1459,7 +1475,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                 nuint offset = length - sizeof(nuint);
                 nuint differentBits = LoadNUInt(ref first) - LoadNUInt(ref second);
                 differentBits |= LoadNUInt(ref first, offset) - LoadNUInt(ref second, offset);
-                return (differentBits == 0);
+                result = (differentBits == 0);
+                goto Result;
             }
             else
 #endif
@@ -1483,10 +1500,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
 
                     // Do final compare as sizeof(nuint) from end rather than start
                     Debug.Assert(lengthToExamine >= 0);
-                    if (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine))
-                    {
-                        goto Equal;
-                    }
+                    result = (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine));
+                    goto Result;
                 }
             }
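
The short-length path above folds the 2-byte and 1-byte tail compares into a single "differentBits" accumulator, so there is one flags test instead of a chain of early-out branches, and (length & 2) doubles as the byte offset of the trailing byte. A span-based sketch of that shape, simplified to XOR rather than the patch's subtraction (either works, since the accumulator only needs to be zero iff every compared byte matched):

    using System;
    using System.Buffers.Binary;

    static class ShortTailSketch
    {
        // Branch-reduced equality for equal-length buffers of 0..3 bytes.
        internal static bool Equal0To3(ReadOnlySpan<byte> a, ReadOnlySpan<byte> b)
        {
            if (a.Length != b.Length || a.Length > 3)
                throw new ArgumentException("sketch handles equal lengths of 0..3 only");

            uint differentBits = 0;
            int offset = a.Length & 2;   // 2 when a 2-byte head exists, else 0
            if (offset != 0)
            {
                differentBits = (uint)(BinaryPrimitives.ReadUInt16LittleEndian(a)
                                     ^ BinaryPrimitives.ReadUInt16LittleEndian(b));
            }
            if ((a.Length & 1) != 0)     // odd length: fold in the last byte
            {
                differentBits |= (uint)(a[offset] ^ b[offset]);
            }
            return differentBits == 0;
        }
    }
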
From 6b3970a1f1e785166e278322e97463c18b597dff Mon Sep 17 00:00:00 2001
From: Ben Adams
Date: Tue, 17 Mar 2020 02:05:26 +0000
Subject: [PATCH 4/4] Fix potential overflows

---
 .../src/System/SpanHelpers.Byte.cs | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
index a3fceeba6ea2f..8e2245a97a124 100644
--- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
@@ -1379,6 +1379,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     Vector256<byte> vecResult;
                     nuint offset = 0;
                     nuint lengthToExamine = length - (nuint)Vector256<byte>.Count;
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
                     if (lengthToExamine != 0)
                     {
                         do
@@ -1393,7 +1395,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     }
 
                     // Do final compare as Vector256<byte>.Count from end rather than start
-                    Debug.Assert(lengthToExamine >= 0);
                     vecResult = Avx2.CompareEqual(LoadVector256(ref first, lengthToExamine), LoadVector256(ref second, lengthToExamine));
                     if (Avx2.MoveMask(vecResult) == -1)
                     {
@@ -1411,6 +1412,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     Vector128<byte> vecResult;
                     nuint offset = 0;
                     nuint lengthToExamine = length - Vector128.Size;
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
                     if (lengthToExamine != 0)
                     {
                         do
@@ -1427,7 +1430,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     }
 
                     // Do final compare as Vector128<byte>.Count from end rather than start
-                    Debug.Assert(lengthToExamine >= 0);
                     vecResult = Sse2.CompareEqual(LoadVector128(ref first, lengthToExamine), LoadVector128(ref second, lengthToExamine));
                     if (Sse2.MoveMask(vecResult) == 0xFFFF)
                     {
@@ -1439,10 +1441,12 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     goto NotEqual;
                 }
             }
-            else if (Vector.IsHardwareAccelerated && (nint)length >= Vector<byte>.Count)
+            else if (Vector.IsHardwareAccelerated && length >= (nuint)Vector<byte>.Count)
             {
-                nint offset = 0;
-                nint lengthToExamine = (nint)length - Vector<byte>.Count;
+                nuint offset = 0;
+                nuint lengthToExamine = length - (nuint)Vector<byte>.Count;
+                // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                Debug.Assert(lengthToExamine < length);
                 if (lengthToExamine > 0)
                 {
                     do
@@ -1451,12 +1455,11 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                         {
                             goto NotEqual;
                         }
-                        offset += Vector<byte>.Count;
+                        offset += (nuint)Vector<byte>.Count;
                     } while (lengthToExamine > offset);
                 }
 
                 // Do final compare as Vector<byte>.Count from end rather than start
-                Debug.Assert(lengthToExamine >= 0);
                 if (LoadVector(ref first, lengthToExamine) == LoadVector(ref second, lengthToExamine))
                 {
                     // C# compiler inverts this test, making the outer goto the conditional jmp.
@@ -1485,6 +1488,8 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                 {
                     nuint offset = 0;
                     nuint lengthToExamine = length - sizeof(nuint);
+                    // Unsigned, so it shouldn't have overflowed larger than length (rather than negative)
+                    Debug.Assert(lengthToExamine < length);
                     if (lengthToExamine > 0)
                     {
                         do
@@ -1499,7 +1504,6 @@ public static bool SequenceEqual(ref byte first, ref byte second, nuint length)
                     }
 
                     // Do final compare as sizeof(nuint) from end rather than start
-                    Debug.Assert(lengthToExamine >= 0);
                     result = (LoadNUInt(ref first, lengthToExamine) == LoadNUInt(ref second, lengthToExamine));
                     goto Result;
                 }
@@ -1772,7 +1776,7 @@ private static ushort LoadUShort(ref byte start)
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static ushort LoadUShort(ref byte start, nuint offset)
-            => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+            => Unsafe.ReadUnaligned<ushort>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static uint LoadUInt(ref byte start)
@@ -1780,7 +1784,7 @@ private static uint LoadUInt(ref byte start)
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static uint LoadUInt(ref byte start, nuint offset)
-            => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+            => Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static nuint LoadNUInt(ref byte start)
@@ -1788,7 +1792,7 @@ private static nuint LoadNUInt(ref byte start)
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static nuint LoadNUInt(ref byte start, nuint offset)
-            => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+            => Unsafe.ReadUnaligned<nuint>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static UIntPtr LoadUIntPtr(ref byte start, IntPtr offset)
@@ -1799,8 +1803,8 @@ private static Vector<byte> LoadVector(ref byte start, IntPtr offset)
             => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static Vector<byte> LoadVector(ref byte start, nint offset)
-            => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+        private static Vector<byte> LoadVector(ref byte start, nuint offset)
+            => Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
@@ -1808,7 +1812,7 @@ private static Vector128<byte> LoadVector128(ref byte start, IntPtr offset)
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static Vector128<byte> LoadVector128(ref byte start, nuint offset)
-            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+            => Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
@@ -1816,7 +1820,7 @@ private static Vector256<byte> LoadVector256(ref byte start, IntPtr offset)
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static Vector256<byte> LoadVector256(ref byte start, nuint offset)
-            => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, (IntPtr)offset));
+            => Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static unsafe IntPtr GetByteVectorSpanLength(IntPtr offset, int length)
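
The asserts added in this last patch lean on unsigned wraparound: with nuint arithmetic, "length - vectorSize" on a too-small length does not go negative, it wraps to a huge value, so "lengthToExamine < length" is the idiomatic way to assert the subtraction did not underflow. A tiny demonstration with illustrative values:

    using System.Diagnostics;

    static class UnderflowAssertSketch
    {
        internal static void Demo()
        {
            nuint vectorSize = 32;   // stand-in for (nuint)Vector256<byte>.Count
            nuint length = 40;       // the guarded paths guarantee length >= vectorSize
            nuint lengthToExamine = length - vectorSize;

            // Holds here: 8 < 40. Had length been 16, 16 - 32 would wrap to a
            // huge nuint and this assert would fire, exposing the missing guard.
            Debug.Assert(lengthToExamine < length);
        }
    }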