From 9c884225c1ced599f9494fc1c2578460013d71e9 Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Thu, 1 Jun 2023 07:39:02 -0700 Subject: [PATCH] Some changes to simdlib (#2885) Summary: - Use elementwise operation and reduction once instead of across-vector comparing operation twice - Use already implemented supporting functions - Unify semantics of `operator==` as same as `simd16uint16` - `operator==` of `simd8uint32` and `simd8float32` had been implemented on https://github.com/facebookresearch/faiss/issues/2568, but these has not same semantics as `simd16uint16` (which had been implemented in a long time ago). For getting the vector equality as `bool` , now we should use `is_same_as` member function. - Change `is_same_as` to accept any vector type as argument for `simdlib_neon` - `is_same_as` has supported any vector type on `simdlib_avx2` and `simdlib_emulated` already - Remove unused function `simd16uint16::is_same` on `simdlib_avx2` - Is it typo of `is_same_as` ? Anyway it seems to be used unlikely Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2885 Reviewed By: mdouze Differential Revision: D46330666 Pulled By: alexanderguzhva fbshipit-source-id: 0ea14f8e9a8bda78f24a655219dffe3e07fc110f --- faiss/utils/simdlib_avx2.h | 6 -- faiss/utils/simdlib_neon.h | 149 ++++++++++++++++++------------------- 2 files changed, 72 insertions(+), 83 deletions(-) diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simdlib_avx2.h index 34d788ccd5..fc51e3ed18 100644 --- a/faiss/utils/simdlib_avx2.h +++ b/faiss/utils/simdlib_avx2.h @@ -202,12 +202,6 @@ struct simd16uint16 : simd256bit { return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i)); } - bool is_same(simd16uint16 other) const { - const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i); - unsigned bitmask = _mm256_movemask_epi8(pcmp); - return (bitmask == 0xffffffffU); - } - simd16uint16 operator~() const { return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1))); } diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 1dbfa2cd27..656a561217 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -559,15 +559,13 @@ struct simd16uint16 { } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd16uint16 other) const { - const bool equal0 = - (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) == - 0xffff); - const bool equal1 = - (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) == - 0xffff); - - return equal0 && equal1; + template + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u16(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u16>(); + const auto equal = vandq_u16(equals.val[0], equals.val[1]); + return vminvq_u16(equal) == 0xffffu; } simd16uint16 operator~() const { @@ -689,13 +687,12 @@ inline void cmplt_min_max_fast( simd16uint16& minIndices, simd16uint16& maxValues, simd16uint16& maxIndices) { - const uint16x8x2_t comparison = uint16x8x2_t{ - vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint16x8x2_t comparison = + detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vcltq_u16>(); - minValues.data = uint16x8x2_t{ - vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + minValues = min(candidateValues, currentValues); minIndices.data = uint16x8x2_t{ vbslq_u16( comparison.val[0], @@ -706,9 +703,7 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = uint16x8x2_t{ - vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + maxValues = max(candidateValues, currentValues); maxIndices.data = uint16x8x2_t{ vbslq_u16( comparison.val[0], @@ -869,13 +864,13 @@ struct simd32uint8 { } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd32uint8 other) const { - const bool equal0 = - (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff); - const bool equal1 = - (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff); - - return equal0 && equal1; + template + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u8(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u8>(); + const auto equal = vandq_u8(equals.val[0], equals.val[1]); + return vminvq_u8(equal) == 0xffu; } }; @@ -960,27 +955,28 @@ struct simd8uint32 { return *this; } - bool operator==(simd8uint32 other) const { - const auto equals = detail::simdlib::binary_func(data, other.data) - .call<&vceqq_u32>(); - const auto equal = vandq_u32(equals.val[0], equals.val[1]); - return vminvq_u32(equal) == 0xffffffff; + simd8uint32 operator==(simd8uint32 other) const { + return simd8uint32{detail::simdlib::binary_func(data, other.data) + .call<&vceqq_u32>()}; } - bool operator!=(simd8uint32 other) const { - return !(*this == other); + simd8uint32 operator~() const { + return simd8uint32{ + detail::simdlib::unary_func(data).call<&vmvnq_u32>()}; } - // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd8uint32 other) const { - const bool equal0 = - (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) == - 0xffffffff); - const bool equal1 = - (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) == - 0xffffffff); + simd8uint32 operator!=(simd8uint32 other) const { + return ~(*this == other); + } - return equal0 && equal1; + // Checks whether the other holds exactly the same bytes. + template + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u32(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u32>(); + const auto equal = vandq_u32(equals.val[0], equals.val[1]); + return vminvq_u32(equal) == 0xffffffffu; } void clear() { @@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast( simd8uint32& minIndices, simd8uint32& maxValues, simd8uint32& maxIndices) { - const uint32x4x2_t comparison = uint32x4x2_t{ - vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; - - minValues.data = uint32x4x2_t{ - vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint32x4x2_t comparison = + detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vcltq_u32>(); + + minValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vminq_u32>(); minIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = uint32x4x2_t{ - vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; + maxValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vmaxq_u32>(); maxIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1167,28 +1164,25 @@ struct simd8float32 { return *this; } - bool operator==(simd8float32 other) const { - const auto equals = + simd8uint32 operator==(simd8float32 other) const { + return simd8uint32{ detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data) - .call<&vceqq_f32>(); - const auto equal = vandq_u32(equals.val[0], equals.val[1]); - return vminvq_u32(equal) == 0xffffffff; + .call<&vceqq_f32>()}; } - bool operator!=(simd8float32 other) const { - return !(*this == other); + simd8uint32 operator!=(simd8float32 other) const { + return ~(*this == other); } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd8float32 other) const { - const bool equal0 = - (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) == - 0xffffffff); - const bool equal1 = - (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) == - 0xffffffff); - - return equal0 && equal1; + template + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_f32(other.data); + const auto equals = + detail::simdlib::binary_func<::uint32x4x2_t>(data, o) + .template call<&vceqq_f32>(); + const auto equal = vandq_u32(equals.val[0], equals.val[1]); + return vminvq_u32(equal) == 0xffffffffu; } std::string tostring() const { @@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast( simd8uint32& minIndices, simd8float32& maxValues, simd8uint32& maxIndices) { - const uint32x4x2_t comparison = uint32x4x2_t{ - vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])}; - - minValues.data = float32x4x2_t{ - vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint32x4x2_t comparison = + detail::simdlib::binary_func<::uint32x4x2_t>( + candidateValues.data, currentValues.data) + .call<&vcltq_f32>(); + + minValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vminq_f32>(); minIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = float32x4x2_t{ - vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])}; + maxValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vmaxq_f32>(); maxIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0],