Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support grouping of vector data in HNSW #3227

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(FAISS_SRC
impl/AuxIndexStructures.cpp
impl/CodePacker.cpp
impl/IDSelector.cpp
impl/IDGrouper.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
Expand Down Expand Up @@ -149,6 +150,7 @@ set(FAISS_HEADERS
impl/AuxIndexStructures.h
impl/CodePacker.h
impl/IDSelector.h
impl/IDGrouper.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
Expand Down Expand Up @@ -183,6 +185,7 @@ set(FAISS_HEADERS
invlists/InvertedLists.h
invlists/InvertedListsIOHook.h
utils/AlignedTable.h
utils/GroupHeap.h
utils/Heap.h
utils/WorkerThread.h
utils/distances.h
Expand Down
8 changes: 6 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations, see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/IDGrouper.h and impl/DistanceComputer.h
struct IDSelector;
struct IDGrouper;
struct RangeSearchResult;
struct DistanceComputer;

Expand All @@ -52,6 +53,9 @@ struct DistanceComputer;
struct SearchParameters {
    /// if non-null, only these IDs will be considered during search.
    IDSelector* sel = nullptr;
    /// if non-null, only the best-matching ID of each group (as reported by
    /// IDGrouper::get_group) will be included in the result.
    IDGrouper* grp = nullptr;
    /// virtual destructor: makes the type polymorphic so that a
    /// SearchParameters pointer can be dynamic_cast to a derived type
    virtual ~SearchParameters() {}
};
Expand Down
13 changes: 10 additions & 3 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,17 @@ void IndexHNSW::search(
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);

using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);
if (params_in && params_in->grp) {
using RH = GroupedHeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k, params_in->grp);

hnsw_search(this, n, x, bres, params_in);
hnsw_search(this, n, x, bres, params_in);
} else {
using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);

hnsw_search(this, n, x, bres, params_in);
}

if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
Expand Down
29 changes: 29 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ struct ScopedSelChange {
}
};

/// RAII object to reset the IDGrouper in the params object
struct ScopedGrpChange {
SearchParameters* params = nullptr;
IDGrouper* old_grp = nullptr;

void set(SearchParameters* params_2, IDGrouper* new_grp) {
this->params = params_2;
old_grp = params_2->grp;
params_2->grp = new_grp;
}
~ScopedGrpChange() {
if (params) {
params->grp = old_grp;
}
}
};

} // namespace

template <typename IndexT>
Expand All @@ -114,6 +131,8 @@ void IndexIDMapTemplate<IndexT>::search(
const SearchParameters* params) const {
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;
IDGrouperTranslated this_idgrptrans(this->id_map, nullptr);
ScopedGrpChange grp_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
Expand All @@ -131,6 +150,16 @@ void IndexIDMapTemplate<IndexT>::search(
sel_change.set(params_non_const, &this_idtrans);
}
}

if (params && params->grp) {
auto idtrans = dynamic_cast<const IDGrouperTranslated*>(params->grp);

if (!idtrans) {
auto params_non_const = const_cast<SearchParameters*>(params);
this_idgrptrans.grp = params->grp;
grp_change.set(params_non_const, &this_idgrptrans);
}
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
Expand Down
22 changes: 22 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/IDGrouper.h>
#include <faiss/impl/IDSelector.h>

#include <unordered_map>
Expand Down Expand Up @@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector {
}
};

// IDGrouper adapter for ID-mapped indexes: an incoming id is first looked up
// in id_map, and the wrapped grouper is then queried with the mapped value.
struct IDGrouperTranslated : IDGrouper {
    // translation table, indexed by the id passed to get_group()
    const std::vector<int64_t>& id_map;
    // grouper that receives the translated id
    const IDGrouper* grp;

    IDGrouperTranslated(
            const std::vector<int64_t>& id_map,
            const IDGrouper* grp)
            : id_map(id_map), grp(grp) {}

    // convenience constructors: borrow the id_map of an existing ID-map index
    IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp)
            : IDGrouperTranslated(index_idmap.id_map, grp) {}

    IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp)
            : IDGrouperTranslated(index_idmap.id_map, grp) {}

    // group of the translated id, as reported by the wrapped grouper
    idx_t get_group(idx_t id) const override {
        const int64_t translated_id = id_map[id];
        return grp->get_group(translated_id);
    }
};

} // namespace faiss
6 changes: 6 additions & 0 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
return hres->k;
}

if (auto hres = dynamic_cast<
GroupedHeapBlockResultHandler<C>::SingleResultHandler*>(&res)) {
return hres->k;
}

return 1;
}

Expand Down
51 changes: 51 additions & 0 deletions faiss/impl/IDGrouper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <assert.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDGrouper.h>

namespace faiss {

/***********************************************************************
* IDGrouperBitmap
***********************************************************************/

IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap)
: n(n), bitmap(bitmap) {}

/**
 * Return the group of `id`: the position of the first set bit located at or
 * after bit `id` in the bitmap.
 *
 * @param id  bit position to start from. Must satisfy 0 <= id < n * 64; this
 *            is only verified with assert(), i.e. not when NDEBUG is defined.
 * @return    position of the first set bit >= id, or NO_MORE_DOCS if no bit
 *            is set between `id` and the end of the bitmap.
 */
idx_t IDGrouperBitmap::get_group(idx_t id) const {
    assert(id >= 0 && "id shouldn't be less than zero");
    // Compare word indices instead of `id < n * 64`: this avoids both the
    // signed/unsigned comparison and a possible overflow of `n * 64`.
    assert(static_cast<size_t>(id >> 6) < this->n &&
           "id should be less than total number of bits");

    idx_t index = id >> 6; // word that holds bit `id` (id / 64)

    // Shift out the bits below `id` (id % 64 of them) so that only positions
    // >= id remain in the word.
    uint64_t block = this->bitmap[index] >> (id & 63);

    // A non-zero block after the shift means the next set bit is in the
    // current word: its position is `id` plus the number of trailing zeros.
    if (block != 0) {
        return id + __builtin_ctzll(block);
    }

    // Otherwise scan the following words for the first one with a set bit.
    const idx_t n_words = static_cast<idx_t>(this->n);
    while (++index < n_words) {
        block = this->bitmap[index];
        if (block != 0) {
            return (index << 6) + __builtin_ctzll(block);
        }
    }

    return NO_MORE_DOCS;
}

/**
 * Mark `group_id` as a group boundary by setting its bit in the bitmap.
 *
 * @param group_id  bit position to set. Must satisfy 0 <= group_id < n * 64;
 *                  this is only verified with assert(), i.e. not when NDEBUG
 *                  is defined.
 */
void IDGrouperBitmap::set_group(idx_t group_id) {
    // Same range checks as get_group(): without them an out-of-range
    // group_id silently writes outside the bitmap array.
    assert(group_id >= 0 && "group_id shouldn't be less than zero");
    assert(static_cast<size_t>(group_id >> 6) < this->n &&
           "group_id should be less than total number of bits");

    const idx_t index = group_id >> 6; // word that holds the bit (value / 64)
    this->bitmap[index] |= 1ULL << (group_id & 63); // bit within word (% 64)
}

} // namespace faiss
51 changes: 51 additions & 0 deletions faiss/impl/IDGrouper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <limits>
#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>

/** IDGrouper is intended to define a group of vectors to include only
* the nearest vector of each group during search */

namespace faiss {

/** Maps a vector id to the id of the group it belongs to. */
struct IDGrouper {
    /// Sentinel group id; IDGrouperBitmap::get_group returns it when no group
    /// can be found for the given id.
    /// Declared static constexpr rather than as a per-object const member so
    /// that it takes no space in each instance and does not delete the
    /// implicit copy-assignment operator.
    static constexpr idx_t NO_MORE_DOCS = std::numeric_limits<idx_t>::max();

    /// @return the group id of `id`
    virtual idx_t get_group(idx_t id) const = 0;

    virtual ~IDGrouper() {}
};

/** Bitmap-based grouper, one bit per id. The bitmap is an array of n 64-bit
 * words (i.e. n * 64 bits). The group of an id is the position of the first
 * set bit at or after that id.
 */
struct IDGrouperBitmap : IDGrouper {
    // length of the bitmap array, in uint64_t words (not in bits or bytes)
    size_t n;

    // Array of uint64_t holding the bits.
    // uint64_t is used so that __builtin_ctzll can be applied to each word.
    // The group id of a given id is the next set bit in the bitmap.
    // The pointer is stored as passed to the constructor; the destructor
    // below does not free it.
    uint64_t* bitmap;

    /** Construct with a binary mask
     *
     * @param n size of the bitmap array, in uint64_t words
     * @param bitmap group id of a given id is next set bit in the bitmap
     */
    IDGrouperBitmap(size_t n, uint64_t* bitmap);

    /// @return position of the first set bit at or after `id`, or
    /// NO_MORE_DOCS when there is none
    idx_t get_group(idx_t id) const final;

    /// set the bit at position `group_id` in the bitmap
    void set_group(idx_t group_id);

    ~IDGrouperBitmap() override {}
};

} // namespace faiss
Loading