Skip to content

Commit

Permalink
Custom patch to support multi-vector
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Dec 27, 2023
1 parent 5b6c4b4 commit 231b522
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 12 deletions.
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ set(FAISS_HEADERS
impl/ProductQuantizer.h
impl/Quantizer.h
impl/ResidualQuantizer.h
impl/ResultCollector.h
impl/ResultCollectorFactory.h
impl/ResultHandler.h
impl/ScalarQuantizer.h
impl/ThreadedIndex-inl.h
Expand Down
6 changes: 4 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;
struct ResultCollectorFactory;

/** Parent class for the optional search paramenters.
*
Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
/// if non-null, only these IDs will be considered during search.
IDSelector* sel = nullptr;
ResultCollectorFactory* col = nullptr;
/// make sure we can dynamic_cast this
virtual ~SearchParameters() {}
};
Expand Down
24 changes: 24 additions & 0 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/WorkerThread.h>
#include <faiss/impl/ResultCollectorFactory.h>

namespace faiss {

Expand Down Expand Up @@ -102,6 +103,24 @@ struct ScopedSelChange {
}
};

// RAII object to reset the id_map parameter in ResultCollectorFactory object
// This object make sure to reset the id_map parameter in ResultCollectorFactory
// once the program exist current method scope.
struct ScopedColChange {
ResultCollectorFactory* collector_factory = nullptr;
void set(
ResultCollectorFactory* collector_factory,
const std::vector<int64_t>* id_map) {
this->collector_factory = collector_factory;
collector_factory->id_map = id_map;
}
~ScopedColChange() {
if (collector_factory) {
collector_factory->id_map = nullptr;
}
}
};

} // namespace

template <typename IndexT>
Expand All @@ -114,6 +133,7 @@ void IndexIDMapTemplate<IndexT>::search(
const SearchParameters* params) const {
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
ScopedSelChange sel_change;
ScopedColChange col_change;

if (params && params->sel) {
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
Expand All @@ -131,6 +151,10 @@ void IndexIDMapTemplate<IndexT>::search(
sel_change.set(params_non_const, &this_idtrans);
}
}

if (params && params->col && !params->col->id_map) {
col_change.set(params->col, &this->id_map);
}
index->search(n, x, k, distances, labels, params);
idx_t* li = labels;
#pragma omp parallel for
Expand Down
27 changes: 17 additions & 10 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultCollectorFactory.h>
#include <faiss/utils/prefetch.h>

#include <faiss/impl/platform_macros.h>
Expand Down Expand Up @@ -530,6 +531,15 @@ int search_from_candidates(
int level,
int nres_in = 0,
const SearchParametersHNSW* params = nullptr) {
ResultCollectorFactory defaultFactory;
ResultCollectorFactory* collectorFactory;
if (params == nullptr || params->col == nullptr) {
collectorFactory = &defaultFactory;
} else {
collectorFactory = params->col;
}
ResultCollector* collector = collectorFactory->newCollector();

int nres = nres_in;
int ndis = 0;

Expand All @@ -544,11 +554,7 @@ int search_from_candidates(
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
collector->collect(k, nres, D, I, d, v1);
}
vt.set(v1);
}
Expand Down Expand Up @@ -612,11 +618,7 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
collector->collect(k, nres, D, I, dis, idx);
}
candidates.push(idx, dis);
};
Expand Down Expand Up @@ -660,6 +662,11 @@ int search_from_candidates(
}
}

// Completed collection of result. Run post processor.
collector->post_process(nres, I);
// Collector completed its task. Release all resource of the collector.
collectorFactory->deleteCollector(collector);

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
74 changes: 74 additions & 0 deletions faiss/impl/ResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>
#include <faiss/utils/Heap.h>

/**
* ResultCollector is intended to define how to collect search result
* For each single search result, collect method will be called.
* After every results are collected, post_process method is called at the end.
*/

namespace faiss {

/** Encapsulates a set of ids to handle. */
struct ResultCollector {
/**
* For each result, collect method is called to store result
* @param k number of vectors to search
* @param nres number of results in queue
* @param bh_val search result, distances from query
* @param bh_ids search result, ids of vectors
* @param val distance from query for current vector
* @param ids id of current vector
*/
virtual void collect(
int k,
int& nres,
float* bh_val,
idx_t* bh_ids,
float val,
idx_t ids) = 0;

// This method is called after all result is collected
virtual void post_process(idx_t nres, idx_t* bh_ids) = 0;
virtual ~ResultCollector() {}
};

struct DefaultCollector : ResultCollector {
void collect(
int k,
int& nres,
float* bh_val,
idx_t* bh_ids,
float val,
idx_t ids) override {
if (nres < k) {
faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids);
} else if (val < bh_val[0]) {
faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids);
}
}

// This method is called once all result is collected so that final post
// processing can be done For example, if the result is collected using
// group id, the group id can be converted back to its original id inside
// this method
void post_process(idx_t nres, idx_t* bh_ids) override {
// Do nothing
}

~DefaultCollector() override {}
};

} // namespace faiss
33 changes: 33 additions & 0 deletions faiss/impl/ResultCollectorFactory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <faiss/impl/ResultCollector.h>
namespace faiss {

/** ResultCollectorFactory to create a ResultCollector object */
struct ResultCollectorFactory {
DefaultCollector default_collector;
const std::vector<int64_t>* id_map;

// Create a new ResultCollector object
virtual ResultCollector* newCollector() {
return &default_collector;
}

// For default case, the factory share single object and no need to delete
// the object. For other case, the factory can create a new object which
// need to be deleted later. We have deleteCollector method to handle both
// case as factory class knows how to release resource that it created
virtual void deleteCollector(ResultCollector* collector) {
// Do nothing
}

virtual ~ResultCollectorFactory() {}
};

} // namespace faiss

0 comments on commit 231b522

Please sign in to comment.