Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1D convolution optimization and general codegen tweaks #1477

Merged
merged 12 commits into from
Dec 16, 2020
Merged
24 changes: 14 additions & 10 deletions src/ImageSharp/ColorSpaces/Companding/SRgbCompanding.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Six Labors.
// Copyright (c) Six Labors.
// Licensed under the Apache License, Version 2.0.

using System;
Expand All @@ -25,12 +25,14 @@ public static class SRgbCompanding
[MethodImpl(InliningOptions.ShortMethod)]
public static void Expand(Span<Vector4> vectors)
{
ref Vector4 baseRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsStart = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsStart, vectors.Length);

for (int i = 0; i < vectors.Length; i++)
while (Unsafe.IsAddressLessThan(ref vectorsStart, ref vectorsEnd))
{
ref Vector4 v = ref Unsafe.Add(ref baseRef, i);
Expand(ref v);
Expand(ref vectorsStart);

vectorsStart = ref Unsafe.Add(ref vectorsStart, 1);
}
}

Expand All @@ -41,12 +43,14 @@ public static void Expand(Span<Vector4> vectors)
[MethodImpl(InliningOptions.ShortMethod)]
public static void Compress(Span<Vector4> vectors)
{
ref Vector4 baseRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsStart = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsStart, vectors.Length);

for (int i = 0; i < vectors.Length; i++)
while (Unsafe.IsAddressLessThan(ref vectorsStart, ref vectorsEnd))
{
ref Vector4 v = ref Unsafe.Add(ref baseRef, i);
Compress(ref v);
Compress(ref vectorsStart);

vectorsStart = ref Unsafe.Add(ref vectorsStart, 1);
}
}

Expand Down Expand Up @@ -90,4 +94,4 @@ public static void Compress(ref Vector4 vector)
[MethodImpl(InliningOptions.ShortMethod)]
public static float Compress(float channel) => channel <= 0.0031308F ? 12.92F * channel : (1.055F * MathF.Pow(channel, 0.416666666666667F)) - 0.055F;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we ever figure out how to do an accurate SIMD-enabled approximation of this, we would be laughing.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

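pow(channel, 0.416666666666667F) => exp(0.416666666666667F * log(channel))

Note that the logarithm applies to the channel, not to the exponent: log(0.416666666666667F) == -0.875468737353899935628f is a constant, but multiplying the channel by it does not reproduce pow, so the sketch below also needs a vectorized log of the channels.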

So...

/// <summary>
/// Sketch of a vectorized sRGB compress that reinterprets <paramref name="vector"/> as a
/// <c>Vector128&lt;float&gt;</c>, transforms all four lanes and writes the result back in place.
/// </summary>
/// <param name="vector">The vector to transform; it is overwritten with the result.</param>
/// <remarks>
/// NOTE(review): pow(x, y) equals exp(y * log(x)). This sketch multiplies the channels by
/// log(0.416666666666667) instead of multiplying log(channels) by 0.416666666666667, so as
/// written it does not compute pow(channel, 0.416666666666667F).
/// NOTE(review): the scalar <c>Compress(float)</c> uses a linear segment (12.92F * channel) for
/// channel &lt;= 0.0031308F; no equivalent branch or mask exists here.
/// NOTE(review): all four lanes are transformed, including W - confirm whether W must be preserved.
/// </remarks>
public static void Compress(ref Vector4 vector)
{
    // Reinterpret the four floats of the Vector4 as a single 128-bit SIMD register.
    var channels = Unsafe.As<Vector4, Vector128<float>>(ref vector);

    // Broadcast of ln(0.416666666666667) to every lane (see the review notes above).
    var log = Vector128.Create(-0.875468737353899935628f);

    channels = Sse.Multiply(channels, log);

    // Exp is not defined in this sketch and has no hardware intrinsic; a vectorized
    // approximation would have to be supplied.
    channels = Exp(channels); // Isn't simd intrinsic

    // Both branches compute (1.055 * channels) - 0.055; the FMA path fuses the multiply and add.
    if (Fma.IsSupported)
    {
        channels = Fma.MultiplyAdd(Vector128.Create(1.055F), channels, Vector128.Create(-0.055F));
    }
    else
    {
        channels = Sse.Add(Sse.Multiply(Vector128.Create(1.055F), channels), Vector128.Create(-0.055F));
    }

    // Store the transformed lanes back over the caller's vector.
    Unsafe.As<Vector4, Vector128<float>>(ref vector) = channels;
}

But Exp isn't a SIMD intrinsic; however, could you approximate it with the sequences from sse_mathfun or avx_mathfun?

}
}
}
198 changes: 113 additions & 85 deletions src/ImageSharp/Common/Helpers/Numerics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@ public static int GreatestCommonDivisor(int a, int b)

/// <summary>
/// Determine the Least Common Multiple (LCM) of two numbers.
/// See https://en.wikipedia.org/wiki/Least_common_multiple#Reduction_by_the_greatest_common_divisor.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int LeastCommonMultiple(int a, int b)
{
// https://en.wikipedia.org/wiki/Least_common_multiple#Reduction_by_the_greatest_common_divisor
return (a / GreatestCommonDivisor(a, b)) * b;
}
=> a / GreatestCommonDivisor(a, b) * b;

/// <summary>
/// Calculates <paramref name="x"/> % 2
Expand Down Expand Up @@ -290,10 +288,14 @@ public static void Clamp(Span<byte> span, byte min, byte max)

if (remainder.Length > 0)
{
for (int i = 0; i < remainder.Length; i++)
ref byte remainderStart = ref MemoryMarshal.GetReference(remainder);
ref byte remainderEnd = ref Unsafe.Add(ref remainderStart, remainder.Length);

while (Unsafe.IsAddressLessThan(ref remainderStart, ref remainderEnd))
{
ref byte v = ref remainder[i];
v = Clamp(v, min, max);
remainderStart = Clamp(remainderStart, min, max);

remainderStart = ref Unsafe.Add(ref remainderStart, 1);
}
}
}
Expand All @@ -311,10 +313,14 @@ public static void Clamp(Span<uint> span, uint min, uint max)

if (remainder.Length > 0)
{
for (int i = 0; i < remainder.Length; i++)
ref uint remainderStart = ref MemoryMarshal.GetReference(remainder);
ref uint remainderEnd = ref Unsafe.Add(ref remainderStart, remainder.Length);

while (Unsafe.IsAddressLessThan(ref remainderStart, ref remainderEnd))
{
ref uint v = ref remainder[i];
v = Clamp(v, min, max);
remainderStart = Clamp(remainderStart, min, max);

remainderStart = ref Unsafe.Add(ref remainderStart, 1);
}
}
}
Expand All @@ -332,10 +338,14 @@ public static void Clamp(Span<int> span, int min, int max)

if (remainder.Length > 0)
{
for (int i = 0; i < remainder.Length; i++)
ref int remainderStart = ref MemoryMarshal.GetReference(remainder);
ref int remainderEnd = ref Unsafe.Add(ref remainderStart, remainder.Length);

while (Unsafe.IsAddressLessThan(ref remainderStart, ref remainderEnd))
{
ref int v = ref remainder[i];
v = Clamp(v, min, max);
remainderStart = Clamp(remainderStart, min, max);

remainderStart = ref Unsafe.Add(ref remainderStart, 1);
}
}
}
Expand All @@ -353,10 +363,14 @@ public static void Clamp(Span<float> span, float min, float max)

if (remainder.Length > 0)
{
for (int i = 0; i < remainder.Length; i++)
ref float remainderStart = ref MemoryMarshal.GetReference(remainder);
ref float remainderEnd = ref Unsafe.Add(ref remainderStart, remainder.Length);

while (Unsafe.IsAddressLessThan(ref remainderStart, ref remainderEnd))
{
ref float v = ref remainder[i];
v = Clamp(v, min, max);
remainderStart = Clamp(remainderStart, min, max);

remainderStart = ref Unsafe.Add(ref remainderStart, 1);
}
}
}
Expand All @@ -374,10 +388,14 @@ public static void Clamp(Span<double> span, double min, double max)

if (remainder.Length > 0)
{
for (int i = 0; i < remainder.Length; i++)
ref double remainderStart = ref MemoryMarshal.GetReference(remainder);
ref double remainderEnd = ref Unsafe.Add(ref remainderStart, remainder.Length);

while (Unsafe.IsAddressLessThan(ref remainderStart, ref remainderEnd))
{
ref double v = ref remainder[i];
v = Clamp(v, min, max);
remainderStart = Clamp(remainderStart, min, max);

remainderStart = ref Unsafe.Add(ref remainderStart, 1);
}
}
}
Expand Down Expand Up @@ -407,33 +425,42 @@ private static void ClampImpl<T>(Span<T> span, T min, T max)
where T : unmanaged
{
ref T sRef = ref MemoryMarshal.GetReference(span);
ref Vector<T> vsBase = ref Unsafe.As<T, Vector<T>>(ref MemoryMarshal.GetReference(span));
var vmin = new Vector<T>(min);
var vmax = new Vector<T>(max);

int n = span.Length / Vector<T>.Count;
int m = Modulo4(n);
int u = n - m;

for (int i = 0; i < u; i += 4)
{
ref Vector<T> vs0 = ref Unsafe.Add(ref vsBase, i);
ref Vector<T> vs1 = ref Unsafe.Add(ref vs0, 1);
ref Vector<T> vs2 = ref Unsafe.Add(ref vs0, 2);
ref Vector<T> vs3 = ref Unsafe.Add(ref vs0, 3);
ref Vector<T> vs0 = ref Unsafe.As<T, Vector<T>>(ref MemoryMarshal.GetReference(span));
ref Vector<T> vs1 = ref Unsafe.Add(ref vs0, 1);
ref Vector<T> vs2 = ref Unsafe.Add(ref vs0, 2);
ref Vector<T> vs3 = ref Unsafe.Add(ref vs0, 3);
ref Vector<T> vsEnd = ref Unsafe.Add(ref vs0, u);

while (Unsafe.IsAddressLessThan(ref vs0, ref vsEnd))
{
vs0 = Vector.Min(Vector.Max(vmin, vs0), vmax);
vs1 = Vector.Min(Vector.Max(vmin, vs1), vmax);
vs2 = Vector.Min(Vector.Max(vmin, vs2), vmax);
vs3 = Vector.Min(Vector.Max(vmin, vs3), vmax);

vs0 = ref Unsafe.Add(ref vs0, 4);
vs1 = ref Unsafe.Add(ref vs1, 4);
vs2 = ref Unsafe.Add(ref vs2, 4);
vs3 = ref Unsafe.Add(ref vs3, 4);
}

if (m > 0)
{
for (int i = u; i < n; i++)
vs0 = ref vsEnd;
vsEnd = ref Unsafe.Add(ref vsEnd, m);

while (Unsafe.IsAddressLessThan(ref vs0, ref vsEnd))
{
ref Vector<T> vs0 = ref Unsafe.Add(ref vsBase, i);
vs0 = Vector.Min(Vector.Max(vmin, vs0), vmax);

vs0 = ref Unsafe.Add(ref vs0, 1);
}
}
}
Expand Down Expand Up @@ -472,10 +499,8 @@ public static void Premultiply(Span<Vector4> vectors)
#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported && vectors.Length >= 2)
{
ref Vector256<float> vectorsBase =
ref Unsafe.As<Vector4, Vector256<float>>(ref MemoryMarshal.GetReference(vectors));

// Divide by 2 as 4 elements per Vector4 and 8 per Vector256<float>
ref Vector256<float> vectorsBase = ref Unsafe.As<Vector4, Vector256<float>>(ref MemoryMarshal.GetReference(vectors));
ref Vector256<float> vectorsLast = ref Unsafe.Add(ref vectorsBase, (IntPtr)((uint)vectors.Length / 2u));

while (Unsafe.IsAddressLessThan(ref vectorsBase, ref vectorsLast))
Expand All @@ -495,12 +520,14 @@ public static void Premultiply(Span<Vector4> vectors)
else
#endif
{
ref Vector4 baseRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsStart = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsStart, vectors.Length);

for (int i = 0; i < vectors.Length; i++)
while (Unsafe.IsAddressLessThan(ref vectorsStart, ref vectorsEnd))
{
ref Vector4 v = ref Unsafe.Add(ref baseRef, i);
Premultiply(ref v);
Premultiply(ref vectorsStart);

vectorsStart = ref Unsafe.Add(ref vectorsStart, 1);
}
}
}
Expand All @@ -515,10 +542,8 @@ public static void UnPremultiply(Span<Vector4> vectors)
#if SUPPORTS_RUNTIME_INTRINSICS
if (Avx2.IsSupported && vectors.Length >= 2)
{
ref Vector256<float> vectorsBase =
ref Unsafe.As<Vector4, Vector256<float>>(ref MemoryMarshal.GetReference(vectors));

// Divide by 2 as 4 elements per Vector4 and 8 per Vector256<float>
ref Vector256<float> vectorsBase = ref Unsafe.As<Vector4, Vector256<float>>(ref MemoryMarshal.GetReference(vectors));
ref Vector256<float> vectorsLast = ref Unsafe.Add(ref vectorsBase, (IntPtr)((uint)vectors.Length / 2u));

while (Unsafe.IsAddressLessThan(ref vectorsBase, ref vectorsLast))
Expand All @@ -538,12 +563,14 @@ public static void UnPremultiply(Span<Vector4> vectors)
else
#endif
{
ref Vector4 baseRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsStart = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsStart, vectors.Length);

for (int i = 0; i < vectors.Length; i++)
while (Unsafe.IsAddressLessThan(ref vectorsStart, ref vectorsEnd))
{
ref Vector4 v = ref Unsafe.Add(ref baseRef, i);
UnPremultiply(ref v);
UnPremultiply(ref vectorsStart);

vectorsStart = ref Unsafe.Add(ref vectorsStart, 1);
}
}
}
Expand Down Expand Up @@ -633,53 +660,54 @@ public static unsafe void CubeRootOnXYZ(Span<Vector4> vectors)
vectors128Ref = y4;
vectors128Ref = ref Unsafe.Add(ref vectors128Ref, 1);
}

return;
}
else
#endif
ref Vector4 vectorsRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsRef, vectors.Length);

// Fallback with scalar preprocessing and vectorized approximation steps
while (Unsafe.IsAddressLessThan(ref vectorsRef, ref vectorsEnd))
{
Vector4 v = vectorsRef;
ref Vector4 vectorsRef = ref MemoryMarshal.GetReference(vectors);
ref Vector4 vectorsEnd = ref Unsafe.Add(ref vectorsRef, vectors.Length);

double
x64 = v.X,
y64 = v.Y,
z64 = v.Z;
float a = v.W;

ulong
xl = *(ulong*)&x64,
yl = *(ulong*)&y64,
zl = *(ulong*)&z64;

// Here we use a trick to compute the starting value x0 for the cube root. This is because doing
// pow(x, 1 / gamma) is the same as the gamma-th root of x, and since gamma is 3 in this case,
// this means what we actually want is to find the cube root of our clamped values.
// For more info on the constant below, see:
// https://community.intel.com/t5/Intel-C-Compiler/Fast-approximate-of-transcendental-operations/td-p/1044543.
// Here we perform the same trick on all RGB channels separately to help the CPU execute them in parallel, and
// store the alpha channel to preserve it. Then we set these values to the fields of a temporary 128-bit
// register, and use it to accelerate two steps of the Newton approximation using SIMD.
xl = 0x2a9f8a7be393b600 + (xl / 3);
yl = 0x2a9f8a7be393b600 + (yl / 3);
zl = 0x2a9f8a7be393b600 + (zl / 3);

Vector4 y4;
y4.X = (float)*(double*)&xl;
y4.Y = (float)*(double*)&yl;
y4.Z = (float)*(double*)&zl;
y4.W = 0;

y4 = (2 / 3f * y4) + (1 / 3f * (v / (y4 * y4)));
y4 = (2 / 3f * y4) + (1 / 3f * (v / (y4 * y4)));
y4.W = a;

vectorsRef = y4;
vectorsRef = ref Unsafe.Add(ref vectorsRef, 1);
// Fallback with scalar preprocessing and vectorized approximation steps
while (Unsafe.IsAddressLessThan(ref vectorsRef, ref vectorsEnd))
{
Vector4 v = vectorsRef;

double
x64 = v.X,
y64 = v.Y,
z64 = v.Z;
float a = v.W;

ulong
xl = *(ulong*)&x64,
yl = *(ulong*)&y64,
zl = *(ulong*)&z64;

// Here we use a trick to compute the starting value x0 for the cube root. This is because doing
// pow(x, 1 / gamma) is the same as the gamma-th root of x, and since gamma is 3 in this case,
// this means what we actually want is to find the cube root of our clamped values.
// For more info on the constant below, see:
// https://community.intel.com/t5/Intel-C-Compiler/Fast-approximate-of-transcendental-operations/td-p/1044543.
// Here we perform the same trick on all RGB channels separately to help the CPU execute them in parallel, and
// store the alpha channel to preserve it. Then we set these values to the fields of a temporary 128-bit
// register, and use it to accelerate two steps of the Newton approximation using SIMD.
xl = 0x2a9f8a7be393b600 + (xl / 3);
yl = 0x2a9f8a7be393b600 + (yl / 3);
zl = 0x2a9f8a7be393b600 + (zl / 3);

Vector4 y4;
y4.X = (float)*(double*)&xl;
y4.Y = (float)*(double*)&yl;
y4.Z = (float)*(double*)&zl;
y4.W = 0;

y4 = (2 / 3f * y4) + (1 / 3f * (v / (y4 * y4)));
y4 = (2 / 3f * y4) + (1 / 3f * (v / (y4 * y4)));
y4.W = a;

vectorsRef = y4;
vectorsRef = ref Unsafe.Add(ref vectorsRef, 1);
}
}
}
}
Expand Down
Loading