diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 6bda8b2c900e1..9e9e1dd28f8cf 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -1149,10 +1149,45 @@ private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan in } } + /// Mask used to handle alignment elements before vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the + /// beginning of the input, where elements in the vector after that will be zeroed. + /// + /// There actually exist 17 rows in the table with the last row being a repeat of the first. This is + /// done because it allows the main algorithms to use a simplified algorithm when computing the amount + /// of misalignment where we always skip the first 16 elements, even if already aligned, so we don't + /// double process them. This allows us to avoid an additional branch. 
+ /// + private static ReadOnlySpan AlignmentUInt32Mask_16x16 => + [ + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 
0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + ]; + /// Mask used to handle remaining elements after vectorized handling of the input. /// /// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. 
/// private static ReadOnlySpan RemainderUInt32Mask_16x16 => [ @@ -1172,7 +1207,7 @@ private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan in 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, ]; } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index d77cd743a0713..fc2821163bbad 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -7,7 +7,6 @@ using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; -using System.Security.Cryptography; namespace System.Numerics.Tensors { @@ -797,248 +796,1699 @@ private static float Aggregate( where TTransformOperator : struct, IUnaryOperator where TAggregationOperator : struct, IAggregationOperator { - if (x.Length == 0) + // Since every branch has a cost and 
since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized512Small(ref xRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available. + // It requires no branches to hit. 
+ + return SoftwareFallback(ref xRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. NOTE(review): misalignment then stays 0 although the loop below still consumes the first vector; the later beg/blocks handling appears to treat 0 as "first vector not yet aggregated" -- confirm it is not counted twice. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointer accordingly + // + // Noting that there is only a single input and no destination here, so xPtr is the + // only pointer to align and the loads in the main loop below all start at an address + // that is a multiple of the vector size.
+ + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, transform, and aggregate the first four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, transform, and aggregate the next four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We advance the source pointer, then update + // the count of remaining elements to process.
+ + xPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 
3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; 
+ + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. NOTE(review): misalignment then stays 0 although the loop below still consumes the first vector; the later beg/blocks handling appears to treat 0 as "first vector not yet aggregated" -- confirm it is not counted twice. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointer accordingly + // + // Noting that there is only a single input and no destination here, so xPtr is the + // only pointer to align and the loads in the main loop below all start at an address + // that is a multiple of the vector size. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, transform, and aggregate the first four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, 
transform, and aggregate the next four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We advance the source pointer, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static 
float Vectorized256Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = 
TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. NOTE(review): misalignment then stays 0 although the loop below still consumes the first vector; the later beg/blocks handling appears to treat 0 as "first vector not yet aggregated" -- confirm it is not counted twice. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointer accordingly + // + // Noting that there is only a single input and no destination here, so xPtr is the + // only pointer to align and the loads in the main loop below all start at an address + // that is a multiple of the vector size. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, transform, and aggregate the first four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = 
TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, transform, and aggregate the next four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We advance the source pointer, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static 
float Vectorized512Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + 
Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. 
+ /// + private static float Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized512Small(ref xRef, ref yRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. 
To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. 
+ + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + 
(uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); + vresult = 
TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. 
+ + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = 
TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. 
+ + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - 
(uint)(Vector256.Count * 1)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) { - return 0; + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = 
TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; } - ref float xRef = ref MemoryMarshal.GetReference(x); +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. 
This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) - { - // Load the first vector as the initial set of results - Vector512 result = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; - int i = Vector512.Count; + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + 
(uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i))); - i += Vector512.Count; - } + // Adjusting the refs here allows us to avoid pinning for very small inputs - // Process the last vector in the span, masking off elements already processed. - if (i != x.Length) - { - result = TAggregationOperator.Invoke(result, - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.Create(TAggregationOperator.IdentityValue), - TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count))))); + xRef = ref *xPtr; + yRef = ref *yPtr; + } } - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); - } -#endif + // Store the first block. 
Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) - { - // Load the first vector as the initial set of results - Vector256 result = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; - int i = Vector256.Count; + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i))); - i += Vector256.Count; - } + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. - // Process the last vector in the span, masking off elements already processed. - if (i != x.Length) + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 
1u : 0u; + remainder -= trailing; + + switch (blocks) { - result = TAggregationOperator.Invoke(result, - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.Create(TAggregationOperator.IdentityValue), - TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count))))); - } + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); - } + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } - if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) - { - // Load the first vector as the initial set of results - Vector128 result = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; - int i = Vector128.Count; + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. 
- while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i))); - i += Vector128.Count; - } + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } - // Process the last vector in the span, masking off elements already processed. - if (i != x.Length) - { - result = TAggregationOperator.Invoke(result, - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.Create(TAggregationOperator.IdentityValue), - TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count))))); + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); 
+ vresult = TAggregationOperator.Invoke(vresult, end); + break; + } } - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); + return TAggregationOperator.Invoke(vresult); } - // Vectorization isn't supported or there are too few elements to vectorize. - // Use a scalar implementation. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) { - float result = TTransformOperator.Invoke(x[0]); - for (int i = 1; i < x.Length; i++) + float result = TAggregationOperator.IdentityValue; + + switch (remainder) { - result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(x[i])); - } + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - return result; - } - } + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); - /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. - /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . - /// - /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. - /// The aggregation is applied to the results of the binary operations on the pair-wise values. 
- /// - private static float Aggregate( - ReadOnlySpan x, ReadOnlySpan y) - where TBinaryOperator : struct, IBinaryOperator - where TAggregationOperator : struct, IAggregationOperator - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); - if (x.IsEmpty) - { - return 0; - } + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + result = TAggregationOperator.Invoke(vresult); + break; + } -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) - { - // Load the first vector as the initial set of results - Vector512 result = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, 0), Vector512.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; - int i = Vector512.Count; + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), Vector512.LoadUnsafe(ref yRef, (uint)i))); - i += Vector512.Count; - } + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); - // Process the last vector in the spans, masking off elements already processed. 
- if (i != x.Length) - { - result = TAggregationOperator.Invoke(result, - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.Create(TAggregationOperator.IdentityValue), - TBinaryOperator.Invoke( - Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)), - Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count))))); - } + result = TAggregationOperator.Invoke(vresult); + break; + } - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); - } -#endif + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) - { - // Load the first vector as the initial set of results - Vector256 result = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, 0), Vector256.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; - int i = Vector256.Count; + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), Vector256.LoadUnsafe(ref yRef, (uint)i))); - i += Vector256.Count; - } + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); - // Process the last vector in the spans, masking off elements already processed. 
- if (i != x.Length) - { - result = TAggregationOperator.Invoke(result, - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.Create(TAggregationOperator.IdentityValue), - TBinaryOperator.Invoke( - Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)), - Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count))))); - } + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); - } + result = TAggregationOperator.Invoke(vresult); + break; + } - if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) - { - // Load the first vector as the initial set of results - Vector128 result = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, 0), Vector128.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; - int i = Vector128.Count; + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) - { - result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), Vector128.LoadUnsafe(ref yRef, (uint)i))); - i += Vector128.Count; - } + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); - // Process the last vector in the spans, masking off elements already processed. 
- if (i != x.Length) - { - result = TAggregationOperator.Invoke(result, - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.Create(TAggregationOperator.IdentityValue), - TBinaryOperator.Invoke( - Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)), - Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count))))); - } + result = TAggregationOperator.Invoke(vresult); + break; + } - // Aggregate the lanes in the vector back into the scalar result - return TAggregationOperator.Invoke(result); - } + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } - // Vectorization isn't supported or there are too few elements to vectorize. - // Use a scalar implementation. - { - float result = TBinaryOperator.Invoke(xRef, yRef); - for (int i = 1; i < x.Length; i++) - { - result = TAggregationOperator.Invoke(result, - TBinaryOperator.Invoke( - Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i))); + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } } return result; } +#endif } /// @@ -1350,7 +2800,7 @@ static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. 
@@ -1591,7 +3041,7 @@ static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -1858,7 +3308,7 @@ static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -2265,7 +3715,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -2539,7 +3989,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. 
@@ -2842,7 +4292,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -3294,7 +4744,7 @@ static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remaind { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -3564,7 +5014,7 @@ static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remaind { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -3865,7 +5315,7 @@ static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remaind { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. 
@@ -4315,7 +5765,7 @@ static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref fl { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -4622,7 +6072,7 @@ static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref fl { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -4961,7 +6411,7 @@ static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref fl { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -5444,7 +6894,7 @@ static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. 
@@ -5747,7 +7197,7 @@ static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -6084,7 +7534,7 @@ static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -6567,7 +8017,7 @@ static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -6870,7 +8320,7 @@ static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. 
@@ -7207,7 +8657,7 @@ static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRe { // Compute by how many elements we're misaligned and adjust the pointers accordingly // - // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // Noting that we are only actually aligning dPtr. This is because unaligned stores // are more expensive than unaligned loads and aligning both is significantly more // complex. @@ -7590,10 +9040,15 @@ private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512Aggregates all of the elements in the into a single value. /// Specifies the operation to be performed on each pair of values. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator => - TAggregate.Invoke( - TAggregate.Invoke(x[0], x[1]), - TAggregate.Invoke(x[2], x[3])); + private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator + { + // We need to do log2(count) operations to compute the total sum + + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(2, 3, 0, 1))); + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(1, 0, 3, 2))); + + return x.ToScalar(); + } /// Aggregates all of the elements in the into a single value. /// Specifies the operation to be performed on each pair of values. @@ -7658,6 +9113,38 @@ private static float GetFirstNaN(Vector512 vector) /// Gets the base 2 logarithm of . private static float Log2(float x) => MathF.Log2(x); + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateAlignmentMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the first <paramref name="count"/> elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateAlignmentMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the first <paramref name="count"/> elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateAlignmentMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + /// /// Gets a vector mask that will be all-ones-set for the last elements /// and zero for all other elements.
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 5e6e9ac6252e3..fefdeec80570a 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -93,63 +93,292 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan - private static float Aggregate( + private static unsafe float Aggregate( ReadOnlySpan x, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) where TTransformOperator : struct, IUnaryOperator where TAggregationOperator : struct, IAggregationOperator { - if (x.Length == 0) + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && transformOp.CanVectorize) { - return 0; + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + result = VectorizedSmall(ref xRef, remainder); + } + + return result; } - float result; + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, remainder); - if (Vector.IsHardwareAccelerated && transformOp.CanVectorize && x.Length >= Vector.Count) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) { - ref float xRef = ref MemoryMarshal.GetReference(x); + float result = aggregationOp.IdentityValue; - // Load the first vector as the initial set of results - Vector resultVector = transformOp.Invoke(AsVector(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector.Count; - int i = Vector.Count; + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = transformOp.Invoke(AsVector(ref xRef)); + Vector end = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - while (i <= oneVectorFromEnd) + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) { - resultVector = aggregationOp.Invoke(resultVector, transformOp.Invoke(AsVector(ref xRef, i))); - i += Vector.Count; + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. 
+ + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointer accordingly + // + // Noting that xPtr is the only pointer we can align here: this aggregation only reads + // its input and has no destination to store to, so there is no trade-off between + // aligned loads and aligned stores to make. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and aggregate the first four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and aggregate the next four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr +
(uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source pointer, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } - // Process the last vector in the span, masking off elements already processed. - if (i != x.Length) + // Aggregate the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ?
1u : 0u; + remainder -= trailing; + + switch (blocks) { - resultVector = aggregationOp.Invoke(resultVector, - Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - new Vector(aggregationOp.IdentityValue), - transformOp.Invoke(AsVector(ref xRef, x.Length - Vector.Count)))); + case 7: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } } - // Aggregate the lanes in the vector back into the scalar result - result = resultVector[0]; - for (int 
f = 1; f < Vector.Count; f++) + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) { - result = aggregationOp.Invoke(result, resultVector[f]); + result = aggregationOp.Invoke(result, vresult[i]); } return result; } - // Aggregate the remaining items in the input span. - result = transformOp.Invoke(x[0]); - for (int i = 1; i < x.Length; i++) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) { - result = aggregationOp.Invoke(result, transformOp.Invoke(x[i])); - } + float result = aggregationOp.IdentityValue; - return result; + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } } /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. @@ -158,7 +387,7 @@ private static float Aggregate( /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. 
/// The aggregation is applied to the results of the binary operations on the pair-wise values.
///
private static unsafe float Aggregate<TBinaryOperator, TAggregationOperator>(
    ReadOnlySpan<float> x, ReadOnlySpan<float> y, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default)
    where TBinaryOperator : struct, IBinaryOperator
    where TAggregationOperator : struct, IAggregationOperator
{
    // NOTE(review): the opening of this guard lies in lines elided between two diff hunks;
    // the condition is reconstructed from the throw helper's name - confirm against the full file.
    if (x.Length != y.Length)
    {
        ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
    }

    // Since every branch has a cost and since that cost is
    // essentially lost for larger inputs, we do branches
    // in a way that allows us to have the minimum possible
    // for small sizes

    ref float xRef = ref MemoryMarshal.GetReference(x);
    ref float yRef = ref MemoryMarshal.GetReference(y);

    nuint remainder = (uint)(x.Length);

    if (Vector.IsHardwareAccelerated)
    {
        float result;

        if (remainder >= (uint)(Vector<float>.Count))
        {
            result = Vectorized(ref xRef, ref yRef, remainder);
        }
        else
        {
            // We have less than a vector and so we can only handle this as scalar. To do this
            // efficiently, we simply have a small jump table and fallthrough. So we get a simple
            // length check, single jump, and then linear execution.

            result = VectorizedSmall(ref xRef, ref yRef, remainder);
        }

        return result;
    }

    // This is the software fallback when no acceleration is available
    // It requires no branches to hit

    return SoftwareFallback(ref xRef, ref yRef, remainder);

    // Scalar loop: folds binaryOp(x[i], y[i]) into the aggregate, starting from the identity value,
    // so an empty input yields the identity value.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    static float SoftwareFallback(ref float xRef, ref float yRef, nuint length, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default)
    {
        float result = aggregationOp.IdentityValue;

        for (nuint i = 0; i < length; i++)
        {
            result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)),
                                                                  Unsafe.Add(ref yRef, (nint)(i))));
        }

        return result;
    }

    // Vector path for inputs of at least one full vector. The input is split into:
    //   * `beg`  - the first vector, masked down to the elements that precede the main body
    //   * an 8x unrolled loop over the main body (only for inputs larger than 8 vectors)
    //   * up to 7 remaining full vectors, handled by a fall-through jump table
    //   * `end`  - the last vector, masked down to the trailing partial block
    // Masked-off lanes are replaced with the aggregation identity so they do not affect the result.
    static float Vectorized(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default)
    {
        Vector<float> vresult = new Vector<float>(aggregationOp.IdentityValue);

        // Preload the beginning and end so that overlapping accesses don't negatively impact the data

        Vector<float> beg = binaryOp.Invoke(AsVector(ref xRef),
                                            AsVector(ref yRef));
        Vector<float> end = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count)),
                                            AsVector(ref yRef, remainder - (uint)(Vector<float>.Count)));

        // Number of leading elements that are left to `beg` once the pointers have been advanced.
        // It stays 0 only when the large-input path below is not taken, in which case `beg`
        // stands in for the whole first block.
        nuint misalignment = 0;

        if (remainder > (uint)(Vector<float>.Count * 8))
        {
            // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful
            // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted.

            fixed (float* px = &xRef)
            fixed (float* py = &yRef)
            {
                float* xPtr = px;
                float* yPtr = py;

                // We need to ensure the underlying data can be aligned and only align
                // it if it can. It is possible we have an unaligned ref, in which case we
                // can never achieve the required SIMD alignment.

                bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0;

                // Compute how many leading elements are left to `beg` and step past them.
                //
                // When x can be aligned this is the distance to its next vector boundary, which is
                // in the range [1, Count]. Only xPtr is aligned; yPtr advances by the same amount
                // and may remain unaligned, since aligning both is significantly more complex.
                //
                // When x cannot be aligned we still step past exactly one full vector. `beg` is
                // always folded in below, so leaving this at 0 would count the first vector twice
                // and would make the `blocks` adjustment below skip a block the loop never covered
                // (or underflow when no full block remains, dropping the trailing elements).

                misalignment = canAlign
                    ? ((uint)(sizeof(Vector<float>)) - ((nuint)(xPtr) % (uint)(sizeof(Vector<float>)))) / sizeof(float)
                    : (uint)(Vector<float>.Count);

                xPtr += misalignment;
                yPtr += misalignment;

                Debug.Assert(!canAlign || (((nuint)(xPtr) % (uint)(sizeof(Vector<float>))) == 0));

                remainder -= misalignment;

                Vector<float> vector1;
                Vector<float> vector2;
                Vector<float> vector3;
                Vector<float> vector4;

                // We only need to load, so there isn't a lot of benefit to doing non-temporal operations

                while (remainder >= (uint)(Vector<float>.Count * 8))
                {
                    // We load, process, and aggregate the first four vectors

                    vector1 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 0)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 0)));
                    vector2 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 1)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 1)));
                    vector3 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 2)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 2)));
                    vector4 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 3)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 3)));

                    vresult = aggregationOp.Invoke(vresult, vector1);
                    vresult = aggregationOp.Invoke(vresult, vector2);
                    vresult = aggregationOp.Invoke(vresult, vector3);
                    vresult = aggregationOp.Invoke(vresult, vector4);

                    // We load, process, and aggregate the next four vectors

                    vector1 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 4)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 4)));
                    vector2 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 5)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 5)));
                    vector3 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 6)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 6)));
                    vector4 = binaryOp.Invoke(*(Vector<float>*)(xPtr + (uint)(Vector<float>.Count * 7)),
                                              *(Vector<float>*)(yPtr + (uint)(Vector<float>.Count * 7)));

                    vresult = aggregationOp.Invoke(vresult, vector1);
                    vresult = aggregationOp.Invoke(vresult, vector2);
                    vresult = aggregationOp.Invoke(vresult, vector3);
                    vresult = aggregationOp.Invoke(vresult, vector4);

                    // We adjust the source pointers, then update
                    // the count of remaining elements to process.

                    xPtr += (uint)(Vector<float>.Count * 8);
                    yPtr += (uint)(Vector<float>.Count * 8);

                    remainder -= (uint)(Vector<float>.Count * 8);
                }

                // Adjusting the refs here allows us to avoid pinning for very small inputs

                xRef = ref *xPtr;
                yRef = ref *yPtr;
            }
        }

        // Fold in the first block. Handling this separately simplifies the latter code as we know
        // everything else comes after it and so can be relegated to full blocks or the trailing elements

        beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector<float>(aggregationOp.IdentityValue));
        vresult = aggregationOp.Invoke(vresult, beg);

        // Process the remaining [0, Count * 7] elements via a jump table
        //
        // We end up handling any trailing elements in case 0 and in the
        // worst case end up just doing the identity operation here if there
        // were no trailing elements.

        nuint blocks = remainder / (nuint)(Vector<float>.Count);
        nuint trailing = remainder - (blocks * (nuint)(Vector<float>.Count));

        // When the pointers were never advanced, `beg` already covered the first full block
        blocks -= (misalignment == 0) ? 1u : 0u;
        remainder -= trailing;

        switch (blocks)
        {
            case 7:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 7)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 7)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 6;
            }

            case 6:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 6)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 6)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 5;
            }

            case 5:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 5)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 5)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 4;
            }

            case 4:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 4)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 4)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 3;
            }

            case 3:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 3)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 3)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 2;
            }

            case 2:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 2)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 2)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 1;
            }

            case 1:
            {
                Vector<float> vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector<float>.Count * 1)),
                                                       AsVector(ref yRef, remainder - (uint)(Vector<float>.Count * 1)));
                vresult = aggregationOp.Invoke(vresult, vector);
                goto case 0;
            }

            case 0:
            {
                // Fold in the last block, which includes any elements that wouldn't fill a full vector
                end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector<float>(aggregationOp.IdentityValue));
                vresult = aggregationOp.Invoke(vresult, end);
                break;
            }
        }

        // Aggregate the lanes in the vector back into the scalar result

        float result = aggregationOp.IdentityValue;

        for (int i = 0; i < Vector<float>.Count; i++)
        {
            result = aggregationOp.Invoke(result, vresult[i]);
        }

        return result;
    }

    // Scalar path for inputs shorter than one vector: jumps to the case matching the element
    // count and falls through, folding elements from the highest index down to index 0.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    static float VectorizedSmall(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default)
    {
        float result = aggregationOp.IdentityValue;

        switch (remainder)
        {
            case 7:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 6),
                                                                      Unsafe.Add(ref yRef, 6)));
                goto case 6;
            }

            case 6:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 5),
                                                                      Unsafe.Add(ref yRef, 5)));
                goto case 5;
            }

            case 5:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 4),
                                                                      Unsafe.Add(ref yRef, 4)));
                goto case 4;
            }

            case 4:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 3),
                                                                      Unsafe.Add(ref yRef, 3)));
                goto case 3;
            }

            case 3:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 2),
                                                                      Unsafe.Add(ref yRef, 2)));
                goto case 2;
            }

            case 2:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 1),
                                                                      Unsafe.Add(ref yRef, 1)));
                goto case 1;
            }

            case 1:
            {
                result = aggregationOp.Invoke(result, binaryOp.Invoke(xRef, yRef));
                goto case 0;
            }

            case 0:
            {
                break;
            }
        }

        return result;
    }
}

///
@@ -2256,7 +2741,6 @@ static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float d case 0: { - Debug.Assert(remainder == 0); break; } }
@@ -2290,6 +2774,19 @@ private static Vector IsNegative(Vector f) => /// Gets the base 2 logarithm of .
private static float Log2(float x) => MathF.Log(x, 2);

/// <summary>
/// Builds the mask that selects the leading lanes of a vector: the first
/// <paramref name="count"/> lanes have every bit set and all remaining lanes are zero.
/// </summary>
/// <param name="count">The row of the alignment mask table to load, i.e. the number of leading lanes to keep.</param>
/// <returns>The mask, loaded from the start of the selected table row.</returns>
private static Vector<float> CreateAlignmentMaskSingleVector(int count)
{
    Debug.Assert(Vector<float>.Count is 4 or 8 or 16);

    // Every row of the table is 16 uints wide, so row `count` begins `count * 16` elements in
    int rowStart = count * 16;

    // View the uint table as floats so a row can be read directly as a Vector<float>
    ref float table = ref Unsafe.As<uint, float>(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16));

    return AsVector(ref table, rowStart);
}

///
/// Gets a vector mask that will be all-ones-set for the last elements
/// and zero for all other elements.