From c552f969621753e40adc2e1b667b01153322ab13 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 22 Nov 2023 12:18:26 -0800 Subject: [PATCH] Add vector deduplication --- faiss/CMakeLists.txt | 2 + faiss/Index.h | 4 +- faiss/IndexHNSW.cpp | 1 - faiss/impl/HNSW.cpp | 56 ++++++++++-- faiss/impl/IDDeduper.cpp | 22 +++++ faiss/impl/IDDeduper.h | 36 ++++++++ faiss/utils/Heap.h | 174 ++++++++++++++++++++++++++++++++++++ tutorial/cpp/6-Dedupe.cpp | 111 +++++++++++++++++++++++ tutorial/cpp/CMakeLists.txt | 3 + 9 files changed, 399 insertions(+), 10 deletions(-) create mode 100644 faiss/impl/IDDeduper.cpp create mode 100644 faiss/impl/IDDeduper.h create mode 100644 tutorial/cpp/6-Dedupe.cpp diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 291b225cd7..43778eee62 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -54,6 +54,7 @@ set(FAISS_SRC impl/AuxIndexStructures.cpp impl/CodePacker.cpp impl/IDSelector.cpp + impl/IDDeduper.cpp impl/FaissException.cpp impl/HNSW.cpp impl/NSG.cpp @@ -148,6 +149,7 @@ set(FAISS_HEADERS impl/AdditiveQuantizer.h impl/AuxIndexStructures.h impl/IDSelector.h + impl/IDDeduper.h impl/DistanceComputer.h impl/FaissAssert.h impl/FaissException.h diff --git a/faiss/Index.h b/faiss/Index.h index 4b4b302b47..55ccdf484d 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,9 +38,10 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and +/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, impl/IDDeduper.h and /// impl/DistanceComputer.h struct IDSelector; +struct IDDeduper; struct RangeSearchResult; struct DistanceComputer; @@ -52,6 +53,7 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; + IDDeduper* dedup = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index f846223479..4baabebac0 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -296,7 +296,6 @@ void IndexHNSW::search( for (idx_t i0 = 0; i0 < n; i0 += check_period) { idx_t i1 = std::min(i0 + check_period, n); - #pragma omp parallel { VisitedTable vt(ntotal); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..c0b2ca4792 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -538,16 +538,34 @@ int search_from_candidates( : hnsw.check_relative_distance; int efSearch = params ? params->efSearch : hnsw.efSearch; const IDSelector* sel = params ? params->sel : nullptr; + const IDDeduper* dedup = params ? params->dedup : nullptr; + std::unordered_map group_id_to_id; + std::unordered_map group_id_to_index; for (int i = 0; i < candidates.size(); i++) { idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); + if (dedup) { + idx_t group_id = dedup->group_id(v1); + if (group_id_to_index.find(group_id) == group_id_to_index.end()) { + if (nres < k) { + faiss::maxheap_push_with_dedupe(++nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id); + } else if (d < D[0]) { + faiss::maxheap_replace_top_with_dedupe(nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id); + } + } + else if (d < D[group_id_to_index.at(group_id)]) { + faiss::maxheap_update_with_dedupe(nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id); + } + } + else { + if (nres < k) { + faiss::maxheap_push(++nres, D, I, d, v1); + } else if (d < D[0]) { + faiss::maxheap_replace_top(nres, D, I, d, v1); + } } } vt.set(v1); @@ -612,10 +630,25 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); + if (dedup) { + idx_t group_id = dedup->group_id(idx); + if (group_id_to_index.find(group_id) == group_id_to_index.end()) { + if (nres < k) { + faiss::maxheap_push_with_dedupe(++nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id); + } else if (dis < D[0]) { + faiss::maxheap_replace_top_with_dedupe(nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id); + } + } + else if (dis < D[group_id_to_index.at(group_id)]) { + faiss::maxheap_update_with_dedupe(nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id); + } + } + else { + if (nres < k) { + faiss::maxheap_push(++nres, D, I, dis, idx); + } else if (dis < D[0]) { + faiss::maxheap_replace_top(nres, D, I, dis, idx); + } } } candidates.push(idx, dis); @@ -660,6 +693,13 @@ int search_from_candidates( } } + // Convert group id to id before return + if (dedup) { + for (size_t icnt = 0; icnt < nres; icnt++) { + I[icnt] = group_id_to_id.at(I[icnt]); + } + } + if (level == 0) { stats.n1++; if (candidates.size() == 0) { diff --git a/faiss/impl/IDDeduper.cpp b/faiss/impl/IDDeduper.cpp new file mode 100644 index 0000000000..94aee3bcba --- /dev/null +++ b/faiss/impl/IDDeduper.cpp @@ -0,0 +1,22 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace faiss { + +/*********************************************************************** + * IDDeduperMap + ***********************************************************************/ +IDDeduperMap::IDDeduperMap(std::unordered_map* m) + : m(m) {} + +idx_t IDDeduperMap::group_id(idx_t id) const { + return m->at(id); +} + +} // namespace faiss diff --git a/faiss/impl/IDDeduper.h b/faiss/impl/IDDeduper.h new file mode 100644 index 0000000000..38816a88c9 --- /dev/null +++ b/faiss/impl/IDDeduper.h @@ -0,0 +1,36 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +/** IDDeduper is intended to define a group of vectors to dedupe */ + +namespace faiss { + +/** Encapsulates a set of id groups to handle. */ +struct IDDeduper { + virtual idx_t group_id(idx_t id) const = 0; + virtual ~IDDeduper() {} +}; + +/** id to group id mapping */ +struct IDDeduperMap : IDDeduper { + std::unordered_map* m; + + IDDeduperMap(std::unordered_map* m); + + idx_t group_id(idx_t id) const final; + + ~IDDeduperMap() override {} +}; + +} // namespace faiss diff --git a/faiss/utils/Heap.h b/faiss/utils/Heap.h index cdb714f4d6..b06d2e6c28 100644 --- a/faiss/utils/Heap.h +++ b/faiss/utils/Heap.h @@ -32,6 +32,7 @@ #include #include +#include namespace faiss { @@ -148,6 +149,140 @@ inline void heap_replace_top( bh_ids[i] = id; } + +template +inline void up_heap_with_dedupe( + size_t k, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + typename C::TI group_id, + size_t start_index) { + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = start_index + 1, i_father; + + while (i > 1) { + i_father = i >> 1; + if (!C::cmp2(val, bh_val[i_father], group_id, bh_ids[i_father])) { + /* the heap structure is ok */ + break; + } + bh_val[i] = bh_val[i_father]; + bh_ids[i] = bh_ids[i_father]; + (*group_id_to_index)[bh_ids[i]] = i - 1; + i = i_father; + } + bh_val[i] = val; + bh_ids[i] = group_id; + (*group_id_to_id)[group_id] = id; + (*group_id_to_index)[group_id] = i - 1; +} + +template +inline void down_heap_with_dedupe( + size_t k, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + typename C::TI group_id, + size_t start_index) { + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = start_index + 1, i1, i2; + + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) { + break; + } + + // Note that C::cmp2() is a bool function answering + // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max + // heap and same with the `<` sign for min heap. + if ((i2 == k + 1) || + C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) { + if (C::cmp2(val, bh_val[i1], group_id, bh_ids[i1])) { + break; + } + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + (*group_id_to_index)[bh_ids[i]] = i - 1; + i = i1; + } else { + if (C::cmp2(val, bh_val[i2], group_id, bh_ids[i2])) { + break; + } + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + (*group_id_to_index)[bh_ids[i]] = i - 1; + i = i2; + } + } + bh_val[i] = val; + bh_ids[i] = group_id; + (*group_id_to_id)[group_id] = id; + (*group_id_to_index)[group_id] = i - 1; +} + +/** heap_push with group id + */ +template +inline void heap_push_with_dedupe( + size_t k, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + typename C::TI group_id) { + up_heap_with_dedupe(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, k - 1); +} + +/** + * heap_replace_top with with group id + */ +template +inline void heap_replace_top_with_dedupe( + size_t k, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + typename C::TI group_id) { + group_id_to_id->erase(bh_ids[0]); + group_id_to_index->erase(bh_ids[0]); + down_heap_with_dedupe(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, 0); +} + +/** + * heap_update with group id + */ +template +inline void heap_update_with_dedupe( + size_t k, + typename C::T* bh_val, + typename C::TI* bh_ids, + typename C::T val, + typename C::TI id, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + typename C::TI group_id) { + size_t start_index = group_id_to_index->at(group_id); + up_heap_with_dedupe(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, start_index); + down_heap_with_dedupe(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, start_index); +} + /* Partial instanciation for heaps with TI = int64_t */ template @@ -200,6 +335,45 @@ inline void maxheap_replace_top( heap_replace_top>(k, bh_val, bh_ids, val, ids); } +template +inline void maxheap_push_with_dedupe( + size_t k, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t ids, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + int64_t group_id) { + heap_push_with_dedupe>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id); +} + +template +inline void maxheap_replace_top_with_dedupe( + size_t k, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t ids, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + int64_t group_id) { + heap_replace_top_with_dedupe>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id); +} + +template +inline void maxheap_update_with_dedupe( + size_t k, + T* bh_val, + int64_t* bh_ids, + T val, + int64_t ids, + std::unordered_map* group_id_to_id, + std::unordered_map* group_id_to_index, + int64_t group_id) { + heap_update_with_dedupe>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id); +} + /******************************************************************* * Heap initialization *******************************************************************/ diff --git a/tutorial/cpp/6-Dedupe.cpp b/tutorial/cpp/6-Dedupe.cpp new file mode 100644 index 0000000000..0afaa5dc9a --- /dev/null +++ b/tutorial/cpp/6-Dedupe.cpp @@ -0,0 +1,111 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include + +// 64-bit int +using idx_t = faiss::idx_t; + +int main() { + int d = 64; // dimension + int nb = 10; // database size + int nq = 1; // nb of queries + + std::mt19937 rng; + std::uniform_real_distribution<> distrib; + + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) + xb[d * i + j] = distrib(rng); + xb[d * i] += i / 1000.; + } + + for (int i = 0; i < nq; i++) { + for (int j = 0; j < d; j++) + xq[d * i + j] = distrib(rng); + xq[d * i] += i / 1000.; + } + + int k = 4; + int m = 8; + faiss::Index* index = + new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); + printf("is_trained = %s\n", index->is_trained ? "true" : "false"); + index->add(nb, xb); // add vectors to the index + printf("ntotal = %zd\n", index->ntotal); + + { // sanity check: search 5 first vectors of xb + idx_t* I = new idx_t[k * 5]; + float* D = new float[k * 5]; + + index->search(5, xb, k, D, I); + + // print results + printf("I=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) + printf("%5zd ", I[i * k + j]); + printf("\n"); + } + + printf("D=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) + printf("%7g ", D[i * k + j]); + printf("\n"); + } + + delete[] I; + delete[] D; + } + + { // search 5 first vectors of xb with deduper + idx_t* I = new idx_t[k * 5]; + float* D = new float[k * 5]; + std::unordered_map group; + for (int i = 0; i < nb; i++) { + group[i] = i % 2; + } + faiss::IDDeduperMap idDeduper(&group); + auto pSearchParameters = new faiss::SearchParametersHNSW(); + pSearchParameters->dedup = &idDeduper; + + index->search(5, xb, k, D, I, pSearchParameters); + + // print results + printf("I=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) + printf("%5zd ", I[i * k + j]); + printf("\n"); + } + + printf("D=\n"); + for (int i = 0; i < 5; i++) { + for (int j = 0; j < k; j++) + printf("%7g ", D[i * k + j]); + printf("\n"); + } + + delete[] I; + delete[] D; + } + + delete[] xb; + delete[] xq; + + return 0; +} diff --git a/tutorial/cpp/CMakeLists.txt b/tutorial/cpp/CMakeLists.txt index 7361b33a03..6022540a05 100644 --- a/tutorial/cpp/CMakeLists.txt +++ b/tutorial/cpp/CMakeLists.txt @@ -18,3 +18,6 @@ target_link_libraries(4-GPU PRIVATE faiss) add_executable(5-Multiple-GPUs EXCLUDE_FROM_ALL 5-Multiple-GPUs.cpp) target_link_libraries(5-Multiple-GPUs PRIVATE faiss) + +add_executable(6-Dedupe EXCLUDE_FROM_ALL 6-Dedupe.cpp) +target_link_libraries(6-Dedupe PRIVATE faiss)