From dfb3d2cef2907290e8910f19a8b1a2cfb766feed Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 19 May 2023 10:49:29 +0200 Subject: [PATCH] ivf-flat: fix incorrect recomputed size of the index (#1525) Fix ivf-flat's `recompute_internal_state` incorrectly using the amortized list sizes to compute the index size. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1525 --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 27 +++++++++++-------- cpp/test/neighbors/ann_ivf_flat.cuh | 1 + 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index ccdc3f28da..2e2e49cdbc 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -23,15 +23,18 @@ #include #include #include +#include +#include #include #include #include #include +#include + #include // std::max #include #include -#include #include namespace raft::neighbors::ivf_flat { @@ -303,20 +306,22 @@ struct index : ann::index { auto stream = resource::get_cuda_stream(res); // Actualize the list pointers - auto this_lists = lists(); - auto this_data_ptrs = data_ptrs(); - auto this_inds_ptrs = inds_ptrs(); - IdxT recompute_total_size = 0; + auto this_lists = lists(); + auto this_data_ptrs = data_ptrs(); + auto this_inds_ptrs = inds_ptrs(); for (uint32_t label = 0; label < this_lists.size(); label++) { - auto& list = this_lists[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - const auto list_size = list ? IdxT(list->size) : 0; + auto& list = this_lists[label]; + const auto data_ptr = list ? list->data.data_handle() : nullptr; + const auto inds_ptr = list ? list->indices.data_handle() : nullptr; copy(&this_data_ptrs(label), &data_ptr, 1, stream); copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); - recompute_total_size += list_size; } - total_size_ = recompute_total_size; + auto this_list_sizes = list_sizes().data_handle(); + total_size_ = thrust::reduce(resource::get_thrust_policy(res), + this_list_sizes, + this_list_sizes + this_lists.size(), + 0, + raft::add_op{}); check_consistency(); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 1c9406e8a9..88bf53280b 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -201,6 +201,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); auto index_loaded = ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); + ASSERT_EQ(index_2.size(), index_loaded.size()); ivf_flat::search(handle_, search_params,