From a91a2887fe44e2290072041cb52c6069a074f2ec Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Mon, 26 Jun 2023 14:06:10 -0700 Subject: [PATCH] use dispatcher function to call HammingComputer (#2918) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2918 The HammingComputer class is optimized for several vector sizes. So far it's been the caller's responsiblity to instanciate the relevant optimized version. This diff introduces a `dispatch_HammingComputer` function that can be called with a template class that is instanciated for all existing optimized HammingComputer's. Reviewed By: algoriddle Differential Revision: D46858553 fbshipit-source-id: 32c31689bba7c0b406b309fc8574c95fa24022ba --- benchs/bench_hamming_computer.cpp | 60 +++++ faiss/IndexBinaryHNSW.cpp | 30 +-- faiss/IndexBinaryHash.cpp | 74 +++--- faiss/IndexBinaryIVF.cpp | 247 +++++++-------------- faiss/IndexIVFPQ.cpp | 31 +-- faiss/IndexIVFSpectralHash.cpp | 25 +-- faiss/IndexPQ.cpp | 75 +++---- faiss/utils/hamming.cpp | 174 ++++++--------- faiss/utils/hamming_distance/avx2-inl.h | 87 -------- faiss/utils/hamming_distance/common.h | 1 + faiss/utils/hamming_distance/generic-inl.h | 87 -------- faiss/utils/hamming_distance/hamdis-inl.h | 57 +++++ faiss/utils/hamming_distance/neon-inl.h | 103 --------- 13 files changed, 365 insertions(+), 686 deletions(-) diff --git a/benchs/bench_hamming_computer.cpp b/benchs/bench_hamming_computer.cpp index fa37b86af4..0dafeb24b8 100644 --- a/benchs/bench_hamming_computer.cpp +++ b/benchs/bench_hamming_computer.cpp @@ -18,6 +18,66 @@ using namespace faiss; +// These implementations are currently slower than HammingComputerDefault so +// they are not in the main faiss anymore. +struct HammingComputerM8 { + const uint64_t* a; + int n; + + HammingComputerM8() {} + + HammingComputerM8(const uint8_t* a8, int code_size) { + set(a8, code_size); + } + + void set(const uint8_t* a8, int code_size) { + assert(code_size % 8 == 0); + a = (uint64_t*)a8; + n = code_size / 8; + } + + int hamming(const uint8_t* b8) const { + const uint64_t* b = (uint64_t*)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64(a[i] ^ b[i]); + return accu; + } + + inline int get_code_size() const { + return n * 8; + } +}; + +struct HammingComputerM4 { + const uint32_t* a; + int n; + + HammingComputerM4() {} + + HammingComputerM4(const uint8_t* a4, int code_size) { + set(a4, code_size); + } + + void set(const uint8_t* a4, int code_size) { + assert(code_size % 4 == 0); + a = (uint32_t*)a4; + n = code_size / 4; + } + + int hamming(const uint8_t* b8) const { + const uint32_t* b = (uint32_t*)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64(a[i] ^ b[i]); + return accu; + } + + inline int get_code_size() const { + return n * 4; + } +}; + template void hamming_cpt_test( int code_size, diff --git a/faiss/IndexBinaryHNSW.cpp b/faiss/IndexBinaryHNSW.cpp index 1f034009f8..9481fe67f2 100644 --- a/faiss/IndexBinaryHNSW.cpp +++ b/faiss/IndexBinaryHNSW.cpp @@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer { } }; +struct BuildDistanceComputer { + using T = DistanceComputer*; + template + DistanceComputer* f(IndexBinaryFlat* flat_storage) { + return new FlatHammingDis(*flat_storage); + } +}; + } // namespace DistanceComputer* IndexBinaryHNSW::get_distance_computer() const { IndexBinaryFlat* flat_storage = dynamic_cast(storage); - FAISS_ASSERT(flat_storage != nullptr); - - switch (code_size) { - case 4: - return new FlatHammingDis(*flat_storage); - case 8: - return new FlatHammingDis(*flat_storage); - case 16: - return new FlatHammingDis(*flat_storage); - case 20: - return new FlatHammingDis(*flat_storage); - case 32: - return new FlatHammingDis(*flat_storage); - case 64: - return new FlatHammingDis(*flat_storage); - default: - break; - } - - return new FlatHammingDis(*flat_storage); + BuildDistanceComputer bd; + return dispatch_HammingComputer(code_size, bd, flat_storage); } } // namespace faiss diff --git a/faiss/IndexBinaryHash.cpp b/faiss/IndexBinaryHash.cpp index 0e449bab77..86a6d52ded 100644 --- a/faiss/IndexBinaryHash.cpp +++ b/faiss/IndexBinaryHash.cpp @@ -176,6 +176,14 @@ void search_single_query_template( } while (fe.next()); } +struct Run_search_single_query { + using T = void; + template + T f(Types... args) { + search_single_query_template(args...); + } +}; + template void search_single_query( const IndexBinaryHash& index, @@ -184,29 +192,9 @@ void search_single_query( size_t& n0, size_t& nlist, size_t& ndis) { -#define HC(name) \ - search_single_query_template(index, q, res, n0, nlist, ndis); - switch (index.code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 20: - HC(HammingComputer20); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_search_single_query r; + dispatch_HammingComputer( + index.code_size, r, index, q, res, n0, nlist, ndis); } } // anonymous namespace @@ -349,15 +337,15 @@ namespace { template static void verify_shortlist( - const IndexBinaryFlat& index, + const IndexBinaryFlat* index, const uint8_t* q, const std::unordered_set& shortlist, SearchResults& res) { - size_t code_size = index.code_size; + size_t code_size = index->code_size; size_t nlist = 0, ndis = 0, n0 = 0; HammingComputer hc(q, code_size); - const uint8_t* codes = index.xb.data(); + const uint8_t* codes = index->xb.data(); for (auto i : shortlist) { int dis = hc.hamming(codes + i * code_size); @@ -365,6 +353,14 @@ static void verify_shortlist( } } +struct Run_verify_shortlist { + using T = void; + template + void f(Types... args) { + verify_shortlist(args...); + } +}; + template void search_1_query_multihash( const IndexBinaryMultiHash& index, @@ -405,29 +401,9 @@ void search_1_query_multihash( ndis += shortlist.size(); // verify shortlist - -#define HC(name) verify_shortlist(*index.storage, xi, shortlist, res) - switch (index.code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 20: - HC(HammingComputer20); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_verify_shortlist r; + dispatch_HammingComputer( + index.code_size, r, index.storage, xi, shortlist, res); } } // anonymous namespace diff --git a/faiss/IndexBinaryIVF.cpp b/faiss/IndexBinaryIVF.cpp index 65b98280dc..0e886b47e2 100644 --- a/faiss/IndexBinaryIVF.cpp +++ b/faiss/IndexBinaryIVF.cpp @@ -370,7 +370,7 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner { }; void search_knn_hamming_heap( - const IndexBinaryIVF& ivf, + const IndexBinaryIVF* ivf, size_t n, const uint8_t* __restrict x, idx_t k, @@ -380,10 +380,10 @@ void search_knn_hamming_heap( idx_t* __restrict labels, bool store_pairs, const IVFSearchParameters* params) { - idx_t nprobe = params ? params->nprobe : ivf.nprobe; - nprobe = std::min((idx_t)ivf.nlist, nprobe); - idx_t max_codes = params ? params->max_codes : ivf.max_codes; - MetricType metric_type = ivf.metric_type; + idx_t nprobe = params ? params->nprobe : ivf->nprobe; + nprobe = std::min((idx_t)ivf->nlist, nprobe); + idx_t max_codes = params ? params->max_codes : ivf->max_codes; + MetricType metric_type = ivf->metric_type; // almost verbatim copy from IndexIVF::search_preassigned @@ -394,11 +394,11 @@ void search_knn_hamming_heap( #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) { std::unique_ptr scanner( - ivf.get_InvertedListScanner(store_pairs)); + ivf->get_InvertedListScanner(store_pairs)); #pragma omp for for (idx_t i = 0; i < n; i++) { - const uint8_t* xi = x + i * ivf.code_size; + const uint8_t* xi = x + i * ivf->code_size; scanner->set_query(xi); const idx_t* keysi = keys + i * nprobe; @@ -420,23 +420,24 @@ void search_knn_hamming_heap( continue; } FAISS_THROW_IF_NOT_FMT( - key < (idx_t)ivf.nlist, + key < (idx_t)ivf->nlist, "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n", key, ik, - ivf.nlist); + ivf->nlist); scanner->set_list(key, coarse_dis[i * nprobe + ik]); nlistv++; - size_t list_size = ivf.invlists->list_size(key); - InvertedLists::ScopedCodes scodes(ivf.invlists, key); + size_t list_size = ivf->invlists->list_size(key); + InvertedLists::ScopedCodes scodes(ivf->invlists, key); std::unique_ptr sids; const idx_t* ids = nullptr; if (!store_pairs) { - sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key)); + sids.reset( + new InvertedLists::ScopedIds(ivf->invlists, key)); ids = sids->get(); } @@ -466,7 +467,7 @@ void search_knn_hamming_heap( template void search_knn_hamming_count( - const IndexBinaryIVF& ivf, + const IndexBinaryIVF* ivf, size_t nx, const uint8_t* __restrict x, const idx_t* __restrict keys, @@ -474,21 +475,21 @@ void search_knn_hamming_count( int32_t* __restrict distances, idx_t* __restrict labels, const IVFSearchParameters* params) { - const int nBuckets = ivf.d + 1; + const int nBuckets = ivf->d + 1; std::vector all_counters(nx * nBuckets, 0); std::unique_ptr all_ids_per_dis(new idx_t[nx * nBuckets * k]); - idx_t nprobe = params ? params->nprobe : ivf.nprobe; - nprobe = std::min((idx_t)ivf.nlist, nprobe); - idx_t max_codes = params ? params->max_codes : ivf.max_codes; + idx_t nprobe = params ? params->nprobe : ivf->nprobe; + nprobe = std::min((idx_t)ivf->nlist, nprobe); + idx_t max_codes = params ? params->max_codes : ivf->max_codes; std::vector> cs; for (size_t i = 0; i < nx; ++i) { cs.push_back(HCounterState( all_counters.data() + i * nBuckets, all_ids_per_dis.get() + i * nBuckets * k, - x + i * ivf.code_size, - ivf.d, + x + i * ivf->code_size, + ivf->d, k)); } @@ -508,27 +509,28 @@ void search_knn_hamming_count( continue; } FAISS_THROW_IF_NOT_FMT( - key < (idx_t)ivf.nlist, + key < (idx_t)ivf->nlist, "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n", key, ik, - ivf.nlist); + ivf->nlist); nlistv++; - size_t list_size = ivf.invlists->list_size(key); - InvertedLists::ScopedCodes scodes(ivf.invlists, key); + size_t list_size = ivf->invlists->list_size(key); + InvertedLists::ScopedCodes scodes(ivf->invlists, key); const uint8_t* list_vecs = scodes.get(); const idx_t* ids = - store_pairs ? nullptr : ivf.invlists->get_ids(key); + store_pairs ? nullptr : ivf->invlists->get_ids(key); for (size_t j = 0; j < list_size; j++) { - const uint8_t* yj = list_vecs + ivf.code_size * j; + const uint8_t* yj = list_vecs + ivf->code_size * j; idx_t id = store_pairs ? (key << 32 | j) : ids[j]; csi.update_counter(yj, id); } - if (ids) - ivf.invlists->release_ids(key, ids); + if (ids) { + ivf->invlists->release_ids(key, ids); + } nscan += list_size; if (max_codes && nscan >= max_codes) @@ -634,7 +636,7 @@ struct BlockSearchVariableK { template void search_knn_hamming_per_invlist( - const IndexBinaryIVF& ivf, + const IndexBinaryIVF* ivf, size_t n, const uint8_t* __restrict x, idx_t k, @@ -644,12 +646,12 @@ void search_knn_hamming_per_invlist( idx_t* __restrict labels, bool store_pairs, const IVFSearchParameters* params) { - idx_t nprobe = params ? params->nprobe : ivf.nprobe; - nprobe = std::min((idx_t)ivf.nlist, nprobe); - idx_t max_codes = params ? params->max_codes : ivf.max_codes; + idx_t nprobe = params ? params->nprobe : ivf->nprobe; + nprobe = std::min((idx_t)ivf->nlist, nprobe); + idx_t max_codes = params ? params->max_codes : ivf->max_codes; FAISS_THROW_IF_NOT(max_codes == 0); FAISS_THROW_IF_NOT(!store_pairs); - MetricType metric_type = ivf.metric_type; + MetricType metric_type = ivf->metric_type; // reorder buckets std::vector lims(n + 1); @@ -658,18 +660,18 @@ void search_knn_hamming_per_invlist( for (idx_t i = 0; i < n * nprobe; i++) { keys[i] = keys_in[i]; } - matrix_bucket_sort_inplace(n, nprobe, keys, ivf.nlist, lims.data(), 0); + matrix_bucket_sort_inplace(n, nprobe, keys, ivf->nlist, lims.data(), 0); using C = CMax; heap_heapify(n * k, distances, labels); - const size_t code_size = ivf.code_size; + const size_t code_size = ivf->code_size; - for (idx_t l = 0; l < ivf.nlist; l++) { + for (idx_t l = 0; l < ivf->nlist; l++) { idx_t l0 = lims[l], nq = lims[l + 1] - l0; - InvertedLists::ScopedCodes scodes(ivf.invlists, l); - InvertedLists::ScopedIds sidx(ivf.invlists, l); - idx_t nb = ivf.invlists->list_size(l); + InvertedLists::ScopedCodes scodes(ivf->invlists, l); + InvertedLists::ScopedIds sidx(ivf->invlists, l); + idx_t nb = ivf->invlists->list_size(l); const uint8_t* bcodes = scodes.get(); const idx_t* ids = sidx.get(); @@ -735,151 +737,70 @@ void search_knn_hamming_per_invlist( } } +struct Run_search_knn_hamming_per_invlist { + using T = void; + + template + void f(Types... args) { + search_knn_hamming_per_invlist(args...); + } +}; + template -void search_knn_hamming_count_1( - const IndexBinaryIVF& ivf, - size_t nx, - const uint8_t* x, - const idx_t* keys, - int k, - int32_t* distances, - idx_t* labels, - const IVFSearchParameters* params) { - switch (ivf.code_size) { -#define HANDLE_CS(cs) \ - case cs: \ - search_knn_hamming_count( \ - ivf, nx, x, keys, k, distances, labels, params); \ - break; - HANDLE_CS(4); - HANDLE_CS(8); - HANDLE_CS(16); - HANDLE_CS(20); - HANDLE_CS(32); - HANDLE_CS(64); -#undef HANDLE_CS - default: - search_knn_hamming_count( - ivf, nx, x, keys, k, distances, labels, params); - break; +struct Run_search_knn_hamming_count { + using T = void; + + template + void f(Types... args) { + search_knn_hamming_count(args...); } -} +}; -void search_knn_hamming_per_invlist_1( - const IndexBinaryIVF& ivf, - size_t n, - const uint8_t* x, - idx_t k, - const idx_t* keys, - const int32_t* coarse_dis, - int32_t* distances, - idx_t* labels, - bool store_pairs, - const IVFSearchParameters* params) { - switch (ivf.code_size) { -#define HANDLE_CS(cs) \ - case cs: \ - search_knn_hamming_per_invlist( \ - ivf, \ - n, \ - x, \ - k, \ - keys, \ - coarse_dis, \ - distances, \ - labels, \ - store_pairs, \ - params); \ - break; - HANDLE_CS(4); - HANDLE_CS(8); - HANDLE_CS(16); - HANDLE_CS(20); - HANDLE_CS(32); - HANDLE_CS(64); -#undef HANDLE_CS - default: - search_knn_hamming_per_invlist( - ivf, - n, - x, - k, - keys, - coarse_dis, - distances, - labels, - store_pairs, - params); - break; +struct BuildScanner { + using T = BinaryInvertedListScanner*; + + template + T f(size_t code_size, bool store_pairs) { + return new IVFBinaryScannerL2(code_size, store_pairs); } -} +}; } // anonymous namespace BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner( bool store_pairs) const { -#define HC(name) return new IVFBinaryScannerL2(code_size, store_pairs) - switch (code_size) { - case 4: - HC(HammingComputer4); - case 8: - HC(HammingComputer8); - case 16: - HC(HammingComputer16); - case 20: - HC(HammingComputer20); - case 32: - HC(HammingComputer32); - case 64: - HC(HammingComputer64); - default: - HC(HammingComputerDefault); - } -#undef HC + BuildScanner bs; + return dispatch_HammingComputer(code_size, bs, code_size, store_pairs); } void IndexBinaryIVF::search_preassigned( idx_t n, const uint8_t* x, idx_t k, - const idx_t* idx, - const int32_t* coarse_dis, - int32_t* distances, - idx_t* labels, + const idx_t* cidx, + const int32_t* cdis, + int32_t* dis, + idx_t* idx, bool store_pairs, const IVFSearchParameters* params) const { if (per_invlist_search) { - search_knn_hamming_per_invlist_1( - *this, - n, - x, - k, - idx, - coarse_dis, - distances, - labels, - store_pairs, - params); + Run_search_knn_hamming_per_invlist r; + // clang-format off + dispatch_HammingComputer( + code_size, r, this, n, x, k, + cidx, cdis, dis, idx, store_pairs, params); + // clang-format on } else if (use_heap) { search_knn_hamming_heap( - *this, - n, - x, - k, - idx, - coarse_dis, - distances, - labels, - store_pairs, - params); - } else { - if (store_pairs) { - search_knn_hamming_count_1( - *this, n, x, idx, k, distances, labels, params); - } else { - search_knn_hamming_count_1( - *this, n, x, idx, k, distances, labels, params); - } + this, n, x, k, cidx, cdis, dis, idx, store_pairs, params); + } else if (store_pairs) { // !use_heap && store_pairs + Run_search_knn_hamming_count r; + dispatch_HammingComputer( + code_size, r, this, n, x, cidx, k, dis, idx, params); + } else { // !use_heap && !store_pairs + Run_search_knn_hamming_count r; + dispatch_HammingComputer( + code_size, r, this, n, x, cidx, k, dis, idx, params); } } diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index 60633cc41b..058798b15c 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -1154,30 +1154,23 @@ struct IVFPQScannerT : QueryTables { { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; } } + template + struct Run_scan_list_polysemous_hc { + using T = void; + template + void f(const IVFPQScannerT* scanner, Types... args) { + scanner->scan_list_polysemous_hc( + args...); + } + }; + template void scan_list_polysemous( size_t ncode, const uint8_t* codes, SearchResultType& res) const { - switch (pq.code_size) { -#define HANDLE_CODE_SIZE(cs) \ - case cs: \ - scan_list_polysemous_hc( \ - ncode, codes, res); \ - break - HANDLE_CODE_SIZE(4); - HANDLE_CODE_SIZE(8); - HANDLE_CODE_SIZE(16); - HANDLE_CODE_SIZE(20); - HANDLE_CODE_SIZE(32); - HANDLE_CODE_SIZE(64); -#undef HANDLE_CODE_SIZE - default: - scan_list_polysemous_hc< - HammingComputerDefault, - SearchResultType>(ncode, codes, res); - break; - } + Run_scan_list_polysemous_hc r; + dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res); } }; diff --git a/faiss/IndexIVFSpectralHash.cpp b/faiss/IndexIVFSpectralHash.cpp index 443c45dee6..d9a51fbe64 100644 --- a/faiss/IndexIVFSpectralHash.cpp +++ b/faiss/IndexIVFSpectralHash.cpp @@ -288,26 +288,23 @@ struct IVFScanner : InvertedListScanner { } }; +struct BuildScanner { + using T = InvertedListScanner*; + + template + static T f(const IndexIVFSpectralHash* index, bool store_pairs) { + return new IVFScanner(index, store_pairs); + } +}; + } // anonymous namespace InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner( bool store_pairs, const IDSelector* sel) const { FAISS_THROW_IF_NOT(!sel); - switch (code_size) { -#define HANDLE_CODE_SIZE(cs) \ - case cs: \ - return new IVFScanner(this, store_pairs) - HANDLE_CODE_SIZE(4); - HANDLE_CODE_SIZE(8); - HANDLE_CODE_SIZE(16); - HANDLE_CODE_SIZE(20); - HANDLE_CODE_SIZE(32); - HANDLE_CODE_SIZE(64); -#undef HANDLE_CODE_SIZE - default: - return new IVFScanner(this, store_pairs); - } + BuildScanner bs; + return dispatch_HammingComputer(code_size, bs, this, store_pairs); } void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) { diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index 7b1c28f8fd..2986745fbc 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -263,21 +263,23 @@ void IndexPQStats::reset() { IndexPQStats indexPQ_stats; +namespace { + template -static size_t polysemous_inner_loop( - const IndexPQ& index, +size_t polysemous_inner_loop( + const IndexPQ* index, const float* dis_table_qi, const uint8_t* q_code, size_t k, float* heap_dis, int64_t* heap_ids, int ht) { - int M = index.pq.M; - int code_size = index.pq.code_size; - int ksub = index.pq.ksub; - size_t ntotal = index.ntotal; + int M = index->pq.M; + int code_size = index->pq.code_size; + int ksub = index->pq.ksub; + size_t ntotal = index->ntotal; - const uint8_t* b_code = index.codes.data(); + const uint8_t* b_code = index->codes.data(); size_t n_pass_i = 0; @@ -305,6 +307,16 @@ static size_t polysemous_inner_loop( return n_pass_i; } +struct Run_polysemous_inner_loop { + using T = size_t; + template + size_t f(Types... args) { + return polysemous_inner_loop(args...); + } +}; + +} // anonymous namespace + void IndexPQ::search_core_polysemous( idx_t n, const float* x, @@ -355,45 +367,24 @@ void IndexPQ::search_core_polysemous( maxheap_heapify(k, heap_dis, heap_ids); if (!generalized_hamming) { - switch (pq.code_size) { -#define DISPATCH(cs) \ - case cs: \ - n_pass += polysemous_inner_loop( \ - *this, \ - dis_table_qi, \ - q_code, \ - k, \ - heap_dis, \ - heap_ids, \ - polysemous_ht); \ - break; - DISPATCH(4) - DISPATCH(8) - DISPATCH(16) - DISPATCH(32) - DISPATCH(20) - default: - if (pq.code_size % 4 == 0) { - n_pass += polysemous_inner_loop( - *this, - dis_table_qi, - q_code, - k, - heap_dis, - heap_ids, - polysemous_ht); - } else { - bad_code_size++; - } - break; - } -#undef DISPATCH + Run_polysemous_inner_loop r; + n_pass += dispatch_HammingComputer( + pq.code_size, + r, + this, + dis_table_qi, + q_code, + k, + heap_dis, + heap_ids, + polysemous_ht); + } else { // generalized hamming switch (pq.code_size) { #define DISPATCH(cs) \ case cs: \ n_pass += polysemous_inner_loop( \ - *this, \ + this, \ dis_table_qi, \ q_code, \ k, \ @@ -407,7 +398,7 @@ void IndexPQ::search_core_polysemous( default: if (pq.code_size % 8 == 0) { n_pass += polysemous_inner_loop( - *this, + this, dis_table_qi, q_code, k, diff --git a/faiss/utils/hamming.cpp b/faiss/utils/hamming.cpp index 7019183bd0..ee9ed60eb2 100644 --- a/faiss/utils/hamming.cpp +++ b/faiss/utils/hamming.cpp @@ -5,14 +5,13 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - /* * Implementation of Hamming related functions (distances, smallest distance * selection with regular heap|radix and probabilistic heap|radix. * * IMPLEMENTATION NOTES - * Bitvectors are generally assumed to be multiples of 64 bits. + * Optimal speed is typically obtained for vector sizes of multiples of 64 + * bits. * * hamdis_t is used for distances because at this time * it is not clear how we will need to balance @@ -20,8 +19,6 @@ * - memory usage * - cache-misses when dealing with large volumes of data (lower bits is better) * - * The hamdis_t should optimally be compatibe with one of the Torch Storage - * (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes */ #include @@ -165,9 +162,11 @@ size_t match_hamming_thres( return posm; } +namespace { + /* Return closest neighbors w.r.t Hamming distance, using a heap. */ template -static void hammings_knn_hc( +void hammings_knn_hc( int bytes_per_code, int_maxheap_array_t* __restrict ha, const uint8_t* __restrict bs1, @@ -234,7 +233,7 @@ static void hammings_knn_hc( /* Return closest neighbors w.r.t Hamming distance, using max count. */ template -static void hammings_knn_mc( +void hammings_knn_mc( int bytes_per_code, const uint8_t* __restrict a, const uint8_t* __restrict b, @@ -287,6 +286,63 @@ static void hammings_knn_mc( } } +template +void hamming_range_search( + const uint8_t* a, + const uint8_t* b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult* res) { +#pragma omp parallel + { + RangeSearchPartialResult pres(res); + +#pragma omp for + for (int64_t i = 0; i < na; i++) { + HammingComputer hc(a + i * code_size, code_size); + const uint8_t* yi = b; + RangeQueryResult& qres = pres.new_result(i); + + for (size_t j = 0; j < nb; j++) { + int dis = hc.hamming(yi); + if (dis < radius) { + qres.add(dis, j); + } + yi += code_size; + } + } + pres.finalize(); + } +} + +struct Run_hammings_knn_hc { + using T = void; + template + void f(Types... args) { + hammings_knn_hc(args...); + } +}; + +struct Run_hammings_knn_mc { + using T = void; + template + void f(Types... args) { + hammings_knn_mc(args...); + } +}; + +struct Run_hamming_range_search { + using T = void; + template + void f(Types... args) { + hamming_range_search(args...); + } +}; + +} // namespace + /* Functions to maps vectors to bits. Assume proper allocation done beforehand, meaning that b should be be able to receive as many bits as x may produce. */ @@ -437,28 +493,9 @@ void hammings_knn_hc( size_t ncodes, int order, ApproxTopK_mode_t approx_topk_mode) { - switch (ncodes) { - case 4: - hammings_knn_hc( - 4, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 8: - hammings_knn_hc( - 8, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 16: - hammings_knn_hc( - 16, ha, a, b, nb, order, true, approx_topk_mode); - break; - case 32: - hammings_knn_hc( - 32, ha, a, b, nb, order, true, approx_topk_mode); - break; - default: - hammings_knn_hc( - ncodes, ha, a, b, nb, order, true, approx_topk_mode); - break; - } + Run_hammings_knn_hc r; + dispatch_HammingComputer( + ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode); } void hammings_knn_mc( @@ -470,58 +507,9 @@ void hammings_knn_mc( size_t ncodes, int32_t* __restrict distances, int64_t* __restrict labels) { - switch (ncodes) { - case 4: - hammings_knn_mc( - 4, a, b, na, nb, k, distances, labels); - break; - case 8: - hammings_knn_mc( - 8, a, b, na, nb, k, distances, labels); - break; - case 16: - hammings_knn_mc( - 16, a, b, na, nb, k, distances, labels); - break; - case 32: - hammings_knn_mc( - 32, a, b, na, nb, k, distances, labels); - break; - default: - hammings_knn_mc( - ncodes, a, b, na, nb, k, distances, labels); - break; - } -} -template -static void hamming_range_search_template( - const uint8_t* a, - const uint8_t* b, - size_t na, - size_t nb, - int radius, - size_t code_size, - RangeSearchResult* res) { -#pragma omp parallel - { - RangeSearchPartialResult pres(res); - -#pragma omp for - for (int64_t i = 0; i < na; i++) { - HammingComputer hc(a + i * code_size, code_size); - const uint8_t* yi = b; - RangeQueryResult& qres = pres.new_result(i); - - for (size_t j = 0; j < nb; j++) { - int dis = hc.hamming(yi); - if (dis < radius) { - qres.add(dis, j); - } - yi += code_size; - } - } - pres.finalize(); - } + Run_hammings_knn_mc r; + dispatch_HammingComputer( + ncodes, r, ncodes, a, b, na, nb, k, distances, labels); } void hamming_range_search( @@ -532,27 +520,9 @@ void hamming_range_search( int radius, size_t code_size, RangeSearchResult* result) { -#define HC(name) \ - hamming_range_search_template(a, b, na, nb, radius, code_size, result) - - switch (code_size) { - case 4: - HC(HammingComputer4); - break; - case 8: - HC(HammingComputer8); - break; - case 16: - HC(HammingComputer16); - break; - case 32: - HC(HammingComputer32); - break; - default: - HC(HammingComputerDefault); - break; - } -#undef HC + Run_hamming_range_search r; + dispatch_HammingComputer( + code_size, r, a, b, na, nb, radius, code_size, result); } /* Count number of matches given a max threshold */ diff --git a/faiss/utils/hamming_distance/avx2-inl.h b/faiss/utils/hamming_distance/avx2-inl.h index 2393b75778..fdc746c019 100644 --- a/faiss/utils/hamming_distance/avx2-inl.h +++ b/faiss/utils/hamming_distance/avx2-inl.h @@ -345,93 +345,6 @@ struct HammingComputerDefault { } }; -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM8 { - const uint64_t* a; - int n; - - HammingComputerM8() {} - - HammingComputerM8(const uint8_t* a8, int code_size) { - set(a8, code_size); - } - - void set(const uint8_t* a8, int code_size) { - assert(code_size % 8 == 0); - a = (uint64_t*)a8; - n = code_size / 8; - } - - int hamming(const uint8_t* b8) const { - const uint64_t* b = (uint64_t*)b8; - int accu = 0; - for (int i = 0; i < n; i++) - accu += popcount64(a[i] ^ b[i]); - return accu; - } - - inline int get_code_size() const { - return n * 8; - } -}; - -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM4 { - const uint32_t* a; - int n; - - HammingComputerM4() {} - - HammingComputerM4(const uint8_t* a4, int code_size) { - set(a4, code_size); - } - - void set(const uint8_t* a4, int code_size) { - assert(code_size % 4 == 0); - a = (uint32_t*)a4; - n = code_size / 4; - } - - int hamming(const uint8_t* b8) const { - const uint32_t* b = (uint32_t*)b8; - int accu = 0; - for (int i = 0; i < n; i++) - accu += popcount64(a[i] ^ b[i]); - return accu; - } - - inline int get_code_size() const { - return n * 4; - } -}; - -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); - -#undef SPECIALIZED_HC - /*************************************************************************** * generalized Hamming = number of bytes that are different between * two codes. diff --git a/faiss/utils/hamming_distance/common.h b/faiss/utils/hamming_distance/common.h index 62c1f9fd03..0a2de08d17 100644 --- a/faiss/utils/hamming_distance/common.h +++ b/faiss/utils/hamming_distance/common.h @@ -17,6 +17,7 @@ using hamdis_t = int32_t; namespace faiss { +// trust the compiler to provide efficient popcount implementations inline int popcount32(uint32_t x) { return __builtin_popcount(x); } diff --git a/faiss/utils/hamming_distance/generic-inl.h b/faiss/utils/hamming_distance/generic-inl.h index 8e9356c9ab..1607fb5d05 100644 --- a/faiss/utils/hamming_distance/generic-inl.h +++ b/faiss/utils/hamming_distance/generic-inl.h @@ -329,93 +329,6 @@ struct HammingComputerDefault { } }; -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM8 { - const uint64_t* a; - int n; - - HammingComputerM8() {} - - HammingComputerM8(const uint8_t* a8, int code_size) { - set(a8, code_size); - } - - void set(const uint8_t* a8, int code_size) { - assert(code_size % 8 == 0); - a = (uint64_t*)a8; - n = code_size / 8; - } - - int hamming(const uint8_t* b8) const { - const uint64_t* b = (uint64_t*)b8; - int accu = 0; - for (int i = 0; i < n; i++) - accu += popcount64(a[i] ^ b[i]); - return accu; - } - - inline int get_code_size() const { - return n * 8; - } -}; - -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM4 { - const uint32_t* a; - int n; - - HammingComputerM4() {} - - HammingComputerM4(const uint8_t* a4, int code_size) { - set(a4, code_size); - } - - void set(const uint8_t* a4, int code_size) { - assert(code_size % 4 == 0); - a = (uint32_t*)a4; - n = code_size / 4; - } - - int hamming(const uint8_t* b8) const { - const uint32_t* b = (uint32_t*)b8; - int accu = 0; - for (int i = 0; i < n; i++) - accu += popcount64(a[i] ^ b[i]); - return accu; - } - - inline int get_code_size() const { - return n * 4; - } -}; - -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); - -#undef SPECIALIZED_HC - /*************************************************************************** * generalized Hamming = number of bytes that are different between * two codes. diff --git a/faiss/utils/hamming_distance/hamdis-inl.h b/faiss/utils/hamming_distance/hamdis-inl.h index aaea84735e..b830df38b6 100644 --- a/faiss/utils/hamming_distance/hamdis-inl.h +++ b/faiss/utils/hamming_distance/hamdis-inl.h @@ -23,4 +23,61 @@ #include #endif +namespace faiss { + +/*************************************************************************** + * Equivalence with a template class when code size is known at compile time + **************************************************************************/ + +// default template +template +struct HammingComputer : HammingComputerDefault { + HammingComputer(const uint8_t* a, int code_size) + : HammingComputerDefault(a, code_size) {} +}; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template <> \ + struct HammingComputer : HammingComputer##CODE_SIZE { \ + HammingComputer(const uint8_t* a) \ + : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ + } + +SPECIALIZED_HC(4); +SPECIALIZED_HC(8); +SPECIALIZED_HC(16); +SPECIALIZED_HC(20); +SPECIALIZED_HC(32); +SPECIALIZED_HC(64); + +#undef SPECIALIZED_HC + +/*************************************************************************** + * Dispatching function that takes a code size and a consumer object + * the consumer object should contain a retun type t and a operation template + * function f() that to be called to perform the operation. + **************************************************************************/ + +template +typename Consumer::T dispatch_HammingComputer( + int code_size, + Consumer& consumer, + Types... args) { + switch (code_size) { +#define DISPATCH_HC(CODE_SIZE) \ + case CODE_SIZE: \ + return consumer.template f(args...); + DISPATCH_HC(4); + DISPATCH_HC(8); + DISPATCH_HC(16); + DISPATCH_HC(20); + DISPATCH_HC(32); + DISPATCH_HC(64); + default: + return consumer.template f(args...); + } +} + +} // namespace faiss + #endif diff --git a/faiss/utils/hamming_distance/neon-inl.h b/faiss/utils/hamming_distance/neon-inl.h index 38b5aa6af2..d1a0fdee7a 100644 --- a/faiss/utils/hamming_distance/neon-inl.h +++ b/faiss/utils/hamming_distance/neon-inl.h @@ -392,109 +392,6 @@ struct HammingComputerDefault { } }; -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM8 { - const uint64_t* a; - int n; - - HammingComputerM8() {} - - HammingComputerM8(const uint8_t* a8, int code_size) { - set(a8, code_size); - } - - void set(const uint8_t* a8, int code_size) { - assert(code_size % 8 == 0); - a = (uint64_t*)a8; - n = code_size / 8; - } - - int hamming(const uint8_t* b8) const { - const uint64_t* b = (uint64_t*)b8; - int n4 = (n / 4) * 4; - int accu = 0; - - int i = 0; - for (; i < n4; i += 4) { - accu += ::faiss::hamming<256>(a + i, b + i); - } - for (; i < n; i++) { - accu += popcount64(a[i] ^ b[i]); - } - return accu; - } - - inline int get_code_size() const { - return n * 8; - } -}; - -// more inefficient than HammingComputerDefault (obsolete) -struct HammingComputerM4 { - const uint32_t* a; - int n; - - HammingComputerM4() {} - - HammingComputerM4(const uint8_t* a4, int code_size) { - set(a4, code_size); - } - - void set(const uint8_t* a4, int code_size) { - assert(code_size % 4 == 0); - a = (uint32_t*)a4; - n = code_size / 4; - } - - int hamming(const uint8_t* b8) const { - const uint32_t* b = (uint32_t*)b8; - - int n8 = (n / 8) * 8; - int accu = 0; - - int i = 0; - for (; i < n8; i += 8) { - accu += ::faiss::hamming<256>( - (const uint64_t*)(a + i), (const uint64_t*)(b + i)); - } - for (; i < n; i++) { - accu += popcount64(a[i] ^ b[i]); - } - return accu; - } - - inline int get_code_size() const { - return n * 4; - } -}; - -/*************************************************************************** - * Equivalence with a template class when code size is known at compile time - **************************************************************************/ - -// default template -template -struct HammingComputer : HammingComputerDefault { - HammingComputer(const uint8_t* a, int code_size) - : HammingComputerDefault(a, code_size) {} -}; - -#define SPECIALIZED_HC(CODE_SIZE) \ - template <> \ - struct HammingComputer : HammingComputer##CODE_SIZE { \ - HammingComputer(const uint8_t* a) \ - : HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \ - } - -SPECIALIZED_HC(4); -SPECIALIZED_HC(8); -SPECIALIZED_HC(16); -SPECIALIZED_HC(20); -SPECIALIZED_HC(32); -SPECIALIZED_HC(64); - -#undef SPECIALIZED_HC - /*************************************************************************** * generalized Hamming = number of bytes that are different between * two codes.