Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vector deduplication for HNSW #3140

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(FAISS_SRC
impl/AuxIndexStructures.cpp
impl/CodePacker.cpp
impl/IDSelector.cpp
impl/IDDeduper.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
Expand Down Expand Up @@ -148,6 +149,7 @@ set(FAISS_HEADERS
impl/AdditiveQuantizer.h
impl/AuxIndexStructures.h
impl/IDSelector.h
impl/IDDeduper.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
Expand Down
6 changes: 4 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/IDDeduper.h and impl/DistanceComputer.h
struct IDSelector;
struct IDDeduper;
struct RangeSearchResult;
struct DistanceComputer;

Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
    /// Optional ID filter: when set, only IDs accepted by the selector are
    /// considered during search.
    IDSelector* sel{nullptr};

    /// Optional deduper: when set, it supplies a group id for each vector id
    /// (see IDDeduper::group_id).
    IDDeduper* dedup{nullptr};

    /// Virtual destructor makes the type polymorphic, so derived parameter
    /// structs can be recovered with dynamic_cast.
    virtual ~SearchParameters() = default;
};
Expand Down
1 change: 0 additions & 1 deletion faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ void IndexHNSW::search(

for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
{
VisitedTable vt(ntotal);
Expand Down
102 changes: 94 additions & 8 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,16 +538,57 @@ int search_from_candidates(
: hnsw.check_relative_distance;
int efSearch = params ? params->efSearch : hnsw.efSearch;
const IDSelector* sel = params ? params->sel : nullptr;
const IDDeduper* dedup = params ? params->dedup : nullptr;
std::unordered_map<idx_t, idx_t> group_id_to_id;
std::unordered_map<idx_t, size_t> group_id_to_index;

for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
if (dedup) {
idx_t group_id = dedup->group_id(v1);
if (group_id_to_index.find(group_id) ==
group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(
++nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
} else if (d < D[0]) {
faiss::maxheap_replace_top_with_dedupe(
nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else if (d < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(
nres,
D,
I,
d,
v1,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
}
}
vt.set(v1);
Expand Down Expand Up @@ -612,10 +653,48 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
if (dedup) {
idx_t group_id = dedup->group_id(idx);
if (group_id_to_index.find(group_id) ==
group_id_to_index.end()) {
if (nres < k) {
faiss::maxheap_push_with_dedupe(
++nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
} else if (dis < D[0]) {
faiss::maxheap_replace_top_with_dedupe(
nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else if (dis < D[group_id_to_index.at(group_id)]) {
faiss::maxheap_update_with_dedupe(
nres,
D,
I,
dis,
idx,
&group_id_to_id,
&group_id_to_index,
group_id);
}
} else {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
}
}
candidates.push(idx, dis);
Expand Down Expand Up @@ -660,6 +739,13 @@ int search_from_candidates(
}
}

// Convert group id to id before return
if (dedup) {
for (size_t icnt = 0; icnt < nres; icnt++) {
I[icnt] = group_id_to_id.at(I[icnt]);
}
}

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
21 changes: 21 additions & 0 deletions faiss/impl/IDDeduper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <faiss/impl/IDDeduper.h>

namespace faiss {

/***********************************************************************
* IDDeduperMap
***********************************************************************/
IDDeduperMap::IDDeduperMap(std::unordered_map<idx_t, idx_t>* m) : m(m) {}

/// Look up the group that `id` belongs to.
///
/// The lookup uses unordered_map::at, so an id that is absent from the map
/// raises std::out_of_range rather than returning a default value.
idx_t IDDeduperMap::group_id(idx_t id) const {
    const std::unordered_map<idx_t, idx_t>& id_to_group = *m;
    return id_to_group.at(id);
}

} // namespace faiss
36 changes: 36 additions & 0 deletions faiss/impl/IDDeduper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_map>
#include <vector>

#include <faiss/MetricType.h>

/** IDDeduper assigns each vector id to a group id, so that vectors belonging
 *  to the same group can be deduplicated in search results. */

namespace faiss {

/** Abstract interface that assigns every vector id to a group id.
 *
 * Implementations only need to provide group_id(); ids that return the same
 * value are treated as members of the same group.
 */
struct IDDeduper {
    virtual ~IDDeduper() = default;

    /// @return the group id that the vector `id` belongs to
    virtual idx_t group_id(idx_t id) const = 0;
};

/** IDDeduper backed by an explicit id -> group id hash map.
 *
 * The map is held through a raw pointer and is never freed by this struct,
 * so the caller must keep it alive for as long as the deduper is in use.
 */
struct IDDeduperMap : IDDeduper {
    /// id -> group id table supplied at construction (not owned)
    std::unordered_map<idx_t, idx_t>* m;

    /// @param m  id -> group id table to consult in group_id()
    IDDeduperMap(std::unordered_map<idx_t, idx_t>* m);

    ~IDDeduperMap() override = default;

    /// @return the group id stored in the map for `id`
    idx_t group_id(idx_t id) const final;
};

} // namespace faiss
Loading