diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 846f542409..e3e78076ce 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -13,8 +13,8 @@ #include #include +#include #include - #ifdef __SSE__ #include #endif @@ -100,30 +100,13 @@ struct Codec8bit { const __m512 one_255 = _mm512_set1_ps(1.f / 255.f); return _mm512_fmadd_ps(f16, one_255, half_one_255); } -#elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 - decode_8_components(const uint8_t* code, int i) { - const uint64_t c8 = *(uint64_t*)(code + i); - - const __m128i i8 = _mm_set1_epi64x(c8); - const __m256i i32 = _mm256_cvtepu8_epi32(i8); - const __m256 f8 = _mm256_cvtepi32_ps(i32); - const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f); - const __m256 one_255 = _mm256_set1_ps(1.f / 255.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); - } -#endif - -#ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t +#else + static FAISS_ALWAYS_INLINE simd8float32 decode_8_components(const uint8_t* code, int i) { - float32_t result[8] = {}; - for (size_t j = 0; j < 8; j++) { - result[j] = decode_component(code, i + j); - } - float32x4_t res1 = vld1q_f32(result); - float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; + simd8float32 f8 = load8(code, i); + simd8float32 half_one_255 = simd8float32(0.5f / 255.f); + simd8float32 one_255 = simd8float32(1.f / 255.f); + return fmadd(f8, one_255, half_one_255); } #endif }; @@ -162,7 +145,8 @@ struct Codec4bit { return _mm512_fmadd_ps(f16, one_255, half_one_255); } #elif defined(__AVX2__) - static FAISS_ALWAYS_INLINE __m256 + + static FAISS_ALWAYS_INLINE simd8float32 decode_8_components(const uint8_t* code, int i) { uint32_t c4 = *(uint32_t*)(code + (i >> 1)); uint32_t mask = 0x0f0f0f0f; @@ -180,12 +164,12 @@ struct Codec4bit { __m256 half = _mm256_set1_ps(0.5f); f8 = _mm256_add_ps(f8, half); __m256 one_255 = _mm256_set1_ps(1.f / 15.f); - return _mm256_mul_ps(f8, one_255); + return simd8float32(_mm256_mul_ps(f8, one_255)); } #endif #ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t + static FAISS_ALWAYS_INLINE simd8float32 decode_8_components(const uint8_t* code, int i) { float32_t result[8] = {}; for (size_t j = 0; j < 8; j++) { @@ -193,7 +177,7 @@ struct Codec4bit { } float32x4_t res1 = vld1q_f32(result); float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; + return simd8float32({res1, res2}); } #endif }; @@ -300,7 +284,7 @@ struct Codec6bit { /* Load 6 bytes that represent 8 6-bit values, return them as a * 8*32 bit vector register */ - static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) { + static FAISS_ALWAYS_INLINE simd8uint32 load6(const uint16_t* code16) { const __m128i perm = _mm_set_epi8( -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0); const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0); @@ -316,10 +300,10 @@ struct Codec6bit { // shift and mask out useless bits __m256i c4 = _mm256_srlv_epi32(c3, shifts); __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4); - return c5; + return simd8uint32(c5); } - static FAISS_ALWAYS_INLINE __m256 + static FAISS_ALWAYS_INLINE simd8float32 decode_8_components(const uint8_t* code, int i) { // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here // // for the reference, maybe, it becomes used oned day. @@ -334,19 +318,19 @@ struct Codec6bit { // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); // return _mm256_fmadd_ps(f8, one_255, half_one_255); - __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)); + __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3)).i; __m256 f8 = _mm256_cvtepi32_ps(i8); // this could also be done with bit manipulations but it is // not obviously faster const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f); const __m256 one_255 = _mm256_set1_ps(1.f / 63.f); - return _mm256_fmadd_ps(f8, one_255, half_one_255); + return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255)); } #endif #ifdef USE_NEON - static FAISS_ALWAYS_INLINE float32x4x2_t + static FAISS_ALWAYS_INLINE simd8float32 decode_8_components(const uint8_t* code, int i) { float32_t result[8] = {}; for (size_t j = 0; j < 8; j++) { @@ -354,7 +338,7 @@ struct Codec6bit { } float32x4_t res1 = vld1q_f32(result); float32x4_t res2 = vld1q_f32(result + 4); - return {res1, res2}; + return simd8float32({res1, res2}); } #endif }; @@ -426,27 +410,7 @@ struct QuantizerTemplate } }; -#elif defined(__AVX2__) - -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( - d, - trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin)); - } -}; - -#endif - -#ifdef USE_NEON +#else template struct QuantizerTemplate @@ -456,17 +420,11 @@ struct QuantizerTemplate d, trained) {} - FAISS_ALWAYS_INLINE float32x4x2_t + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); - return {vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[0], - vdupq_n_f32(this->vdiff)), - vfmaq_f32( - vdupq_n_f32(this->vmin), - xi.val[1], - vdupq_n_f32(this->vdiff))}; + simd8float32 xi = Codec::decode_8_components(code, i); + return simd8float32( + fmadd(xi, simd8float32(this->vdiff), simd8float32(this->vmin))); } }; @@ -543,13 +501,13 @@ struct QuantizerTemplate QuantizerTemplateScaling::NON_UNIFORM, 1>(d, trained) {} - FAISS_ALWAYS_INLINE __m256 + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - __m256 xi = Codec::decode_8_components(code, i); - return _mm256_fmadd_ps( - xi, - _mm256_loadu_ps(this->vdiff + i), - _mm256_loadu_ps(this->vmin + i)); + simd8float32 xi = Codec::decode_8_components(code, i); + return simd8float32( + fmadd(xi, + simd8float32(this->vdiff + i), + simd8float32(this->vmin + i))); } }; @@ -566,15 +524,16 @@ struct QuantizerTemplate QuantizerTemplateScaling::NON_UNIFORM, 1>(d, trained) {} - FAISS_ALWAYS_INLINE float32x4x2_t + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - float32x4x2_t xi = Codec::decode_8_components(code, i); + float32x4x2_t xi = Codec::decode_8_components(code, i).data; float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i); float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i); - return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), - vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}; + return simd8float32( + {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]), + vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])}); } }; @@ -627,36 +586,31 @@ struct QuantizerFP16<16> : QuantizerFP16<1> { #endif -#if defined(USE_F16C) +#if defined(USE_F16C) || defined(USE_NEON) template <> struct QuantizerFP16<8> : QuantizerFP16<1> { QuantizerFP16(size_t d, const std::vector& trained) : QuantizerFP16<1>(d, trained) {} - FAISS_ALWAYS_INLINE __m256 + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { +#ifdef USE_F16C __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - return _mm256_cvtph_ps(codei); - } -}; - + return simd8float32(_mm256_cvtph_ps(codei)); #endif -#ifdef USE_NEON +#ifdef __aarch64__ -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} - - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), - vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}; + return simd8float32( + {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), + vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))}); + +#endif } }; + #endif /******************************************************************* @@ -705,39 +659,43 @@ struct QuantizerBF16<16> : QuantizerBF16<1> { } }; -#elif defined(__AVX2__) +#else template <> struct QuantizerBF16<8> : QuantizerBF16<1> { QuantizerBF16(size_t d, const std::vector& trained) : QuantizerBF16<1>(d, trained) {} - FAISS_ALWAYS_INLINE __m256 + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); - __m256i code_256i = _mm256_cvtepu16_epi32(code_128i); - code_256i = _mm256_slli_epi32(code_256i, 16); - return _mm256_castsi256_ps(code_256i); - } -}; +#ifdef __AVX2__ + // reference impl: decode_bf16(((uint16_t*)code)[i]); + // decode_bf16(v) -> (uint32_t(v) << 16) + // read 128-bits (16 uint8_t) -> (uint16_t*)code)[i] + // read as 8 x 16 bits and 0-extend into 8 x 32 bits -> uint32_t(v) + // bit shift by 16 -> uint32_t(v) << 16 + + // load 8 as i32 and bit shift by 16 + simd8uint32 code_256i = load8_16bits_as_uint32(code, i); + simd8uint32 shifted_16 = code_256i << 16; + return as_float32(shifted_16); #endif -#ifdef USE_NEON - -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} +#ifdef __aarch64__ - FAISS_ALWAYS_INLINE float32x4x2_t - reconstruct_8_components(const uint8_t* code, int i) const { uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i)); - return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), - vreinterpretq_f32_u32( - vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}; + return simd8float32( + {vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[0]), 16)), + vreinterpretq_f32_u32( + vshlq_n_u32(vmovl_u16(codei.val[1]), 16))}); + +#endif + throw std::runtime_error("unreachable"); } }; + #endif /******************************************************************* @@ -787,39 +745,16 @@ struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { } }; -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - return _mm256_cvtepi32_ps(y8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON +#else template <> struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { Quantizer8bitDirect(size_t d, const std::vector& trained) : Quantizer8bitDirect<1>(d, trained) {} - FAISS_ALWAYS_INLINE float32x4x2_t + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - // convert uint16 -> uint32 -> fp32 - return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))}; + return load8(code, i); } }; @@ -874,46 +809,17 @@ struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { } }; -#elif defined(__AVX2__) - -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} - - FAISS_ALWAYS_INLINE __m256 - reconstruct_8_components(const uint8_t* code, int i) const { - __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 - __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 - __m256i c8 = _mm256_set1_epi32(128); - __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes - return _mm256_cvtepi32_ps(z8); // 8 * float32 - } -}; - -#endif - -#ifdef USE_NEON +#else template <> struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { Quantizer8bitDirectSigned(size_t d, const std::vector& trained) : Quantizer8bitDirectSigned<1>(d, trained) {} - FAISS_ALWAYS_INLINE float32x4x2_t + FAISS_ALWAYS_INLINE simd8float32 reconstruct_8_components(const uint8_t* code, int i) const { - uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); - uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16 - uint16x4_t y8_0 = vget_low_u16(y8); - uint16x4_t y8_1 = vget_high_u16(y8); - - float32x4_t z8_0 = vcvtq_f32_u32( - vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32 - float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1)); - - // subtract 128 to convert into signed numbers - return {vsubq_f32(z8_0, vmovq_n_f32(128.0)), - vsubq_f32(z8_1, vmovq_n_f32(128.0))}; + simd8float32 f8 = load8(code, i); // 8 * float32 + return f8 - simd8float32(128.0); // subtract 128 from all lanes } }; @@ -1209,7 +1115,7 @@ struct SimilarityL2<16> { } }; -#elif defined(__AVX2__) +#elif defined(USE_F16C) || defined(USE_NEON) template <> struct SimilarityL2<8> { @@ -1219,87 +1125,38 @@ struct SimilarityL2<8> { const float *y, *yi; explicit SimilarityL2(const float* y) : y(y) {} - __m256 accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); - yi += 8; - __m256 tmp = _mm256_sub_ps(yiv, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) { - __m256 tmp = _mm256_sub_ps(y_2, x); - accu8 = _mm256_fmadd_ps(tmp, tmp, accu8); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; - -#endif - -#ifdef USE_NEON -template <> -struct SimilarityL2<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_L2; - - const float *y, *yi; - explicit SimilarityL2(const float* y) : y(y) {} - float32x4x2_t accu8; + simd8float32 accu8; FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; + accu8.clear(); yi = y; } - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + // reference implementation + // yi is the current pointer into floats y + // *(yi++) - x -> increment yi and compute distance from float x + // tmp = *yi++ - x + // accu += tmp * tmp + // accu += (yi - x)^2 + simd8float32 yiv(yi); yi += 8; - - float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; + simd8float32 tmp = yiv - x; + accu8 = fmadd(tmp, tmp, accu8); } FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x, - float32x4x2_t y) { - float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]); - float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]); - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1); - - accu8 = {accu8_0, accu8_1}; + simd8float32 x, + simd8float32 y_2) { + simd8float32 tmp = y_2 - x; + accu8 = fmadd(tmp, tmp, accu8); } FAISS_ALWAYS_INLINE float result_8() { - float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]); - float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]); - - float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0); - float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1); - return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0); + return accu8.accumulate(); } }; + #endif template @@ -1369,7 +1226,7 @@ struct SimilarityIP<16> { } }; -#elif defined(__AVX2__) +#elif defined(USE_F16C) || defined(USE_NEON) template <> struct SimilarityIP<8> { @@ -1382,78 +1239,27 @@ struct SimilarityIP<8> { explicit SimilarityIP(const float* y) : y(y) {} - __m256 accu8; + simd8float32 accu8; FAISS_ALWAYS_INLINE void begin_8() { - accu8 = _mm256_setzero_ps(); + accu8.clear(); yi = y; } - FAISS_ALWAYS_INLINE void add_8_components(__m256 x) { - __m256 yiv = _mm256_loadu_ps(yi); + FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) { + simd8float32 yiv(yi); yi += 8; - accu8 = _mm256_fmadd_ps(yiv, x, accu8); - } - - FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) { - accu8 = _mm256_fmadd_ps(x1, x2, accu8); - } - - FAISS_ALWAYS_INLINE float result_8() { - const __m128 sum = _mm_add_ps( - _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1)); - const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); - const __m128 v1 = _mm_add_ps(sum, v0); - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - const __m128 v3 = _mm_add_ps(v1, v2); - return _mm_cvtss_f32(v3); - } -}; -#endif - -#ifdef USE_NEON - -template <> -struct SimilarityIP<8> { - static constexpr int simdwidth = 8; - static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; - - const float *y, *yi; - - explicit SimilarityIP(const float* y) : y(y) {} - float32x4x2_t accu8; - - FAISS_ALWAYS_INLINE void begin_8() { - accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)}; - yi = y; - } - - FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) { - float32x4x2_t yiv = vld1q_f32_x2(yi); - yi += 8; - - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]); - accu8 = {accu8_0, accu8_1}; + accu8 = fmadd(yiv, x, accu8); } FAISS_ALWAYS_INLINE void add_8_components_2( - float32x4x2_t x1, - float32x4x2_t x2) { - float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]); - float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]); - accu8 = {accu8_0, accu8_1}; + simd8float32 x1, + simd8float32 x2) { + accu8 = fmadd(x1, x2, accu8); } FAISS_ALWAYS_INLINE float result_8() { - float32x4x2_t sum = { - vpaddq_f32(accu8.val[0], accu8.val[0]), - vpaddq_f32(accu8.val[1], accu8.val[1])}; - - float32x4x2_t sum2 = { - vpaddq_f32(sum.val[0], sum.val[0]), - vpaddq_f32(sum.val[1], sum.val[1])}; - return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0); + return accu8.accumulate(); } }; #endif @@ -1559,7 +1365,7 @@ struct DCTemplate } }; -#elif defined(USE_F16C) +#elif defined(USE_F16C) || defined(USE_NEON) template struct DCTemplate : SQDistanceComputer { @@ -1574,7 +1380,7 @@ struct DCTemplate : SQDistanceComputer { Similarity sim(x); sim.begin_8(); for (size_t i = 0; i < quant.d; i += 8) { - __m256 xi = quant.reconstruct_8_components(code, i); + simd8float32 xi = quant.reconstruct_8_components(code, i); sim.add_8_components(xi); } return sim.result_8(); @@ -1585,8 +1391,8 @@ struct DCTemplate : SQDistanceComputer { Similarity sim(nullptr); sim.begin_8(); for (size_t i = 0; i < quant.d; i += 8) { - __m256 x1 = quant.reconstruct_8_components(code1, i); - __m256 x2 = quant.reconstruct_8_components(code2, i); + simd8float32 x1 = quant.reconstruct_8_components(code1, i); + simd8float32 x2 = quant.reconstruct_8_components(code2, i); sim.add_8_components_2(x1, x2); } return sim.result_8(); @@ -1608,53 +1414,6 @@ struct DCTemplate : SQDistanceComputer { #endif -#ifdef USE_NEON - -template -struct DCTemplate : SQDistanceComputer { - using Sim = Similarity; - - Quantizer quant; - - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} - float compute_distance(const float* x, const uint8_t* code) const { - Similarity sim(x); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t xi = quant.reconstruct_8_components(code, i); - sim.add_8_components(xi); - } - return sim.result_8(); - } - - float compute_code_distance(const uint8_t* code1, const uint8_t* code2) - const { - Similarity sim(nullptr); - sim.begin_8(); - for (size_t i = 0; i < quant.d; i += 8) { - float32x4x2_t x1 = quant.reconstruct_8_components(code1, i); - float32x4x2_t x2 = quant.reconstruct_8_components(code2, i); - sim.add_8_components_2(x1, x2); - } - return sim.result_8(); - } - - void set_query(const float* x) final { - q = x; - } - - float symmetric_dis(idx_t i, idx_t j) override { - return compute_code_distance( - codes + i * code_size, codes + j * code_size); - } - - float query_to_code(const uint8_t* code) const final { - return compute_distance(q, code); - } -}; -#endif - /******************************************************************* * DistanceComputerByte: computes distances in the integer domain *******************************************************************/ diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simdlib_avx2.h index fc51e3ed18..daf0c382de 100644 --- a/faiss/utils/simdlib_avx2.h +++ b/faiss/utils/simdlib_avx2.h @@ -217,6 +217,7 @@ struct simd16uint16 : simd256bit { __m256i j = thresh.i; __m256i max = _mm256_max_epu16(i, j); __m256i ge = _mm256_cmpeq_epi16(i, max); + // 0xFFFFFFFF if this >= thresh else 0 return _mm256_movemask_epi8(ge); } @@ -240,6 +241,8 @@ struct simd16uint16 : simd256bit { } void accu_min(simd16uint16 incoming) { + // compare 16 16-bit unsigned integers and return the corresponding + // smaller integer as part of the 256-bit result i = _mm256_min_epu16(i, incoming.i); } @@ -547,6 +550,15 @@ struct simd8uint32 : simd256bit { return !(*this == other); } + // // shift must be known at compile time + simd8uint32 operator<<(const int shift) const { + return simd8uint32(_mm256_slli_epi32(i, shift)); + } + + // // shift must be known at compile time + simd8uint32 operator>>(const int shift) const { + return simd8uint32(_mm256_srli_epi32(i, shift)); + } std::string elements_to_string(const char* fmt) const { uint32_t bytes[8]; storeu((void*)bytes); @@ -676,6 +688,22 @@ struct simd8float32 : simd256bit { ptr[-1] = 0; return std::string(res); } + + float accumulate() const { + // sum = (s0, s1, s2, s3) = (f0+f4, f1+f5, f2+f6, f3+f7) + // v0 = (s2, s3, s0, s0) + // v1 = (s2+s0, s3+s1, s0+s2, s0+s3) + // v2 = (s1+s3, s2+s0, s2+s0, s2+s0) + // v3 = (s0+s1+s2+s3, s1+s2+s3, 2s0+2s2, 2s0+s2+s3) + // return v3[0] + const __m128 sum = _mm_add_ps( + _mm256_castps256_ps128(f), _mm256_extractf128_ps(f, 1)); + const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)); + const __m128 v1 = _mm_add_ps(sum, v0); + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + const __m128 v3 = _mm_add_ps(v1, v2); + return _mm_cvtss_f32(v3); + } }; inline simd8float32 hadd(simd8float32 a, simd8float32 b) { @@ -695,6 +723,22 @@ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) { return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f)); } +// load8_8bits uint8_t from code + i and extend the elements to float32 +inline simd8float32 load8(const uint8_t* code, int i) { + __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8 + __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32 + return simd8float32(_mm256_cvtepi32_ps(y8)); // 8 * float32 +} + +inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) { + __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i)); + return simd8uint32(_mm256_cvtepu16_epi32(code_128i)); +} + +inline simd8float32 as_float32(simd8uint32 x) { + return simd8float32(_mm256_castsi256_ps(x.i)); +} + // The following primitive is a vectorized version of the following code // snippet: // float lowestValue = HUGE_VAL; diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simdlib_emulated.h index f9cfb3b34b..b98e72ae5d 100644 --- a/faiss/utils/simdlib_emulated.h +++ b/faiss/utils/simdlib_emulated.h @@ -833,6 +833,14 @@ struct simd8float32 : simd256bit { ptr[-1] = 0; return std::string(res); } + + float accumulate() const { + float res = 0; + for (int i = 0; i < 8; i++) { + res += f32[i]; + } + return res; + } }; // hadd does not cross lanes @@ -893,6 +901,13 @@ inline simd8float32 fmadd( return res; } +inline simd8float32 load8(const uint8_t* code, int i) { + simd8float32 res; + for (int j = 0; j < 8; j++) { + res.f32[i] = *(code + i + j); + } + return res; +} namespace { // get even float32's of a and b, interleaved diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 1bdf0ed01e..a47a51b2e3 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -972,6 +972,19 @@ struct simd8uint32 { return ~(*this == other); } + // shift must be known at compile time + simd8uint32 operator<<(const int shift) const { + uint32x4_t shifts = { + static_cast(shift), + static_cast(shift), + static_cast(shift), + static_cast(shift)}; + simd8uint32 result; + result.data.val[0] = vshlq_u32(data.val[0], shifts); + result.data.val[1] = vshlq_u32(data.val[1], shifts); + return result; + } + // Checks whether the other holds exactly the same bytes. template bool is_same_as(T other) const { @@ -1040,9 +1053,9 @@ struct simd8uint32 { // maxValues[i] = !flag ? candidateValues[i] : currentValues[i]; // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i]; // } -// Max indices evaluation is inaccurate in case of equal values (the index of -// the last equal value is saved instead of the first one), but this behavior -// saves instructions. +// Max indices evaluation is inaccurate in case of equal values (the +// index of the last equal value is saved instead of the first one), but +// this behavior saves instructions. inline void cmplt_min_max_fast( const simd8uint32 candidateValues, const simd8uint32 candidateIndices, @@ -1160,8 +1173,8 @@ struct simd8float32 { } simd8float32& operator+=(const simd8float32& other) { - // In this context, it is more compiler friendly to write intrinsics - // directly instead of using binary_func + // In this context, it is more compiler friendly to write + // intrinsics directly instead of using binary_func data.val[0] = vaddq_f32(data.val[0], other.data.val[0]); data.val[1] = vaddq_f32(data.val[1], other.data.val[1]); return *this; @@ -1191,6 +1204,19 @@ struct simd8float32 { std::string tostring() const { return detail::simdlib::elements_to_string("%g,", *this); } + + float accumulate() const { + // data = {v01, v02, v03, v04, v11, v12, v13, v14} + // sum_0 = {v01+v02, v03+v04, v01+v02, v03+v04} + // sum_1 = {v11+v12, v13+v14, v11+v12, v13+v14} + // sum2_0 = {v01+v02+v03+v04, v01+v02+v03+v04, v01+v02+v03+v04, + // v01+v02+v03+v04} + // sum2_1 = {v11+v12+v13+v14, v11+v12+v13+v14, v11+v12+v13+v14, + // v11+v12+v13+v14} + // vgetq_lane_f32(sum2_0, 0) = v01+v02+v03+v04 + // vgetq_lane_f32(sum2_1, 0) = v11+v12+v13+v14 + return vaddvq_f32(data.val[0]) + vaddvq_f32(data.val[1]); + } }; // hadd does not cross lanes @@ -1219,6 +1245,17 @@ inline simd8float32 fmadd( vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}}; } +// load 8 uint8_t from code + i and extend the elements to float32 +inline simd8float32 load8(const uint8_t* code, int i) { + uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i)); + uint16x8_t y8 = vmovl_u8(x8); + uint16x4_t y8_0 = vget_low_u16(y8); + uint16x4_t y8_1 = vget_high_u16(y8); + + // convert uint16 -> uint32 -> fp32 + return simd8float32( + {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))}); +} // The following primitive is a vectorized version of the following code // snippet: // float lowestValue = HUGE_VAL; @@ -1229,8 +1266,8 @@ inline simd8float32 fmadd( // lowestIndex = i; // } // } -// Vectorized version can be implemented via two operations: cmp and blend -// with something like this: +// Vectorized version can be implemented via two operations: cmp and +// blend with something like this: // lowestValues = [HUGE_VAL; 8]; // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7}; // for (size_t i = 0; i < n; i += 8) { @@ -1248,8 +1285,9 @@ inline simd8float32 fmadd( // The problem is that blend primitive needs very different instruction // order for AVX and ARM. // So, let's introduce a combination of these two in order to avoid -// confusion for ppl who write in low-level SIMD instructions. Additionally, -// these two ops (cmp and blend) are very often used together. +// confusion for ppl who write in low-level SIMD instructions. +// Additionally, these two ops (cmp and blend) are very often used +// together. inline void cmplt_and_blend_inplace( const simd8float32 candidateValues, const simd8uint32 candidateIndices, @@ -1287,9 +1325,9 @@ inline void cmplt_and_blend_inplace( // maxValues[i] = !flag ? candidateValues[i] : currentValues[i]; // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i]; // } -// Max indices evaluation is inaccurate in case of equal values (the index of -// the last equal value is saved instead of the first one), but this behavior -// saves instructions. +// Max indices evaluation is inaccurate in case of equal values (the +// index of the last equal value is saved instead of the first one), but +// this behavior saves instructions. inline void cmplt_min_max_fast( const simd8float32 candidateValues, const simd8uint32 candidateIndices,