Skip to content

Commit

Permalink
Add IDGrouper for HNSW
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed May 22, 2024
1 parent b39dd4d commit e5c507e
Show file tree
Hide file tree
Showing 13 changed files with 890 additions and 5 deletions.
3 changes: 3 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(FAISS_SRC
impl/AuxIndexStructures.cpp
impl/CodePacker.cpp
impl/IDSelector.cpp
impl/IDGrouper.cpp
impl/FaissException.cpp
impl/HNSW.cpp
impl/NSG.cpp
Expand Down Expand Up @@ -149,6 +150,7 @@ set(FAISS_HEADERS
impl/AuxIndexStructures.h
impl/CodePacker.h
impl/IDSelector.h
impl/IDGrouper.h
impl/DistanceComputer.h
impl/FaissAssert.h
impl/FaissException.h
Expand Down Expand Up @@ -183,6 +185,7 @@ set(FAISS_HEADERS
invlists/InvertedLists.h
invlists/InvertedListsIOHook.h
utils/AlignedTable.h
utils/GroupHeap.h
utils/Heap.h
utils/WorkerThread.h
utils/distances.h
Expand Down
8 changes: 6 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/IDGrouper.h and impl/DistanceComputer.h
struct IDSelector;
struct IDGrouper;
struct RangeSearchResult;
struct DistanceComputer;

Expand All @@ -52,6 +53,9 @@ struct DistanceComputer;
struct SearchParameters {
    /// optional ID filter: when set, only the selected IDs are considered
    /// during search
    IDSelector* sel = nullptr;
    /// optional grouper: when set, the result keeps only the best matched ID
    /// of each group
    IDGrouper* grp = nullptr;
    /// virtual destructor so that the struct can be used with dynamic_cast
    virtual ~SearchParameters() = default;
};
Expand Down
13 changes: 10 additions & 3 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,17 @@ void IndexHNSW::search(
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);

using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);
if (params_in && params_in->grp) {
using RH = GroupedHeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k, params_in->grp);

hnsw_search(this, n, x, bres, params_in);
hnsw_search(this, n, x, bres, params_in);
} else {
using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);

hnsw_search(this, n, x, bres, params_in);
}

if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
Expand Down
29 changes: 29 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ struct ScopedSelChange {
}
};

/// RAII object to reset the IDGrouper in the params object
struct ScopedGrpChange {
SearchParameters* params = nullptr;
IDGrouper* old_grp = nullptr;

void set(SearchParameters* params_2, IDGrouper* new_grp) {
this->params = params_2;
old_grp = params_2->grp;
params_2->grp = new_grp;
}
~ScopedGrpChange() {
if (params) {
params->grp = old_grp;
}
}
};

} // namespace

template <typename IndexT>
Expand All @@ -114,6 +131,8 @@ void IndexIDMapTemplate<IndexT>::search(
const SearchParameters* params) const {
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;
IDGrouperTranslated this_idgrptrans(this->id_map, nullptr);
ScopedGrpChange grp_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
Expand All @@ -131,6 +150,16 @@ void IndexIDMapTemplate<IndexT>::search(
sel_change.set(params_non_const, &this_idtrans);
}
}

if (params && params->grp) {
auto idtrans = dynamic_cast<const IDGrouperTranslated*>(params->grp);

if (!idtrans) {
auto params_non_const = const_cast<SearchParameters*>(params);
this_idgrptrans.grp = params->grp;
grp_change.set(params_non_const, &this_idgrptrans);
}
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
Expand Down
22 changes: 22 additions & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/IDGrouper.h>
#include <faiss/impl/IDSelector.h>

#include <unordered_map>
Expand Down Expand Up @@ -124,4 +125,25 @@ struct IDSelectorTranslated : IDSelector {
}
};

// IDGrouper adapter: looks up id_map[id] and returns the group that the
// wrapped grouper assigns to that mapped id
struct IDGrouperTranslated : IDGrouper {
    const std::vector<int64_t>& id_map;
    const IDGrouper* grp;

    // Primary constructor; the two overloads below forward to it.
    IDGrouperTranslated(
            const std::vector<int64_t>& id_map,
            const IDGrouper* grp)
            : id_map(id_map), grp(grp) {}

    // Use the id_map of a binary IDMap index.
    IDGrouperTranslated(IndexBinaryIDMap& index_idmap, const IDGrouper* grp)
            : IDGrouperTranslated(index_idmap.id_map, grp) {}

    // Use the id_map of an IDMap index.
    IDGrouperTranslated(IndexIDMap& index_idmap, const IDGrouper* grp)
            : IDGrouperTranslated(index_idmap.id_map, grp) {}

    // Translate id through id_map first, then delegate to grp.
    // No null check on grp and no bounds check on id are performed here.
    idx_t get_group(idx_t id) const override {
        const int64_t mapped_id = id_map[id];
        return grp->get_group(mapped_id);
    }
};

} // namespace faiss
6 changes: 6 additions & 0 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
return hres->k;
}

if (auto hres = dynamic_cast<
GroupedHeapBlockResultHandler<C>::SingleResultHandler*>(&res)) {
return hres->k;
}

return 1;
}

Expand Down
51 changes: 51 additions & 0 deletions faiss/impl/IDGrouper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <assert.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDGrouper.h>

namespace faiss {

/***********************************************************************
* IDGrouperBitmap
***********************************************************************/

IDGrouperBitmap::IDGrouperBitmap(size_t n, uint64_t* bitmap)
: n(n), bitmap(bitmap) {}

/**
 * Return the group of `id`, defined as the position of the first set bit
 * located at or after bit `id` in the bitmap.
 *
 * @param id  bit position to start from; must satisfy 0 <= id < n * 64
 *            (checked by assert only, so not in NDEBUG builds)
 * @return    position of that set bit, or NO_MORE_DOCS when no bit is set
 *            at or after `id`
 */
idx_t IDGrouperBitmap::get_group(idx_t id) const {
    assert(id >= 0 && "id shouldn't be less than zero");
    assert(static_cast<size_t>(id) < this->n * 64 &&
           "id should be less than total number of bits");

    // Word that holds bit `id` (id / 64).
    size_t index = static_cast<size_t>(id) >> 6;
    // Shift out the bits below `id` (id % 64) so that bit 0 of `block`
    // corresponds to `id` itself.
    uint64_t block = this->bitmap[index] >> (id & 63);
    // A non-zero block means the next set bit is in the current word, at
    // `id` + number of trailing zeros. __builtin_ctzll is undefined for 0,
    // hence the explicit check.
    if (block != 0) {
        return id + __builtin_ctzll(block);
    }

    // Otherwise scan the following words for the first one with a set bit.
    while (++index < this->n) {
        block = this->bitmap[index];
        if (block != 0) {
            return static_cast<idx_t>(index << 6) + __builtin_ctzll(block);
        }
    }

    return NO_MORE_DOCS;
}

/**
 * Set bit `group_id` in the bitmap, making it a position that get_group can
 * return.
 *
 * @param group_id  bit position to set; must satisfy
 *                  0 <= group_id < n * 64 (checked by assert only, so not
 *                  in NDEBUG builds)
 */
void IDGrouperBitmap::set_group(idx_t group_id) {
    // Same range checks as get_group: an out-of-range value would write
    // outside the bitmap array.
    assert(group_id >= 0 && "group_id shouldn't be less than zero");
    assert(static_cast<size_t>(group_id) < this->n * 64 &&
           "group_id should be less than total number of bits");

    const size_t index = static_cast<size_t>(group_id) >> 6; // div by 64
    this->bitmap[index] |= 1ULL
            << (group_id & 63); // Equivalent of 1ULL << (value % 64)
}

} // namespace faiss
51 changes: 51 additions & 0 deletions faiss/impl/IDGrouper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <limits>
#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>

/** IDGrouper is intended to define a group of vectors to include only
* the nearest vector of each group during search */

namespace faiss {

/** Interface that maps a vector id to the id of the group it belongs to */
struct IDGrouper {
    /// largest idx_t value, available to implementations as a "no group"
    /// result
    const idx_t NO_MORE_DOCS = std::numeric_limits<idx_t>::max();
    /// @return the group id of the given id
    virtual idx_t get_group(idx_t id) const = 0;
    virtual ~IDGrouper() = default;
};

/** Grouper backed by a bitmap stored as n 64-bit words (n * 64 bits).
 * The group id of a given id is the next set bit in the bitmap.
 */
struct IDGrouperBitmap : IDGrouper {
    /// number of uint64_t words in the bitmap array
    size_t n;

    /// Array of uint64_t holding the bits.
    /// uint64_t is used to leverage __builtin_ctzll (see
    /// faiss/impl/platform_macros.h). The pointer is stored as given and the
    /// array is not freed by this struct.
    uint64_t* bitmap;

    /** Build a grouper on top of an existing bitmap
     *
     * @param n       length of the bitmap array, in uint64_t words
     * @param bitmap  the bits; the group id of a given id is the next set
     *                bit in the bitmap
     */
    IDGrouperBitmap(size_t n, uint64_t* bitmap);
    /// @return position of the next set bit for id (defined in IDGrouper.cpp)
    idx_t get_group(idx_t id) const final;
    /// set the bit at position group_id
    void set_group(idx_t group_id);
    ~IDGrouperBitmap() override = default;
};

} // namespace faiss
Loading

0 comments on commit e5c507e

Please sign in to comment.