From 328a1794569c0b1986c4dec99ad5c42917f8360f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 7 Feb 2023 02:38:58 +0100 Subject: [PATCH 01/33] Initial index splitting --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 103 +++++++++++++++++- 1 file changed, 98 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index d234822a23..b67cf38139 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -65,6 +65,35 @@ struct search_params : ann::search_params { static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); +/** The data for a single list (cluster). */ +template +struct list_data { + /** Cluster data. */ + device_mdarray, row_major> data; + /** Source indices. */ + device_mdarray, row_major> indices; + /** The actual size of the content. */ + std::atomic size; + + list_data(raft::device_resources const& handle, IdxT n_rows, uint32_t dim) + : size{n_rows} + { + auto capacity = round_up_safe(bound_by_power_of_two(size), kIndexGroupSize); + try { + data = make_device_mdarray(handle, make_extents(capacity, dim)); + indices = make_device_mdarray(handle, make_extents(capacity)); + } catch (std::bad_alloc& e) { + RAFT_FAIL( + "ivf-flat: failed to allocate a big enough index list to hold all data " + "(requested size: %zu records, selected capacity: %zu records). " + "Allocator exception: %s", + size_t(size), + size_t(capacity), + e.what()); + } + } +}; + /** * @brief IVF-flat index. * @@ -78,6 +107,12 @@ struct index : ann::index { "IdxT must be able to represent all values of uint32_t"); public: + /** + * Default value filled in the `indices()` array. + * One may encounter it trying to access a record within a cluster that is outside of the + * `list_sizes()` bound (due to the record alignment `kIndexGroupSize`). 
+ */ + constexpr static IdxT kInvalidRecord = std::numeric_limits::max() - 1; /** * Vectorized load/store size in elements, determines the size of interleaved data chunks. * @@ -252,8 +287,60 @@ struct index : ann::index { * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount * of data. */ - void allocate(raft::device_resources const& handle, IdxT index_size) + void resize_list(raft::device_resources const& handle, uint32_t label, uint32_t list_size) { + uint32_t prev_size = 0; + auto& list = lists()(label); + bool skip_resize = false; + if (list) { + copy(&prev_size, &list_sizes()(label), 1, handle.get_stream()); + handle.sync_stream(); + if (list_size <= list->indices.extent(0)) { + auto shared_list_size = prev_size; + if (list_size <= prev_size || + list->size.compare_exchange_strong(shared_list_size, list_size)) { + // We don't need to resize the list if: + // 1. The list exists + // 2. The new size fits in the list + // 3. The list doesn't grow or no-one else has grown it yet + skip_resize = true; + } + } + } + // Sic! 
We're writing the min(list_size, prev_size) + // to keep the number of _valid_ records after update + const auto new_list_size = std::min(list_size, prev_size); + raft::copy(&list_sizes()(label), &new_list_size, 1, handle.get_stream()); + if (skip_resize) { return; } + auto new_list = new list_data(handle, list_size, dim()); + if (prev_size > 0) { + auto copied_data_extents = make_extents(prev_size, dim()); + auto copied_view = make_mdspan( + new_list->data.data_handle(), copied_data_extents); + copy(copied_view.data_handle(), + list->data.data_handle(), + copied_view.size(), + handle.get_stream()); + copy(new_list->indices.data_handle(), + list->indices.data_handle(), + prev_size, + handle.get_stream()); + } + // swap the shared pointer content with the new list + list.reset(new_list); + // fill unused index spaces with placeholder values for easier debugging + thrust::fill_n(handle.get_thrust_policy(), + list->indices.data_handle() + prev_size, + list->indices.size() - prev_size, + kInvalidRecord); + // keep the device pointers updated + const auto new_data_ptr = list->data.data_handle(); + copy(&(data_ptrs()(label)), &new_data_ptr, 1, handle.get_stream()); + const auto new_inds_ptr = list->indices.data_handle(); + copy(&(inds_ptrs()(label)), &new_inds_ptr, 1, handle.get_stream()); + + + data_ = make_device_mdarray(handle, make_extents(index_size, dim())); indices_ = make_device_mdarray(handle, make_extents(index_size)); @@ -278,13 +365,19 @@ struct index : ann::index { uint32_t veclen_; raft::distance::DistanceType metric_; bool adaptive_centers_; - device_mdarray, row_major> data_; - device_mdarray, row_major> indices_; - device_mdarray, row_major> list_sizes_; - device_mdarray, row_major> list_offsets_; + host_mdarray, extent_1d, row_major> lists_; + //device_mdarray, row_major> data_; + //device_mdarray, row_major> indices_; + //device_mdarray, row_major> list_sizes_; + //device_mdarray, row_major> list_offsets_; device_mdarray, row_major> centers_; 
std::optional, row_major>> center_norms_; + // Computed members + device_mdarray, row_major> data_ptrs_; + device_mdarray, row_major> inds_ptrs_; + host_mdarray, row_major> accum_sorted_sizes_; + /** Throw an error if the index content is inconsistent. */ void check_consistency() { From a83aca4d312e6c2299e89abb4f4034f705605113 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 7 Feb 2023 18:59:14 +0100 Subject: [PATCH 02/33] Adapt `extend` --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 205 +++++++----------- .../spatial/knn/detail/ivf_flat_build.cuh | 122 ++++++++--- 2 files changed, 175 insertions(+), 152 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index b67cf38139..b384cbcbb3 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -66,22 +66,22 @@ static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); /** The data for a single list (cluster). */ -template +template struct list_data { /** Cluster data. */ - device_mdarray, row_major> data; + device_mdarray, row_major> data; /** Source indices. */ - device_mdarray, row_major> indices; + device_mdarray, row_major> indices; /** The actual size of the content. 
*/ - std::atomic size; + std::atomic size; - list_data(raft::device_resources const& handle, IdxT n_rows, uint32_t dim) + list_data(raft::device_resources const& handle, SizeT n_rows, uint32_t dim) : size{n_rows} { - auto capacity = round_up_safe(bound_by_power_of_two(size), kIndexGroupSize); + auto capacity = round_up_safe(bound_by_power_of_two(size), kIndexGroupSize); try { - data = make_device_mdarray(handle, make_extents(capacity, dim)); - indices = make_device_mdarray(handle, make_extents(capacity)); + data = make_device_mdarray(handle, make_extents(capacity, dim)); + indices = make_device_mdarray(handle, make_extents(capacity)); } catch (std::bad_alloc& e) { RAFT_FAIL( "ivf-flat: failed to allocate a big enough index list to hold all data " @@ -153,27 +153,6 @@ struct index : ann::index { * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , * */ - inline auto data() noexcept -> device_mdspan, row_major> - { - return data_.view(); - } - [[nodiscard]] inline auto data() const noexcept - -> device_mdspan, row_major> - { - return data_.view(); - } - - /** Inverted list indices: ids of items in the source data [size] */ - inline auto indices() noexcept -> device_mdspan, row_major> - { - return indices_.view(); - } - [[nodiscard]] inline auto indices() const noexcept - -> device_mdspan, row_major> - { - return indices_.view(); - } - /** Sizes of the lists (clusters) [n_lists] */ inline auto list_sizes() noexcept -> device_mdspan, row_major> { @@ -185,20 +164,6 @@ struct index : ann::index { return list_sizes_.view(); } - /** - * Offsets into the lists [n_lists + 1]. - * The last value contains the total length of the index. 
- */ - inline auto list_offsets() noexcept -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - [[nodiscard]] inline auto list_offsets() const noexcept - -> device_mdspan, row_major> - { - return list_offsets_.view(); - } - /** k-means cluster centers corresponding to the lists [n_lists, dim] */ inline auto centers() noexcept -> device_mdspan, row_major> { @@ -238,7 +203,10 @@ struct index : ann::index { } /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return total_size_; + } /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { @@ -247,7 +215,7 @@ struct index : ann::index { /** Number of clusters/inverted lists. */ [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t { - return centers_.extent(0); + return lists_.extent(0); } // Don't allow copying the index for performance reasons (try avoiding copying data) @@ -267,14 +235,17 @@ struct index : ann::index { veclen_(calculate_veclen(dim)), metric_(metric), adaptive_centers_(adaptive_centers), - data_(make_device_mdarray(handle, make_extents(0, dim))), - indices_(make_device_mdarray(handle, make_extents(0))), - list_sizes_(make_device_mdarray(handle, make_extents(n_lists))), - list_offsets_(make_device_mdarray(handle, make_extents(n_lists + 1))), centers_(make_device_mdarray(handle, make_extents(n_lists, dim))), - center_norms_(std::nullopt) + center_norms_(std::nullopt), + lists_{make_host_mdarray>>(make_extents(n_lists))}, + list_sizes_{make_device_mdarray(handle, make_extents(n_lists))}, + data_ptrs_{make_device_mdarray(handle, make_extents(n_lists))}, + inds_ptrs_{make_device_mdarray(handle, make_extents(n_lists))} { check_consistency(); + for (uint32_t i = 0; i < n_lists; i++) { + lists_(i) = std::shared_ptr>(); + } } /** Construct an empty index. 
It needs to be trained and then populated. */ @@ -283,78 +254,61 @@ struct index : ann::index { { } - /** - * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount - * of data. - */ - void resize_list(raft::device_resources const& handle, uint32_t label, uint32_t list_size) + /** Pointers to the inverted lists (clusters) data [n_lists]. */ + inline auto data_ptrs() noexcept -> device_mdspan, row_major> { - uint32_t prev_size = 0; - auto& list = lists()(label); - bool skip_resize = false; - if (list) { - copy(&prev_size, &list_sizes()(label), 1, handle.get_stream()); - handle.sync_stream(); - if (list_size <= list->indices.extent(0)) { - auto shared_list_size = prev_size; - if (list_size <= prev_size || - list->size.compare_exchange_strong(shared_list_size, list_size)) { - // We don't need to resize the list if: - // 1. The list exists - // 2. The new size fits in the list - // 3. The list doesn't grow or no-one else has grown it yet - skip_resize = true; - } - } - } - // Sic! 
We're writing the min(list_size, prev_size) - // to keep the number of _valid_ records after update - const auto new_list_size = std::min(list_size, prev_size); - raft::copy(&list_sizes()(label), &new_list_size, 1, handle.get_stream()); - if (skip_resize) { return; } - auto new_list = new list_data(handle, list_size, dim()); - if (prev_size > 0) { - auto copied_data_extents = make_extents(prev_size, dim()); - auto copied_view = make_mdspan( - new_list->data.data_handle(), copied_data_extents); - copy(copied_view.data_handle(), - list->data.data_handle(), - copied_view.size(), - handle.get_stream()); - copy(new_list->indices.data_handle(), - list->indices.data_handle(), - prev_size, - handle.get_stream()); - } - // swap the shared pointer content with the new list - list.reset(new_list); - // fill unused index spaces with placeholder values for easier debugging - thrust::fill_n(handle.get_thrust_policy(), - list->indices.data_handle() + prev_size, - list->indices.size() - prev_size, - kInvalidRecord); - // keep the device pointers updated - const auto new_data_ptr = list->data.data_handle(); - copy(&(data_ptrs()(label)), &new_data_ptr, 1, handle.get_stream()); - const auto new_inds_ptr = list->indices.data_handle(); - copy(&(inds_ptrs()(label)), &new_inds_ptr, 1, handle.get_stream()); - - + return data_ptrs_.view(); + } + [[nodiscard]] inline auto data_ptrs() const noexcept + -> device_mdspan, row_major> + { + return data_ptrs_.view(); + } - data_ = make_device_mdarray(handle, make_extents(index_size, dim())); - indices_ = make_device_mdarray(handle, make_extents(index_size)); + /** Pointers to the inverted lists (clusters) indices [n_lists]. 
*/ + inline auto inds_ptrs() noexcept -> device_mdspan, row_major> + { + return inds_ptrs_.view(); + } + [[nodiscard]] inline auto inds_ptrs() const noexcept + -> device_mdspan, row_major> + { + return inds_ptrs_.view(); + } - switch (metric_) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - center_norms_ = make_device_mdarray(handle, make_extents(n_lists())); - break; - default: center_norms_ = std::nullopt; + /** + * Update the state of the dependent index members. + */ + void recompute_internal_state(raft::device_resources const& res) + { + auto stream = res.get_stream(); + + // Actualize the list pointers + auto lists = lists(); + auto data_ptrs = data_ptrs(); + auto inds_ptrs = inds_ptrs(); + IdxT recompute_total_size = 0; + for (uint32_t label = 0; label < lists.size(); label++) { + const auto data_ptr = lists(label) ? lists(label)->data.data_handle() : nullptr; + const auto inds_ptr = lists(label) ? lists(label)->indices.data_handle() : nullptr; + const auto list_size = lists(label) ? lists(label)->size() : 0; + copy(&data_ptrs(label), &data_ptr, 1, stream); + copy(&inds_ptrs(label), &inds_ptr, 1, stream); + recompute_total_size += list_size; } + total_size_ = recompute_total_size; + } - check_consistency(); + /** Lists' data and indices. 
*/ + inline auto lists() noexcept + -> host_mdspan>, extent_1d, row_major> + { + return lists_.view(); + } + [[nodiscard]] inline auto lists() const noexcept + -> host_mdspan>, extent_1d, row_major> + { + return lists_.view(); } private: @@ -365,10 +319,10 @@ struct index : ann::index { uint32_t veclen_; raft::distance::DistanceType metric_; bool adaptive_centers_; - host_mdarray, extent_1d, row_major> lists_; + host_mdarray>, extent_1d, row_major> lists_; //device_mdarray, row_major> data_; //device_mdarray, row_major> indices_; - //device_mdarray, row_major> list_sizes_; + device_mdarray, row_major> list_sizes_; //device_mdarray, row_major> list_offsets_; device_mdarray, row_major> centers_; std::optional, row_major>> center_norms_; @@ -376,21 +330,22 @@ struct index : ann::index { // Computed members device_mdarray, row_major> data_ptrs_; device_mdarray, row_major> inds_ptrs_; - host_mdarray, row_major> accum_sorted_sizes_; + IdxT total_size_; /** Throw an error if the index content is inconsistent. 
*/ void check_consistency() { + auto n_lists = lists_.extent(0); RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(data_.extent(0) == indices_.extent(0), "inconsistent index size"); - RAFT_EXPECTS(data_.extent(1) == IdxT(centers_.extent(1)), "inconsistent data dimensionality"); + RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); RAFT_EXPECTS( // (centers_.extent(0) == list_sizes_.extent(0)) && // - (centers_.extent(0) + 1 == list_offsets_.extent(0)) && // (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), "inconsistent number of lists (clusters)"); - RAFT_EXPECTS(reinterpret_cast(data_.data_handle()) % (veclen_ * sizeof(T)) == 0, - "The data storage pointer is not aligned to the vector length"); + //RAFT_EXPECTS(reinterpret_cast(data_.data_handle()) % (veclen_ * sizeof(T)) == 0, + // "The data storage pointer is not aligned to the vector length"); } static auto calculate_veclen(uint32_t dim) -> uint32_t diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index c417a97531..1b8bc307b6 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -41,6 +41,63 @@ namespace raft::spatial::knn::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT +/** + * Resize a list by the given id, so that it can contain the given number of records; + * possibly, copy the data. + * + * Besides resizing the corresponding list_data, this function updates the device pointers + * data_ptrs, inds_ptrs, and the list_sizes if necessary. 
+ * + * The new `list_sizes(label)` represents the number of valid records in the index; + * it can be `list_size` if the previous size was not smaller; otherwise it's not updated. + * + * @param[in] handle + * @param[in] label list id + * @param[in] list_size the minimum size the list should grow. + */ +template +void resize_list(raft::device_resources const& handle, + std::shared_ptr>& orig_list, + SizeT new_used_size, + SizeT old_used_size, + uint32_t dim) +{ + bool skip_resize = false; + // TODO update total_size + if (orig_list) { + if (new_used_size <= orig_list->indices.extent(0)) { + auto shared_list_size = old_used_size; + if (new_used_size <= old_used_size || + orig_list->size.compare_exchange_strong(shared_list_size, new_used_size)) { + // We don't need to resize the list if: + // 1. The list exists + // 2. The new size fits in the list + // 3. The list doesn't grow or no-one else has grown it yet + skip_resize = true; + } + } + } else { + old_used_size = 0; + } + if (skip_resize) { return; } + auto new_list = std::make_shared>(res, new_used_size, dim); + if (old_used_size > 0) { + auto copied_data_extents = make_extents(old_used_size, dim); + auto copied_view = make_mdspan( + new_list->data.data_handle(), copied_data_extents); + copy(copied_view.data_handle(), + orig_list->data.data_handle(), + copied_view.size(), + res.get_stream()); + copy(new_list->indices.data_handle(), + orig_list->indices.data_handle(), + old_used_size, + res.get_stream()); + } + // swap the shared pointer content with the new list + new_list.swap(orig_list); +} + /** * @brief Record the dataset into the index, one source row at a time. 
* @@ -119,29 +176,30 @@ __global__ void build_index_kernel(const LabelT* labels, /** See raft::spatial::knn::ivf_flat::extend docs */ template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { using LabelT = uint32_t; + RAFT_EXPECTS(orig_index != nullptr, "index cannot be empty."); auto stream = handle.get_stream(); - auto n_lists = orig_index.n_lists(); - auto dim = orig_index.dim(); + auto n_lists = index->n_lists(); + auto dim = index->dim(); common::nvtx::range fun_scope( "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, + RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); rmm::device_uvector new_labels(n_rows, stream); raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = orig_index.metric(); + kmeans_params.metric = index->metric(); auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); auto orig_centroids_view = raft::make_device_matrix_view( - orig_index.centers().data_handle(), n_lists, dim); + index->centers().data_handle(), n_lists, dim); auto labels_view = raft::make_device_vector_view(new_labels.data(), n_rows); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, @@ -150,20 +208,13 @@ inline auto extend(raft::device_resources const& handle, labels_view, utils::mapping{}); - index ext_index( - handle, orig_index.metric(), n_lists, orig_index.adaptive_centers(), dim); - - auto list_sizes_ptr = ext_index.list_sizes().data_handle(); - auto list_offsets_ptr = ext_index.list_offsets().data_handle(); - auto centers_ptr = ext_index.centers().data_handle(); + auto list_sizes_ptr = index->list_sizes().data_handle(); + //auto 
list_offsets_ptr = ext_index.list_offsets().data_handle(); // Calculate the centers and sizes on the new data, starting from the original values - raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream); - if (ext_index.adaptive_centers()) { - raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); - auto centroids_view = raft::make_device_matrix_view(centers_ptr, n_lists, dim); + if (index->adaptive_centers()) { + auto centroids_view = index->centers().view(); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); @@ -177,6 +228,7 @@ inline auto extend(raft::device_resources const& handle, false, utils::mapping{}); } else { + auto new_list_sizes = raft::make_device_vector(handle, n_lists); raft::stats::histogram(raft::stats::HistTypeAuto, reinterpret_cast(list_sizes_ptr), IdxT(n_lists), @@ -188,8 +240,24 @@ inline auto extend(raft::device_resources const& handle, list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream); } - // Calculate new offsets - IdxT index_size = 0; + // Calculate and allocate new list data + { + std::vector new_cluster_sizes(n_lists); + std::vector old_cluster_sizes(n_lists); + copy(new_cluster_sizes.data(), list_sizes, n_lists, stream); + copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_lists, stream); + handle.sync_stream(); + auto lists = index->lists(); + for (uint32_t label = 0; label < n_lists; label++) { + resize_list(handle, + lists(label), + new_cluster_sizes[label], + old_cluster_sizes[label], + index->dim()); + } + // Update the pointers and the sizes + index->recompute_internal_state(handle); + /*IdxT index_size = 0; update_device(list_offsets_ptr, &index_size, 1, stream); thrust::inclusive_scan( rmm::exec_policy(stream), @@ -219,7 +287,7 @@ inline auto extend(raft::device_resources const& handle, ext_index.indices().data_handle(), IdxT(1), stream); - } + }*/ // Copy 
the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. @@ -357,13 +425,13 @@ inline void fill_refinement_index(raft::device_resources const& handle, raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); + //auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); // We do not fill centers and center norms, since we will not run coarse search. // Calculate new offsets uint32_t n_roundup = Pow2::roundUp(n_candidates); - auto list_offsets_view = raft::make_device_vector_view( - list_offsets_ptr, refinement_index->list_offsets().size()); + //auto list_offsets_view = raft::make_device_vector_view( + // list_offsets_ptr, refinement_index->list_offsets().size()); linalg::map_offset(handle, list_offsets_view, raft::compose_op(raft::cast_op(), raft::mul_const_op(n_roundup))); From 843904a46307b6cb0b179631d836909f1f38168f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 9 Feb 2023 23:09:30 +0100 Subject: [PATCH 03/33] Refactoring: build and extend fix --- cpp/include/raft/neighbors/ivf_flat.cuh | 12 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 21 +- .../spatial/knn/detail/ivf_flat_build.cuh | 189 ++++++++---------- cpp/test/neighbors/ann_ivf_flat.cuh | 28 +-- 4 files changed, 115 insertions(+), 135 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f18611b9f1..d210d2c74b 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -252,7 +252,7 @@ void extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + raft::spatial::knn::ivf_flat::detail::extend(handle, index, new_vectors, 
new_indices, n_rows); } /** @@ -293,11 +293,11 @@ void extend(raft::device_resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices = std::nullopt) { - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); } /** @} */ diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index b384cbcbb3..fcca6cd858 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -19,10 +19,13 @@ #include "ann_types.hpp" #include +#include +#include #include #include #include +#include #include #include @@ -284,16 +287,16 @@ struct index : ann::index { auto stream = res.get_stream(); // Actualize the list pointers - auto lists = lists(); - auto data_ptrs = data_ptrs(); - auto inds_ptrs = inds_ptrs(); + auto this_lists = lists(); + auto this_data_ptrs = data_ptrs(); + auto this_inds_ptrs = inds_ptrs(); IdxT recompute_total_size = 0; - for (uint32_t label = 0; label < lists.size(); label++) { - const auto data_ptr = lists(label) ? lists(label)->data.data_handle() : nullptr; - const auto inds_ptr = lists(label) ? lists(label)->indices.data_handle() : nullptr; - const auto list_size = lists(label) ? lists(label)->size() : 0; - copy(&data_ptrs(label), &data_ptr, 1, stream); - copy(&inds_ptrs(label), &inds_ptr, 1, stream); + for (uint32_t label = 0; label < this_lists.size(); label++) { + const auto data_ptr = this_lists(label) ? this_lists(label)->data.data_handle() : nullptr; + const auto inds_ptr = this_lists(label) ? this_lists(label)->indices.data_handle() : nullptr; + const auto list_size = this_lists(label) ? 
IdxT(this_lists(label)->size) : 0; + copy(&this_data_ptrs(label), &data_ptr, 1, stream); + copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); recompute_total_size += list_size; } total_size_ = recompute_total_size; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 1b8bc307b6..ca595f7303 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -16,7 +16,6 @@ #pragma once -#include "../ivf_flat_types.hpp" #include "ann_utils.cuh" #include @@ -29,6 +28,7 @@ #include #include #include +#include #include #include @@ -40,7 +40,13 @@ namespace raft::spatial::knn::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT +using namespace raft::neighbors::ivf_flat; // NOLINT +using raft::neighbors::ivf_flat::index; +using raft::neighbors::ivf_flat::index_params; +using raft::neighbors::ivf_flat::kIndexGroupSize; +using raft::neighbors::ivf_flat::search_params; +using raft::neighbors::ivf_flat::list_data; /** * Resize a list by the given id, so that it can contain the given number of records; * possibly, copy the data. 
@@ -80,7 +86,7 @@ void resize_list(raft::device_resources const& handle, old_used_size = 0; } if (skip_resize) { return; } - auto new_list = std::make_shared>(res, new_used_size, dim); + auto new_list = std::make_shared>(handle, new_used_size, dim); if (old_used_size > 0) { auto copied_data_extents = make_extents(old_used_size, dim); auto copied_view = make_mdspan( @@ -88,11 +94,11 @@ void resize_list(raft::device_resources const& handle, copy(copied_view.data_handle(), orig_list->data.data_handle(), copied_view.size(), - res.get_stream()); + handle.get_stream()); copy(new_list->indices.data_handle(), orig_list->indices.data_handle(), old_used_size, - res.get_stream()); + handle.get_stream()); } // swap the shared pointer content with the new list new_list.swap(orig_list); @@ -131,11 +137,11 @@ void resize_list(raft::device_resources const& handle, */ template __global__ void build_index_kernel(const LabelT* labels, - const IdxT* list_offsets, + //const IdxT* list_offsets, const T* source_vecs, const IdxT* source_ixs, - T* list_data, - IdxT* list_index, + T** list_data_ptrs, + IdxT** list_index_ptrs, uint32_t* list_sizes_ptr, IdxT n_rows, uint32_t dim, @@ -146,10 +152,11 @@ __global__ void build_index_kernel(const LabelT* labels, auto list_id = labels[i]; auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); - auto list_offset = list_offsets[list_id]; + auto* list_index = list_index_ptrs[list_id]; + auto* list_data = list_data_ptrs[list_id]; // Record the source vector id in the index - list_index[list_offset + inlist_id] = source_ixs == nullptr ? i : source_ixs[i]; + list_index[inlist_id] = source_ixs == nullptr ? 
i : source_ixs[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = Pow2; @@ -157,7 +164,7 @@ __global__ void build_index_kernel(const LabelT* labels, auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; // Point to the location of the interleaved group of vectors - list_data += (list_offset + group_offset) * dim; + list_data += group_offset * dim; // Point to the source vector if constexpr (gather_src) { @@ -183,7 +190,7 @@ void extend(raft::device_resources const& handle, IdxT n_rows) { using LabelT = uint32_t; - RAFT_EXPECTS(orig_index != nullptr, "index cannot be empty."); + RAFT_EXPECTS(index != nullptr, "index cannot be empty."); auto stream = handle.get_stream(); auto n_lists = index->n_lists(); @@ -194,32 +201,30 @@ void extend(raft::device_resources const& handle, RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); - rmm::device_uvector new_labels(n_rows, stream); + auto new_labels = raft::make_device_vector(handle, n_rows); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.metric = index->metric(); auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); auto orig_centroids_view = raft::make_device_matrix_view( index->centers().data_handle(), n_lists, dim); - auto labels_view = raft::make_device_vector_view(new_labels.data(), n_rows); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, new_vectors_view, orig_centroids_view, - labels_view, + new_labels.view(), utils::mapping{}); - auto list_sizes_ptr = index->list_sizes().data_handle(); - //auto list_offsets_ptr = ext_index.list_offsets().data_handle(); + auto* list_sizes_ptr = index->list_sizes().data_handle(); + auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and sizes on the new data, starting from the 
original values - if (index->adaptive_centers()) { - auto centroids_view = index->centers().view(); + auto centroids_view = raft::make_device_matrix_view(index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); - auto const_labels_view = - raft::make_device_vector_view(new_labels.data(), n_rows); + auto const_labels_view = make_const_mdspan(new_labels.view()); raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, new_vectors_view, const_labels_view, @@ -228,109 +233,82 @@ void extend(raft::device_resources const& handle, false, utils::mapping{}); } else { - auto new_list_sizes = raft::make_device_vector(handle, n_lists); raft::stats::histogram(raft::stats::HistTypeAuto, reinterpret_cast(list_sizes_ptr), IdxT(n_lists), - new_labels.data(), + new_labels.data_handle(), n_rows, 1, stream); raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, orig_index.list_sizes().data_handle(), n_lists, stream); + list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); } // Calculate and allocate new list data { - std::vector new_cluster_sizes(n_lists); - std::vector old_cluster_sizes(n_lists); - copy(new_cluster_sizes.data(), list_sizes, n_lists, stream); - copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_lists, stream); + std::vector new_list_sizes(n_lists); + std::vector old_list_sizes(n_lists); + copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); + copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); handle.sync_stream(); auto lists = index->lists(); for (uint32_t label = 0; label < n_lists; label++) { resize_list(handle, lists(label), - new_cluster_sizes[label], - old_cluster_sizes[label], + new_list_sizes[label], + old_list_sizes[label], index->dim()); + } } // Update the pointers and the sizes index->recompute_internal_state(handle); - /*IdxT index_size = 0; - 
update_device(list_offsets_ptr, &index_size, 1, stream); - thrust::inclusive_scan( - rmm::exec_policy(stream), - list_sizes_ptr, - list_sizes_ptr + n_lists, - list_offsets_ptr + 1, - [] __device__(IdxT s, uint32_t l) { return s + Pow2::roundUp(l); }); - update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); - handle.sync_stream(stream); - - ext_index.allocate(handle, index_size); - - // Populate index with the old data - if (orig_index.size() > 0) { - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.data().data_handle(), - ext_index.data().data_handle(), - IdxT(dim), - stream); - - utils::block_copy(orig_index.list_offsets().data_handle(), - list_offsets_ptr, - IdxT(n_lists), - orig_index.indices().data_handle(), - ext_index.indices().data_handle(), - IdxT(1), - stream); - }*/ // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. 
raft::copy( - list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); + list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x)); - build_index_kernel<<>>(new_labels.data(), - list_offsets_ptr, + build_index_kernel<<>>(new_labels.data_handle(), new_vectors, new_indices, - ext_index.data().data_handle(), - ext_index.indices().data_handle(), + index->data_ptrs().data_handle(),//ext_index.data().data_handle(), + index->inds_ptrs().data_handle(),//ext_index.indices().data_handle(), list_sizes_ptr, n_rows, dim, - ext_index.veclen()); + index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Precompute the centers vector norms for L2Expanded distance - if (ext_index.center_norms().has_value()) { - if (!ext_index.adaptive_centers() && orig_index.center_norms().has_value()) { - raft::copy(ext_index.center_norms()->data_handle(), - orig_index.center_norms()->data_handle(), - orig_index.center_norms()->size(), - stream); - } else { - raft::linalg::rowNorm(ext_index.center_norms()->data_handle(), - ext_index.centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); - RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min(dim, 20)); - } + if (index->center_norms().has_value() && index->adaptive_centers()) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream); + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } +} - // assemble the index +/** See raft::spatial::knn::ivf_flat::extend docs */ +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + auto ext_index = clone(handle, &orig_index); + extend(handle, &ext_index, 
new_vectors, new_indices, n_rows); return ext_index; } + /** See raft::spatial::knn::ivf_flat::build docs */ template inline auto build(raft::device_resources const& handle, @@ -348,7 +326,8 @@ inline auto build(raft::device_resources const& handle, index index(handle, params, dim); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream); + utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); + utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); // Train the kmeans clustering { @@ -378,10 +357,9 @@ inline auto build(raft::device_resources const& handle, // add the data if necessary if (params.add_data_on_build) { - return detail::extend(handle, index, dataset, nullptr, n_rows); - } else { - return index; + detail::extend(handle, &index, dataset, nullptr, n_rows); } + return index; } /** @@ -425,19 +403,17 @@ inline void fill_refinement_index(raft::device_resources const& handle, raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - //auto list_offsets_ptr = refinement_index->list_offsets().data_handle(); // We do not fill centers and center norms, since we will not run coarse search. 
- // Calculate new offsets - uint32_t n_roundup = Pow2::roundUp(n_candidates); - //auto list_offsets_view = raft::make_device_vector_view( - // list_offsets_ptr, refinement_index->list_offsets().size()); - linalg::map_offset(handle, - list_offsets_view, - raft::compose_op(raft::cast_op(), raft::mul_const_op(n_roundup))); - - IdxT index_size = n_roundup * n_lists; - refinement_index->allocate(handle, index_size); + // Allocate new memory + auto lists = refinement_index->lists(); + for (uint32_t label = 0; label < n_lists; label++) { + resize_list(handle, + lists(label), + n_candidates, + uint32_t(0), + refinement_index->dim()); + } RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -445,11 +421,10 @@ inline void fill_refinement_index(raft::device_resources const& handle, const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); build_index_kernel <<>>(new_labels.data(), - list_offsets_ptr, dataset, candidate_idx, - refinement_index->data().data_handle(), - refinement_index->indices().data_handle(), + refinement_index->data_ptrs().data_handle(), + refinement_index->inds_ptrs().data_handle(), list_sizes_ptr, n_queries * n_candidates, refinement_index->dim(), @@ -464,9 +439,10 @@ inline void fill_refinement_index(raft::device_resources const& handle, // compatible fashion. constexpr int serialization_version = 2; +/* TODO static_assert(sizeof(index) == 408, "The size of the index struct has changed since the last update; " - "paste in the new size and consider updating the save/load logic"); + "paste in the new size and consider updating the save/load logic");*/ /** * Save the index to file. 
@@ -488,6 +464,7 @@ void serialize(raft::device_resources const& handle, RAFT_LOG_DEBUG( "Saving IVF-PQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + /* TODO serialize_scalar(handle, of, serialization_version); serialize_scalar(handle, of, index_.size()); serialize_scalar(handle, of, index_.dim()); @@ -507,7 +484,7 @@ void serialize(raft::device_resources const& handle, } else { bool has_norms = false; serialize_scalar(handle, of, has_norms); - } + }*/ of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } @@ -541,8 +518,8 @@ auto deserialize(raft::device_resources const& handle, const std::string& filena bool adaptive_centers = deserialize_scalar(handle, infile); index index_ = - raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); - + index(handle, metric, n_lists, adaptive_centers, dim); + /* TODO index_.allocate(handle, n_rows); auto data = index_.data(); deserialize_mdspan(handle, infile, data); @@ -558,7 +535,7 @@ auto deserialize(raft::device_resources const& handle, const std::string& filena auto center_norms = *index_.center_norms(); deserialize_mdspan(handle, infile, center_norms); } - } + }*/ infile.close(); return index_; } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index c100afb2c4..3807e414c0 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -169,7 +169,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + ivf_flat::extend(handle_, &index, half_of_data_view); auto new_half_of_data_view = raft::make_device_matrix_view( database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); @@ -178,7 +178,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { 
vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::extend(handle_, - &index_2, + &index, new_half_of_data_view, std::make_optional>( new_half_of_data_indices_view)); @@ -189,7 +189,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { indices_ivfflat_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); - raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); + raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", index); auto index_loaded = raft::spatial::knn::ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); @@ -207,30 +207,30 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { handle_.sync_stream(stream_); // Test the centroid invariants - if (index_2.adaptive_centers()) { + if (index.adaptive_centers()) { // The centers must be up-to-date with the corresponding data - std::vector list_sizes(index_2.n_lists()); - std::vector list_offsets(index_2.n_lists()); + std::vector list_sizes(index.n_lists()); + //std::vector list_offsets(index.n_lists()); rmm::device_uvector centroid(ps.dim, stream_); raft::copy( - list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); - raft::copy( - list_offsets.data(), index_2.list_offsets().data_handle(), index_2.n_lists(), stream_); + list_sizes.data(), index.list_sizes().data_handle(), index.n_lists(), stream_); + //raft::copy( + // list_offsets.data(), index.list_offsets().data_handle(), index.n_lists(), stream_); handle_.sync_stream(stream_); - for (uint32_t l = 0; l < index_2.n_lists(); l++) { + for (uint32_t l = 0; l < index.n_lists(); l++) { rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); raft::spatial::knn::detail::utils::copy_selected( (IdxT)list_sizes[l], (IdxT)ps.dim, database.data(), - index_2.indices().data_handle() + list_offsets[l], + index.inds_ptrs()(l), (IdxT)ps.dim, 
cluster_data.data(), (IdxT)ps.dim, stream_); raft::stats::mean( centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + ASSERT_TRUE(raft::devArrMatch(index.centers().data_handle() + ps.dim * l, centroid.data(), ps.dim, raft::CompareApprox(0.001), @@ -238,9 +238,9 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } else { // The centers must be immutable - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), + ASSERT_TRUE(raft::devArrMatch(index.centers().data_handle(), index.centers().data_handle(), - index_2.centers().size(), + index.centers().size(), raft::Compare(), stream_)); } From f09b3a003949c3b7838b1c197c0b64ee61789858 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 10 Feb 2023 15:49:39 +0100 Subject: [PATCH 04/33] Refactor ivf flat search for index splitting --- .../spatial/knn/detail/ivf_flat_search.cuh | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 7f70d4b8a5..77ecae745a 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -16,7 +16,6 @@ #pragma once -#include "../ivf_flat_types.hpp" #include "ann_utils.cuh" #include @@ -30,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +42,12 @@ namespace raft::spatial::knn::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT +using namespace raft::neighbors::ivf_flat; // NOLINT + +using raft::neighbors::ivf_flat::index; +using raft::neighbors::ivf_flat::index_params; +using raft::neighbors::ivf_flat::kIndexGroupSize; +using raft::neighbors::ivf_flat::search_params; constexpr int kThreadsPerBlock = 128; @@ -673,10 +679,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const uint32_t 
query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* list_indices, - const T* list_data, + const IdxT* const* list_indices_ptrs, // const IdxT* list_indices + const T* const* list_data_ptrs, // const T* list_data const uint32_t* list_sizes, - const IdxT* list_offsets, + //const IdxT* list_offsets, const uint32_t n_probes, const uint32_t k, const uint32_t dim, @@ -723,7 +729,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Every CUDA block scans one cluster at a time. for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - const size_t list_offset = list_offsets[list_id]; + //const size_t list_offset = list_offsets[list_id]; // The number of vectors in each cluster(list); [nlist] const uint32_t list_length = list_sizes[list_id]; @@ -741,7 +747,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_id += kNumWarps) { AccT dist = 0; // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim; + //const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim; + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * WarpSize + lane_id; @@ -778,7 +785,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices[list_offset + vec_id]) : 0; + const size_t idx = valid ? 
static_cast(list_indices_ptrs[list_id][vec_id]) : 0; queue.add(val, idx); } } @@ -819,7 +826,7 @@ template void launch_kernel(Lambda lambda, PostLambda post_process, - const ivf_flat::index& index, + const index& index, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, @@ -869,10 +876,10 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.indices().data_handle(), - index.data().data_handle(), - index.list_sizes().data_handle(), - index.list_offsets().data_handle(), + index.inds_ptrs().data_handle(),//indices().data_handle(), + index.data_ptrs().data_handle(),//data().data_handle(), + index.list_sizes().data_handle(),// TODO + //index.list_offsets().data_handle(), n_probes, k, index.dim(), @@ -1056,7 +1063,7 @@ struct select_interleaved_scan_kernel { * @param stream */ template -void ivfflat_interleaved_scan(const ivf_flat::index& index, +void ivfflat_interleaved_scan(const index& index, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, From 93d5b35863b1d1b940fbedd624be3ea3176d45ad Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Sun, 12 Feb 2023 22:57:51 +0100 Subject: [PATCH 05/33] Use mdpsan/mdarray aliases --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 88 +++++++++---------- .../spatial/knn/detail/ivf_flat_build.cuh | 14 ++- .../spatial/knn/detail/ivf_flat_search.cuh | 7 +- 3 files changed, 57 insertions(+), 52 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index fcca6cd858..c57e5e5c6e 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -38,6 +39,14 @@ namespace raft::neighbors::ivf_flat { /** Size of the interleaved group (see `index::data` description). */ constexpr static uint32_t kIndexGroupSize = 32; +/** + * Default value filled in the `indices()` array. 
+ * One may encounter it trying to access a record within a cluster that is outside of the + * `list_sizes()` bound (due to the record alignment `kIndexGroupSize`). + */ +template +constexpr static IdxT kInvalidRecord = std::numeric_limits::max() - 1; + struct index_params : ann::index_params { /** The number of inverted lists (clusters) */ uint32_t n_lists = 1024; @@ -72,19 +81,19 @@ static_assert(std::is_aggregate_v); template struct list_data { /** Cluster data. */ - device_mdarray, row_major> data; + device_matrix data; /** Source indices. */ - device_mdarray, row_major> indices; + device_vector indices; /** The actual size of the content. */ std::atomic size; - list_data(raft::device_resources const& handle, SizeT n_rows, uint32_t dim) + list_data(raft::device_resources const& res, SizeT n_rows, uint32_t dim) : size{n_rows} { auto capacity = round_up_safe(bound_by_power_of_two(size), kIndexGroupSize); try { - data = make_device_mdarray(handle, make_extents(capacity, dim)); - indices = make_device_mdarray(handle, make_extents(capacity)); + data = make_device_matrix(res, capacity, dim); + indices = make_device_vector(res, capacity); } catch (std::bad_alloc& e) { RAFT_FAIL( "ivf-flat: failed to allocate a big enough index list to hold all data " @@ -94,6 +103,8 @@ struct list_data { size_t(capacity), e.what()); } + // Fill the index buffer with a pre-defined marker for easier debugging + matrix::fill(res, indices.view(), ivf_flat::kInvalidRecord); } }; @@ -110,12 +121,6 @@ struct index : ann::index { "IdxT must be able to represent all values of uint32_t"); public: - /** - * Default value filled in the `indices()` array. - * One may encounter it trying to access a record within a cluster that is outside of the - * `list_sizes()` bound (due to the record alignment `kIndexGroupSize`). - */ - constexpr static IdxT kInvalidRecord = std::numeric_limits::max() - 1; /** * Vectorized load/store size in elements, determines the size of interleaved data chunks. 
* @@ -157,23 +162,23 @@ struct index : ann::index { * */ /** Sizes of the lists (clusters) [n_lists] */ - inline auto list_sizes() noexcept -> device_mdspan, row_major> + inline auto list_sizes() noexcept -> device_vector_view { return list_sizes_.view(); } [[nodiscard]] inline auto list_sizes() const noexcept - -> device_mdspan, row_major> + -> device_vector_view { return list_sizes_.view(); } /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - inline auto centers() noexcept -> device_mdspan, row_major> + inline auto centers() noexcept -> device_matrix_view { return centers_.view(); } [[nodiscard]] inline auto centers() const noexcept - -> device_mdspan, row_major> + -> device_matrix_view { return centers_.view(); } @@ -185,20 +190,20 @@ struct index : ann::index { * calculation. */ inline auto center_norms() noexcept - -> std::optional, row_major>> + -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( + return std::make_optional>( center_norms_->view()); } else { return std::nullopt; } } [[nodiscard]] inline auto center_norms() const noexcept - -> std::optional, row_major>> + -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional, row_major>>( + return std::make_optional>( center_norms_->view()); } else { return std::nullopt; @@ -229,7 +234,7 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. It needs to be trained and then populated. 
*/ - index(raft::device_resources const& handle, + index(raft::device_resources const& res, raft::distance::DistanceType metric, uint32_t n_lists, bool adaptive_centers, @@ -238,12 +243,12 @@ struct index : ann::index { veclen_(calculate_veclen(dim)), metric_(metric), adaptive_centers_(adaptive_centers), - centers_(make_device_mdarray(handle, make_extents(n_lists, dim))), + centers_(make_device_matrix(res, n_lists, dim)), center_norms_(std::nullopt), - lists_{make_host_mdarray>>(make_extents(n_lists))}, - list_sizes_{make_device_mdarray(handle, make_extents(n_lists))}, - data_ptrs_{make_device_mdarray(handle, make_extents(n_lists))}, - inds_ptrs_{make_device_mdarray(handle, make_extents(n_lists))} + lists_{make_host_vector>, uint32_t>(n_lists)}, + list_sizes_{make_device_vector(res, n_lists)}, + data_ptrs_{make_device_vector(res, n_lists)}, + inds_ptrs_{make_device_vector(res, n_lists)} { check_consistency(); for (uint32_t i = 0; i < n_lists; i++) { @@ -252,29 +257,29 @@ struct index : ann::index { } /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& handle, const index_params& params, uint32_t dim) - : index(handle, params.metric, params.n_lists, params.adaptive_centers, dim) + index(raft::device_resources const& res, const index_params& params, uint32_t dim) + : index(res, params.metric, params.n_lists, params.adaptive_centers, dim) { } /** Pointers to the inverted lists (clusters) data [n_lists]. */ - inline auto data_ptrs() noexcept -> device_mdspan, row_major> + inline auto data_ptrs() noexcept -> device_vector_view { return data_ptrs_.view(); } [[nodiscard]] inline auto data_ptrs() const noexcept - -> device_mdspan, row_major> + -> device_vector_view { return data_ptrs_.view(); } /** Pointers to the inverted lists (clusters) indices [n_lists]. 
*/ - inline auto inds_ptrs() noexcept -> device_mdspan, row_major> + inline auto inds_ptrs() noexcept -> device_vector_view { return inds_ptrs_.view(); } [[nodiscard]] inline auto inds_ptrs() const noexcept - -> device_mdspan, row_major> + -> device_vector_view { return inds_ptrs_.view(); } @@ -304,12 +309,12 @@ struct index : ann::index { /** Lists' data and indices. */ inline auto lists() noexcept - -> host_mdspan>, extent_1d, row_major> + -> host_vector_view>, uint32_t> { return lists_.view(); } [[nodiscard]] inline auto lists() const noexcept - -> host_mdspan>, extent_1d, row_major> + -> host_vector_view>, uint32_t> { return lists_.view(); } @@ -322,17 +327,14 @@ struct index : ann::index { uint32_t veclen_; raft::distance::DistanceType metric_; bool adaptive_centers_; - host_mdarray>, extent_1d, row_major> lists_; - //device_mdarray, row_major> data_; - //device_mdarray, row_major> indices_; - device_mdarray, row_major> list_sizes_; - //device_mdarray, row_major> list_offsets_; - device_mdarray, row_major> centers_; - std::optional, row_major>> center_norms_; + host_vector>, uint32_t> lists_; + device_vector list_sizes_; + device_matrix centers_; + std::optional> center_norms_; // Computed members - device_mdarray, row_major> data_ptrs_; - device_mdarray, row_major> inds_ptrs_; + device_vector data_ptrs_; + device_vector inds_ptrs_; IdxT total_size_; /** Throw an error if the index content is inconsistent. 
*/ @@ -347,8 +349,6 @@ struct index : ann::index { (centers_.extent(0) == list_sizes_.extent(0)) && // (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), "inconsistent number of lists (clusters)"); - //RAFT_EXPECTS(reinterpret_cast(data_.data_handle()) % (veclen_ * sizeof(T)) == 0, - // "The data storage pointer is not aligned to the vector length"); } static auto calculate_veclen(uint32_t dim) -> uint32_t diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index ca595f7303..1f31a54dce 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -439,10 +439,16 @@ inline void fill_refinement_index(raft::device_resources const& handle, // compatible fashion. constexpr int serialization_version = 2; -/* TODO -static_assert(sizeof(index) == 408, - "The size of the index struct has changed since the last update; " - "paste in the new size and consider updating the save/load logic");*/ +// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error +// message. +template +struct check_index_layout { + static_assert(RealSize == ExpectedSize, + "The size of the index struct has changed since the last update; " + "paste in the new size and consider updating the serialization logic"); +}; + +template struct check_index_layout), 376>; /** * Save the index to file. 
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 77ecae745a..8ab3d06efc 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -876,10 +876,9 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(),//indices().data_handle(), - index.data_ptrs().data_handle(),//data().data_handle(), - index.list_sizes().data_handle(),// TODO - //index.list_offsets().data_handle(), + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), n_probes, k, index.dim(), From 5fcf5647077b7da6cf3ec392990140aa083d833c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 13 Feb 2023 16:47:36 +0100 Subject: [PATCH 06/33] Add serialization --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 9 +- .../spatial/knn/detail/ivf_flat_build.cuh | 84 ++++++++++++++++--- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index c57e5e5c6e..476fa755e0 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -23,11 +23,11 @@ #include #include #include -#include #include #include #include +#include #include namespace raft::neighbors::ivf_flat { @@ -90,7 +90,7 @@ struct list_data { list_data(raft::device_resources const& res, SizeT n_rows, uint32_t dim) : size{n_rows} { - auto capacity = round_up_safe(bound_by_power_of_two(size), kIndexGroupSize); + auto capacity = round_up_safe(n_rows, kIndexGroupSize); try { data = make_device_matrix(res, capacity, dim); indices = make_device_vector(res, capacity); @@ -99,12 +99,13 @@ struct list_data { "ivf-flat: failed to allocate a big enough index list to hold all data " "(requested size: %zu records, selected capacity: %zu records). 
" "Allocator exception: %s", - size_t(size), + size_t(n_rows), size_t(capacity), e.what()); } // Fill the index buffer with a pre-defined marker for easier debugging - matrix::fill(res, indices.view(), ivf_flat::kInvalidRecord); + thrust::fill_n( + res.get_thrust_policy(), indices.data_handle(), indices.size(), kInvalidRecord); } }; diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 1f31a54dce..5a90be93ae 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -274,8 +274,8 @@ void extend(raft::device_resources const& handle, build_index_kernel<<>>(new_labels.data_handle(), new_vectors, new_indices, - index->data_ptrs().data_handle(),//ext_index.data().data_handle(), - index->inds_ptrs().data_handle(),//ext_index.indices().data_handle(), + index->data_ptrs().data_handle(), + index->inds_ptrs().data_handle(), list_sizes_ptr, n_rows, dim, @@ -432,12 +432,12 @@ inline void fill_refinement_index(raft::device_resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } -// Serialization version 2 +// Serialization version 3 // No backward compatibility yet; that is, can't add additional fields without breaking // backward compatibility. // TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward // compatible fashion. -constexpr int serialization_version = 2; +constexpr int serialization_version = 3; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. 
@@ -450,6 +450,59 @@ struct check_index_layout { template struct check_index_layout), 376>; + +template +void serialize_list(const raft::device_resources& handle, + std::ostream& os, + const list_data& ld, + std::optional size_override = std::nullopt) +{ + auto size = size_override.value_or(ld.size.load()); + serialize_scalar(handle, os, size); + if (size == 0) { return; } + + auto data_extents = make_extents(size, ld.data.extent(1)); + auto data_array = make_host_mdarray(data_extents); + auto inds_array = make_host_mdarray(make_extents(size)); + copy(data_array.data_handle(), ld.data.data_handle(), data_array.size(), handle.get_stream()); + copy(inds_array.data_handle(), ld.indices.data_handle(), inds_array.size(), handle.get_stream()); + handle.sync_stream(); + serialize_mdspan(handle, os, data_array.view()); + serialize_mdspan(handle, os, inds_array.view()); +} + +template +void serialize_list(const raft::device_resources& handle, + std::ostream& os, + const std::shared_ptr>& ld, + std::optional size_override = std::nullopt) +{ + if (ld) { + return serialize_list(handle, os, *ld, size_override); + } else { + return serialize_scalar(handle, os, SizeT{0}); + } +} + +template +void deserialize_list(const raft::device_resources& handle, + std::istream& is, + std::shared_ptr>& ld, + uint32_t dim) +{ + auto size = deserialize_scalar(handle, is); + if (size == 0) { return ld.reset(); } + std::make_shared>(handle, size, dim).swap(ld); + auto data_extents = make_extents(size, ld->data.extent(1)); + auto data_array = make_host_mdarray(data_extents); + auto inds_array = make_host_mdarray(make_extents(size)); + deserialize_mdspan(handle, is, data_array.view()); + deserialize_mdspan(handle, is, inds_array.view()); + copy(ld->data.data_handle(), data_array.data_handle(), data_array.size(), handle.get_stream()); + // NB: copying exactly 'size' indices to leave the rest 'kInvalidRecord' intact. 
+ copy(ld->indices.data_handle(), inds_array.data_handle(), size, handle.get_stream()); +} + /** * Save the index to file. * @@ -469,8 +522,8 @@ void serialize(raft::device_resources const& handle, if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } RAFT_LOG_DEBUG( - "Saving IVF-PQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - /* TODO + "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + serialize_scalar(handle, of, serialization_version); serialize_scalar(handle, of, index_.size()); serialize_scalar(handle, of, index_.dim()); @@ -478,10 +531,7 @@ void serialize(raft::device_resources const& handle, serialize_scalar(handle, of, index_.metric()); serialize_scalar(handle, of, index_.veclen()); serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_mdspan(handle, of, index_.data()); - serialize_mdspan(handle, of, index_.indices()); serialize_mdspan(handle, of, index_.list_sizes()); - serialize_mdspan(handle, of, index_.list_offsets()); serialize_mdspan(handle, of, index_.centers()); if (index_.center_norms()) { bool has_norms = true; @@ -490,7 +540,17 @@ void serialize(raft::device_resources const& handle, } else { bool has_norms = false; serialize_scalar(handle, of, has_norms); - }*/ + } + auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); + copy(sizes_host.data_handle(), + index_.list_sizes().data_handle(), + sizes_host.size(), + handle.get_stream()); + handle.sync_stream(); + serialize_mdspan(handle, of, sizes_host.view()); + for (uint32_t label = 0; label < index_.n_lists(); label++) { + serialize_list(handle, of, index_.lists()(label), sizes_host(label)); + } of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } @@ -541,7 +601,9 @@ auto deserialize(raft::device_resources const& handle, const std::string& filena auto center_norms = *index_.center_norms(); deserialize_mdspan(handle, infile, center_norms); } - }*/ + } + 
index_.recompute_internal_state(handle); + */ infile.close(); return index_; } From c49bf6728d1dbeb2c9b256cb009bcd72e279fe7f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 16 Feb 2023 18:44:49 +0100 Subject: [PATCH 07/33] Deserialize ivf_flat and style fix --- cpp/include/raft/neighbors/ivf_flat_types.hpp | 66 +++++---- .../spatial/knn/detail/ivf_flat_build.cuh | 134 +++++++++++------- .../spatial/knn/detail/ivf_flat_search.cuh | 9 +- cpp/test/neighbors/ann_ivf_flat.cuh | 32 +++-- 4 files changed, 133 insertions(+), 108 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 476fa755e0..53c73ae6fe 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -19,9 +19,9 @@ #include "ann_types.hpp" #include +#include #include #include -#include #include #include @@ -87,12 +87,11 @@ struct list_data { /** The actual size of the content. */ std::atomic size; - list_data(raft::device_resources const& res, SizeT n_rows, uint32_t dim) - : size{n_rows} + list_data(raft::device_resources const& res, SizeT n_rows, uint32_t dim) : size{n_rows} { auto capacity = round_up_safe(n_rows, kIndexGroupSize); try { - data = make_device_matrix(res, capacity, dim); + data = make_device_matrix(res, capacity, dim); indices = make_device_vector(res, capacity); } catch (std::bad_alloc& e) { RAFT_FAIL( @@ -190,12 +189,10 @@ struct index : ann::index { * NB: this may be empty if the index is empty or if the metric does not require the center norms * calculation. 
*/ - inline auto center_norms() noexcept - -> std::optional> + inline auto center_norms() noexcept -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } @@ -204,18 +201,14 @@ struct index : ann::index { -> std::optional> { if (center_norms_.has_value()) { - return std::make_optional>( - center_norms_->view()); + return std::make_optional>(center_norms_->view()); } else { return std::nullopt; } } /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT - { - return total_size_; - } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return total_size_; } /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { @@ -249,12 +242,22 @@ struct index : ann::index { lists_{make_host_vector>, uint32_t>(n_lists)}, list_sizes_{make_device_vector(res, n_lists)}, data_ptrs_{make_device_vector(res, n_lists)}, - inds_ptrs_{make_device_vector(res, n_lists)} + inds_ptrs_{make_device_vector(res, n_lists)}, + total_size_{0} { - check_consistency(); for (uint32_t i = 0; i < n_lists; i++) { lists_(i) = std::shared_ptr>(); } + switch (metric_) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + center_norms_ = make_device_vector(res, n_lists); + break; + default: center_norms_ = std::nullopt; + } + check_consistency(); } /** Construct an empty index. It needs to be trained and then populated. */ @@ -264,12 +267,8 @@ struct index : ann::index { } /** Pointers to the inverted lists (clusters) data [n_lists]. 
*/ - inline auto data_ptrs() noexcept -> device_vector_view - { - return data_ptrs_.view(); - } - [[nodiscard]] inline auto data_ptrs() const noexcept - -> device_vector_view + inline auto data_ptrs() noexcept -> device_vector_view { return data_ptrs_.view(); } + [[nodiscard]] inline auto data_ptrs() const noexcept -> device_vector_view { return data_ptrs_.view(); } @@ -279,8 +278,7 @@ struct index : ann::index { { return inds_ptrs_.view(); } - [[nodiscard]] inline auto inds_ptrs() const noexcept - -> device_vector_view + [[nodiscard]] inline auto inds_ptrs() const noexcept -> device_vector_view { return inds_ptrs_.view(); } @@ -290,27 +288,27 @@ struct index : ann::index { */ void recompute_internal_state(raft::device_resources const& res) { - auto stream = res.get_stream(); + auto stream = res.get_stream(); // Actualize the list pointers - auto this_lists = lists(); - auto this_data_ptrs = data_ptrs(); - auto this_inds_ptrs = inds_ptrs(); + auto this_lists = lists(); + auto this_data_ptrs = data_ptrs(); + auto this_inds_ptrs = inds_ptrs(); IdxT recompute_total_size = 0; for (uint32_t label = 0; label < this_lists.size(); label++) { - const auto data_ptr = this_lists(label) ? this_lists(label)->data.data_handle() : nullptr; - const auto inds_ptr = this_lists(label) ? this_lists(label)->indices.data_handle() : nullptr; + const auto data_ptr = this_lists(label) ? this_lists(label)->data.data_handle() : nullptr; + const auto inds_ptr = this_lists(label) ? this_lists(label)->indices.data_handle() : nullptr; const auto list_size = this_lists(label) ? IdxT(this_lists(label)->size) : 0; copy(&this_data_ptrs(label), &data_ptr, 1, stream); copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); recompute_total_size += list_size; } total_size_ = recompute_total_size; + check_consistency(); } /** Lists' data and indices. 
*/ - inline auto lists() noexcept - -> host_vector_view>, uint32_t> + inline auto lists() noexcept -> host_vector_view>, uint32_t> { return lists_.view(); } @@ -346,8 +344,8 @@ struct index : ann::index { RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // + RAFT_EXPECTS( // + (centers_.extent(0) == list_sizes_.extent(0)) && // (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), "inconsistent number of lists (clusters)"); } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 5a90be93ae..86724c8cea 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -45,24 +45,24 @@ using namespace raft::neighbors::ivf_flat; // NOLINT using raft::neighbors::ivf_flat::index; using raft::neighbors::ivf_flat::index_params; using raft::neighbors::ivf_flat::kIndexGroupSize; -using raft::neighbors::ivf_flat::search_params; using raft::neighbors::ivf_flat::list_data; +using raft::neighbors::ivf_flat::search_params; /** - * Resize a list by the given id, so that it can contain the given number of records; - * possibly, copy the data. - * - * Besides resizing the corresponding list_data, this function updates the device pointers - * data_ptrs, inds_ptrs, and the list_sizes if necessary. - * - * The new `list_sizes(label)` represents the number of valid records in the index; - * it can be `list_size` if the previous size was not smaller; otherwise it's not updated. - * - * @param[in] handle - * @param[in] label list id - * @param[in] list_size the minimum size the list should grow. 
- */ + * Resize a list by the given id, so that it can contain the given number of records; + * possibly, copy the data. + * + * Besides resizing the corresponding list_data, this function updates the device pointers + * data_ptrs, inds_ptrs, and the list_sizes if necessary. + * + * The new `list_sizes(label)` represents the number of valid records in the index; + * it can be `list_size` if the previous size was not smaller; otherwise it's not updated. + * + * @param[in] handle + * @param[in] label list id + * @param[in] list_size the minimum size the list should grow. + */ template -void resize_list(raft::device_resources const& handle, +void resize_list(raft::device_resources const& handle, std::shared_ptr>& orig_list, SizeT new_used_size, SizeT old_used_size, @@ -89,8 +89,8 @@ void resize_list(raft::device_resources const& handle, auto new_list = std::make_shared>(handle, new_used_size, dim); if (old_used_size > 0) { auto copied_data_extents = make_extents(old_used_size, dim); - auto copied_view = make_mdspan( - new_list->data.data_handle(), copied_data_extents); + auto copied_view = make_mdspan(new_list->data.data_handle(), + copied_data_extents); copy(copied_view.data_handle(), orig_list->data.data_handle(), copied_view.size(), @@ -104,6 +104,44 @@ void resize_list(raft::device_resources const& handle, new_list.swap(orig_list); } +template +auto clone(const raft::device_resources& res, const index& source) -> index +{ + auto stream = res.get_stream(); + + // Allocate the new index + index target( + res, source.metric(), source.n_lists(), source.adaptive_centers(), source.dim()); + + // Copy the independent parts + copy(target.list_sizes().data_handle(), + source.list_sizes().data_handle(), + source.list_sizes().size(), + stream); + copy(target.centers().data_handle(), + source.centers().data_handle(), + source.centers().size(), + stream); + if (source.center_norms().has_value()) + copy(target.center_norms().value.data_handle(), + 
source.center_norms().value.data_handle(), + source.center_norms().value.size(), + stream); + // Copy shared pointers + { + auto source_lists = source.lists(); + auto target_lists = target.lists(); + for (uint32_t label = 0; label < source_lists.size(); label++) { + target_lists(label) = source_lists(label); + } + } + + // Make sure the device pointers point to the new lists + target.recompute_internal_state(res); + + return target; +} + /** * @brief Record the dataset into the index, one source row at a time. * @@ -137,7 +175,7 @@ void resize_list(raft::device_resources const& handle, */ template __global__ void build_index_kernel(const LabelT* labels, - //const IdxT* list_offsets, + // const IdxT* list_offsets, const T* source_vecs, const IdxT* source_ixs, T** list_data_ptrs, @@ -203,10 +241,10 @@ void extend(raft::device_resources const& handle, auto new_labels = raft::make_device_vector(handle, n_rows); raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); - auto orig_centroids_view = raft::make_device_matrix_view( - index->centers().data_handle(), n_lists, dim); + kmeans_params.metric = index->metric(); + auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); + auto orig_centroids_view = + raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, new_vectors_view, @@ -214,13 +252,14 @@ void extend(raft::device_resources const& handle, new_labels.view(), utils::mapping{}); - auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + auto* list_sizes_ptr = index->list_sizes().data_handle(); + auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and 
sizes on the new data, starting from the original values if (index->adaptive_centers()) { - auto centroids_view = raft::make_device_matrix_view(index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); + auto centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); @@ -253,20 +292,15 @@ void extend(raft::device_resources const& handle, handle.sync_stream(); auto lists = index->lists(); for (uint32_t label = 0; label < n_lists; label++) { - resize_list(handle, - lists(label), - new_list_sizes[label], - old_list_sizes[label], - index->dim()); - } + resize_list(handle, lists(label), new_list_sizes[label], old_list_sizes[label], index->dim()); + } } // Update the pointers and the sizes index->recompute_internal_state(handle); // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. - raft::copy( - list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); // Kernel to insert the new vectors const dim3 block_dim(256); @@ -308,7 +342,6 @@ auto extend(raft::device_resources const& handle, return ext_index; } - /** See raft::spatial::knn::ivf_flat::build docs */ template inline auto build(raft::device_resources const& handle, @@ -402,17 +435,13 @@ inline void fill_refinement_index(raft::device_resources const& handle, new_labels_view, raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); - auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); + auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); // We do not fill centers and center norms, since we will not run coarse search. 
// Allocate new memory auto lists = refinement_index->lists(); for (uint32_t label = 0; label < n_lists; label++) { - resize_list(handle, - lists(label), - n_candidates, - uint32_t(0), - refinement_index->dim()); + resize_list(handle, lists(label), n_candidates, uint32_t(0), refinement_index->dim()); } RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -450,7 +479,6 @@ struct check_index_layout { template struct check_index_layout), 376>; - template void serialize_list(const raft::device_resources& handle, std::ostream& os, @@ -492,7 +520,7 @@ void deserialize_list(const raft::device_resources& handle, { auto size = deserialize_scalar(handle, is); if (size == 0) { return ld.reset(); } - std::make_shared>(handle, size, dim).swap(ld); + std::make_shared>(handle, size, dim).swap(ld); auto data_extents = make_extents(size, ld->data.extent(1)); auto data_array = make_host_mdarray(data_extents); auto inds_array = make_host_mdarray(make_extents(size)); @@ -531,7 +559,6 @@ void serialize(raft::device_resources const& handle, serialize_scalar(handle, of, index_.metric()); serialize_scalar(handle, of, index_.veclen()); serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_mdspan(handle, of, index_.list_sizes()); serialize_mdspan(handle, of, index_.centers()); if (index_.center_norms()) { bool has_norms = true; @@ -583,15 +610,8 @@ auto deserialize(raft::device_resources const& handle, const std::string& filena auto veclen = deserialize_scalar(handle, infile); bool adaptive_centers = deserialize_scalar(handle, infile); - index index_ = - index(handle, metric, n_lists, adaptive_centers, dim); - /* TODO - index_.allocate(handle, n_rows); - auto data = index_.data(); - deserialize_mdspan(handle, infile, data); - deserialize_mdspan(handle, infile, index_.indices()); - deserialize_mdspan(handle, infile, index_.list_sizes()); - deserialize_mdspan(handle, infile, index_.list_offsets()); + index index_ = index(handle, metric, 
n_lists, adaptive_centers, dim); + deserialize_mdspan(handle, infile, index_.centers()); bool has_norms = deserialize_scalar(handle, infile); if (has_norms) { @@ -602,9 +622,15 @@ auto deserialize(raft::device_resources const& handle, const std::string& filena deserialize_mdspan(handle, infile, center_norms); } } - index_.recompute_internal_state(handle); - */ + deserialize_mdspan(handle, infile, index_.list_sizes()); + for (uint32_t label = 0; label < index_.n_lists(); label++) { + deserialize_list(handle, infile, index_.lists()(label), index_.dim()); + } + handle.sync_stream(); infile.close(); + + index_.recompute_internal_state(handle); + return index_; } } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 8ab3d06efc..905d7abc07 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -679,10 +679,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, // const IdxT* list_indices - const T* const* list_data_ptrs, // const T* list_data + const IdxT* const* list_indices_ptrs, // const IdxT* list_indices + const T* const* list_data_ptrs, // const T* list_data const uint32_t* list_sizes, - //const IdxT* list_offsets, const uint32_t n_probes, const uint32_t k, const uint32_t dim, @@ -728,8 +727,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) // Every CUDA block scans one cluster at a time. 
for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - //const size_t list_offset = list_offsets[list_id]; + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) // The number of vectors in each cluster(list); [nlist] const uint32_t list_length = list_sizes[list_id]; @@ -747,7 +745,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) group_id += kNumWarps) { AccT dist = 0; // This is where this warp begins reading data (start position of an interleaved group) - //const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim; const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; // This is the vector a given lane/thread handles diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 3807e414c0..fb0dbb130b 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -58,6 +58,15 @@ struct AnnIvfFlatInputs { bool adaptive_centers; }; +template +::std::ostream& operator<<(::std::ostream& os, const AnnIvfFlatInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " + << p.adaptive_centers << '}' << std::endl; + return os; +} + template class AnnIVFFlatTest : public ::testing::TestWithParam> { public: @@ -210,24 +219,19 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { if (index.adaptive_centers()) { // The centers must be up-to-date with the corresponding data std::vector list_sizes(index.n_lists()); - //std::vector list_offsets(index.n_lists()); rmm::device_uvector centroid(ps.dim, stream_); - raft::copy( - list_sizes.data(), index.list_sizes().data_handle(), index.n_lists(), stream_); - //raft::copy( - // list_offsets.data(), index.list_offsets().data_handle(), index.n_lists(), stream_); + 
raft::copy(list_sizes.data(), index.list_sizes().data_handle(), index.n_lists(), stream_); handle_.sync_stream(stream_); for (uint32_t l = 0; l < index.n_lists(); l++) { rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); - raft::spatial::knn::detail::utils::copy_selected( - (IdxT)list_sizes[l], - (IdxT)ps.dim, - database.data(), - index.inds_ptrs()(l), - (IdxT)ps.dim, - cluster_data.data(), - (IdxT)ps.dim, - stream_); + raft::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], + (IdxT)ps.dim, + database.data(), + index.inds_ptrs()(l), + (IdxT)ps.dim, + cluster_data.data(), + (IdxT)ps.dim, + stream_); raft::stats::mean( centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); ASSERT_TRUE(raft::devArrMatch(index.centers().data_handle() + ps.dim * l, From 7e2d80bae4bf3cd7f527c7a545eaab4a6c87e3d9 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 21 Feb 2023 15:29:14 +0100 Subject: [PATCH 08/33] Integrate ivf::list to ivf_flat index splitting --- .../detail/ivf_flat_build.cuh | 266 ++---------------- .../detail/ivf_flat_search.cuh | 10 +- .../neighbors/detail/ivf_flat_serialize.cuh | 158 +++++++++++ .../neighbors/detail/ivf_pq_serialize.cuh | 4 +- cpp/include/raft/neighbors/detail/refine.cuh | 22 +- cpp/include/raft/neighbors/ivf_flat.cuh | 23 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 92 +++--- cpp/include/raft/neighbors/ivf_list.hpp | 45 ++- cpp/include/raft/neighbors/ivf_list_types.hpp | 7 +- cpp/include/raft/neighbors/ivf_pq_types.hpp | 2 +- cpp/test/neighbors/ann_ivf_flat.cuh | 10 +- 11 files changed, 295 insertions(+), 344 deletions(-) rename cpp/include/raft/{spatial/knn => neighbors}/detail/ivf_flat_build.cuh (63%) rename cpp/include/raft/{spatial/knn => neighbors}/detail/ivf_flat_search.cuh (99%) create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh 
similarity index 63% rename from cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh rename to cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 86724c8cea..9dafb63149 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -16,93 +16,35 @@ #pragma once -#include "ann_utils.cuh" - #include #include #include #include #include #include -#include #include #include #include #include +#include +#include +#include #include #include #include #include -#include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -using namespace raft::neighbors::ivf_flat; // NOLINT using raft::neighbors::ivf_flat::index; using raft::neighbors::ivf_flat::index_params; using raft::neighbors::ivf_flat::kIndexGroupSize; using raft::neighbors::ivf_flat::list_data; using raft::neighbors::ivf_flat::search_params; -/** - * Resize a list by the given id, so that it can contain the given number of records; - * possibly, copy the data. - * - * Besides resizing the corresponding list_data, this function updates the device pointers - * data_ptrs, inds_ptrs, and the list_sizes if necessary. - * - * The new `list_sizes(label)` represents the number of valid records in the index; - * it can be `list_size` if the previous size was not smaller; otherwise it's not updated. - * - * @param[in] handle - * @param[in] label list id - * @param[in] list_size the minimum size the list should grow. 
- */ -template -void resize_list(raft::device_resources const& handle, - std::shared_ptr>& orig_list, - SizeT new_used_size, - SizeT old_used_size, - uint32_t dim) -{ - bool skip_resize = false; - // TODO update total_size - if (orig_list) { - if (new_used_size <= orig_list->indices.extent(0)) { - auto shared_list_size = old_used_size; - if (new_used_size <= old_used_size || - orig_list->size.compare_exchange_strong(shared_list_size, new_used_size)) { - // We don't need to resize the list if: - // 1. The list exists - // 2. The new size fits in the list - // 3. The list doesn't grow or no-one else has grown it yet - skip_resize = true; - } - } - } else { - old_used_size = 0; - } - if (skip_resize) { return; } - auto new_list = std::make_shared>(handle, new_used_size, dim); - if (old_used_size > 0) { - auto copied_data_extents = make_extents(old_used_size, dim); - auto copied_view = make_mdspan(new_list->data.data_handle(), - copied_data_extents); - copy(copied_view.data_handle(), - orig_list->data.data_handle(), - copied_view.size(), - handle.get_stream()); - copy(new_list->indices.data_handle(), - orig_list->indices.data_handle(), - old_used_size, - handle.get_stream()); - } - // swap the shared pointer content with the new list - new_list.swap(orig_list); -} template auto clone(const raft::device_resources& res, const index& source) -> index @@ -110,8 +52,12 @@ auto clone(const raft::device_resources& res, const index& source) -> i auto stream = res.get_stream(); // Allocate the new index - index target( - res, source.metric(), source.n_lists(), source.adaptive_centers(), source.dim()); + index target(res, + source.metric(), + source.n_lists(), + source.adaptive_centers(), + source.conservative_memory_allocation(), + source.dim()); // Copy the independent parts copy(target.list_sizes().data_handle(), @@ -219,7 +165,7 @@ __global__ void build_index_kernel(const LabelT* labels, } } -/** See raft::spatial::knn::ivf_flat::extend docs */ +/** See 
raft::neighbors::ivf_flat::extend docs */ template void extend(raft::device_resources const& handle, index* index, @@ -290,9 +236,11 @@ void extend(raft::device_resources const& handle, copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); handle.sync_stream(); - auto lists = index->lists(); + auto lists = index->lists(); + auto list_device_spec = list_spec{index->dim(), false}; for (uint32_t label = 0; label < n_lists; label++) { - resize_list(handle, lists(label), new_list_sizes[label], old_list_sizes[label], index->dim()); + ivf::resize_list( + handle, lists(label), list_device_spec, new_list_sizes[label], old_list_sizes[label]); } } // Update the pointers and the sizes @@ -329,7 +277,7 @@ void extend(raft::device_resources const& handle, } } -/** See raft::spatial::knn::ivf_flat::extend docs */ +/** See raft::neighbors::ivf_flat::extend docs */ template auto extend(raft::device_resources const& handle, const index& orig_index, @@ -342,7 +290,7 @@ auto extend(raft::device_resources const& handle, return ext_index; } -/** See raft::spatial::knn::ivf_flat::build docs */ +/** See raft::neighbors::ivf_flat::build docs */ template inline auto build(raft::device_resources const& handle, const index_params& params, @@ -439,9 +387,10 @@ inline void fill_refinement_index(raft::device_resources const& handle, // We do not fill centers and center norms, since we will not run coarse search. 
// Allocate new memory - auto lists = refinement_index->lists(); + auto lists = refinement_index->lists(); + auto list_device_spec = list_spec{refinement_index->dim(), false}; for (uint32_t label = 0; label < n_lists; label++) { - resize_list(handle, lists(label), n_candidates, uint32_t(0), refinement_index->dim()); + ivf::resize_list(handle, lists(label), list_device_spec, n_candidates, uint32_t(0)); } RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); @@ -460,177 +409,4 @@ inline void fill_refinement_index(raft::device_resources const& handle, refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } - -// Serialization version 3 -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int serialization_version = 3; - -// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error -// message. 
-template -struct check_index_layout { - static_assert(RealSize == ExpectedSize, - "The size of the index struct has changed since the last update; " - "paste in the new size and consider updating the serialization logic"); -}; - -template struct check_index_layout), 376>; - -template -void serialize_list(const raft::device_resources& handle, - std::ostream& os, - const list_data& ld, - std::optional size_override = std::nullopt) -{ - auto size = size_override.value_or(ld.size.load()); - serialize_scalar(handle, os, size); - if (size == 0) { return; } - - auto data_extents = make_extents(size, ld.data.extent(1)); - auto data_array = make_host_mdarray(data_extents); - auto inds_array = make_host_mdarray(make_extents(size)); - copy(data_array.data_handle(), ld.data.data_handle(), data_array.size(), handle.get_stream()); - copy(inds_array.data_handle(), ld.indices.data_handle(), inds_array.size(), handle.get_stream()); - handle.sync_stream(); - serialize_mdspan(handle, os, data_array.view()); - serialize_mdspan(handle, os, inds_array.view()); -} - -template -void serialize_list(const raft::device_resources& handle, - std::ostream& os, - const std::shared_ptr>& ld, - std::optional size_override = std::nullopt) -{ - if (ld) { - return serialize_list(handle, os, *ld, size_override); - } else { - return serialize_scalar(handle, os, SizeT{0}); - } -} - -template -void deserialize_list(const raft::device_resources& handle, - std::istream& is, - std::shared_ptr>& ld, - uint32_t dim) -{ - auto size = deserialize_scalar(handle, is); - if (size == 0) { return ld.reset(); } - std::make_shared>(handle, size, dim).swap(ld); - auto data_extents = make_extents(size, ld->data.extent(1)); - auto data_array = make_host_mdarray(data_extents); - auto inds_array = make_host_mdarray(make_extents(size)); - deserialize_mdspan(handle, is, data_array.view()); - deserialize_mdspan(handle, is, inds_array.view()); - copy(ld->data.data_handle(), data_array.data_handle(), data_array.size(), 
handle.get_stream()); - // NB: copying exactly 'size' indices to leave the rest 'kInvalidRecord' intact. - copy(ld->indices.data_handle(), inds_array.data_handle(), size, handle.get_stream()); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index_ IVF-Flat index - * - */ -template -void serialize(raft::device_resources const& handle, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - RAFT_LOG_DEBUG( - "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - - serialize_scalar(handle, of, serialization_version); - serialize_scalar(handle, of, index_.size()); - serialize_scalar(handle, of, index_.dim()); - serialize_scalar(handle, of, index_.n_lists()); - serialize_scalar(handle, of, index_.metric()); - serialize_scalar(handle, of, index_.veclen()); - serialize_scalar(handle, of, index_.adaptive_centers()); - serialize_mdspan(handle, of, index_.centers()); - if (index_.center_norms()) { - bool has_norms = true; - serialize_scalar(handle, of, has_norms); - serialize_mdspan(handle, of, *index_.center_norms()); - } else { - bool has_norms = false; - serialize_scalar(handle, of, has_norms); - } - auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); - copy(sizes_host.data_handle(), - index_.list_sizes().data_handle(), - sizes_host.size(), - handle.get_stream()); - handle.sync_stream(); - serialize_mdspan(handle, of, sizes_host.view()); - for (uint32_t label = 0; label < index_.n_lists(); label++) { - serialize_list(handle, of, index_.lists()(label), sizes_host(label)); - } - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. 
- * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ IVF-Flat index - * - */ -template -auto deserialize(raft::device_resources const& handle, const std::string& filename) - -> index -{ - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle, infile); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(handle, infile); - auto dim = deserialize_scalar(handle, infile); - auto n_lists = deserialize_scalar(handle, infile); - auto metric = deserialize_scalar(handle, infile); - auto veclen = deserialize_scalar(handle, infile); - bool adaptive_centers = deserialize_scalar(handle, infile); - - index index_ = index(handle, metric, n_lists, adaptive_centers, dim); - - deserialize_mdspan(handle, infile, index_.centers()); - bool has_norms = deserialize_scalar(handle, infile); - if (has_norms) { - if (!index_.center_norms()) { - RAFT_FAIL("Error inconsistent center norms"); - } else { - auto center_norms = *index_.center_norms(); - deserialize_mdspan(handle, infile, center_norms); - } - } - deserialize_mdspan(handle, infile, index_.list_sizes()); - for (uint32_t label = 0; label < index_.n_lists(); label++) { - deserialize_list(handle, infile, index_.lists()(label), index_.dim()); - } - handle.sync_stream(); - infile.close(); - - index_.recompute_internal_state(handle); - - return index_; -} -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh similarity index 99% rename from 
cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh rename to cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index 905d7abc07..0003b02fc5 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -16,8 +16,6 @@ #pragma once -#include "ann_utils.cuh" - #include #include #include @@ -30,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -39,10 +38,9 @@ #include #include -namespace raft::spatial::knn::ivf_flat::detail { +namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -using namespace raft::neighbors::ivf_flat; // NOLINT using raft::neighbors::ivf_flat::index; using raft::neighbors::ivf_flat::index_params; @@ -1271,7 +1269,7 @@ inline bool is_min_close(distance::DistanceType metric) return select_min; } -/** See raft::spatial::knn::ivf_flat::search docs */ +/** See raft::neighbors::ivf_flat::search docs */ template inline void search(raft::device_resources const& handle, const search_params& params, @@ -1308,4 +1306,4 @@ inline void search(raft::device_resources const& handle, mr); } -} // namespace raft::spatial::knn::ivf_flat::detail +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh new file mode 100644 index 0000000000..405bc3252f --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::ivf_flat::detail { + +// Serialization version 3 +// No backward compatibility yet; that is, can't add additional fields without breaking +// backward compatibility. +// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward +// compatible fashion. +constexpr int serialization_version = 3; + +// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error +// message. +template +struct check_index_layout { + static_assert(RealSize == ExpectedSize, + "The size of the index struct has changed since the last update; " + "paste in the new size and consider updating the serialization logic"); +}; + +template struct check_index_layout), 376>; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index_ IVF-Flat index + * + */ +template +void serialize(raft::device_resources const& handle, + const std::string& filename, + const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } + + RAFT_LOG_DEBUG( + "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + + serialize_scalar(handle, of, serialization_version); + serialize_scalar(handle, of, index_.size()); + serialize_scalar(handle, of, index_.dim()); + serialize_scalar(handle, of, index_.n_lists()); + serialize_scalar(handle, of, index_.metric()); + serialize_scalar(handle, of, index_.veclen()); + serialize_scalar(handle, of, index_.adaptive_centers()); + serialize_scalar(handle, of, index_.conservative_memory_allocation()); + serialize_mdspan(handle, of, index_.centers()); + if (index_.center_norms()) { + bool has_norms = true; + serialize_scalar(handle, of, has_norms); + serialize_mdspan(handle, of, *index_.center_norms()); + } else { + bool has_norms = false; + serialize_scalar(handle, of, has_norms); + } + auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); + copy(sizes_host.data_handle(), + index_.list_sizes().data_handle(), + sizes_host.size(), + handle.get_stream()); + handle.sync_stream(); + serialize_mdspan(handle, of, sizes_host.view()); + + auto list_store_spec = list_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::serialize_list( + handle, of, index_.lists()(label), list_store_spec, sizes_host(label)); + } + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +/** Load an index from file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @return the deserialized IVF-Flat index + * + */ +template +auto deserialize(raft::device_resources const& handle, const std::string& filename) + -> index +{ + std::ifstream infile(filename, std::ios::in | std::ios::binary); + + if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } + + auto ver = deserialize_scalar(handle, infile); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = deserialize_scalar(handle, infile); + auto dim = deserialize_scalar(handle, infile); + auto n_lists = deserialize_scalar(handle, infile); + auto metric = deserialize_scalar(handle, infile); + auto veclen = deserialize_scalar(handle, infile); + bool adaptive_centers = deserialize_scalar(handle, infile); + bool cma = deserialize_scalar(handle, infile); + + index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); + + deserialize_mdspan(handle, infile, index_.centers()); + bool has_norms = deserialize_scalar(handle, infile); + if (has_norms) { + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = *index_.center_norms(); + deserialize_mdspan(handle, infile, center_norms); + } + } + deserialize_mdspan(handle, infile, index_.list_sizes()); + + auto list_device_spec = list_spec{index_.dim(), cma}; + auto list_store_spec = list_spec{index_.dim(), true}; + for (uint32_t label = 0; label < index_.n_lists(); label++) { + ivf::deserialize_list( + handle, infile, index_.lists()(label), list_store_spec, list_device_spec); + } + handle.sync_stream(); + infile.close(); + + index_.recompute_internal_state(handle); + + return index_; +} +} // namespace raft::neighbors::ivf_flat::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh
b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 33d9b363ba..bab9a52ed1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -96,7 +96,7 @@ void serialize(raft::device_resources const& handle_, serialize_mdspan(handle_, of, sizes_host.view()); auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { - ivf::serialize_list( + ivf::serialize_list( handle_, of, index.lists()[label], list_store_spec, sizes_host(label)); } @@ -154,7 +154,7 @@ auto deserialize(raft::device_resources const& handle_, const std::string& filen auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; auto list_store_spec = list_spec{pq_bits, pq_dim, true}; for (auto& list : index.lists()) { - ivf::deserialize_list( + ivf::deserialize_list( handle_, infile, list, list_store_spec, list_device_spec); } diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b264643584..5d226132a4 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,9 +20,9 @@ #include #include #include +#include +#include #include -#include -#include #include #include @@ -108,17 +108,17 @@ void refine_device(raft::device_resources const& handle, handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); raft::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, dim); + handle, metric, n_queries, false, false, dim); - raft::spatial::knn::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); + raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); uint32_t grid_dim_x 
= 1; - raft::spatial::knn::ivf_flat::detail::ivfflat_interleaved_scan< + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, idx_t>(refinement_index, @@ -128,7 +128,7 @@ void refine_device(raft::device_resources const& handle, refinement_index.metric(), 1, k, - raft::spatial::knn::ivf_flat::detail::is_min_close(metric), + raft::neighbors::ivf_flat::detail::is_min_close(metric), indices.data_handle(), distances.data_handle(), grid_dim_x, diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index d210d2c74b..4a5e2a3641 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -16,9 +16,10 @@ #pragma once +#include +#include +#include #include -#include -#include #include @@ -67,7 +68,7 @@ auto build(raft::device_resources const& handle, IdxT n_rows, uint32_t dim) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); + return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); } /** @@ -112,11 +113,11 @@ auto build(raft::device_resources const& handle, raft::device_matrix_view dataset, const index_params& params) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); } /** @} */ @@ -160,7 +161,7 @@ auto extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT n_rows) -> index { - return raft::spatial::knn::ivf_flat::detail::extend( + return raft::neighbors::ivf_flat::detail::extend( handle, orig_index, new_vectors, new_indices, n_rows); } @@ -252,7 +253,7 @@ void extend(raft::device_resources const& handle, const IdxT* new_indices, IdxT 
n_rows) { - raft::spatial::knn::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); + raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); } /** @@ -355,7 +356,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_flat::detail::search( + return raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr); } diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 53c73ae6fe..b60d5b3aa7 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -39,14 +40,6 @@ namespace raft::neighbors::ivf_flat { /** Size of the interleaved group (see `index::data` description). */ constexpr static uint32_t kIndexGroupSize = 32; -/** - * Default value filled in the `indices()` array. - * One may encounter it trying to access a record within a cluster that is outside of the - * `list_sizes()` bound (due to the record alignment `kIndexGroupSize`). - */ -template -constexpr static IdxT kInvalidRecord = std::numeric_limits::max() - 1; - struct index_params : ann::index_params { /** The number of inverted lists (clusters) */ uint32_t n_lists = 1024; @@ -67,6 +60,16 @@ struct index_params : ann::index_params { * `index.centers()` "drift" together with the changing distribution of the newly added data. */ bool adaptive_centers = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). 
+ * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ + bool conservative_memory_allocation = false; }; struct search_params : ann::search_params { @@ -77,37 +80,38 @@ struct search_params : ann::search_params { static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); -/** The data for a single list (cluster). */ -template -struct list_data { - /** Cluster data. */ - device_matrix data; - /** Source indices. */ - device_vector indices; - /** The actual size of the content. */ - std::atomic size; - - list_data(raft::device_resources const& res, SizeT n_rows, uint32_t dim) : size{n_rows} +template +struct list_spec { + using list_extents = matrix_extent; + + SizeT align_max; + SizeT align_min; + uint32_t dim; + + constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) + : dim(dim), + align_min(kIndexGroupSize), + align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) { - auto capacity = round_up_safe(n_rows, kIndexGroupSize); - try { - data = make_device_matrix(res, capacity, dim); - indices = make_device_vector(res, capacity); - } catch (std::bad_alloc& e) { - RAFT_FAIL( - "ivf-flat: failed to allocate a big enough index list to hold all data " - "(requested size: %zu records, selected capacity: %zu records).
" - "Allocator exception: %s", - size_t(n_rows), - size_t(capacity), - e.what()); - } - // Fill the index buffer with a pre-defined marker for easier debugging - thrust::fill_n( - res.get_thrust_policy(), indices.data_handle(), indices.size(), kInvalidRecord); + } + + // Allow casting between different size-types (for safer size and offset calculations) + template + constexpr explicit list_spec(const list_spec& other_spec) + : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} + { + } + + /** Determine the extents of an array enough to hold a given amount of data. */ + constexpr auto make_list_extents(SizeT n_rows) const -> list_extents + { + return make_extents(n_rows, dim); } }; +template +using list_data = ivf::list; + /** * @brief IVF-flat index. * @@ -232,11 +236,13 @@ struct index : ann::index { raft::distance::DistanceType metric, uint32_t n_lists, bool adaptive_centers, + bool conservative_memory_allocation, uint32_t dim) : ann::index(), veclen_(calculate_veclen(dim)), metric_(metric), adaptive_centers_(adaptive_centers), + conservative_memory_allocation_{conservative_memory_allocation}, centers_(make_device_matrix(res, n_lists, dim)), center_norms_(std::nullopt), lists_{make_host_vector>, uint32_t>(n_lists)}, @@ -262,7 +268,12 @@ struct index : ann::index { /** Construct an empty index. It needs to be trained and then populated. */ index(raft::device_resources const& res, const index_params& params, uint32_t dim) - : index(res, params.metric, params.n_lists, params.adaptive_centers, dim) + : index(res, + params.metric, + params.n_lists, + params.adaptive_centers, + params.conservative_memory_allocation, + dim) { } @@ -282,6 +293,14 @@ struct index : ann::index { { return inds_ptrs_.view(); } + /** + * Whether to use conservative memory allocation when extending the list (cluster) data + * (see index_params.conservative_memory_allocation).
+ */ + [[nodiscard]] constexpr inline auto conservative_memory_allocation() const noexcept -> bool + { + return conservative_memory_allocation_; + } /** * Update the state of the dependent index members. @@ -326,6 +345,7 @@ struct index : ann::index { uint32_t veclen_; raft::distance::DistanceType metric_; bool adaptive_centers_; + bool conservative_memory_allocation_; host_vector>, uint32_t> lists_; device_vector list_sizes_; device_matrix centers_; diff --git a/cpp/include/raft/neighbors/ivf_list.hpp b/cpp/include/raft/neighbors/ivf_list.hpp index 4644143057..2fe3d2cf3d 100644 --- a/cpp/include/raft/neighbors/ivf_list.hpp +++ b/cpp/include/raft/neighbors/ivf_list.hpp @@ -35,10 +35,10 @@ namespace raft::neighbors::ivf { /** The data for a single IVF list. */ -template