From 44d162e8434e8c0716c0bc1a36435751e2c96528 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Wed, 6 Dec 2023 16:33:52 -0800 Subject: [PATCH] Introduce result collector for HNSW --- faiss/CMakeLists.txt | 2 + faiss/Index.h | 6 ++- faiss/impl/HNSW.cpp | 22 ++++++----- faiss/impl/ResultCollector.h | 58 +++++++++++++++++++++++++++++ faiss/impl/ResultCollectorFactory.h | 21 +++++++++++ 5 files changed, 97 insertions(+), 12 deletions(-) create mode 100644 faiss/impl/ResultCollector.h create mode 100644 faiss/impl/ResultCollectorFactory.h diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 27701586c8..af682a056d 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -162,6 +162,8 @@ set(FAISS_HEADERS impl/ProductQuantizer.h impl/Quantizer.h impl/ResidualQuantizer.h + impl/ResultCollector.h + impl/ResultCollectorFactory.h impl/ResultHandler.h impl/ScalarQuantizer.h impl/ThreadedIndex-inl.h diff --git a/faiss/Index.h b/faiss/Index.h index 4b4b302b47..13eab0c077 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,11 +38,12 @@ namespace faiss { -/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and -/// impl/DistanceComputer.h +/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h, +/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h struct IDSelector; struct RangeSearchResult; struct DistanceComputer; +struct ResultCollectorFactory; /** Parent class for the optional search paramenters. * @@ -52,6 +53,7 @@ struct DistanceComputer; struct SearchParameters { /// if non-null, only these IDs will be considered during search. IDSelector* sel = nullptr; + ResultCollectorFactory* col = nullptr; /// make sure we can dynamic_cast this virtual ~SearchParameters() {} }; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..674fd9b987 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -530,6 +531,12 @@ int search_from_candidates( int level, int nres_in = 0, const SearchParametersHNSW* params = nullptr) { + ResultCollector* collector; + if (params == nullptr || params->col == nullptr) { + collector = new DefaultCollector(); + } else { + collector = params->col->newCollector(); + } int nres = nres_in; int ndis = 0; @@ -544,11 +551,7 @@ int search_from_candidates( float d = candidates.dis[i]; FAISS_ASSERT(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); - } + collector->collect(k, nres, D, I, d, v1); } vt.set(v1); } @@ -612,11 +615,7 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); - } + collector->collect(k, nres, D, I, dis, idx); } candidates.push(idx, dis); }; @@ -660,6 +659,9 @@ int search_from_candidates( } } + collector->finalize(nres, I); + delete collector; + if (level == 0) { stats.n1++; if (candidates.size() == 0) { diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h new file mode 100644 index 0000000000..3e4dac3426 --- /dev/null +++ b/faiss/impl/ResultCollector.h @@ -0,0 +1,58 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +/** ResultCollector is intended to define how to collect search result */ + +namespace faiss { + +/** Encapsulates a set of ids to handle. */ +struct ResultCollector { + // For each result, collect method is called to store result + virtual void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) = 0; + + // This method is called after all result is collected + virtual void finalize(idx_t nres, idx_t* bh_ids) = 0; + virtual ~ResultCollector() {} +}; + +struct DefaultCollector : ResultCollector { + void collect( + int k, + int& nres, + float* bh_val, + idx_t* bh_ids, + float val, + idx_t ids) override { + if (nres < k) { + faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids); + } else if (val < bh_val[0]) { + faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids); + } + } + + void finalize(idx_t nres, idx_t* bh_ids) override { + // Do nothing + } + + ~DefaultCollector() override {} +}; + +} // namespace faiss diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h new file mode 100644 index 0000000000..f9f4098163 --- /dev/null +++ b/faiss/impl/ResultCollectorFactory.h @@ -0,0 +1,21 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +namespace faiss { + +/** ResultCollector is intended to define how to collect search result */ +struct ResultCollectorFactory { + // For each result, collect method is called to store result + virtual ResultCollector* newCollector() = 0; + + // This method is called after all result is collected + virtual ~ResultCollectorFactory() {} +}; + +} // namespace faiss