Skip to content

Commit

Permalink
ivf-flat: fix incorrect recomputed size of the index (#1525)
Browse files Browse the repository at this point in the history
Fix ivf-flat's `recompute_internal_state`, which incorrectly used the amortized list sizes (`list->size`) instead of the actual per-list sizes from `list_sizes()` to compute the total index size.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1525
  • Loading branch information
achirkin authored May 19, 2023
1 parent 29d1c15 commit dfb3d2c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
27 changes: 16 additions & 11 deletions cpp/include/raft/neighbors/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/ivf_list_types.hpp>
#include <raft/util/integer_utils.hpp>

#include <thrust/reduce.h>

#include <algorithm> // std::max
#include <memory>
#include <optional>
#include <thrust/fill.h>
#include <type_traits>

namespace raft::neighbors::ivf_flat {
Expand Down Expand Up @@ -303,20 +306,22 @@ struct index : ann::index {
auto stream = resource::get_cuda_stream(res);

// Actualize the list pointers
auto this_lists = lists();
auto this_data_ptrs = data_ptrs();
auto this_inds_ptrs = inds_ptrs();
IdxT recompute_total_size = 0;
auto this_lists = lists();
auto this_data_ptrs = data_ptrs();
auto this_inds_ptrs = inds_ptrs();
for (uint32_t label = 0; label < this_lists.size(); label++) {
auto& list = this_lists[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
const auto list_size = list ? IdxT(list->size) : 0;
auto& list = this_lists[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
copy(&this_data_ptrs(label), &data_ptr, 1, stream);
copy(&this_inds_ptrs(label), &inds_ptr, 1, stream);
recompute_total_size += list_size;
}
total_size_ = recompute_total_size;
auto this_list_sizes = list_sizes().data_handle();
total_size_ = thrust::reduce(resource::get_thrust_policy(res),
this_list_sizes,
this_list_sizes + this_lists.size(),
0,
raft::add_op{});
check_consistency();
}

Expand Down
1 change: 1 addition & 0 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2);

auto index_loaded = ivf_flat::detail::deserialize<DataT, IdxT>(handle_, "ivf_flat_index");
ASSERT_EQ(index_2.size(), index_loaded.size());

ivf_flat::search(handle_,
search_params,
Expand Down

0 comments on commit dfb3d2c

Please sign in to comment.