Skip to content

Commit

Permalink
Add vector deduplication
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Nov 22, 2023
1 parent 467f70e commit 864e0b3
Show file tree
Hide file tree
Showing 9 changed files with 504 additions and 11 deletions.
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(FAISS_SRC
impl/AuxIndexStructures.cpp
impl/CodePacker.cpp
impl/IDSelector.cpp
impl/IDDeduper.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
Expand Down Expand Up @@ -148,6 +149,7 @@ set(FAISS_HEADERS
impl/AdditiveQuantizer.h
impl/AuxIndexStructures.h
impl/IDSelector.h
impl/IDDeduper.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
Expand Down
6 changes: 4 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/IDDeduper.h and impl/DistanceComputer.h
struct IDSelector;
struct IDDeduper;
struct RangeSearchResult;
struct DistanceComputer;

Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
/// if non-null, only these IDs will be considered during search.
IDSelector* sel = nullptr;
IDDeduper* dedup = nullptr;
/// make sure we can dynamic_cast this
virtual ~SearchParameters() {}
};
Expand Down
1 change: 0 additions & 1 deletion faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ void IndexHNSW::search(

for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
{
VisitedTable vt(ntotal);
Expand Down
102 changes: 94 additions & 8 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,16 +538,57 @@ int search_from_candidates(
: hnsw.check_relative_distance;
int efSearch = params ? params->efSearch : hnsw.efSearch;
const IDSelector* sel = params ? params->sel : nullptr;
const IDDeduper* dedup = params ? params->dedup : nullptr;
std::unordered_map<idx_t, idx_t> group_id_to_id;
std::unordered_map<idx_t, size_t> group_id_to_index;

for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
if (dedup) {
idx_t group_id = dedup->group_id(v1);
if (group_id_to_index.find(group_id) ==
group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(
++nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
} else if (d < D[0]) {
faiss::maxheap_replace_top_with_dedupe(
nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else if (d < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(
nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
}
}
vt.set(v1);
Expand Down Expand Up @@ -612,10 +653,48 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
if (dedup) {
idx_t group_id = dedup->group_id(idx);
if (group_id_to_index.find(group_id) ==
group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(
++nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
} else if (dis < D[0]) {
faiss::maxheap_replace_top_with_dedupe(
nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else if (dis < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(
nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
}
}
candidates.push(idx, dis);
Expand Down Expand Up @@ -660,6 +739,13 @@ int search_from_candidates(
}
}

// Convert group id to id before return
if (dedup) {
for (size_t icnt = 0; icnt < nres; icnt++) {
I[icnt] = group_id_to_id.at(I[icnt]);
}
}

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
21 changes: 21 additions & 0 deletions faiss/impl/IDDeduper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/impl/IDDeduper.h>

namespace faiss {

/***********************************************************************
* IDDeduperMap
***********************************************************************/
IDDeduperMap::IDDeduperMap(std::unordered_map<idx_t, idx_t>* m) : m(m) {}

idx_t IDDeduperMap::group_id(idx_t id) const {
return m->at(id);
}

} // namespace faiss
36 changes: 36 additions & 0 deletions faiss/impl/IDDeduper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_map>
#include <vector>

#include <faiss/MetricType.h>

/** IDDeduper is intended to define a group of vectors to dedupe */

namespace faiss {

/** Encapsulates a set of id groups to handle. */
struct IDDeduper {
virtual idx_t group_id(idx_t id) const = 0;
virtual ~IDDeduper() {}
};

/** id to group id mapping */
struct IDDeduperMap : IDDeduper {
std::unordered_map<idx_t, idx_t>* m;

IDDeduperMap(std::unordered_map<idx_t, idx_t>* m);

idx_t group_id(idx_t id) const final;

~IDDeduperMap() override {}
};

} // namespace faiss
Loading

0 comments on commit 864e0b3

Please sign in to comment.