From 0680adad99f46dcb5dd4d74c647a8782bf161a1e Mon Sep 17 00:00:00 2001 From: Mengdi Lin Date: Wed, 25 Sep 2024 12:30:25 -0700 Subject: [PATCH] refactor Differential Revision: D63406173 --- faiss/impl/ScalarQuantizer.cpp | 20 +++++----- faiss/utils/simdlib_neon.h | 69 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 91e00ee032..8d440926a7 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -671,7 +671,7 @@ struct QuantizerBF16<8> : QuantizerBF16<1> { FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { -#ifdef __AVX2__ + // #ifdef __AVX2__ // reference impl: decode_bf16(((uint16_t*)code)[i]); // decode_bf16(v) -> (uint32_t(v) << 16) // read 128-bits (16 uint8_t) -> (uint16_t*)code)[i] @@ -683,18 +683,18 @@ struct QuantizerBF16<8> : QuantizerBF16<1> { simd8uint32 shifted_16 = code_256i << 16; return as_float32(shifted_16); -#endif + // #endif -#ifdef __aarch64__ + // #ifdef __aarch64__ - uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return simd8float32( - {vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), - vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); + // uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * + // i)); return simd8float32( + // {vreinterpretq_f32_u32( + // vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + // vreinterpretq_f32_u32( + // vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); -#endif + // #endif } }; diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 21bda18898..456a35551e 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -254,6 +254,11 @@ static inline uint32_t cmp_xe32( return d0_mask | static_cast(d1_mask) << 16; } +template +static inline uint32x4_t vshlq(uint32x4_t vec) { + return vshlq_n_u32(vec, Shift); +} + template static inline uint16x8_t vshlq(uint16x8_t vec) { return vshlq_n_u16(vec, Shift); @@ -972,6 +977,63 @@ struct simd8uint32 { return ~(*this == other); } + // shift must be known at compile time + simd8uint32 operator<<(const int shift) const { + switch (shift) { + case 0: + return *this; + case 1: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 2: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 3: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 4: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 5: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 6: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 7: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 8: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 9: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 10: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 11: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 12: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 13: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 14: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 15: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + case 16: + return simd8uint32{detail::simdlib::unary_func(data) + .call>()}; + default: + FAISS_THROW_FMT("Invalid shift %d", shift); + } + } // Checks whether the other holds exactly the same bytes. template bool is_same_as(T other) const { @@ -1240,6 +1302,13 @@ inline simd8float32 load8(const uint8_t* code, int i) { {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))}); } +inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) { + uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); + return simd8uint32({vmovl_u16(codei.val[0]), vmovl_u16(codei.val[1])}); +} +inline simd8float32 as_float32(simd8uint32 x) { + return simd8float32(detail::simdlib::reinterpret_f32(x.data)); +} // The following primitive is a vectorized version of the following code // snippet: // float lowestValue = HUGE_VAL;