Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed Mar 18, 2024
1 parent 3be9fc9 commit 74e4ae9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 29 deletions.
47 changes: 24 additions & 23 deletions faiss/impl/pq4_fast_scan_search_qbs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void kernel_accumulate_block(

#else

// a special version for NQ=1.
// a special version for NQ=1.
// Despite the function being large in the text form, it compiles to a very
// compact assembler code.
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
Expand Down Expand Up @@ -143,10 +143,8 @@ void kernel_accumulate_block_avx512_nq1(

// process "nsq - scaler.nscale" part
const int nsq_minus_nscale = nsq - scaler.nscale;
const int nsq_minus_nscale_8 =
(nsq_minus_nscale / 8) * 8;
const int nsq_minus_nscale_4 =
(nsq_minus_nscale / 4) * 4;
const int nsq_minus_nscale_8 = (nsq_minus_nscale / 8) * 8;
const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;

// process in chunks of 8
for (int sq = 0; sq < nsq_minus_nscale_8; sq += 8) {
Expand Down Expand Up @@ -291,7 +289,7 @@ void kernel_accumulate_block_avx512_nq1(
accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
}
}

for (int q = 0; q < NQ; q++) {
// load LUTs for 4 quantizers
simd64uint8 lut(LUT);
Expand Down Expand Up @@ -352,12 +350,16 @@ void kernel_accumulate_block_avx512_nq1(
LUT += 32;

simd32uint8 res0 = scaler.lookup(lut, clo);
accu[q][0] += simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
accu[q][1] += simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
accu[q][0] +=
simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
accu[q][1] +=
simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15

simd32uint8 res1 = scaler.lookup(lut, chi);
accu[q][2] += simd32uint16(scaler.scale_lo(res1)); // handle vectors 16..23
accu[q][3] += simd32uint16(scaler.scale_hi(res1)); // handle vectors 24..31
accu[q][2] += simd32uint16(
scaler.scale_lo(res1)); // handle vectors 16..23
accu[q][3] += simd32uint16(
scaler.scale_hi(res1)); // handle vectors 24..31
}
}

Expand Down Expand Up @@ -385,7 +387,6 @@ void kernel_accumulate_block_avx512_nqx(
const uint8_t* LUT,
ResultHandler& res,
const Scaler& scaler) {

// dummy alloc to keep the windows compiler happy
constexpr int NQA = NQ > 0 ? NQ : 1;
// distance accumulators
Expand All @@ -400,8 +401,7 @@ void kernel_accumulate_block_avx512_nqx(

// process "nsq - scaler.nscale" part
const int nsq_minus_nscale = nsq - scaler.nscale;
const int nsq_minus_nscale_4 =
(nsq_minus_nscale / 4) * 4;
const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;

// process in chunks of 8
for (int sq = 0; sq < nsq_minus_nscale_4; sq += 4) {
Expand Down Expand Up @@ -518,12 +518,16 @@ void kernel_accumulate_block_avx512_nqx(
LUT += 32;

simd32uint8 res0 = scaler.lookup(lut, clo);
accu[q][0] += simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
accu[q][1] += simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
accu[q][0] +=
simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
accu[q][1] +=
simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15

simd32uint8 res1 = scaler.lookup(lut, chi);
accu[q][2] += simd32uint16(scaler.scale_lo(res1)); // handle vectors 16..23
accu[q][3] += simd32uint16(scaler.scale_hi(res1)); // handle vectors 24..31
accu[q][2] += simd32uint16(
scaler.scale_lo(res1)); // handle vectors 16..23
accu[q][3] += simd32uint16(
scaler.scale_hi(res1)); // handle vectors 24..31
}
}

Expand All @@ -542,16 +546,13 @@ void kernel_accumulate_block(
const uint8_t* codes,
const uint8_t* LUT,
ResultHandler& res,
const Scaler& scaler
) {
const Scaler& scaler) {
if constexpr (NQ == 1) {
kernel_accumulate_block_avx512_nq1<ResultHandler, Scaler>(
nsq, codes, LUT, res, scaler
);
nsq, codes, LUT, res, scaler);
} else {
kernel_accumulate_block_avx512_nqx<NQ, ResultHandler, Scaler>(
nsq, codes, LUT, res, scaler
);
nsq, codes, LUT, res, scaler);
}
}

Expand Down
2 changes: 1 addition & 1 deletion faiss/utils/simdlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#if defined(__AVX512F__)

#include <faiss/utils/simdlib_avx512.h>
#include <faiss/utils/simdlib_avx2.h>
#include <faiss/utils/simdlib_avx512.h>

#elif defined(__AVX2__)

Expand Down
15 changes: 10 additions & 5 deletions faiss/utils/simdlib_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ struct simd512bit {
: i(_mm512_loadu_si512((__m512i const*)x)) {}

// sets up a lower half of the register while keeping upper one as zero
explicit simd512bit(simd256bit lo) :
simd512bit(_mm512_inserti32x8(_mm512_castsi256_si512(lo.i), _mm256_setzero_si256(), 1)) {}
explicit simd512bit(simd256bit lo)
: simd512bit(_mm512_inserti32x8(
_mm512_castsi256_si512(lo.i),
_mm256_setzero_si256(),
1)) {}

// constructs from lower and upper halves
explicit simd512bit(simd256bit lo, simd256bit hi) :
simd512bit(_mm512_inserti32x8(_mm512_castsi256_si512(lo.i), hi.i, 1)) {}

explicit simd512bit(simd256bit lo, simd256bit hi)
: simd512bit(_mm512_inserti32x8(
_mm512_castsi256_si512(lo.i),
hi.i,
1)) {}

void clear() {
i = _mm512_setzero_si512();
Expand Down

0 comments on commit 74e4ae9

Please sign in to comment.