From 333eaeb841aae11de6f82014a707caa125794858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Smoli=C5=84ski?= <29839376+lukaszsmolinski@users.noreply.github.com> Date: Wed, 18 Oct 2023 20:40:39 +0200 Subject: [PATCH 1/3] Fix incorrect results in bruteforce with filter --- hnswlib/bruteforce.h | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 8727cc8a..371847ad 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> { searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { assert(k <= cur_element_count); std::priority_queue<std::pair<dist_t, labeltype >> topResults; - if (cur_element_count == 0) return topResults; - for (int i = 0; i < k; i++) { + dist_t lastdist = std::numeric_limits<dist_t>::max(); + for (int i = 0; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.emplace(dist, label); - } - } - dist_t lastdist = topResults.empty() ? 
std::numeric_limits<dist_t>::max() : topResults.top().first; - for (int i = k; i < cur_element_count; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - if (dist <= lastdist) { + if (dist <= lastdist || topResults.size() < k) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.emplace(dist, label); - } - if (topResults.size() > k) - topResults.pop(); - - if (!topResults.empty()) { - lastdist = topResults.top().first; + if (topResults.size() > k) + topResults.pop(); + if (!topResults.empty()) + lastdist = topResults.top().first; } } } From 76f5affc87e1776642fcf91d6b4b34d32e8481b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Smoli=C5=84ski?= <29839376+lukaszsmolinski@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:26:03 +0100 Subject: [PATCH 2/3] Throw an exception when there are not k elements --- python_bindings/bindings.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index dd09e80a..14f1cabe 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -874,6 +874,9 @@ class BFIndex { ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn( (void*)items.data(row), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. 
There are not enough elements."); for (int i = k - 1; i >= 0; i--) { auto& result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; From 39bc6af6dc6dd14987f8036ff5b641c4bb276e63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Smoli=C5=84ski?= <29839376+lukaszsmolinski@users.noreply.github.com> Date: Sat, 4 Nov 2023 12:17:25 +0100 Subject: [PATCH 3/3] Add missing normalization check to BFIndex --- python_bindings/bindings.cpp | 46 ++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 14f1cabe..56ce9beb 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -871,19 +871,39 @@ class BFIndex { CustomFilterFunctor idFilter(filter); CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; - ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn( - (void*)items.data(row), k, p_idFilter); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contiguous 2D array. There are not enough elements."); - for (int i = k - 1; i >= 0; i--) { - auto& result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - }); + if (!normalize) { + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn( + (void*)items.data(row), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. 
There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } else { + std::vector<float> norm_array(num_threads * features); + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), norm_array.data() + start_idx); + + std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn( + (void*)(norm_array.data() + start_idx), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contiguous 2D array. There are not enough elements."); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } } py::capsule free_when_done_l(data_numpy_l, [](void *f) {