Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize SpanHelpers<T>.IndexOf #60974

Merged
merged 10 commits into from
Nov 22, 2021
111 changes: 111 additions & 0 deletions src/libraries/System.Memory/tests/Span/Contains.T.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
using Xunit;

namespace System.SpanTests
Expand Down Expand Up @@ -193,5 +195,114 @@ public static void ContainsNull_String(string[] spanInput, bool expected)
Span<string> theStrings = spanInput;
Assert.Equal(expected, theStrings.Contains(null));
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@danmoseley I've added some tests here since there were not many tests for Span<T>.Contains. I believe the existing tests for Array.IndexOf (link) already have enough coverage. Let me know if I can add/modify anything.

[Theory]
[InlineData(new int[] { 1, 2, 3, 4 }, 4, true)]
[InlineData(new int[] { 1, 2, 3, 4 }, 5, false)]
public static void Contains_Int32(int[] array, int value, bool expectedResult)
{
// Test with short Span
Span<int> span = new Span<int>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<int>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);
}

[Theory]
[InlineData(new long[] { 1, 2, 3, 4 }, 4, true)]
[InlineData(new long[] { 1, 2, 3, 4 }, 5, false)]
public static void Contains_Int64(long[] array, long value, bool expectedResult)
{
// Test with short Span
Span<long> span = new Span<long>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<long>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);
}

[Theory]
[InlineData(new byte[] { 1, 2, 3, 4 }, 4, true)]
[InlineData(new byte[] { 1, 2, 3, 4 }, 5, false)]
public static void Contains_Byte(byte[] array, byte value, bool expectedResult)
{
// Test with short Span
Span<byte> span = new Span<byte>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<byte>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);
}

[Theory]
[InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'd', true)]
[InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'e', false)]
public static void Contains_Char(char[] array, char value, bool expectedResult)
{
// Test with short Span
Span<char> span = new Span<char>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<char>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);

}

[Theory]
[InlineData(new float[] { 1, 2, 3, 4 }, 4, true)]
[InlineData(new float[] { 1, 2, 3, 4 }, 5, false)]
public static void Contains_Float(float[] array, float value, bool expectedResult)
{
// Test with short Span
Span<float> span = new Span<float>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<float>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);
}

[Theory]
[InlineData(new double[] { 1, 2, 3, 4 }, 4, true)]
[InlineData(new double[] { 1, 2, 3, 4 }, 5, false)]
public static void Contains_Double(double[] array, double value, bool expectedResult)
{
// Test with short Span
Span<double> span = new Span<double>(array);
bool result = span.Contains(value);
Assert.Equal(result, expectedResult);

// Test with long Span
for (int i = 0; i < 10; i++)
array = array.Concat(array).ToArray();
span = new Span<double>(array);
result = span.Contains(value);
Assert.Equal(result, expectedResult);
}
}
}
26 changes: 18 additions & 8 deletions src/libraries/System.Private.CoreLib/src/System/Array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,18 +1232,28 @@ ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<char[]>(array))
}
else if (Unsafe.SizeOf<T>() == sizeof(int))
{
int result = SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count);
int result = typeof(T).IsValueType
? SpanHelpers.IndexOfValueType(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count)
: SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count);
return (result >= 0 ? startIndex : 0) + result;
}
else if (Unsafe.SizeOf<T>() == sizeof(long))
{
int result = SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count);
int result = typeof(T).IsValueType
? SpanHelpers.IndexOfValueType(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count)
: SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count);
return (result >= 0 ? startIndex : 0) + result;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, char>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(int))
return -1 != SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, int>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(long))
return -1 != SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, long>(ref value),
span.Length);
}

return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length);
Expand Down Expand Up @@ -306,6 +318,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, char>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(int))
return -1 != SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, int>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(long))
return -1 != SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, long>(ref value),
span.Length);
}

return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length);
Expand All @@ -332,6 +356,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, char>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(int))
return SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, int>(ref value),
span.Length);

if (Unsafe.SizeOf<T>() == sizeof(long))
return SpanHelpers.IndexOfValueType(
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
Unsafe.As<T, long>(ref value),
span.Length);
}

return SpanHelpers.IndexOf(ref MemoryMarshal.GetReference(span), value, span.Length);
Expand Down
110 changes: 110 additions & 0 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using Internal.Runtime.CompilerServices;

Expand Down Expand Up @@ -291,6 +292,115 @@ public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) wh
return true;
}

internal static unsafe int IndexOfValueType<T>(ref T searchSpace, T value, int length) where T : struct, IEquatable<T>
{
Debug.Assert(length >= 0);

nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
if (Vector.IsHardwareAccelerated && Vector<T>.IsTypeSupported && (Vector<T>.Count * 2) <= length)
{
Vector<T> valueVector = new Vector<T>(value);
Vector<T> compareVector = default;
Vector<T> matchVector = default;
if ((uint)length % (uint)Vector<T>.Count != 0)
{
// Number of elements is not a multiple of Vector<T>.Count, so do one
// check and shift only enough for the remaining set to be a multiple
// of Vector<T>.Count.
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
matchVector = Vector.Equals(valueVector, compareVector);
if (matchVector != Vector<T>.Zero)
{
goto VectorMatch;
}
index += length % Vector<T>.Count;
length -= length % Vector<T>.Count;
}
while (length > 0)
{
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
matchVector = Vector.Equals(valueVector, compareVector);
if (matchVector != Vector<T>.Zero)
{
goto VectorMatch;
}
index += Vector<T>.Count;
length -= Vector<T>.Count;
}
goto NotFound;
VectorMatch:
for (int i = 0; i < Vector<T>.Count; i++)
if (compareVector[i].Equals(value))
return (int)(index + i);
}

while (length >= 8)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
goto Found1;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
goto Found2;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
goto Found3;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 4)))
goto Found4;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 5)))
goto Found5;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 6)))
goto Found6;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 7)))
goto Found7;

length -= 8;
index += 8;
}

while (length >= 4)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
goto Found1;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
goto Found2;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
goto Found3;

length -= 4;
index += 4;
}

while (length > 0)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;

index += 1;
length--;
}
NotFound:
return -1;

Found: // Workaround for https://github.com/dotnet/runtime/issues/8795
return (int)index;
Found1:
return (int)(index + 1);
Found2:
return (int)(index + 2);
Found3:
return (int)(index + 3);
Found4:
return (int)(index + 4);
Found5:
return (int)(index + 5);
Found6:
return (int)(index + 6);
Found7:
return (int)(index + 7);
}

public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
{
Debug.Assert(length >= 0);
Expand Down