Skip to content

Commit

Permalink
upgrade horizontal sum in distance_single_code for PQ/IVFPQ (#2830)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2830

17 cycles per AVX2 horizontal sum instead of 19

Reviewed By: mdouze

Differential Revision: D45244153

fbshipit-source-id: 15accba2e8b4f12725dba41696c302e72f61c2db
  • Loading branch information
Alexandr Guzhva authored and facebook-github-bot committed Apr 25, 2023
1 parent d0ba4c0 commit 1cb1e54
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions faiss/impl/code_distance/code_distance-avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,19 @@

namespace {

// Computes a horizontal sum over an __m256 register
inline float horizontal_sum(const __m256 reg) {
const __m256 h0 = _mm256_hadd_ps(reg, reg);
const __m256 h1 = _mm256_hadd_ps(h0, h0);

// extract high and low __m128 regs from __m256
const __m128 h2 = _mm256_extractf128_ps(h1, 1);
const __m128 h3 = _mm256_castps256_ps128(h1);

// get a final hsum into all 4 regs
const __m128 h4 = _mm_add_ss(h2, h3);
// Computes a horizontal sum over an __m128 register:
// returns the sum of its four float lanes.
inline float horizontal_sum(const __m128 v) {
    // Bring the upper two lanes down: upper_half = [v2, v3, v0, v0].
    const __m128 upper_half = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));

    // Lane 0 = v0 + v2, lane 1 = v1 + v3. Lanes 2 and 3 do not contribute
    // to the returned value.
    const __m128 pair_sums = _mm_add_ps(v, upper_half);

    // Move lane 1 of the pairwise sums into lane 0.
    const __m128 odd_pair =
            _mm_shuffle_ps(pair_sums, pair_sums, _MM_SHUFFLE(0, 0, 0, 1));

    // Lane 0 of the final add holds (v0 + v2) + (v1 + v3); extract it.
    return _mm_cvtss_f32(_mm_add_ps(pair_sums, odd_pair));
}

// extract f[0] from __m128
const float hsum = _mm_cvtss_f32(h4);
return hsum;
// Computes a horizontal sum over an __m256 register
inline float horizontal_sum(const __m256 v) {
    // Split the 256-bit register into its low and high 128-bit halves.
    const __m128 low_half = _mm256_castps256_ps128(v);
    const __m128 high_half = _mm256_extractf128_ps(v, 1);

    // Add the halves element-wise, leaving four partial sums, then let the
    // __m128 overload reduce those to a single float.
    const __m128 folded = _mm_add_ps(low_half, high_half);
    return horizontal_sum(folded);
}

} // namespace
Expand Down

0 comments on commit 1cb1e54

Please sign in to comment.