Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This vectorizes TensorPrimitives.Log2 #92897

Merged
merged 7 commits into from
Oct 3, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -598,20 +598,8 @@ public static void Log(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Log2(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = Log2(x[i]);
}
}
public static void Log2(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<Log2Operator>(x, destination);

/// <summary>Searches for the largest single-precision floating-point number in the specified tensor.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
#endif
}

private readonly struct Log2Operator : IUnaryOperator
{
// This code is based on `vrs4_log2f` from amd/aocl-libm-ose
// Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Spec:
// log2f(x)
// = log2f(x) if x ∈ F and x > 0
// = x if x = qNaN
// = 0 if x = 1
// = -inf if x = (-0, 0}
// = NaN otherwise
//
// Assumptions/Expectations
// - Maximum ULP is observed to be at 4
// - Some FPU Exceptions may not be available
// - Performance is at least 3x
//
// Implementation Notes:
// 1. Range Reduction:
// x = 2^n*(1+f) .... (1)
// where n is exponent and is an integer
// (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2)
//
// From (1), taking log on both sides
// log2(x) = log2(2^n * (1+f))
// = n + log2(1+f) .... (3)
//
// let z = 1 + f
// log2(z) = log2(k) + log2(z) - log2(k)
// log2(z) = log2(kz) - log2(k)
//
// From (2), range of z is [1, 2)
// by simply dividing range by 'k', z is in [1/k, 2/k) .... (4)
// Best choice of k is the one which gives equal and opposite values
// at extrema +- -+
// 1 | 2 |
// --- - 1 = - |--- - 1 |
// k | k | .... (5)
// +- -+
//
// Solving for k, k = 3/2,
// From (4), using 'k' value, range is therefore [-0.3333, 0.3333]
//
// 2. Polynomial Approximation:
// More information refer to tools/sollya/vrs4_logf.sollya
//
// 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19

private const uint V_MIN = 0x00800000;
private const uint V_MAX = 0x7F800000;
private const uint V_MASK = 0x007FFFFF;
private const uint V_OFF = 0x3F2AAAAB;

private const float C0 = 0.0f;
private const float C1 = 1.4426951f;
private const float C2 = -0.72134554f;
private const float C3 = 0.48089063f;
private const float C4 = -0.36084408f;
private const float C5 = 0.2888971f;
private const float C6 = -0.23594281f;
private const float C7 = 0.19948183f;
private const float C8 = -0.22616665f;
private const float C9 = 0.21228963f;

public static float Invoke(float x) => MathF.Log2(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<float> specialResult = x;

// x is subnormal or infinity or NaN
Vector128<uint> specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN));

if (specialMask != Vector128<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector128<float> zeroMask = Vector128.Equals(x, Vector128<float>.Zero);

specialResult = Vector128.ConditionalSelect(
zeroMask,
Vector128.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector128<float> lessThanZeroMask = Vector128.LessThan(x, Vector128<float>.Zero);

specialResult = Vector128.ConditionalSelect(
lessThanZeroMask,
Vector128.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
Vector128<float> temp = zeroMask
| lessThanZeroMask
| ~Vector128.Equals(x, x)
| Vector128.Equals(x, Vector128.Create(float.PositiveInfinity));

// subnormal
Vector128<float> subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp);

x = Vector128.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(),
x
);

specialMask = temp.AsUInt32();
}

Vector128<uint> vx = x.AsUInt32() - Vector128.Create(V_OFF);
Vector128<float> n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23));

vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF);

Vector128<float> r = vx.AsSingle() - Vector128.Create(1.0f);

Vector128<float> r2 = r * r;
Vector128<float> r4 = r2 * r2;
Vector128<float> r8 = r4 * r4;

Vector128<float> poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8
+ (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2
+ (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4
+ ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2
+ (Vector128.Create(C1) * r + Vector128.Create(C0))));

return Vector128.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<float> specialResult = x;

// x is subnormal or infinity or NaN
Vector256<uint> specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN));

if (specialMask != Vector256<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector256<float> zeroMask = Vector256.Equals(x, Vector256<float>.Zero);

specialResult = Vector256.ConditionalSelect(
zeroMask,
Vector256.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector256<float> lessThanZeroMask = Vector256.LessThan(x, Vector256<float>.Zero);

specialResult = Vector256.ConditionalSelect(
lessThanZeroMask,
Vector256.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
Vector256<float> temp = zeroMask
| lessThanZeroMask
| ~Vector256.Equals(x, x)
| Vector256.Equals(x, Vector256.Create(float.PositiveInfinity));

// subnormal
Vector256<float> subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp);

x = Vector256.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(),
x
);

specialMask = temp.AsUInt32();
}

Vector256<uint> vx = x.AsUInt32() - Vector256.Create(V_OFF);
Vector256<float> n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23));

vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF);

Vector256<float> r = vx.AsSingle() - Vector256.Create(1.0f);

Vector256<float> r2 = r * r;
Vector256<float> r4 = r2 * r2;
Vector256<float> r8 = r4 * r4;

Vector256<float> poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8
+ (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2
+ (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4
+ ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2
+ (Vector256.Create(C1) * r + Vector256.Create(C0))));

return Vector256.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<float> specialResult = x;

// x is subnormal or infinity or NaN
Vector512<uint> specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN));

if (specialMask != Vector512<uint>.Zero)
{
// float.IsZero(x) ? float.NegativeInfinity : x
Vector512<float> zeroMask = Vector512.Equals(x, Vector512<float>.Zero);

specialResult = Vector512.ConditionalSelect(
zeroMask,
Vector512.Create(float.NegativeInfinity),
specialResult
);

// (x < 0) ? float.NaN : x
Vector512<float> lessThanZeroMask = Vector512.LessThan(x, Vector512<float>.Zero);

specialResult = Vector512.ConditionalSelect(
lessThanZeroMask,
Vector512.Create(float.NaN),
specialResult
);

// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
Vector512<float> temp = zeroMask
| lessThanZeroMask
| ~Vector512.Equals(x, x)
| Vector512.Equals(x, Vector512.Create(float.PositiveInfinity));

// subnormal
Vector512<float> subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp);

x = Vector512.ConditionalSelect(
subnormalMask,
((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(),
x
);

specialMask = temp.AsUInt32();
}

Vector512<uint> vx = x.AsUInt32() - Vector512.Create(V_OFF);
Vector512<float> n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23));

vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF);

Vector512<float> r = vx.AsSingle() - Vector512.Create(1.0f);

Vector512<float> r2 = r * r;
Vector512<float> r4 = r2 * r2;
Vector512<float> r8 = r4 * r4;

Vector512<float> poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8
+ (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2
+ (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4
+ ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2
+ (Vector512.Create(C1) * r + Vector512.Create(C0))));

return Vector512.ConditionalSelect(
specialMask.AsSingle(),
specialResult,
n + poly
);
}
#endif
}

private interface IUnaryOperator
{
static abstract float Invoke(float x);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private static float Aggregate<TLoad, TAggregate>(

float result;

if (Vector.IsHardwareAccelerated && x.Length >= Vector<float>.Count)
if (Vector.IsHardwareAccelerated && load.CanVectorize && x.Length >= Vector<float>.Count)
{
ref float xRef = ref MemoryMarshal.GetReference(x);

Expand Down Expand Up @@ -304,7 +304,7 @@ private static void InvokeSpanIntoSpan<TUnaryOperator>(
ref float dRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;

if (Vector.IsHardwareAccelerated)
if (Vector.IsHardwareAccelerated && op.CanVectorize)
{
oneVectorFromEnd = x.Length - Vector<float>.Count;
if (oneVectorFromEnd >= 0)
Expand Down Expand Up @@ -885,6 +885,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)

private readonly struct NegateOperator : IUnaryOperator
{
public bool CanVectorize => true;
public float Invoke(float x) => -x;
public Vector<float> Invoke(Vector<float> x) => -x;
}
Expand All @@ -903,24 +904,41 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)

private readonly struct IdentityOperator : IUnaryOperator
{
public bool CanVectorize => true;
public float Invoke(float x) => x;
public Vector<float> Invoke(Vector<float> x) => x;
}

private readonly struct SquaredOperator : IUnaryOperator
{
public bool CanVectorize => true;
public float Invoke(float x) => x * x;
public Vector<float> Invoke(Vector<float> x) => x * x;
}

private readonly struct AbsoluteOperator : IUnaryOperator
{
public bool CanVectorize => true;
public float Invoke(float x) => MathF.Abs(x);
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

private readonly struct Log2Operator : IUnaryOperator
{
public bool CanVectorize => false;
tannergooding marked this conversation as resolved.
Show resolved Hide resolved

public float Invoke(float x) => Log2(x);

public Vector<float> Invoke(Vector<float> x)
{
// Vectorizing requires shift right support, which is .NET 7 or later
throw new NotImplementedException();
}
}

private interface IUnaryOperator
{
bool CanVectorize { get; }
float Invoke(float x);
Vector<float> Invoke(Vector<float> x);
}
Expand Down