Skip to content

Commit

Permalink
Add vector deduplication
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Nov 22, 2023
1 parent 467f70e commit c552f96
Show file tree
Hide file tree
Showing 9 changed files with 399 additions and 10 deletions.
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(FAISS_SRC
impl/AuxIndexStructures.cpp
impl/CodePacker.cpp
impl/IDSelector.cpp
impl/IDDeduper.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
Expand Down Expand Up @@ -148,6 +149,7 @@ set(FAISS_HEADERS
impl/AdditiveQuantizer.h
impl/AuxIndexStructures.h
impl/IDSelector.h
impl/IDDeduper.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
Expand Down
4 changes: 3 additions & 1 deletion faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, impl/IDDeduper.h and
/// impl/DistanceComputer.h
struct IDSelector;
struct IDDeduper;
struct RangeSearchResult;
struct DistanceComputer;

Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
/// if non-null, only these IDs will be considered during search.
IDSelector* sel = nullptr;
IDDeduper* dedup = nullptr;
/// make sure we can dynamic_cast this
virtual ~SearchParameters() {}
};
Expand Down
1 change: 0 additions & 1 deletion faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ void IndexHNSW::search(

for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
{
VisitedTable vt(ntotal);
Expand Down
56 changes: 48 additions & 8 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,16 +538,34 @@ int search_from_candidates(
: hnsw.check_relative_distance;
int efSearch = params ? params->efSearch : hnsw.efSearch;
const IDSelector* sel = params ? params->sel : nullptr;
const IDDeduper* dedup = params ? params->dedup : nullptr;
std::unordered_map<idx_t, idx_t> group_id_to_id;
std::unordered_map<idx_t, size_t> group_id_to_index;

for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
if (dedup) {
idx_t group_id = dedup->group_id(v1);
if (group_id_to_index.find(group_id) == group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(++nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id);
} else if (d < D[0]) {
faiss::maxheap_replace_top_with_dedupe(nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id);
}
}
else if (d < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(nres, D, I, d, v1, &group_id_to_id, &group_id_to_index, group_id);
}
}
else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
}
}
vt.set(v1);
Expand Down Expand Up @@ -612,10 +630,25 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
if (dedup) {
idx_t group_id = dedup->group_id(idx);
if (group_id_to_index.find(group_id) == group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(++nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id);
} else if (dis < D[0]) {
faiss::maxheap_replace_top_with_dedupe(nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id);
}
}
else if (dis < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(nres, D, I, dis, idx, &group_id_to_id, &group_id_to_index, group_id);
}
}
else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
}
}
candidates.push(idx, dis);
Expand Down Expand Up @@ -660,6 +693,13 @@ int search_from_candidates(
}
}

// Convert group id to id before return
if (dedup) {
for (size_t icnt = 0; icnt < nres; icnt++) {
I[icnt] = group_id_to_id.at(I[icnt]);
}
}

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
22 changes: 22 additions & 0 deletions faiss/impl/IDDeduper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/impl/IDDeduper.h>

namespace faiss {

/***********************************************************************
* IDDeduperMap
***********************************************************************/
IDDeduperMap::IDDeduperMap(std::unordered_map<idx_t, idx_t>* m)
: m(m) {}

idx_t IDDeduperMap::group_id(idx_t id) const {
return m->at(id);
}

} // namespace faiss
36 changes: 36 additions & 0 deletions faiss/impl/IDDeduper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>

/** IDDeduper is intended to define a group of vectors to dedupe */

namespace faiss {

/** Encapsulates a set of id groups to handle. */
struct IDDeduper {
virtual idx_t group_id(idx_t id) const = 0;
virtual ~IDDeduper() {}
};

/** id to group id mapping */
struct IDDeduperMap : IDDeduper {
std::unordered_map<idx_t, idx_t>* m;

IDDeduperMap(std::unordered_map<idx_t, idx_t>* m);

idx_t group_id(idx_t id) const final;

~IDDeduperMap() override {}
};

} // namespace faiss
174 changes: 174 additions & 0 deletions faiss/utils/Heap.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <limits>

#include <faiss/utils/ordered_key_value.h>
#include <faiss/impl/IDDeduper.h>

namespace faiss {

Expand Down Expand Up @@ -148,6 +149,140 @@ inline void heap_replace_top(
bh_ids[i] = id;
}


template <class C>
inline void up_heap_with_dedupe(
size_t k,
typename C::T* bh_val,
typename C::TI* bh_ids,
typename C::T val,
typename C::TI id,
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
std::unordered_map<typename C::TI, size_t>* group_id_to_index,
typename C::TI group_id,
size_t start_index) {
bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--;
size_t i = start_index + 1, i_father;

while (i > 1) {
i_father = i >> 1;
if (!C::cmp2(val, bh_val[i_father], group_id, bh_ids[i_father])) {
/* the heap structure is ok */
break;
}
bh_val[i] = bh_val[i_father];
bh_ids[i] = bh_ids[i_father];
(*group_id_to_index)[bh_ids[i]] = i - 1;
i = i_father;
}
bh_val[i] = val;
bh_ids[i] = group_id;
(*group_id_to_id)[group_id] = id;
(*group_id_to_index)[group_id] = i - 1;
}

template <class C>
inline void down_heap_with_dedupe(
size_t k,
typename C::T* bh_val,
typename C::TI* bh_ids,
typename C::T val,
typename C::TI id,
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
std::unordered_map<typename C::TI, size_t>* group_id_to_index,
typename C::TI group_id,
size_t start_index) {
bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--;
size_t i = start_index + 1, i1, i2;

while (1) {
i1 = i << 1;
i2 = i1 + 1;
if (i1 > k) {
break;
}

// Note that C::cmp2() is a bool function answering
// `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max
// heap and same with the `<` sign for min heap.
if ((i2 == k + 1) ||
C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) {
if (C::cmp2(val, bh_val[i1], group_id, bh_ids[i1])) {
break;
}
bh_val[i] = bh_val[i1];
bh_ids[i] = bh_ids[i1];
(*group_id_to_index)[bh_ids[i]] = i - 1;
i = i1;
} else {
if (C::cmp2(val, bh_val[i2], group_id, bh_ids[i2])) {
break;
}
bh_val[i] = bh_val[i2];
bh_ids[i] = bh_ids[i2];
(*group_id_to_index)[bh_ids[i]] = i - 1;
i = i2;
}
}
bh_val[i] = val;
bh_ids[i] = group_id;
(*group_id_to_id)[group_id] = id;
(*group_id_to_index)[group_id] = i - 1;
}

/** heap_push with group id
*/
template <class C>
inline void heap_push_with_dedupe(
size_t k,
typename C::T* bh_val,
typename C::TI* bh_ids,
typename C::T val,
typename C::TI id,
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
std::unordered_map<typename C::TI, size_t>* group_id_to_index,
typename C::TI group_id) {
up_heap_with_dedupe<C>(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, k - 1);
}

/**
* heap_replace_top with with group id
*/
template <class C>
inline void heap_replace_top_with_dedupe(
size_t k,
typename C::T* bh_val,
typename C::TI* bh_ids,
typename C::T val,
typename C::TI id,
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
std::unordered_map<typename C::TI, size_t>* group_id_to_index,
typename C::TI group_id) {
group_id_to_id->erase(bh_ids[0]);
group_id_to_index->erase(bh_ids[0]);
down_heap_with_dedupe<C>(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, 0);
}

/**
* heap_update with group id
*/
template <class C>
inline void heap_update_with_dedupe(
size_t k,
typename C::T* bh_val,
typename C::TI* bh_ids,
typename C::T val,
typename C::TI id,
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
std::unordered_map<typename C::TI, size_t>* group_id_to_index,
typename C::TI group_id) {
size_t start_index = group_id_to_index->at(group_id);
up_heap_with_dedupe<C>(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, start_index);
down_heap_with_dedupe<C>(k, bh_val, bh_ids, val, id, group_id_to_id, group_id_to_index, group_id, start_index);
}

/* Partial instanciation for heaps with TI = int64_t */

template <typename T>
Expand Down Expand Up @@ -200,6 +335,45 @@ inline void maxheap_replace_top(
heap_replace_top<CMax<T, int64_t>>(k, bh_val, bh_ids, val, ids);
}

template <typename T>
inline void maxheap_push_with_dedupe(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t ids,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
heap_push_with_dedupe<CMax<T, int64_t>>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id);
}

template <typename T>
inline void maxheap_replace_top_with_dedupe(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t ids,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
heap_replace_top_with_dedupe<CMax<T, int64_t>>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id);
}

template <typename T>
inline void maxheap_update_with_dedupe(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t ids,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
heap_update_with_dedupe<CMax<T, int64_t>>(k, bh_val, bh_ids, val, ids, group_id_to_id, group_id_to_index, group_id);
}

/*******************************************************************
* Heap initialization
*******************************************************************/
Expand Down
Loading

0 comments on commit c552f96

Please sign in to comment.