Skip to content

Commit

Permalink
Vectorize TensorPrimitives.ConvertToHalf (#92715)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Sep 29, 2023
1 parent 0cc9f21 commit 56251ec
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,301 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < source.Length; i++)
ref float sourceRef = ref MemoryMarshal.GetReference(source);
ref ushort destinationRef = ref Unsafe.As<Half, ushort>(ref MemoryMarshal.GetReference(destination));
int i = 0, twoVectorsFromEnd;

#if NET8_0_OR_GREATER
if (Vector512.IsHardwareAccelerated)
{
destination[i] = (Half)source[i];
twoVectorsFromEnd = source.Length - (Vector512<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);

i += Vector512<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector512<float>.Count * 2);

Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}
#endif

if (Vector256.IsHardwareAccelerated)
{
twoVectorsFromEnd = source.Length - (Vector256<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
Vector256<ushort> halfs = Vector256.Narrow(lower, upper);
halfs.StoreUnsafe(ref destinationRef, (uint)i);

i += Vector256<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector256<float>.Count * 2);

Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}

if (Vector128.IsHardwareAccelerated)
{
twoVectorsFromEnd = source.Length - (Vector128<float>.Count * 2);
if (i <= twoVectorsFromEnd)
{
// Loop handling two input vectors / one output vector at a time.
do
{
Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);

i += Vector128<float>.Count * 2;
}
while (i <= twoVectorsFromEnd);

// Handle any remaining elements with final vectors.
if (i != source.Length)
{
i = source.Length - (Vector128<float>.Count * 2);

Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
}

return;
}
}

while (i < source.Length)
{
Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i));
i++;
}

// This implements a vectorized version of the `explicit operator Half(float value) operator`.
// See detailed description of the algorithm used here:
// https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714
// The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half.
// This does the same, with an input VectorXx<float> and an output VectorXx<uint>.
// Loop handling two input vectors at a time; each input float is double the size of each output Half,
// so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx<T>,
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
#pragma warning restore IDE0059

static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
{
Vector128<uint> bitValue = value.AsUInt32();

// Extract sign bit
Vector128<uint> sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u)
Vector128<uint> realMask = Vector128.Equals(value, value).AsUInt32();

// Clear sign bit
value = Vector128.Abs(value);

// Rectify values that are Infinity in Half.
value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent
Vector128<uint> exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent
exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask);

// Add exponent by 13
exponentOffset0 += Vector128.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// Only exponent bits will be modified if NaN
Vector128<uint> maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask);

// Subtract exponent by 126
bitValue -= Vector128.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector128<uint> newExponent = Vector128.ShiftRightLogical(bitValue, 13);

// Clear the fraction parts if the value was NaN.
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Merge sign bit with possible NaN exponent
Vector128<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge sign bit and possible NaN exponent
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}

static Vector256<uint> SingleToHalfAsWidenedUInt32_Vector256(Vector256<float> value)
{
Vector256<uint> bitValue = value.AsUInt32();

// Extract sign bit
Vector256<uint> sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u)
Vector256<uint> realMask = Vector256.Equals(value, value).AsUInt32();

// Clear sign bit
value = Vector256.Abs(value);

// Rectify values that are Infinity in Half.
value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent
Vector256<uint> exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent
exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask);

// Add exponent by 13
exponentOffset0 += Vector256.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// Only exponent bits will be modified if NaN
Vector256<uint> maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask);

// Subtract exponent by 126
bitValue -= Vector256.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector256<uint> newExponent = Vector256.ShiftRightLogical(bitValue, 13);

// Clear the fraction parts if the value was NaN.
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Merge sign bit with possible NaN exponent
Vector256<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge sign bit and possible NaN exponent
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}

#if NET8_0_OR_GREATER
static Vector512<uint> SingleToHalfAsWidenedUInt32_Vector512(Vector512<float> value)
{
Vector512<uint> bitValue = value.AsUInt32();

// Extract sign bit
Vector512<uint> sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16);

// Detecting NaN (0u if value is NaN; otherwise, ~0u)
Vector512<uint> realMask = Vector512.Equals(value, value).AsUInt32();

// Clear sign bit
value = Vector512.Abs(value);

// Rectify values that are Infinity in Half.
value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value);

// Rectify lower exponent
Vector512<uint> exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32();

// Extract exponent
exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask);

// Add exponent by 13
exponentOffset0 += Vector512.Create(Exponent13);

// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
value += exponentOffset0.AsSingle();
bitValue = value.AsUInt32();

// Only exponent bits will be modified if NaN
Vector512<uint> maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask);

// Subtract exponent by 126
bitValue -= Vector512.Create(Exponent126);

// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
Vector512<uint> newExponent = Vector512.ShiftRightLogical(bitValue, 13);

// Clear the fraction parts if the value was NaN.
bitValue &= realMask;

// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
bitValue += newExponent;

// Clear exponents if value is NaN
bitValue &= ~maskedHalfExponentForNaN;

// Merge sign bit with possible NaN exponent
Vector512<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;

// Merge sign bit and possible NaN exponent
bitValue |= signAndMaskedExponent;

// The final result
return bitValue;
}
#endif
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public static void ConvertToHalf(int tensorLength)
using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
foreach (int destLength in new[] { source.Length, source.Length + 1 })
{
Half[] destination = new Half[destLength];
using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(destLength);
destination.Span.Fill(Half.Zero);

TensorPrimitives.ConvertToHalf(source, destination);

Expand All @@ -35,6 +36,28 @@ public static void ConvertToHalf(int tensorLength)
}
}

[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToHalf_SpecialValues(int tensorLength)
{
using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(tensorLength);

// NaN, infinities, and 0s
source[s_random.Next(source.Length)] = float.NaN;
source[s_random.Next(source.Length)] = float.PositiveInfinity;
source[s_random.Next(source.Length)] = float.NegativeInfinity;
source[s_random.Next(source.Length)] = 0;
source[s_random.Next(source.Length)] = float.NegativeZero;

TensorPrimitives.ConvertToHalf(source, destination);

for (int i = 0; i < source.Length; i++)
{
Assert.Equal((Half)source[i], destination[i]);
}
}

[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
Expand All @@ -51,7 +74,7 @@ public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
[MemberData(nameof(TensorLengthsIncluding0))]
public static void ConvertToSingle(int tensorLength)
{
Half[] source = new Half[tensorLength];
using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
for (int i = 0; i < source.Length; i++)
{
source[i] = (Half)s_random.NextSingle();
Expand All @@ -78,6 +101,32 @@ public static void ConvertToSingle(int tensorLength)
}
}
}
[Theory]
[MemberData(nameof(TensorLengths))]
public static void ConvertToSingle_SpecialValues(int tensorLength)
{
using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
for (int i = 0; i < source.Length; i++)
{
source[i] = (Half)s_random.NextSingle();
}

using BoundedMemory<float> destination = CreateTensor(tensorLength);

// NaN, infinities, and 0s
source[s_random.Next(source.Length)] = Half.NaN;
source[s_random.Next(source.Length)] = Half.PositiveInfinity;
source[s_random.Next(source.Length)] = Half.NegativeInfinity;
source[s_random.Next(source.Length)] = Half.Zero;
source[s_random.Next(source.Length)] = Half.NegativeZero;

TensorPrimitives.ConvertToSingle(source, destination);

for (int i = 0; i < source.Length; i++)
{
Assert.Equal((float)source[i], destination[i]);
}
}

[Theory]
[MemberData(nameof(TensorLengths))]
Expand Down

0 comments on commit 56251ec

Please sign in to comment.