diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 27701586c8..af682a056d 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -162,6 +162,8 @@ set(FAISS_HEADERS impl/ProductQuantizer.h impl/Quantizer.h impl/ResidualQuantizer.h + impl/ResultCollector.h + impl/ResultCollectorFactory.h impl/ResultHandler.h impl/ScalarQuantizer.h impl/ThreadedIndex-inl.h diff --git a/faiss/Index.h b/faiss/Index.h index 4b4b302b47..13eab0c077 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,11 +38,12 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and -/// impl/DistanceComputer.h +/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, +/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h struct IDSelector; struct RangeSearchResult; struct DistanceComputer; +struct ResultCollectorFactory; /** Parent class for the optional search paramenters. * @@ -52,6 +53,7 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; + ResultCollectorFactory* col = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp index 7972bec9a0..0f82a17c6f 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace faiss { @@ -102,6 +103,24 @@ struct ScopedSelChange { } }; +// RAII object to reset the id_map parameter in ResultCollectorFactory object +// This object make sure to reset the id_map parameter in ResultCollectorFactory +// once the program exist current method scope. +struct ScopedColChange { + ResultCollectorFactory* collector_factory = nullptr; + void set( + ResultCollectorFactory* collector_factory, + const std::vector* id_map) { + this->collector_factory = collector_factory; + collector_factory->id_map = id_map; + } + ~ScopedColChange() { + if (collector_factory) { + collector_factory->id_map = nullptr; + } + } +}; + } // namespace template @@ -114,6 +133,7 @@ void IndexIDMapTemplate::search( const SearchParameters* params) const { IDSelectorTranslated this_idtrans(this->id_map, nullptr); ScopedSelChange sel_change; + ScopedColChange col_change; if (params && params->sel) { auto idtrans = dynamic_cast(params->sel); @@ -131,6 +151,10 @@ void IndexIDMapTemplate::search( sel_change.set(params_non_const, &this_idtrans); } } + + if (params && params->col && !params->col->id_map) { + col_change.set(params->col, &this->id_map); + } index->search(n, x, k, distances, labels, params); idx_t* li = labels; #pragma omp parallel for diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..9da9f87feb 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -530,6 +531,15 @@ int search_from_candidates( int level, int nres_in = 0, const SearchParametersHNSW* params = nullptr) { + ResultCollectorFactory defaultFactory; + ResultCollectorFactory* collectorFactory; + if (params == nullptr || params->col == nullptr) { + collectorFactory = &defaultFactory; + } else { + collectorFactory = params->col; + } + ResultCollector* collector = collectorFactory->newCollector(); + int nres = nres_in; int ndis = 0; @@ -544,11 +554,7 @@ int search_from_candidates( float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); - } + collector->collect(k, nres, D, I, d, v1); } vt.set(v1); } @@ -612,11 +618,7 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); - } + collector->collect(k, nres, D, I, dis, idx); } candidates.push(idx, dis); }; @@ -660,6 +662,11 @@ int search_from_candidates( } } + // Completed collection of result. Run post processor. + collector->post_process(nres, I); + // Collector completed its task. Release all resource of the collector. + collectorFactory->deleteCollector(collector); + if (level == 0) { stats.n1++; if (candidates.size() == 0) { diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h new file mode 100644 index 0000000000..a0489fd689 --- /dev/null +++ b/faiss/impl/ResultCollector.h @@ -0,0 +1,74 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +/** + * ResultCollector is intended to define how to collect search result + * For each single search result, collect method will be called. + * After every results are collected, post_process method is called at the end. + */ + +namespace faiss { + +/** Encapsulates a set of ids to handle. */ +struct ResultCollector { + /** + * For each result, collect method is called to store result + * @param k number of vectors to search + * @param nres number of results in queue + * @param bh_val search result, distances from query + * @param bh_ids search result, ids of vectors + * @param val distance from query for current vector + * @param ids id of current vector + */ + virtual void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) = 0; + + // This method is called after all result is collected + virtual void post_process(idx_t nres, idx_t* bh_ids) = 0; + virtual ~ResultCollector() {} +}; + +struct DefaultCollector : ResultCollector { + void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) override { + if (nres < k) { + faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids); + } else if (val < bh_val[0]) { + faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids); + } + } + + // This method is called once all result is collected so that final post + // processing can be done For example, if the result is collected using + // group id, the group id can be converted back to its original id inside + // this method + void post_process(idx_t nres, idx_t* bh_ids) override { + // Do nothing + } + + ~DefaultCollector() override {} +}; + +} // namespace faiss diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h new file mode 100644 index 0000000000..d9a667ac57 --- /dev/null +++ b/faiss/impl/ResultCollectorFactory.h @@ -0,0 +1,33 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +namespace faiss { + +/** ResultCollectorFactory to create a ResultCollector object */ +struct ResultCollectorFactory { + DefaultCollector default_collector; + const std::vector* id_map; + + // Create a new ResultCollector object + virtual ResultCollector* newCollector() { + return &default_collector; + } + + // For default case, the factory share single object and no need to delete + // the object. For other case, the factory can create a new object which + // need to be deleted later. We have deleteCollector method to handle both + // case as factory class knows how to release resource that it created + virtual void deleteCollector(ResultCollector* collector) { + // Do nothing + } + + virtual ~ResultCollectorFactory() {} +}; + +} // namespace faiss