Skip to content

Commit

Permalink
This vectorizes TensorPrimitives.Log2 (#92897)
Browse files Browse the repository at this point in the history
* Add a way to support operations that can't be vectorized on netstandard

* Updating TensorPrimitives.Log2 to be vectorized on .NET Core

* Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs

Co-authored-by: Stephen Toub <[email protected]>

* Ensure we do an arithmetic right shift in the Log2 vectorization

* Ensure the code can compile on .NET 7

* Ensure that edge cases are properly handled and don't resolve to `x`

* Ensure that Log2 special results are explicitly handled.

---------

Co-authored-by: Stephen Toub <[email protected]>
  • Loading branch information
tannergooding and stephentoub authored Oct 3, 2023
1 parent f41715c commit 781e002
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -598,20 +598,8 @@ public static void Log(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Log2(ReadOnlySpan<float> x, Span<float> destination)
{
    // Every input element needs a slot in the destination.
    int count = x.Length;
    if (destination.Length < count)
    {
        ThrowHelper.ThrowArgument_DestinationTooShort();
    }

    // Helper defined elsewhere; presumably rejects overlapping input/output spans.
    ValidateInputOutputSpanNonOverlapping(x, destination);

    // Element-wise scalar evaluation of the single-value Log2 overload.
    int index = 0;
    while (index < count)
    {
        destination[index] = Log2(x[index]);
        index++;
    }
}
public static void Log2(ReadOnlySpan<float> x, Span<float> destination)
{
    // Hand the spans to the shared unary-operator driver, parameterized with Log2Operator.
    InvokeSpanIntoSpan<Log2Operator>(x, destination);
}

/// <summary>Searches for the largest single-precision floating-point number in the specified tensor.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
#endif
}

/// <summary>
/// Unary operator computing the base-2 logarithm. The scalar overload defers to <see cref="MathF.Log2(float)"/>;
/// the Vector128/256/512 overloads use range reduction plus a polynomial approximation.
/// </summary>
private readonly struct Log2Operator : IUnaryOperator
{
// This code is based on `vrs4_log2f` from amd/aocl-libm-ose
// Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Spec:
// log2f(x)
// = log2f(x) if x ∈ F and x > 0
// = x if x = qNaN
// = 0 if x = 1
// = -inf if x = {-0, 0}
// = NaN otherwise
//
// Assumptions/Expectations
// - Maximum ULP is observed to be at 4
// - Some FPU Exceptions may not be available
// - Performance is at least 3x
//
// Implementation Notes:
// 1. Range Reduction:
// x = 2^n*(1+f) .... (1)
// where n is exponent and is an integer
// (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2)
//
// From (1), taking log on both sides
// log2(x) = log2(2^n * (1+f))
// = n + log2(1+f) .... (3)
//
// let z = 1 + f
// log2(z) = log2(k) + log2(z) - log2(k)
// log2(z) = log2(kz) - log2(k)
//
// From (2), range of z is [1, 2)
// by simply dividing range by 'k', z is in [1/k, 2/k) .... (4)
// Best choice of k is the one which gives equal and opposite values
// at extrema:
// 1/k - 1 = -(2/k - 1) .... (5)
//
// Solving for k, k = 3/2,
// From (4), using 'k' value, range is therefore [-0.3333, 0.3333]
//
// 2. Polynomial Approximation:
// For more information, refer to tools/sollya/vrs4_logf.sollya
//
// 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19
// NOTE(review): the line above says 7th degree, but C0..C9 below form a degree-9 polynomial — confirm which is intended.

// Bit pattern of the smallest positive normal float (2^-126).
private const uint V_MIN = 0x00800000;
// Bit pattern of positive infinity.
private const uint V_MAX = 0x7F800000;
// Mask selecting the low 23 bits (the mantissa field) of a float.
private const uint V_MASK = 0x007FFFFF;
// Bit pattern of approximately 2/3 (0.6666667f), i.e. 1/k for k = 3/2 from the range reduction above.
private const uint V_OFF = 0x3F2AAAAB;

// Polynomial coefficients; C0 multiplies r^0 and C9 multiplies r^9 in the evaluation below.
private const float C0 = 0.0f;
private const float C1 = 1.4426951f;
private const float C2 = -0.72134554f;
private const float C3 = 0.48089063f;
private const float C4 = -0.36084408f;
private const float C5 = 0.2888971f;
private const float C6 = -0.23594281f;
private const float C7 = 0.19948183f;
private const float C8 = -0.22616665f;
private const float C9 = 0.21228963f;

// Scalar path: no approximation, just the framework implementation.
public static float Invoke(float x) => MathF.Log2(x);

// Computes log2 of each lane of a 128-bit vector.
public static Vector128<float> Invoke(Vector128<float> x)
{
// Default special result is the input itself (covers NaN and +infinity lanes).
Vector128<float> specialResult = x;

// x is zero, subnormal, infinity, NaN, or negative:
// the unsigned compare of (bits - V_MIN) against (V_MAX - V_MIN) is true when bits < V_MIN
// (the subtraction wraps around) or bits >= V_MAX (which includes every value with the sign bit set).
Vector128<uint> specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN));

// Only pay for the special-case handling when at least one lane needs it.
if (specialMask != Vector128<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector128<float> zeroMask = Vector128.Equals(x, Vector128<float>.Zero);

specialResult = Vector128.ConditionalSelect(
zeroMask,
Vector128.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector128<float> lessThanZeroMask = Vector128.LessThan(x, Vector128<float>.Zero);

specialResult = Vector128.ConditionalSelect(
lessThanZeroMask,
Vector128.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
// These are the lanes whose final value comes from specialResult.
Vector128<float> temp = zeroMask
| lessThanZeroMask
| ~Vector128.Equals(x, x)
| Vector128.Equals(x, Vector128.Create(float.PositiveInfinity));

// subnormal: lanes flagged by specialMask that are none of the cases collected in temp
Vector128<float> subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp);

// Scale subnormal lanes by 2^23 (8388608) to normalize them, then subtract 23 from the exponent
// field of the bit pattern so the exponent extraction below accounts for the scaling.
x = Vector128.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(),
x
);

// From here on, subnormal lanes go through the regular path; only the temp lanes stay special.
specialMask = temp.AsUInt32();
}

// Range reduction: subtract the 2/3 offset from the raw bits, then an arithmetic shift by 23
// yields the signed integer exponent n.
Vector128<uint> vx = x.AsUInt32() - Vector128.Create(V_OFF);
Vector128<float> n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23));

// Keep the remaining mantissa bits and add the offset back to rebuild the reduced value.
vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF);

// r is the reduced argument minus one, the polynomial variable.
Vector128<float> r = vx.AsSingle() - Vector128.Create(1.0f);

Vector128<float> r2 = r * r;
Vector128<float> r4 = r2 * r2;
Vector128<float> r8 = r4 * r4;

// Evaluate C0 + C1*r + ... + C9*r^9 as pairs (C[i+1]*r + C[i]) combined with r2, r4 and r8.
Vector128<float> poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8
+ (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2
+ (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4
+ ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2
+ (Vector128.Create(C1) * r + Vector128.Create(C0))));

// Special lanes take specialResult; all other lanes take n + poly.
return Vector128.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}

// Computes log2 of each lane of a 256-bit vector; same steps as the Vector128 overload.
public static Vector256<float> Invoke(Vector256<float> x)
{
// Default special result is the input itself (covers NaN and +infinity lanes).
Vector256<float> specialResult = x;

// x is zero, subnormal, infinity, NaN, or negative (unsigned wraparound compare on the raw bits)
Vector256<uint> specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN));

// Only pay for the special-case handling when at least one lane needs it.
if (specialMask != Vector256<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector256<float> zeroMask = Vector256.Equals(x, Vector256<float>.Zero);

specialResult = Vector256.ConditionalSelect(
zeroMask,
Vector256.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector256<float> lessThanZeroMask = Vector256.LessThan(x, Vector256<float>.Zero);

specialResult = Vector256.ConditionalSelect(
lessThanZeroMask,
Vector256.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
// These are the lanes whose final value comes from specialResult.
Vector256<float> temp = zeroMask
| lessThanZeroMask
| ~Vector256.Equals(x, x)
| Vector256.Equals(x, Vector256.Create(float.PositiveInfinity));

// subnormal: lanes flagged by specialMask that are none of the cases collected in temp
Vector256<float> subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp);

// Scale subnormal lanes by 2^23, then subtract 23 from the exponent field of the bit pattern.
x = Vector256.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(),
x
);

// Subnormal lanes now go through the regular path; only the temp lanes stay special.
specialMask = temp.AsUInt32();
}

// Range reduction: extract the signed exponent n after removing the 2/3 offset.
Vector256<uint> vx = x.AsUInt32() - Vector256.Create(V_OFF);
Vector256<float> n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23));

// Keep the remaining mantissa bits and add the offset back to rebuild the reduced value.
vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF);

// r is the reduced argument minus one, the polynomial variable.
Vector256<float> r = vx.AsSingle() - Vector256.Create(1.0f);

Vector256<float> r2 = r * r;
Vector256<float> r4 = r2 * r2;
Vector256<float> r8 = r4 * r4;

// Evaluate C0 + C1*r + ... + C9*r^9 as pairs combined with r2, r4 and r8.
Vector256<float> poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8
+ (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2
+ (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4
+ ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2
+ (Vector256.Create(C1) * r + Vector256.Create(C0))));

// Special lanes take specialResult; all other lanes take n + poly.
return Vector256.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}

#if NET8_0_OR_GREATER
// Computes log2 of each lane of a 512-bit vector; same steps as the Vector128 overload.
public static Vector512<float> Invoke(Vector512<float> x)
{
// Default special result is the input itself (covers NaN and +infinity lanes).
Vector512<float> specialResult = x;

// x is zero, subnormal, infinity, NaN, or negative (unsigned wraparound compare on the raw bits)
Vector512<uint> specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN));

// Only pay for the special-case handling when at least one lane needs it.
if (specialMask != Vector512<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector512<float> zeroMask = Vector512.Equals(x, Vector512<float>.Zero);

specialResult = Vector512.ConditionalSelect(
zeroMask,
Vector512.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector512<float> lessThanZeroMask = Vector512.LessThan(x, Vector512<float>.Zero);

specialResult = Vector512.ConditionalSelect(
lessThanZeroMask,
Vector512.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
// These are the lanes whose final value comes from specialResult.
Vector512<float> temp = zeroMask
| lessThanZeroMask
| ~Vector512.Equals(x, x)
| Vector512.Equals(x, Vector512.Create(float.PositiveInfinity));

// subnormal: lanes flagged by specialMask that are none of the cases collected in temp
Vector512<float> subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp);

// Scale subnormal lanes by 2^23, then subtract 23 from the exponent field of the bit pattern.
x = Vector512.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(),
x
);

// Subnormal lanes now go through the regular path; only the temp lanes stay special.
specialMask = temp.AsUInt32();
}

// Range reduction: extract the signed exponent n after removing the 2/3 offset.
Vector512<uint> vx = x.AsUInt32() - Vector512.Create(V_OFF);
Vector512<float> n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23));

// Keep the remaining mantissa bits and add the offset back to rebuild the reduced value.
vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF);

// r is the reduced argument minus one, the polynomial variable.
Vector512<float> r = vx.AsSingle() - Vector512.Create(1.0f);

Vector512<float> r2 = r * r;
Vector512<float> r4 = r2 * r2;
Vector512<float> r8 = r4 * r4;

// Evaluate C0 + C1*r + ... + C9*r^9 as pairs combined with r2, r4 and r8.
Vector512<float> poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8
+ (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2
+ (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4
+ ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2
+ (Vector512.Create(C1) * r + Vector512.Create(C0))));

// Special lanes take specialResult; all other lanes take n + poly.
return Vector512.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}
#endif
}

private interface IUnaryOperator
{
static abstract float Invoke(float x);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private static float Aggregate<TLoad, TAggregate>(

float result;

if (Vector.IsHardwareAccelerated && x.Length >= Vector<float>.Count)
if (Vector.IsHardwareAccelerated && load.CanVectorize && x.Length >= Vector<float>.Count)
{
ref float xRef = ref MemoryMarshal.GetReference(x);

Expand Down Expand Up @@ -304,7 +304,7 @@ private static void InvokeSpanIntoSpan<TUnaryOperator>(
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;

if (Vector.IsHardwareAccelerated)
if (Vector.IsHardwareAccelerated && op.CanVectorize)
{
oneVectorFromEnd = x.Length - Vector<float>.Count;
if (oneVectorFromEnd >= 0)
Expand Down Expand Up @@ -885,6 +885,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)

/// <summary>Unary operator that negates its input, for both scalars and vectors.</summary>
private readonly struct NegateOperator : IUnaryOperator
{
    /// <summary>Always <see langword="true"/>: negation has a vector implementation.</summary>
    public bool CanVectorize
    {
        get { return true; }
    }

    /// <summary>Returns the negated scalar.</summary>
    public float Invoke(float x)
    {
        return -x;
    }

    /// <summary>Returns the element-wise negated vector.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        return -x;
    }
}
Expand All @@ -903,24 +904,41 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)

/// <summary>Unary operator that returns its input unchanged, for both scalars and vectors.</summary>
private readonly struct IdentityOperator : IUnaryOperator
{
    /// <summary>Always <see langword="true"/>: the identity has a vector implementation.</summary>
    public bool CanVectorize
    {
        get { return true; }
    }

    /// <summary>Returns the scalar as-is.</summary>
    public float Invoke(float x)
    {
        return x;
    }

    /// <summary>Returns the vector as-is.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        return x;
    }
}

/// <summary>Unary operator that multiplies its input by itself, for both scalars and vectors.</summary>
private readonly struct SquaredOperator : IUnaryOperator
{
    /// <summary>Always <see langword="true"/>: squaring has a vector implementation.</summary>
    public bool CanVectorize
    {
        get { return true; }
    }

    /// <summary>Returns the scalar multiplied by itself.</summary>
    public float Invoke(float x)
    {
        return x * x;
    }

    /// <summary>Returns the element-wise product of the vector with itself.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        return x * x;
    }
}

/// <summary>Unary operator that takes the absolute value of its input, for both scalars and vectors.</summary>
private readonly struct AbsoluteOperator : IUnaryOperator
{
    /// <summary>Always <see langword="true"/>: absolute value has a vector implementation.</summary>
    public bool CanVectorize
    {
        get { return true; }
    }

    /// <summary>Returns the absolute value of the scalar via <c>MathF.Abs</c>.</summary>
    public float Invoke(float x)
    {
        return MathF.Abs(x);
    }

    /// <summary>Returns the element-wise absolute value via <c>Vector.Abs</c>.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        return Vector.Abs(x);
    }
}

/// <summary>Unary operator for base-2 logarithm that only supports the scalar path in this build.</summary>
private readonly struct Log2Operator : IUnaryOperator
{
    /// <summary>Always <see langword="false"/>: callers that check this flag stay on the scalar path.</summary>
    public bool CanVectorize
    {
        get { return false; }
    }

    /// <summary>Returns the base-2 logarithm of the scalar via the <c>Log2</c> helper defined elsewhere.</summary>
    public float Invoke(float x)
    {
        return Log2(x);
    }

    /// <summary>Not available here; always throws <see cref="NotImplementedException"/>.</summary>
    public Vector<float> Invoke(Vector<float> x)
    {
        // Vectorizing requires shift right support, which is .NET 7 or later
        throw new NotImplementedException();
    }
}

/// <summary>Contract for an operation applied to a single float or to each element of a <see cref="Vector{T}"/> of floats.</summary>
private interface IUnaryOperator
{
/// <summary>Gets whether the vector overload of Invoke may be used; the visible callers test this before taking the vectorized path.</summary>
bool CanVectorize { get; }
/// <summary>Applies the operation to a single value.</summary>
float Invoke(float x);
/// <summary>Applies the operation to a vector of values.</summary>
Vector<float> Invoke(Vector<float> x);
}
Expand Down

0 comments on commit 781e002

Please sign in to comment.