diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index ccdc3f28da..2e2e49cdbc 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -23,15 +23,18 @@ #include #include #include +#include +#include #include #include #include #include +#include + #include // std::max #include #include -#include #include namespace raft::neighbors::ivf_flat { @@ -303,20 +306,22 @@ struct index : ann::index { auto stream = resource::get_cuda_stream(res); // Actualize the list pointers - auto this_lists = lists(); - auto this_data_ptrs = data_ptrs(); - auto this_inds_ptrs = inds_ptrs(); - IdxT recompute_total_size = 0; + auto this_lists = lists(); + auto this_data_ptrs = data_ptrs(); + auto this_inds_ptrs = inds_ptrs(); for (uint32_t label = 0; label < this_lists.size(); label++) { - auto& list = this_lists[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - const auto list_size = list ? IdxT(list->size) : 0; + auto& list = this_lists[label]; + const auto data_ptr = list ? list->data.data_handle() : nullptr; + const auto inds_ptr = list ? list->indices.data_handle() : nullptr; copy(&this_data_ptrs(label), &data_ptr, 1, stream); copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); - recompute_total_size += list_size; } - total_size_ = recompute_total_size; + auto this_list_sizes = list_sizes().data_handle(); + total_size_ = thrust::reduce(resource::get_thrust_policy(res), + this_list_sizes, + this_list_sizes + this_lists.size(), + 0, + raft::add_op{}); check_consistency(); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 1c9406e8a9..88bf53280b 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -201,6 +201,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); auto index_loaded = ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); + ASSERT_EQ(index_2.size(), index_loaded.size()); ivf_flat::search(handle_, search_params,