Skip to content

Commit

Permalink
Condense ElementWiseSelect (#97282)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Jan 24, 2024
1 parent 5654425 commit d523506
Showing 1 changed file with 13 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13834,16 +13834,13 @@ private static Vector128<T> ElementWiseSelect<T>(Vector128<T> mask, Vector128<T>
{
if (Sse41.IsSupported)
{
if (typeof(T) == typeof(byte)) return Sse41.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
if (typeof(T) == typeof(sbyte)) return Sse41.BlendVariable(left.AsSByte(), right.AsSByte(), (~mask).AsSByte()).As<sbyte, T>();
if (typeof(T) == typeof(ushort)) return Sse41.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
if (typeof(T) == typeof(short)) return Sse41.BlendVariable(left.AsInt16(), right.AsInt16(), (~mask).AsInt16()).As<short, T>();
if (typeof(T) == typeof(uint)) return Sse41.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (typeof(T) == typeof(int)) return Sse41.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As<int, T>();
if (typeof(T) == typeof(ulong)) return Sse41.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
if (typeof(T) == typeof(long)) return Sse41.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As<long, T>();
if (typeof(T) == typeof(float)) return Sse41.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
if (typeof(T) == typeof(double)) return Sse41.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

if (sizeof(T) == 1) return Sse41.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
if (sizeof(T) == 2) return Sse41.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
if (sizeof(T) == 4) return Sse41.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Sse41.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
}

return Vector128.ConditionalSelect(mask, left, right);
Expand All @@ -13854,16 +13851,13 @@ private static Vector256<T> ElementWiseSelect<T>(Vector256<T> mask, Vector256<T>
{
if (Avx2.IsSupported)
{
if (typeof(T) == typeof(byte)) return Avx2.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
if (typeof(T) == typeof(sbyte)) return Avx2.BlendVariable(left.AsSByte(), right.AsSByte(), (~mask).AsSByte()).As<sbyte, T>();
if (typeof(T) == typeof(ushort)) return Avx2.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
if (typeof(T) == typeof(short)) return Avx2.BlendVariable(left.AsInt16(), right.AsInt16(), (~mask).AsInt16()).As<short, T>();
if (typeof(T) == typeof(uint)) return Avx2.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (typeof(T) == typeof(int)) return Avx2.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As<int, T>();
if (typeof(T) == typeof(ulong)) return Avx2.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
if (typeof(T) == typeof(long)) return Avx2.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As<long, T>();
if (typeof(T) == typeof(float)) return Avx2.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
if (typeof(T) == typeof(double)) return Avx2.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

if (sizeof(T) == 1) return Avx2.BlendVariable(left.AsByte(), right.AsByte(), (~mask).AsByte()).As<byte, T>();
if (sizeof(T) == 2) return Avx2.BlendVariable(left.AsUInt16(), right.AsUInt16(), (~mask).AsUInt16()).As<ushort, T>();
if (sizeof(T) == 4) return Avx2.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Avx2.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
}

return Vector256.ConditionalSelect(mask, left, right);
Expand All @@ -13874,12 +13868,11 @@ private static Vector512<T> ElementWiseSelect<T>(Vector512<T> mask, Vector512<T>
{
if (Avx512F.IsSupported)
{
if (typeof(T) == typeof(uint)) return Avx512F.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (typeof(T) == typeof(int)) return Avx512F.BlendVariable(left.AsInt32(), right.AsInt32(), (~mask).AsInt32()).As<int, T>();
if (typeof(T) == typeof(ulong)) return Avx512F.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
if (typeof(T) == typeof(long)) return Avx512F.BlendVariable(left.AsInt64(), right.AsInt64(), (~mask).AsInt64()).As<long, T>();
if (typeof(T) == typeof(float)) return Avx512F.BlendVariable(left.AsSingle(), right.AsSingle(), (~mask).AsSingle()).As<float, T>();
if (typeof(T) == typeof(double)) return Avx512F.BlendVariable(left.AsDouble(), right.AsDouble(), (~mask).AsDouble()).As<double, T>();

if (sizeof(T) == 4) return Avx512F.BlendVariable(left.AsUInt32(), right.AsUInt32(), (~mask).AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Avx512F.BlendVariable(left.AsUInt64(), right.AsUInt64(), (~mask).AsUInt64()).As<ulong, T>();
}

return Vector512.ConditionalSelect(mask, left, right);
Expand Down

0 comments on commit d523506

Please sign in to comment.