Skip to content

Commit

Permalink
Custom patch to support multi-vector
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Dec 22, 2023
1 parent 5b6c4b4 commit c59bc53
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 14 deletions.
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ set(FAISS_HEADERS
impl/ProductQuantizer.h
impl/Quantizer.h
impl/ResidualQuantizer.h
impl/ResultCollector.h
impl/ResultCollectorFactory.h
impl/ResultHandler.h
impl/ScalarQuantizer.h
impl/ThreadedIndex-inl.h
Expand Down
6 changes: 4 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;
struct ResultCollectorFactory;

/** Parent class for the optional search parameters.
*
Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
    /// Optional ID filter: when non-null, only these IDs are considered
    /// during search.
    IDSelector* sel = nullptr;
    /// Optional factory for the collector that gathers search results.
    /// When left null, the HNSW candidate search falls back to a default
    /// ResultCollectorFactory.
    ResultCollectorFactory* col = nullptr;
    /// Virtual destructor: keeps the type polymorphic so that it can be
    /// dynamic_cast to a derived parameter class.
    virtual ~SearchParameters() = default;
};
Expand Down
23 changes: 23 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ struct ScopedSelChange {
}
};

// Scope guard for the id_map field of a ResultCollectorFactory.
// set() installs an id_map on the factory; when the guard leaves the current
// scope its destructor resets the factory's id_map back to nullptr, so the
// pointer never outlives the search call that installed it.
struct ScopedColChange {
    // Factory whose id_map must be reset; nullptr while nothing is installed.
    ResultCollectorFactory* collector_factory = nullptr;

    ScopedColChange() = default;

    // Non-copyable: a copy would reset the same factory a second time,
    // possibly while another guard still relies on the installed id_map.
    ScopedColChange(const ScopedColChange&) = delete;
    ScopedColChange& operator=(const ScopedColChange&) = delete;

    // Install id_map on collector_factory and remember the factory so it can
    // be reset later. A null factory is accepted and leaves the guard empty.
    // Calling set() again first resets the factory installed previously.
    void set(
            ResultCollectorFactory* collector_factory,
            const std::vector<int64_t>* id_map) {
        reset();
        this->collector_factory = collector_factory;
        if (collector_factory) {
            collector_factory->id_map = id_map;
        }
    }

    ~ScopedColChange() {
        reset();
    }

   private:
    // Clear the id_map of the tracked factory (if any) and forget the factory.
    void reset() {
        if (collector_factory) {
            collector_factory->id_map = nullptr;
            collector_factory = nullptr;
        }
    }
};

} // namespace

template <typename IndexT>
Expand All @@ -114,6 +132,7 @@ void IndexIDMapTemplate<IndexT>::search(
const SearchParameters* params) const {
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;
ScopedColChange col_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
Expand All @@ -131,6 +150,10 @@ void IndexIDMapTemplate<IndexT>::search(
sel_change.set(params_non_const, &this_idtrans);
}
}

if (params && params->col && !params->col->id_map) {
col_change.set(params->col, &this->id_map);
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
Expand Down
1 change: 1 addition & 0 deletions faiss/IndexIDMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <faiss/Index.h>
#include <faiss/IndexBinary.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultCollectorFactory.h>

#include <unordered_map>
#include <vector>
Expand Down
31 changes: 19 additions & 12 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultCollectorFactory.h>
#include <faiss/utils/prefetch.h>

#include <faiss/impl/platform_macros.h>
Expand Down Expand Up @@ -111,8 +112,8 @@ void HNSW::print_neighbor_stats(int level) const {
level,
nb_neighbors(level));
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
reduction(+: tot_reciprocal) reduction(+: n_node)
#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
reduction(+ : tot_reciprocal) reduction(+ : n_node)
for (int i = 0; i < levels.size(); i++) {
if (levels[i] > level) {
n_node++;
Expand Down Expand Up @@ -530,6 +531,15 @@ int search_from_candidates(
int level,
int nres_in = 0,
const SearchParametersHNSW* params = nullptr) {
ResultCollectorFactory defaultFactory;
ResultCollectorFactory* collectorFactory;
if (params == nullptr || params->col == nullptr) {
collectorFactory = &defaultFactory;
} else {
collectorFactory = params->col;
}
ResultCollector* collector = collectorFactory->newCollector();

int nres = nres_in;
int ndis = 0;

Expand All @@ -544,11 +554,7 @@ int search_from_candidates(
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
collector->collect(k, nres, D, I, d, v1);
}
vt.set(v1);
}
Expand Down Expand Up @@ -612,11 +618,7 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
collector->collect(k, nres, D, I, dis, idx);
}
candidates.push(idx, dis);
};
Expand Down Expand Up @@ -660,6 +662,11 @@ int search_from_candidates(
}
}

// Completed collection of result. Run post processor.
collector->post_process(nres, I);
// Collector completed its task. Release all resource of the collector.
collectorFactory->deleteCollector(collector);

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
62 changes: 62 additions & 0 deletions faiss/impl/ResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>
#include <faiss/utils/Heap.h>

/** ResultCollector is intended to define how to collect search result */

namespace faiss {

/**
 * Interface describing how search results are accumulated.
 *
 * collect() is invoked once per candidate result and post_process() is
 * invoked once after every result has been collected.
 */
struct ResultCollector {
    /**
     * Offer one candidate result to the collector.
     *
     * @param k       number of results requested
     * @param nres    number of results stored so far; may be updated
     * @param bh_val  values of the results stored so far
     * @param bh_ids  ids of the results stored so far
     * @param val     value of the candidate
     * @param ids     id of the candidate
     */
    virtual void collect(
            int k,
            int& nres,
            float* bh_val,
            idx_t* bh_ids,
            float val,
            idx_t ids) = 0;

    /**
     * Hook executed after the last call to collect().
     *
     * @param nres    number of results stored
     * @param bh_ids  ids of the stored results
     */
    virtual void post_process(idx_t nres, idx_t* bh_ids) = 0;

    virtual ~ResultCollector() = default;
};

/**
 * Collector that stores results in a max-heap held in bh_val / bh_ids.
 */
struct DefaultCollector : ResultCollector {
    /**
     * While fewer than k results are stored, push the candidate and increment
     * nres. Otherwise replace the heap top, but only when the candidate value
     * is smaller than the current top value bh_val[0].
     */
    void collect(
            int k,
            int& nres,
            float* bh_val,
            idx_t* bh_ids,
            float val,
            idx_t ids) override {
        const bool has_room = nres < k;
        if (has_room) {
            // The heap grows by one entry; push expects the new size.
            nres += 1;
            faiss::maxheap_push(nres, bh_val, bh_ids, val, ids);
            return;
        }
        if (val < bh_val[0]) {
            faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids);
        }
    }

    /**
     * Called once all results are collected so that a final post-processing
     * step can run. For example, a collector that gathers results by group id
     * could convert the group ids back to the original ids here. The default
     * collector has nothing to do.
     */
    void post_process(idx_t /*nres*/, idx_t* /*bh_ids*/) override {}

    ~DefaultCollector() override = default;
};

} // namespace faiss
33 changes: 33 additions & 0 deletions faiss/impl/ResultCollectorFactory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <faiss/impl/ResultCollector.h>
namespace faiss {

/** ResultCollectorFactory to create a ResultCollector object */
struct ResultCollectorFactory {
    /// Collector handed out by the default newCollector(); it is shared by
    /// every caller of the default implementation.
    DefaultCollector default_collector;

    /// Optional id mapping. Explicitly initialized to nullptr: callers test
    /// this pointer for null before installing their own mapping, so a
    /// default-constructed factory must never carry an indeterminate value.
    const std::vector<int64_t>* id_map = nullptr;

    /// Create (or hand out) the ResultCollector to use for a search.
    /// The default implementation returns the shared default_collector.
    virtual ResultCollector* newCollector() {
        return &default_collector;
    }

    /// Release a collector obtained from newCollector().
    /// In the default case the factory shares a single object, so there is
    /// nothing to delete. Other factories may create a new object that must
    /// be released later; routing the release through the factory handles
    /// both cases, since the factory knows how to free what it created.
    virtual void deleteCollector(ResultCollector* /*collector*/) {
        // Nothing to release for the shared default collector.
    }

    virtual ~ResultCollectorFactory() {}
};

} // namespace faiss

0 comments on commit c59bc53

Please sign in to comment.