Skip to content

Commit

Permalink
Introduce result collector for HNSW
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed Dec 8, 2023
1 parent 5b6c4b4 commit 44d162e
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 12 deletions.
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ set(FAISS_HEADERS
impl/ProductQuantizer.h
impl/Quantizer.h
impl/ResidualQuantizer.h
impl/ResultCollector.h
impl/ResultCollectorFactory.h
impl/ResultHandler.h
impl/ScalarQuantizer.h
impl/ThreadedIndex-inl.h
Expand Down
6 changes: 4 additions & 2 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@

namespace faiss {

/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
/// impl/DistanceComputer.h
/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;
struct ResultCollectorFactory;

/** Parent class for the optional search paramenters.
*
Expand All @@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
/// if non-null, only these IDs will be considered during search.
IDSelector* sel = nullptr;
ResultCollectorFactory* col = nullptr;
/// make sure we can dynamic_cast this
virtual ~SearchParameters() {}
};
Expand Down
22 changes: 12 additions & 10 deletions faiss/impl/HNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultCollectorFactory.h>
#include <faiss/utils/prefetch.h>

#include <faiss/impl/platform_macros.h>
Expand Down Expand Up @@ -530,6 +531,12 @@ int search_from_candidates(
int level,
int nres_in = 0,
const SearchParametersHNSW* params = nullptr) {
ResultCollector* collector;
if (params == nullptr || params->col == nullptr) {
collector = new DefaultCollector();
} else {
collector = params->col->newCollector();
}
int nres = nres_in;
int ndis = 0;

Expand All @@ -544,11 +551,7 @@ int search_from_candidates(
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
collector->collect(k, nres, D, I, d, v1);
}
vt.set(v1);
}
Expand Down Expand Up @@ -612,11 +615,7 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, dis, idx);
} else if (dis < D[0]) {
faiss::maxheap_replace_top(nres, D, I, dis, idx);
}
collector->collect(k, nres, D, I, dis, idx);
}
candidates.push(idx, dis);
};
Expand Down Expand Up @@ -660,6 +659,9 @@ int search_from_candidates(
}
}

collector->finalize(nres, I);
delete collector;

if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
Expand Down
58 changes: 58 additions & 0 deletions faiss/impl/ResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <unordered_set>
#include <vector>

#include <faiss/MetricType.h>
#include <faiss/utils/Heap.h>

/** ResultCollector is intended to define how to collect search result */

namespace faiss {

/** Encapsulates a set of ids to handle. */
struct ResultCollector {
// For each result, collect method is called to store result
virtual void collect(
int k,
int& nres,
float* bh_val,
idx_t* bh_ids,
float val,
idx_t ids) = 0;

// This method is called after all result is collected
virtual void finalize(idx_t nres, idx_t* bh_ids) = 0;
virtual ~ResultCollector() {}
};

struct DefaultCollector : ResultCollector {
void collect(
int k,
int& nres,
float* bh_val,
idx_t* bh_ids,
float val,
idx_t ids) override {
if (nres < k) {
faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids);
} else if (val < bh_val[0]) {
faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids);
}
}

void finalize(idx_t nres, idx_t* bh_ids) override {
// Do nothing
}

~DefaultCollector() override {}
};

} // namespace faiss
21 changes: 21 additions & 0 deletions faiss/impl/ResultCollectorFactory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <faiss/impl/ResultCollector.h>
namespace faiss {

/** ResultCollector is intended to define how to collect search result */
struct ResultCollectorFactory {
// For each result, collect method is called to store result
virtual ResultCollector* newCollector() = 0;

// This method is called after all result is collected
virtual ~ResultCollectorFactory() {}
};

} // namespace faiss

0 comments on commit 44d162e

Please sign in to comment.