diff --git a/faiss/impl/LookupTableScaler.h b/faiss/impl/LookupTableScaler.h
index c553a0f14d..b6438307fb 100644
--- a/faiss/impl/LookupTableScaler.h
+++ b/faiss/impl/LookupTableScaler.h
@@ -38,6 +38,23 @@ struct DummyScaler {
         return simd16uint16(0);
     }
 
+#ifdef __AVX512F__
+    inline simd64uint8 lookup(const simd64uint8&, const simd64uint8&) const {
+        FAISS_THROW_MSG("DummyScaler::lookup should not be called.");
+        return simd64uint8(0);
+    }
+
+    inline simd32uint16 scale_lo(const simd64uint8&) const {
+        FAISS_THROW_MSG("DummyScaler::scale_lo should not be called.");
+        return simd32uint16(0);
+    }
+
+    inline simd32uint16 scale_hi(const simd64uint8&) const {
+        FAISS_THROW_MSG("DummyScaler::scale_hi should not be called.");
+        return simd32uint16(0);
+    }
+#endif
+
     template <class dist_t>
     inline dist_t scale_one(const dist_t&) const {
         FAISS_THROW_MSG("DummyScaler::scale_one should not be called.");
@@ -67,6 +84,23 @@ struct NormTableScaler {
         return (simd16uint16(res) >> 8) * scale_simd;
     }
 
+#ifdef __AVX512F__
+    inline simd64uint8 lookup(const simd64uint8& lut, const simd64uint8& c)
+            const {
+        return lut.lookup_4_lanes(c);
+    }
+
+    inline simd32uint16 scale_lo(const simd64uint8& res) const {
+        auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
+        return simd32uint16(res) * scale_simd_wide;
+    }
+
+    inline simd32uint16 scale_hi(const simd64uint8& res) const {
+        auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
+        return (simd32uint16(res) >> 8) * scale_simd_wide;
+    }
+#endif
+
     // for non-SIMD implem 2, 3, 4
     template <class dist_t>
     inline dist_t scale_one(const dist_t& x) const {
diff --git a/faiss/impl/pq4_fast_scan_search_qbs.cpp b/faiss/impl/pq4_fast_scan_search_qbs.cpp
index d69542c309..bf2ccd1f76 100644
--- a/faiss/impl/pq4_fast_scan_search_qbs.cpp
+++ b/faiss/impl/pq4_fast_scan_search_qbs.cpp
@@ -31,6 +31,8 @@ namespace {
  * writes results in a ResultHandler
  */
 
+#ifndef __AVX512F__
+
 template <int NQ, class ResultHandler, class Scaler>
 void kernel_accumulate_block(
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
         ResultHandler& res,
         const Scaler& scaler) {
@@ -111,6 +113,451 @@ void kernel_accumulate_block(
     }
 }
 
+#else
+
+// a special version for NQ=1.
+// Despite the function being large in text form, it compiles to very
+// compact assembler code.
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+template <class ResultHandler, class Scaler>
+void kernel_accumulate_block_avx512_nq1(
+        int nsq,
+        const uint8_t* codes,
+        const uint8_t* LUT,
+        ResultHandler& res,
+        const Scaler& scaler) {
+    // NQ is kept in order to mirror the baseline function
+    constexpr int NQ = 1;
+    // distance accumulators. We can accept more for NQ=1
+    // layout: accu[q][b]: distance accumulator for vectors 32*b..32*b+15
+    simd32uint16 accu[NQ][4];
+    // layout: accu1[q][b]: distance accumulator for vectors 32*b+16..32*b+31
+    simd32uint16 accu1[NQ][4];
+
+    for (int q = 0; q < NQ; q++) {
+        for (int b = 0; b < 4; b++) {
+            accu[q][b].clear();
+            accu1[q][b].clear();
+        }
+    }
+
+    // process "nsq - scaler.nscale" part
+    const int nsq_minus_nscale = nsq - scaler.nscale;
+    const int nsq_minus_nscale_8 = (nsq_minus_nscale / 8) * 8;
+    const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
+
+    // process in chunks of 8
+    for (int sq = 0; sq < nsq_minus_nscale_8; sq += 8) {
+        // prefetch
+        simd64uint8 c(codes);
+        codes += 64;
+
+        simd64uint8 c1(codes);
+        codes += 64;
+
+        simd64uint8 mask(0xf);
+        // shift op does not exist for int8...
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask; + simd64uint8 clo = c & mask; + + simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask; + simd64uint8 c1lo = c1 & mask; + + for (int q = 0; q < NQ; q++) { + // load LUTs for 4 quantizers + simd64uint8 lut(LUT); + LUT += 64; + + { + simd64uint8 res0 = lut.lookup_4_lanes(clo); + simd64uint8 res1 = lut.lookup_4_lanes(chi); + + accu[q][0] += simd32uint16(res0); + accu[q][1] += simd32uint16(res0) >> 8; + + accu[q][2] += simd32uint16(res1); + accu[q][3] += simd32uint16(res1) >> 8; + } + } + + for (int q = 0; q < NQ; q++) { + // load LUTs for 4 quantizers + simd64uint8 lut(LUT); + LUT += 64; + + { + simd64uint8 res0 = lut.lookup_4_lanes(c1lo); + simd64uint8 res1 = lut.lookup_4_lanes(c1hi); + + accu1[q][0] += simd32uint16(res0); + accu1[q][1] += simd32uint16(res0) >> 8; + + accu1[q][2] += simd32uint16(res1); + accu1[q][3] += simd32uint16(res1) >> 8; + } + } + } + + // process leftovers: a single chunk of size 4 + if (nsq_minus_nscale_8 != nsq_minus_nscale_4) { + // prefetch + simd64uint8 c(codes); + codes += 64; + + simd64uint8 mask(0xf); + // shift op does not exist for int8... + simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask; + simd64uint8 clo = c & mask; + + for (int q = 0; q < NQ; q++) { + // load LUTs for 4 quantizers + simd64uint8 lut(LUT); + LUT += 64; + + simd64uint8 res0 = lut.lookup_4_lanes(clo); + simd64uint8 res1 = lut.lookup_4_lanes(chi); + + accu[q][0] += simd32uint16(res0); + accu[q][1] += simd32uint16(res0) >> 8; + + accu[q][2] += simd32uint16(res1); + accu[q][3] += simd32uint16(res1) >> 8; + } + } + + // process leftovers: a single chunk of size 2 + if (nsq_minus_nscale_4 != nsq_minus_nscale) { + // prefetch + simd32uint8 c(codes); + codes += 32; + + simd32uint8 mask(0xf); + // shift op does not exist for int8... + simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask; + simd32uint8 clo = c & mask; + + for (int q = 0; q < NQ; q++) { + // load LUTs for 2 quantizers + simd32uint8 lut(LUT); + LUT += 32; + + simd32uint8 res0 = lut.lookup_2_lanes(clo); + simd32uint8 res1 = lut.lookup_2_lanes(chi); + + accu[q][0] += simd32uint16(simd16uint16(res0)); + accu[q][1] += simd32uint16(simd16uint16(res0) >> 8); + + accu[q][2] += simd32uint16(simd16uint16(res1)); + accu[q][3] += simd32uint16(simd16uint16(res1) >> 8); + } + } + + // process "sq" part + const int nscale = scaler.nscale; + const int nscale_8 = (nscale / 8) * 8; + const int nscale_4 = (nscale / 4) * 4; + + // process in chunks of 8 + for (int sq = 0; sq < nscale_8; sq += 8) { + // prefetch + simd64uint8 c(codes); + codes += 64; + + simd64uint8 c1(codes); + codes += 64; + + simd64uint8 mask(0xf); + // shift op does not exist for int8... 
+        simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
+        simd64uint8 clo = c & mask;
+
+        simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask;
+        simd64uint8 c1lo = c1 & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 4 quantizers
+            simd64uint8 lut(LUT);
+            LUT += 64;
+
+            {
+                simd64uint8 res0 = scaler.lookup(lut, clo);
+                accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
+                accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
+
+                simd64uint8 res1 = scaler.lookup(lut, chi);
+                accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
+                accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
+            }
+        }
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 4 quantizers
+            simd64uint8 lut(LUT);
+            LUT += 64;
+
+            {
+                simd64uint8 res0 = scaler.lookup(lut, c1lo);
+                accu1[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
+                accu1[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
+
+                simd64uint8 res1 = scaler.lookup(lut, c1hi);
+                accu1[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
+                accu1[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
+            }
+        }
+    }
+
+    // process leftovers: a single chunk of size 4
+    if (nscale_8 != nscale_4) {
+        // prefetch
+        simd64uint8 c(codes);
+        codes += 64;
+
+        simd64uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
+        simd64uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 4 quantizers
+            simd64uint8 lut(LUT);
+            LUT += 64;
+
+            simd64uint8 res0 = scaler.lookup(lut, clo);
+            accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
+            accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
+
+            simd64uint8 res1 = scaler.lookup(lut, chi);
+            accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
+            accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
+        }
+    }
+
+    // process leftovers: a single chunk of size 2
+    if (nscale_4 != nscale) {
+        // prefetch
+        simd32uint8 c(codes);
+        codes += 32;
+
+        simd32uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
+        simd32uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 2 quantizers
+            simd32uint8 lut(LUT);
+            LUT += 32;
+
+            simd32uint8 res0 = scaler.lookup(lut, clo);
+            accu[q][0] +=
+                    simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
+            accu[q][1] +=
+                    simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
+
+            simd32uint8 res1 = scaler.lookup(lut, chi);
+            accu[q][2] += simd32uint16(
+                    scaler.scale_lo(res1)); // handle vectors 16..23
+            accu[q][3] += simd32uint16(
+                    scaler.scale_hi(res1)); // handle vectors 24..31
+        }
+    }
+
+    for (int q = 0; q < NQ; q++) {
+        for (int b = 0; b < 4; b++) {
+            accu[q][b] += accu1[q][b];
+        }
+    }
+
+    for (int q = 0; q < NQ; q++) {
+        accu[q][0] -= accu[q][1] << 8;
+        simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
+        accu[q][2] -= accu[q][3] << 8;
+        simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
+        res.handle(q, 0, dis0, dis1);
+    }
+}
+
+// general-purpose case
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+template <int NQ, class ResultHandler, class Scaler>
+void kernel_accumulate_block_avx512_nqx(
+        int nsq,
+        const uint8_t* codes,
+        const uint8_t* LUT,
+        ResultHandler& res,
+        const Scaler& scaler) {
+    // dummy alloc to keep the windows compiler happy
+    constexpr int NQA = NQ > 0 ? NQ : 1;
+    // distance accumulators
+    // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
+    simd32uint16 accu[NQA][4];
+
+    for (int q = 0; q < NQ; q++) {
+        for (int b = 0; b < 4; b++) {
+            accu[q][b].clear();
+        }
+    }
+
+    // process "nsq - scaler.nscale" part
+    const int nsq_minus_nscale = nsq - scaler.nscale;
+    const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
+
+    // process in chunks of 4
+    for (int sq = 0; sq < nsq_minus_nscale_4; sq += 4) {
+        // prefetch
+        simd64uint8 c(codes);
+        codes += 64;
+
+        simd64uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
+        simd64uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 4 quantizers
+            simd32uint8 lut_a(LUT);
+            simd32uint8 lut_b(LUT + NQ * 32);
+
+            simd64uint8 lut(lut_a, lut_b);
+            LUT += 32;
+
+            {
+                simd64uint8 res0 = lut.lookup_4_lanes(clo);
+                simd64uint8 res1 = lut.lookup_4_lanes(chi);
+
+                accu[q][0] += simd32uint16(res0);
+                accu[q][1] += simd32uint16(res0) >> 8;
+
+                accu[q][2] += simd32uint16(res1);
+                accu[q][3] += simd32uint16(res1) >> 8;
+            }
+        }
+
+        LUT += NQ * 32;
+    }
+
+    // process leftovers: a single chunk of size 2
+    if (nsq_minus_nscale_4 != nsq_minus_nscale) {
+        // prefetch
+        simd32uint8 c(codes);
+        codes += 32;
+
+        simd32uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
+        simd32uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 2 quantizers
+            simd32uint8 lut(LUT);
+            LUT += 32;
+
+            simd32uint8 res0 = lut.lookup_2_lanes(clo);
+            simd32uint8 res1 = lut.lookup_2_lanes(chi);
+
+            accu[q][0] += simd32uint16(simd16uint16(res0));
+            accu[q][1] += simd32uint16(simd16uint16(res0) >> 8);
+
+            accu[q][2] += simd32uint16(simd16uint16(res1));
+            accu[q][3] += simd32uint16(simd16uint16(res1) >> 8);
+        }
+    }
+
+    // process "sq" part
+    const int nscale = scaler.nscale;
+    const int nscale_4 = (nscale / 4) * 4;
+
+    // process in chunks of 4
+    for (int sq = 0; sq < nscale_4; sq += 4) {
+        // prefetch
+        simd64uint8 c(codes);
+        codes += 64;
+
+        simd64uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
+        simd64uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 4 quantizers
+            simd32uint8 lut_a(LUT);
+            simd32uint8 lut_b(LUT + NQ * 32);
+
+            simd64uint8 lut(lut_a, lut_b);
+            LUT += 32;
+
+            {
+                simd64uint8 res0 = scaler.lookup(lut, clo);
+                accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
+                accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
+
+                simd64uint8 res1 = scaler.lookup(lut, chi);
+                accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
+                accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
+            }
+        }
+
+        LUT += NQ * 32;
+    }
+
+    // process leftovers: a single chunk of size 2
+    if (nscale_4 != nscale) {
+        // prefetch
+        simd32uint8 c(codes);
+        codes += 32;
+
+        simd32uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
+        simd32uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 2 quantizers
+            simd32uint8 lut(LUT);
+            LUT += 32;
+
+            simd32uint8 res0 = scaler.lookup(lut, clo);
+            accu[q][0] +=
+                    simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
+            accu[q][1] +=
+                    simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
+
+            simd32uint8 res1 = scaler.lookup(lut, chi);
+            accu[q][2] += simd32uint16(
+                    scaler.scale_lo(res1)); // handle vectors 16..23
+            accu[q][3] += simd32uint16(
+                    scaler.scale_hi(res1)); // handle vectors 24..31
+        }
+    }
+
+    for (int q = 0; q < NQ; q++) {
+        accu[q][0] -= accu[q][1] << 8;
+        simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
+        accu[q][2] -= accu[q][3] << 8;
+        simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
+        res.handle(q, 0, dis0, dis1);
+    }
+}
+
+template <int NQ, class ResultHandler, class Scaler>
+void kernel_accumulate_block(
+        int nsq,
+        const uint8_t* codes,
+        const uint8_t* LUT,
+        ResultHandler& res,
+        const Scaler& scaler) {
+    if constexpr (NQ == 1) {
+        kernel_accumulate_block_avx512_nq1<ResultHandler, Scaler>(
+                nsq, codes, LUT, res, scaler);
+    } else {
+        kernel_accumulate_block_avx512_nqx<NQ, ResultHandler, Scaler>(
+                nsq, codes, LUT, res, scaler);
+    }
+}
+
+#endif
+
 // handle at most 4 blocks of queries
 template <int QBS, class ResultHandler, class Scaler>
 void accumulate_q_4step(
diff --git a/faiss/impl/simd_result_handlers.h b/faiss/impl/simd_result_handlers.h
index 2d8e5388d9..633d480990 100644
--- a/faiss/impl/simd_result_handlers.h
+++ b/faiss/impl/simd_result_handlers.h
@@ -505,7 +505,7 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
         n_per_query.resize(nq + 1);
     }
 
-    virtual void begin(const float* norms) {
+    virtual void begin(const float* norms) override {
         normalizers = norms;
         for (int q = 0; q < nq; ++q) {
             thresholds[q] =
diff --git a/faiss/utils/simdlib.h b/faiss/utils/simdlib.h
index 27e9cc59f5..beeec2374e 100644
--- a/faiss/utils/simdlib.h
+++ b/faiss/utils/simdlib.h
@@ -14,7 +14,12 @@
  * functions.
  */
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+#include <faiss/utils/simdlib_avx2.h>
+#include <faiss/utils/simdlib_avx512.h>
+
+#elif defined(__AVX2__)
 
 #include <faiss/utils/simdlib_avx2.h>
diff --git a/faiss/utils/simdlib_avx512.h b/faiss/utils/simdlib_avx512.h
new file mode 100644
index 0000000000..9ce0965895
--- /dev/null
+++ b/faiss/utils/simdlib_avx512.h
@@ -0,0 +1,296 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include <faiss/impl/platform_macros.h>
+
+#include <immintrin.h>
+
+#include <faiss/utils/simdlib_avx2.h>
+
+namespace faiss {
+
+/** Simple wrapper around the AVX 512-bit registers
+ *
+ * The objective is to separate the different interpretations of the same
+ * registers (as a vector of uint8, uint16 or uint32), to provide printing
+ * functions, and to give more readable names to the AVX intrinsics. It does
+ * not pretend to be exhaustive; functions are added as needed.
+ */
+
+/// 512-bit representation without interpretation as a vector
+struct simd512bit {
+    union {
+        __m512i i;
+        __m512 f;
+    };
+
+    simd512bit() {}
+
+    explicit simd512bit(__m512i i) : i(i) {}
+
+    explicit simd512bit(__m512 f) : f(f) {}
+
+    explicit simd512bit(const void* x)
+            : i(_mm512_loadu_si512((__m512i const*)x)) {}
+
+    // sets up a lower half of the register while keeping upper one as zero
+    explicit simd512bit(simd256bit lo)
+            : simd512bit(_mm512_inserti32x8(
+                      _mm512_castsi256_si512(lo.i),
+                      _mm256_setzero_si256(),
+                      1)) {}
+
+    // constructs from lower and upper halves
+    explicit simd512bit(simd256bit lo, simd256bit hi)
+            : simd512bit(_mm512_inserti32x8(
+                      _mm512_castsi256_si512(lo.i),
+                      hi.i,
+                      1)) {}
+
+    void clear() {
+        i = _mm512_setzero_si512();
+    }
+
+    void storeu(void* ptr) const {
+        _mm512_storeu_si512((__m512i*)ptr, i);
+    }
+
+    void loadu(const void* ptr) {
+        i = _mm512_loadu_si512((__m512i*)ptr);
+    }
+
+    void store(void* ptr) const {
+        _mm512_storeu_si512((__m512i*)ptr, i);
+    }
+
+    void bin(char bits[513]) const {
+        char bytes[64];
+        storeu((void*)bytes);
+        for (int i = 0; i < 512; i++) {
+            bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
+        }
+        bits[512] = 0;
+    }
+
+    std::string bin() const {
+        char bits[513];
+        bin(bits);
+        return std::string(bits);
+    }
+};
+
+/// vector of 32 elements in uint16
+struct simd32uint16 : simd512bit {
+    simd32uint16() {}
+
+    explicit simd32uint16(__m512i i) : simd512bit(i) {}
+
+    explicit simd32uint16(int x) : simd512bit(_mm512_set1_epi16(x)) {}
+
+    explicit simd32uint16(uint16_t x) : simd512bit(_mm512_set1_epi16(x)) {}
+
+    explicit simd32uint16(simd512bit x) : simd512bit(x) {}
+
+    explicit simd32uint16(const uint16_t* x) : simd512bit((const void*)x) {}
+
+    // sets up a lower half of the register
+    explicit simd32uint16(simd256bit lo) : simd512bit(lo) {}
+
+    // constructs from lower and upper halves
+    explicit simd32uint16(simd256bit lo, simd256bit hi) : simd512bit(lo, hi) {}
+
+    std::string elements_to_string(const char* fmt) const {
+        uint16_t bytes[32];
+        storeu((void*)bytes);
+        char res[2000];
+        char* ptr = res;
+        for (int i = 0; i < 32; i++) {
+            ptr += sprintf(ptr, fmt, bytes[i]);
+        }
+        // strip last ,
+        ptr[-1] = 0;
+        return std::string(res);
+    }
+
+    std::string hex() const {
+        return elements_to_string("%02x,");
+    }
+
+    std::string dec() const {
+        return elements_to_string("%3d,");
+    }
+
+    void set1(uint16_t x) {
+        i = _mm512_set1_epi16((short)x);
+    }
+
+    simd32uint16 operator*(const simd32uint16& other) const {
+        return simd32uint16(_mm512_mullo_epi16(i, other.i));
+    }
+
+    // shift must be known at compile time
+    simd32uint16 operator>>(const int shift) const {
+        return simd32uint16(_mm512_srli_epi16(i, shift));
+    }
+
+    // shift must be known at compile time
+    simd32uint16 operator<<(const int shift) const {
+        return simd32uint16(_mm512_slli_epi16(i, shift));
+    }
+
+    simd32uint16 operator+=(simd32uint16 other) {
+        i = _mm512_add_epi16(i, other.i);
+        return *this;
+    }
+
+    simd32uint16 operator-=(simd32uint16 other) {
+        i = _mm512_sub_epi16(i, other.i);
+        return *this;
+    }
+
+    simd32uint16 operator+(simd32uint16 other) const {
+        return simd32uint16(_mm512_add_epi16(i, other.i));
+    }
+
+    simd32uint16 operator-(simd32uint16 other) const {
+        return simd32uint16(_mm512_sub_epi16(i, other.i));
+    }
+
+    simd32uint16 operator&(simd512bit other) const {
+        return simd32uint16(_mm512_and_si512(i, other.i));
+    }
+
+    simd32uint16 operator|(simd512bit other) const {
+        return simd32uint16(_mm512_or_si512(i, other.i));
+    }
+
+    simd32uint16 operator^(simd512bit other) const {
+        return simd32uint16(_mm512_xor_si512(i, other.i));
+    }
+
+    simd32uint16 operator~() const {
+        return simd32uint16(_mm512_xor_si512(i, _mm512_set1_epi32(-1)));
+    }
+
+    simd16uint16 low() const {
+        return simd16uint16(_mm512_castsi512_si256(i));
+    }
+
+    simd16uint16 high() const {
+        return simd16uint16(_mm512_extracti32x8_epi32(i, 1));
+    }
+
+    // for debugging only
+    uint16_t operator[](int i) const {
+        ALIGNED(64) uint16_t tab[32];
+        store(tab);
+        return tab[i];
+    }
+
+    void accu_min(simd32uint16 incoming) {
+        i = _mm512_min_epu16(i, incoming.i);
+    }
+
+    void accu_max(simd32uint16 incoming) {
+        i = _mm512_max_epu16(i, incoming.i);
+    }
+};
+
+// decompose in 128-bit lanes: a = (a0, a1, a2, a3), b = (b0, b1, b2, b3)
+// return (a0 + a1 + a2 + a3, b0 + b1 + b2 + b3)
+inline simd16uint16 combine4x2(simd32uint16 a, simd32uint16 b) {
+    return combine2x2(a.low(), b.low()) + combine2x2(a.high(), b.high());
+}
+
+// vector of 64 unsigned 8-bit integers
+struct simd64uint8 : simd512bit {
+    simd64uint8() {}
+
+    explicit simd64uint8(__m512i i) : simd512bit(i) {}
+
+    explicit simd64uint8(int x) : simd512bit(_mm512_set1_epi8(x)) {}
+
+    explicit simd64uint8(uint8_t x) : simd512bit(_mm512_set1_epi8(x)) {}
+
+    // sets up a lower half of the register
+    explicit simd64uint8(simd256bit lo) : simd512bit(lo) {}
+
+    // constructs from lower and upper halves
+    explicit simd64uint8(simd256bit lo, simd256bit hi) : simd512bit(lo, hi) {}
+
+    explicit simd64uint8(simd512bit x) : simd512bit(x) {}
+
+    explicit simd64uint8(const uint8_t* x) : simd512bit((const void*)x) {}
+
+    std::string elements_to_string(const char* fmt) const {
+        uint8_t bytes[64];
+        storeu((void*)bytes);
+        char res[2000];
+        char* ptr = res;
+        for (int i = 0; i < 64; i++) {
+            ptr += sprintf(ptr, fmt, bytes[i]);
+        }
+        // strip last ,
+        ptr[-1] = 0;
+        return std::string(res);
+    }
+
+    std::string hex() const {
+        return elements_to_string("%02x,");
+    }
+
+    std::string dec() const {
+        return elements_to_string("%3d,");
+    }
+
+    void set1(uint8_t x) {
+        i = _mm512_set1_epi8((char)x);
+    }
+
+    simd64uint8 operator&(simd512bit other) const {
+        return simd64uint8(_mm512_and_si512(i, other.i));
+    }
+
+    simd64uint8 operator+(simd64uint8 other) const {
+        return simd64uint8(_mm512_add_epi8(i, other.i));
+    }
+
+    simd64uint8 lookup_4_lanes(simd64uint8 idx) const {
+        return simd64uint8(_mm512_shuffle_epi8(i, idx.i));
+    }
+
+    // extract + 0-extend lane
+    // this operation is slow (3 cycles)
+    simd32uint16 lane0_as_uint16() const {
+        __m256i x = _mm512_extracti32x8_epi32(i, 0);
+        return simd32uint16(_mm512_cvtepu8_epi16(x));
+    }
+
+    simd32uint16 lane1_as_uint16() const {
+        __m256i x = _mm512_extracti32x8_epi32(i, 1);
+        return simd32uint16(_mm512_cvtepu8_epi16(x));
+    }
+
+    simd64uint8 operator+=(simd64uint8 other) {
+        i = _mm512_add_epi8(i, other.i);
+        return *this;
+    }
+
+    // for debugging only
+    uint8_t operator[](int i) const {
+        ALIGNED(64) uint8_t tab[64];
+        store(tab);
+        return tab[i];
+    }
+};
+
+} // namespace faiss
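
For reference, the byte-interleaved accumulation used by kernel_accumulate_block_avx512_nq1/_nqx can be illustrated with a minimal standalone scalar sketch (not part of the patch): each 16-bit accumulator receives "even byte + 256 * odd byte" pairs reinterpreted from the 8-bit LUT results, a second accumulator receives the odd bytes alone, and the final "accu[q][0] -= accu[q][1] << 8;" recovers the even-byte sums.

#include <cstdint>
#include <cstdio>

int main() {
    // two interleaved streams of 8-bit lookup results, as they would sit in
    // the even and odd byte positions of one SIMD register
    uint8_t even[4] = {10, 20, 30, 40};
    uint8_t odd[4] = {1, 2, 3, 4};

    uint16_t accu0 = 0; // like accu[q][0]: accumulates even + (odd << 8)
    uint16_t accu1 = 0; // like accu[q][1]: accumulates the odd bytes only
    for (int i = 0; i < 4; i++) {
        accu0 += uint16_t(even[i]) + (uint16_t(odd[i]) << 8);
        accu1 += uint16_t(odd[i]);
    }
    // same correction as "accu[q][0] -= accu[q][1] << 8;"
    uint16_t even_sum = uint16_t(accu0 - (accu1 << 8));
    // prints: even sum = 100, odd sum = 10
    printf("even sum = %u, odd sum = %u\n", unsigned(even_sum), unsigned(accu1));
    return 0;
}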