-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP - Speed improvements to resize convolution (no vpermps w/ FMA) #3
base: main
Are you sure you want to change the base?
Changes from all commits
cd1b77a
36fefc6
4728b97
8c19a97
7840665
58f6afb
0594035
e60dd07
72813ee
6e84a34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1097,4 +1097,79 @@ public static nuint Vector512Count<TVector>(this Span<float> span) | |
public static nuint Vector512Count<TVector>(int length)
    where TVector : struct
{
    // Both operands are converted to unsigned 32-bit values first, so the
    // quotient is produced by an unsigned integer division.
    uint elementCount = (uint)length;
    uint lanesPerVector = (uint)Vector512<TVector>.Count;

    return elementCount / lanesPerVector;
}
|
||
/// <summary>
/// Normalizes the values in a given <see cref="Span{T}"/> by dividing each element, in place,
/// by <paramref name="sum"/>.
/// </summary>
/// <param name="span">The sequence of <see cref="float"/> values to normalize.</param>
/// <param name="sum">The sum of the values in <paramref name="span"/>.</param>
/// <remarks>
/// The widest hardware accelerated vector width is used first; whatever is left over is handed down to the
/// next narrower width and finally to a scalar loop.
/// <para>
/// No guard is applied for a <paramref name="sum"/> of zero. The division is performed as-is and, following
/// IEEE 754 floating-point rules, produces infinities or NaN values rather than throwing.
/// </para>
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Normalize(Span<float> span, float sum)
{
    ref float current = ref MemoryMarshal.GetReference(span);
    nuint remaining = (uint)span.Length;

    if (Vector512.IsHardwareAccelerated)
    {
        Vector512<float> divisor512 = Vector512.Create(sum);

        while (remaining >= (nuint)Vector512<float>.Count)
        {
            Unsafe.As<float, Vector512<float>>(ref current) /= divisor512;
            current = ref Unsafe.Add(ref current, (nuint)Vector512<float>.Count);
            remaining -= (nuint)Vector512<float>.Count;
        }
    }

    // Runs as the main loop when 256-bit is the widest accelerated width, and at most once
    // to consume the tail left behind by the 512-bit loop above.
    if (Vector256.IsHardwareAccelerated)
    {
        Vector256<float> divisor256 = Vector256.Create(sum);

        while (remaining >= (nuint)Vector256<float>.Count)
        {
            Unsafe.As<float, Vector256<float>>(ref current) /= divisor256;
            current = ref Unsafe.Add(ref current, (nuint)Vector256<float>.Count);
            remaining -= (nuint)Vector256<float>.Count;
        }
    }

    // Same idea one level down; this also vectorizes hardware that only offers 128-bit SIMD.
    if (Vector128.IsHardwareAccelerated)
    {
        Vector128<float> divisor128 = Vector128.Create(sum);

        while (remaining >= (nuint)Vector128<float>.Count)
        {
            Unsafe.As<float, Vector128<float>>(ref current) /= divisor128;
            current = ref Unsafe.Add(ref current, (nuint)Vector128<float>.Count);
            remaining -= (nuint)Vector128<float>.Count;
        }
    }

    // Scalar tail; this is also the complete implementation when no vector width is accelerated.
    while (remaining > 0)
    {
        current /= sum;
        current = ref Unsafe.Add(ref current, (nuint)1);
        remaining--;
    }
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -245,6 +245,44 @@ public static Vector128<short> PackSignedSaturate(Vector128<int> left, Vector128 | |
return default; | ||
} | ||
|
||
/// <summary>
/// Computes <c>(a * b) + c</c> for every element of the supplied vectors, using a single fused
/// multiply-add instruction when the FMA instruction set is available.
/// </summary>
/// <param name="a">The vector of single-precision values forming the left side of the multiplication.</param>
/// <param name="b">The vector of single-precision values forming the right side of the multiplication.</param>
/// <param name="c">The vector of single-precision values added to the product of
/// <paramref name="a"/> and <paramref name="b"/>.</param>
/// <returns>
/// A <see cref="Vector128{Single}"/> whose elements are the products of the matching elements of
/// <paramref name="a"/> and <paramref name="b"/> plus the matching element of <paramref name="c"/>.
/// </returns>
/// <remarks>
/// When <see cref="Fma.IsSupported"/> is <see langword="true"/> the work is delegated to
/// <see cref="Fma.MultiplyAdd(Vector128{float}, Vector128{float}, Vector128{float})"/>.
/// <para>
/// Otherwise the multiplication and the addition are carried out as two separate operations. Because
/// floating-point rounding is handled differently between the two paths, results may differ slightly.
/// </para>
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<float> MultiplyAddEstimate(Vector128<float> a, Vector128<float> b, Vector128<float> c)
    => Fma.IsSupported
        ? Fma.MultiplyAdd(a, b, c)
        : (a * b) + c;
|
||
/// <summary>
/// Throws an <see cref="UnreachableException"/>.
/// </summary>
/// <exception cref="UnreachableException">Always thrown.</exception>
[DoesNotReturn]
private static void ThrowUnreachableException()
{
    throw new UnreachableException();
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -110,6 +110,44 @@ public static Vector256<int> ConvertToInt32RoundToEven(Vector256<float> vector) | |
return Vector256.ConvertToInt32(val_2p23_f32 | sign); | ||
} | ||
|
||
/// <summary>
/// Computes <c>(a * b) + c</c> for every element of the supplied vectors, using a single fused
/// multiply-add instruction when the FMA instruction set is available.
/// </summary>
/// <param name="a">The vector of single-precision values forming the left side of the multiplication.</param>
/// <param name="b">The vector of single-precision values forming the right side of the multiplication.</param>
/// <param name="c">The vector of single-precision values added to the product of
/// <paramref name="a"/> and <paramref name="b"/>.</param>
/// <returns>
/// A <see cref="Vector256{Single}"/> whose elements are the products of the matching elements of
/// <paramref name="a"/> and <paramref name="b"/> plus the matching element of <paramref name="c"/>.
/// </returns>
/// <remarks>
/// When <see cref="Fma.IsSupported"/> is <see langword="true"/> the work is delegated to
/// <see cref="Fma.MultiplyAdd(Vector256{float}, Vector256{float}, Vector256{float})"/>.
/// <para>
/// Otherwise the multiplication and the addition are carried out as two separate operations. Because
/// floating-point rounding is handled differently between the two paths, results may differ slightly.
/// </para>
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<float> MultiplyAddEstimate(Vector256<float> a, Vector256<float> b, Vector256<float> c)
    => Fma.IsSupported
        ? Fma.MultiplyAdd(a, b, c)
        : (a * b) + c;
|
||
/// <summary>
/// Throws an <see cref="UnreachableException"/>.
/// </summary>
/// <exception cref="UnreachableException">Always thrown.</exception>
[DoesNotReturn]
private static void ThrowUnreachableException()
{
    throw new UnreachableException();
}
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
|
||
using System.Diagnostics; | ||
using System.Diagnostics.CodeAnalysis; | ||
using System.Numerics; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.Intrinsics; | ||
using System.Runtime.Intrinsics.X86; | ||
|
@@ -110,6 +111,43 @@ public static Vector512<int> ConvertToInt32RoundToEven(Vector512<float> vector) | |
return Vector512.ConvertToInt32(val_2p23_f32 | sign); | ||
} | ||
|
||
/// <summary>
/// Performs a multiply-add operation on three vectors, where each element of the resulting vector is the
/// product of corresponding elements in <paramref name="a"/> and <paramref name="b"/> added to the
/// corresponding element in <paramref name="c"/>.
/// If the CPU supports AVX-512F instructions, the operation is performed as a single
/// fused operation for better performance and precision.
/// </summary>
/// <param name="a">The first vector of single-precision floating-point numbers to be multiplied.</param>
/// <param name="b">The second vector of single-precision floating-point numbers to be multiplied.</param>
/// <param name="c">The vector of single-precision floating-point numbers to be added to the product of
/// <paramref name="a"/> and <paramref name="b"/>.</param>
/// <returns>
/// A <see cref="Vector512{Single}"/> where each element is the result of multiplying the corresponding elements
/// of <paramref name="a"/> and <paramref name="b"/>, and then adding the corresponding element from <paramref name="c"/>.
/// </returns>
/// <remarks>
/// If the AVX-512F instruction set is supported by the CPU, the operation is performed using
/// <see cref="Avx512F.FusedMultiplyAdd(Vector512{float}, Vector512{float}, Vector512{float})"/> on the full
/// 512-bit vectors. This approach can result in slightly different results compared to performing the
/// multiplication and addition separately due to differences in how floating-point rounding is handled.
/// <para>
/// If AVX-512F is not supported, the operation is performed as a separate multiplication and addition. This might
/// lead to a minor difference in precision compared to the fused operation, particularly in cases where numerical
/// accuracy is critical.
/// </para>
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<float> MultiplyAddEstimate(Vector512<float> a, Vector512<float> b, Vector512<float> c)
{
    if (Avx512F.IsSupported)
    {
        return Avx512F.FusedMultiplyAdd(a, b, c);
    }

    // The fallback must compute the same expression as the fused path: multiply first, then add.
    return (a * b) + c;
}
|
||
/// <summary>
/// Throws an <see cref="UnreachableException"/>.
/// </summary>
/// <exception cref="UnreachableException">Always thrown.</exception>
[DoesNotReturn]
private static void ThrowUnreachableException()
{
    throw new UnreachableException();
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: Consider adding a check for sum == 0 to avoid division by zero