Generalize ResultHandler, support range search for HNSW and Fast Scan #3190

Closed
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ We try to indicate most contributions here with the contributor names who are no
the Facebook Faiss team. Feel free to add entries here if you submit a PR.

## [Unreleased]
- Support for range search in HNSW and Fast scan IVF.
## [1.7.4] - 2023-04-12
### Added
- Added big batch IVF search for conducting efficient search with big batches of queries
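For context (not part of the diff): a minimal sketch of how the newly supported range search could be invoked on an HNSW index, assuming the standard `Index::range_search` API; `d`, `nb`, `nq`, `xb`, `xq`, and `radius` are caller-provided.

```cpp
#include <faiss/IndexHNSW.h>
#include <faiss/impl/AuxIndexStructures.h>

// Minimal sketch: build an HNSW index and run a range search on it.
void range_search_hnsw_sketch(
        int d, faiss::idx_t nb, faiss::idx_t nq,
        const float* xb, const float* xq, float radius) {
    faiss::IndexHNSWFlat index(d, 32);      // 32 = HNSW connectivity (M)
    index.add(nb, xb);

    faiss::RangeSearchResult result(nq);
    index.range_search(nq, xq, radius, &result);

    // Hits for query i are result.labels / result.distances in the slice
    // [result.lims[i], result.lims[i + 1]).
}
```

The number of hits per query is data-dependent, which is why results are returned through `RangeSearchResult` rather than fixed-size k-NN buffers; the same call shape applies to the fast-scan IVF indexes covered by this PR.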
2 changes: 2 additions & 0 deletions benchs/link_and_code/README.md
@@ -39,6 +39,8 @@ The code runs on top of Faiss. The HNSW index can be extended with a
`ReconstructFromNeighbors` C++ object that refines the distances. The
training is implemented in Python.

Update (2023-12-28): current Faiss has dropped support for reconstruction with
this method.

Reproducing Table 2 in the paper
--------------------------------
1 change: 1 addition & 0 deletions contrib/evaluation.py
@@ -261,6 +261,7 @@ def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
27 changes: 17 additions & 10 deletions faiss/IndexAdditiveQuantizer.cpp
@@ -114,18 +114,19 @@ struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
* scanning implementation for search
************************************************************/

template <class VectorDistance, class ResultHandler>
template <class VectorDistance, class BlockResultHandler>
void search_with_decompress(
const IndexAdditiveQuantizer& ir,
const float* xq,
VectorDistance& vd,
ResultHandler& res) {
BlockResultHandler& res) {
const uint8_t* codes = ir.codes.data();
size_t ntotal = ir.ntotal;
size_t code_size = ir.code_size;
const AdditiveQuantizer* aq = ir.aq;

using SingleResultHandler = typename ResultHandler::SingleResultHandler;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;

#pragma omp parallel for if (res.nq > 100)
for (int64_t q = 0; q < res.nq; q++) {
@@ -142,19 +143,23 @@
}
}

template <bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
template <
bool is_IP,
AdditiveQuantizer::Search_type_t st,
class BlockResultHandler>
void search_with_LUT(
const IndexAdditiveQuantizer& ir,
const float* xq,
ResultHandler& res) {
BlockResultHandler& res) {
const AdditiveQuantizer& aq = *ir.aq;
const uint8_t* codes = ir.codes.data();
size_t ntotal = ir.ntotal;
size_t code_size = aq.code_size;
size_t nq = res.nq;
size_t d = ir.d;

using SingleResultHandler = typename ResultHandler::SingleResultHandler;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);

aq.compute_LUT(nq, xq, LUT.get());
@@ -241,21 +246,23 @@ void IndexAdditiveQuantizer::search(
if (metric_type == METRIC_L2) {
using VD = VectorDistance<METRIC_L2>;
VD vd = {size_t(d), metric_arg};
HeapResultHandler<VD::C> rh(n, distances, labels, k);
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
search_with_decompress(*this, x, vd, rh);
} else if (metric_type == METRIC_INNER_PRODUCT) {
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
VD vd = {size_t(d), metric_arg};
HeapResultHandler<VD::C> rh(n, distances, labels, k);
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
search_with_decompress(*this, x, vd, rh);
}
} else {
if (metric_type == METRIC_INNER_PRODUCT) {
HeapResultHandler<CMin<float, idx_t>> rh(n, distances, labels, k);
HeapBlockResultHandler<CMin<float, idx_t>> rh(
n, distances, labels, k);
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
*this, x, rh);
} else {
HeapResultHandler<CMax<float, idx_t>> rh(n, distances, labels, k);
HeapBlockResultHandler<CMax<float, idx_t>> rh(
n, distances, labels, k);
switch (aq->search_type) {
#define DISPATCH(st) \
case AdditiveQuantizer::st: \
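For context (not part of the diff): the renames above reflect the split between a per-batch `BlockResultHandler` and the per-query `SingleResultHandler` it hands out. Below is a hedged sketch of that pattern, assuming the `begin`/`add_result`/`end` interface visible elsewhere in this diff; `compute_distance` and the buffers are hypothetical caller-provided inputs, not faiss API.

```cpp
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/Heap.h>

// Brute-force scan illustrating the handler pattern: the block handler owns the
// nq * k output buffers; each thread gets a per-query single-result handler.
void scan_with_heap_handler(
        size_t nq, size_t k, size_t ntotal,
        float* distances, faiss::idx_t* labels,
        float (*compute_distance)(faiss::idx_t q, faiss::idx_t j)) {
    using C = faiss::CMax<float, faiss::idx_t>; // L2-style: smaller is better
    faiss::HeapBlockResultHandler<C> bres(nq, distances, labels, k);

#pragma omp parallel
    {
        faiss::HeapBlockResultHandler<C>::SingleResultHandler res(bres);
#pragma omp for
        for (int64_t q = 0; q < int64_t(nq); q++) {
            res.begin(q);                       // reset the per-query heap
            for (faiss::idx_t j = 0; j < faiss::idx_t(ntotal); j++) {
                res.add_result(compute_distance(q, j), j);
            }
            res.end();                          // finalize: sort, write back
        }
    }
}
```

Swapping in a range-search block handler with the same per-query interface is what lets one scanning loop serve both k-NN and range queries, which is the point of the generalization.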
4 changes: 2 additions & 2 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
@@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search(

NormTableScaler scaler(norm_scale);
if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
}
}

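For context (not part of the diff): switching the scaler argument from a reference to a pointer presumably makes "no scaling" representable as `nullptr` at the dispatch level. A generic sketch of that shape (not the actual `search_dispatch_implem` signature):

```cpp
// Generic illustration: a pointer parameter lets the caller opt out of scaling
// with nullptr, which a reference cannot express.
template <class Scaler>
void dispatch_sketch(const Scaler* scaler /* , query/output arguments ... */) {
    if (scaler == nullptr) {
        // plain path: accumulate look-up table values without norm scaling
        return;
    }
    // scaled path: consult *scaler while accumulating results
}
```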
23 changes: 13 additions & 10 deletions faiss/IndexBinaryHNSW.cpp
@@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/IndexBinaryHNSW.h>

#include <omp.h>
@@ -28,6 +26,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
@@ -201,27 +200,31 @@ void IndexBinaryHNSW::search(
!params, "search params not supported for this index");
FAISS_THROW_IF_NOT(k > 0);

// we use the buffer for distances as float but convert them back
// to int in the end
float* distances_f = (float*)distances;

using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances_f, labels, k);

#pragma omp parallel
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(get_distance_computer());
RH::SingleResultHandler res(bres);

#pragma omp for
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = (float*)(distances + i * k);

res.begin(i);
dis->set_query((float*)(x + i * code_size));

maxheap_heapify(k, simi, idxi);
hnsw.search(*dis, k, idxi, simi, vt);
maxheap_reorder(k, simi, idxi);
hnsw.search(*dis, res, vt);
res.end();
}
}

#pragma omp parallel for
for (int i = 0; i < n * k; ++i) {
distances[i] = std::round(((float*)distances)[i]);
distances[i] = std::round(distances_f[i]);
}
}

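For context (not part of the diff): the search above temporarily reuses the caller's `int32_t` distance buffer as `float` storage while the handler runs, then rounds the values back in place. A standalone sketch of that final conversion step:

```cpp
#include <cmath>
#include <cstdint>

// int32_t and float have the same size, so the caller's int32_t buffer can
// hold float distances during the search and be converted back at the end.
void round_distances_back(int32_t* distances, int64_t n_entries) {
    static_assert(sizeof(float) == sizeof(int32_t), "buffer reuse needs equal sizes");
    float* distances_f = (float*)distances;
    for (int64_t i = 0; i < n_entries; ++i) {
        distances[i] = int32_t(std::round(distances_f[i]));
    }
}
```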