diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 27701586c8..af682a056d 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -162,6 +162,8 @@ set(FAISS_HEADERS impl/ProductQuantizer.h impl/Quantizer.h impl/ResidualQuantizer.h + impl/ResultCollector.h + impl/ResultCollectorFactory.h impl/ResultHandler.h impl/ScalarQuantizer.h impl/ThreadedIndex-inl.h diff --git a/faiss/Index.h b/faiss/Index.h index 4b4b302b47..13eab0c077 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,11 +38,12 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and -/// impl/DistanceComputer.h +/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, +/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h struct IDSelector; struct RangeSearchResult; struct DistanceComputer; +struct ResultCollectorFactory; /** Parent class for the optional search paramenters. * @@ -52,6 +53,7 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; + ResultCollectorFactory* col = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..540210a634 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -530,6 +531,15 @@ int search_from_candidates( int level, int nres_in = 0, const SearchParametersHNSW* params = nullptr) { + ResultCollectorFactory defaultFactory; + ResultCollectorFactory* collectorFactory; + if (params == nullptr || params->col == nullptr) { + collectorFactory = &defaultFactory; + } else { + collectorFactory = params->col; + } + ResultCollector* collector = collectorFactory->newCollector(); + int nres = nres_in; int ndis = 0; @@ -544,11 +554,7 @@ int search_from_candidates( float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); - } + collector->collect(k, nres, D, I, d, v1); } vt.set(v1); } @@ -612,11 +618,7 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); - } + collector->collect(k, nres, D, I, dis, idx); } candidates.push(idx, dis); }; @@ -660,6 +662,9 @@ int search_from_candidates( } } + collector->finalize(nres, I); + collectorFactory->deleteCollector(collector); + if (level == 0) { stats.n1++; if (candidates.size() == 0) { diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h new file mode 100644 index 0000000000..3e4dac3426 --- /dev/null +++ b/faiss/impl/ResultCollector.h @@ -0,0 +1,58 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +/** ResultCollector is intended to define how to collect search result */ + +namespace faiss { + +/** Encapsulates a set of ids to handle. */ +struct ResultCollector { + // For each result, collect method is called to store result + virtual void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) = 0; + + // This method is called after all result is collected + virtual void finalize(idx_t nres, idx_t* bh_ids) = 0; + virtual ~ResultCollector() {} +}; + +struct DefaultCollector : ResultCollector { + void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) override { + if (nres < k) { + faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids); + } else if (val < bh_val[0]) { + faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids); + } + } + + void finalize(idx_t nres, idx_t* bh_ids) override { + // Do nothing + } + + ~DefaultCollector() override {} +}; + +} // namespace faiss diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h new file mode 100644 index 0000000000..6a15208a09 --- /dev/null +++ b/faiss/impl/ResultCollectorFactory.h @@ -0,0 +1,28 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +namespace faiss { + +/** ResultCollector is intended to define how to collect search result */ +struct ResultCollectorFactory { + DefaultCollector defaultCollector; + + // For each result, collect method is called to store result + virtual ResultCollector* newCollector() { + return &defaultCollector; + } + + virtual void deleteCollector(ResultCollector* collector) { + // Do nothing + } + // This method is called after all result is collected + virtual ~ResultCollectorFactory() {} +}; + +} // namespace faiss