Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize TensorPrimitives.ConvertToSingle #92779

Merged
merged 3 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,256 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < source.Length; i++)
ref short sourceRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(source));
ref float destinationRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;

#if NET8_0_OR_GREATER
if (Vector512.IsHardwareAccelerated)
{
oneVectorFromEnd = source.Length - Vector512<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));

i += Vector512<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector512<short>.Count;

(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));
}

return;
}
}
#endif

if (Vector256.IsHardwareAccelerated)
{
oneVectorFromEnd = source.Length - Vector256<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));

i += Vector256<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector256<short>.Count;

(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));
}

return;
}
}

if (Vector128.IsHardwareAccelerated)
{
destination[i] = (float)source[i];
oneVectorFromEnd = source.Length - Vector128<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));

i += Vector128<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector128<short>.Count;

(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));
}

return;
}
}

while (i < source.Length)
{
Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As<short, Half>(ref Unsafe.Add(ref sourceRef, i));
i++;
}

// The local functions below implement a vectorized version of the `explicit operator float(Half value)` conversion.
// See detailed description of the algorithm used here:
// https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040
// The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx<uint> and an output VectorXx<float>.
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short> then widened to two VectorXx<int>s and cast to VectorXx<uint>s.
// We loop handling one input vector at a time, producing two output float vectors.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
#pragma warning restore IDE0059

// Converts a vector of Half bit patterns, each already widened into a 32-bit lane, to the
// equivalent Single values (128-bit flavor).
// The sign is read from bit 31 of every lane, so lanes are expected to be sign-extended from the
// 16-bit pattern, which is what the callers in this method do by widening a vector of shorts.
static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
{
    Vector128<uint> halfExponentBits = Vector128.Create(HalfExponentMask);

    // Sign bit, already sitting in the Single sign position.
    Vector128<uint> signBit = value & Vector128.Create(SingleSignMask);

    // Exponent field of the Half, left unshifted.
    Vector128<uint> halfExponent = value & halfExponentBits;

    // All-ones lanes where the Half exponent field is zero (zero or subnormal input).
    Vector128<uint> isSubnormal = Vector128.Equals(halfExponent, Vector128<uint>.Zero);

    // All-ones lanes where the Half exponent field is saturated (Infinity or NaN input).
    Vector128<uint> isInfinityOrNaN = Vector128.Equals(halfExponent, halfExponentBits);

    // ExponentLowerBound in subnormal lanes, 0 elsewhere.
    Vector128<uint> subnormalBias = isSubnormal & Vector128.Create(ExponentLowerBound);

    // Exponent adjustment: ExponentLowerBound for subnormal lanes, ExponentOffset otherwise...
    Vector128<uint> exponentAdjust = Vector128.Create(ExponentOffset) | subnormalBias;

    // ...and doubled for Infinity/NaN lanes so the result lands on the all-ones Single exponent.
    exponentAdjust = Vector128.ConditionalSelect(
        isInfinityOrNaN,
        Vector128.ShiftLeft(exponentAdjust, 1),
        exponentAdjust);

    // Line the Half exponent/fraction boundary up with the Single layout, keep only the bits that
    // originate from the Half, then re-bias the exponent.
    Vector128<uint> magnitudeBits =
        (Vector128.ShiftLeft(value, 13) & Vector128.Create(HalfToSingleBitsMask)) + exponentAdjust;

    // For subnormal lanes, a floating-point subtraction strips the implicit leading 1 that the
    // bias introduced; for all other lanes this subtracts 0.
    Vector128<float> magnitude = magnitudeBits.AsSingle() - subnormalBias.AsSingle();

    // Put the sign back.
    return (magnitude.AsUInt32() | signBit).AsSingle();
}

// Converts a vector of Half bit patterns, each already widened into a 32-bit lane, to the
// equivalent Single values (256-bit flavor).
// The sign is read from bit 31 of every lane, so lanes are expected to be sign-extended from the
// 16-bit pattern, which is what the callers in this method do by widening a vector of shorts.
static Vector256<float> HalfAsWidenedUInt32ToSingle_Vector256(Vector256<uint> value)
{
    Vector256<uint> halfExponentBits = Vector256.Create(HalfExponentMask);

    // Sign bit, already sitting in the Single sign position.
    Vector256<uint> signBit = value & Vector256.Create(SingleSignMask);

    // Exponent field of the Half, left unshifted.
    Vector256<uint> halfExponent = value & halfExponentBits;

    // All-ones lanes where the Half exponent field is zero (zero or subnormal input).
    Vector256<uint> isSubnormal = Vector256.Equals(halfExponent, Vector256<uint>.Zero);

    // All-ones lanes where the Half exponent field is saturated (Infinity or NaN input).
    Vector256<uint> isInfinityOrNaN = Vector256.Equals(halfExponent, halfExponentBits);

    // ExponentLowerBound in subnormal lanes, 0 elsewhere.
    Vector256<uint> subnormalBias = isSubnormal & Vector256.Create(ExponentLowerBound);

    // Exponent adjustment: ExponentLowerBound for subnormal lanes, ExponentOffset otherwise...
    Vector256<uint> exponentAdjust = Vector256.Create(ExponentOffset) | subnormalBias;

    // ...and doubled for Infinity/NaN lanes so the result lands on the all-ones Single exponent.
    exponentAdjust = Vector256.ConditionalSelect(
        isInfinityOrNaN,
        Vector256.ShiftLeft(exponentAdjust, 1),
        exponentAdjust);

    // Line the Half exponent/fraction boundary up with the Single layout, keep only the bits that
    // originate from the Half, then re-bias the exponent.
    Vector256<uint> magnitudeBits =
        (Vector256.ShiftLeft(value, 13) & Vector256.Create(HalfToSingleBitsMask)) + exponentAdjust;

    // For subnormal lanes, a floating-point subtraction strips the implicit leading 1 that the
    // bias introduced; for all other lanes this subtracts 0.
    Vector256<float> magnitude = magnitudeBits.AsSingle() - subnormalBias.AsSingle();

    // Put the sign back.
    return (magnitude.AsUInt32() | signBit).AsSingle();
}

#if NET8_0_OR_GREATER
// Converts a vector of Half bit patterns, each already widened into a 32-bit lane, to the
// equivalent Single values (512-bit flavor).
// The sign is read from bit 31 of every lane, so lanes are expected to be sign-extended from the
// 16-bit pattern, which is what the callers in this method do by widening a vector of shorts.
static Vector512<float> HalfAsWidenedUInt32ToSingle_Vector512(Vector512<uint> value)
{
    Vector512<uint> halfExponentBits = Vector512.Create(HalfExponentMask);

    // Sign bit, already sitting in the Single sign position.
    Vector512<uint> signBit = value & Vector512.Create(SingleSignMask);

    // Exponent field of the Half, left unshifted.
    Vector512<uint> halfExponent = value & halfExponentBits;

    // All-ones lanes where the Half exponent field is zero (zero or subnormal input).
    Vector512<uint> isSubnormal = Vector512.Equals(halfExponent, Vector512<uint>.Zero);

    // All-ones lanes where the Half exponent field is saturated (Infinity or NaN input).
    Vector512<uint> isInfinityOrNaN = Vector512.Equals(halfExponent, halfExponentBits);

    // ExponentLowerBound in subnormal lanes, 0 elsewhere.
    Vector512<uint> subnormalBias = isSubnormal & Vector512.Create(ExponentLowerBound);

    // Exponent adjustment: ExponentLowerBound for subnormal lanes, ExponentOffset otherwise...
    Vector512<uint> exponentAdjust = Vector512.Create(ExponentOffset) | subnormalBias;

    // ...and doubled for Infinity/NaN lanes so the result lands on the all-ones Single exponent.
    exponentAdjust = Vector512.ConditionalSelect(
        isInfinityOrNaN,
        Vector512.ShiftLeft(exponentAdjust, 1),
        exponentAdjust);

    // Line the Half exponent/fraction boundary up with the Single layout, keep only the bits that
    // originate from the Half, then re-bias the exponent.
    Vector512<uint> magnitudeBits =
        (Vector512.ShiftLeft(value, 13) & Vector512.Create(HalfToSingleBitsMask)) + exponentAdjust;

    // For subnormal lanes, a floating-point subtraction strips the implicit leading 1 that the
    // bias introduced; for all other lanes this subtracts 0.
    Vector512<float> magnitude = magnitudeBits.AsSingle() - subnormalBias.AsSingle();

    // Put the sign back.
    return (magnitude.AsUInt32() | signBit).AsSingle();
}
#endif
}

private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Private.CoreLib/src/System/Half.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,15 +1044,15 @@ public static explicit operator float(Half value)
// BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint ExponentOffset = 0x3800_0000u;
// Mask for sign bit in Single
const uint FloatSignMask = float.SignMask;
const uint SingleSignMask = float.SignMask;
// Mask for exponent bits in Half
const uint HalfExponentMask = BiasedExponentMask;
// Mask for bits in Single converted from Half
const int HalfToSingleBitsMask = 0x0FFF_E000;
// Extract the internal representation of value
short valueInInt16Bits = BitConverter.HalfToInt16Bits(value);
// Extract sign bit of value
uint sign = (uint)(int)valueInInt16Bits & FloatSignMask;
uint sign = (uint)(int)valueInInt16Bits & SingleSignMask;
// Copy sign bit to upper bits
uint bitValueInProcess = (uint)valueInInt16Bits;
// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
Expand Down