diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index d82ed158e7..96af5c9522 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -111,12 +111,12 @@ inline auto extend(const handle_t& handle, IdxT n_rows, rmm::cuda_stream_view stream) -> index { - auto n_lists = orig_index.n_lists(); - auto dim = orig_index.dim(); + auto n_lists = orig_index.n_lists; + auto dim = orig_index.dim; common::nvtx::range fun_scope( "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, + RAFT_EXPECTS(new_indices != nullptr || orig_index.size == 0, "You must pass data indices when the index is non-empty."); rmm::device_uvector new_labels(n_rows, stream); @@ -130,12 +130,12 @@ inline auto extend(const handle_t& handle, orig_index.metric, stream); - auto&& list_sizes = make_device_mdarray(stream, n_lists); - auto&& list_offsets = make_device_mdarray(stream, n_lists + 1); + auto&& list_sizes = rmm::device_uvector(n_lists, stream); + auto&& list_offsets = rmm::device_uvector(n_lists + 1, stream); auto list_sizes_ptr = list_sizes.data(); auto list_offsets_ptr = list_offsets.data(); - auto&& centers = make_device_mdarray(stream, n_lists, dim); + auto&& centers = rmm::device_uvector(size_t(n_lists) * size_t(dim), stream); auto centers_ptr = centers.data(); // Calculate the centers and sizes on the new data, starting from the original values @@ -164,11 +164,11 @@ inline auto extend(const handle_t& handle, update_host(&index_size, list_offsets_ptr + n_lists, 1, stream); handle.sync_stream(stream); - auto&& data = make_device_mdarray(stream, index_size, dim); - auto&& indices = make_device_mdarray(stream, index_size); + auto&& data = rmm::device_uvector(index_size * IdxT(dim), stream); + auto&& indices = rmm::device_uvector(index_size, stream); // Populate index with the old data - if (orig_index.size() > 0) { + if (orig_index.size > 0) { utils::block_copy(orig_index.list_offsets.data(), list_offsets_ptr, IdxT(n_lists), @@ -206,10 +206,10 @@ inline auto extend(const handle_t& handle, // Precompute the centers vector norms for L2Expanded distance auto compute_norms = [&]() { - auto&& r = make_device_mdarray(stream, n_lists); + auto&& r = rmm::device_uvector(n_lists, stream); utils::dots_along_rows(n_lists, dim, centers.data(), r.data(), stream); RAFT_LOG_TRACE_VEC(r.data(), 20); - return r; + return std::move(r); }; auto&& center_norms = orig_index.metric == raft::distance::DistanceType::L2Expanded ? std::optional(compute_norms()) @@ -219,6 +219,9 @@ inline auto extend(const handle_t& handle, index new_index{{}, orig_index.veclen, orig_index.metric, + index_size, + orig_index.dim, + orig_index.n_lists, std::move(data), std::move(indices), std::move(list_sizes), @@ -256,7 +259,7 @@ inline auto build(const handle_t& handle, auto n_lists = static_cast(params.n_lists); // kmeans cluster ids for the dataset - auto&& centers = make_device_mdarray(stream, n_lists, dim); + auto&& centers = rmm::device_uvector(size_t(n_lists) * size_t(dim), stream); // Predict labels of the whole dataset kmeans::build_optimized_kmeans(handle, @@ -270,10 +273,10 @@ inline auto build(const handle_t& handle, params.metric, stream); - auto&& data = make_device_mdarray(stream, 0, dim); - auto&& indices = make_device_mdarray(stream, 0); - auto&& list_sizes = make_device_mdarray(stream, n_lists); - auto&& list_offsets = make_device_mdarray(stream, n_lists + 1); + auto&& data = rmm::device_uvector(0, stream); + auto&& indices = rmm::device_uvector(0, stream); + auto&& list_sizes = rmm::device_uvector(n_lists, stream); + auto&& list_offsets = rmm::device_uvector(n_lists + 1, stream); utils::memzero(list_sizes.data(), list_sizes.size(), stream); utils::memzero(list_offsets.data(), list_offsets.size(), stream); @@ -281,6 +284,9 @@ inline auto build(const handle_t& handle, index index{{}, veclen, params.metric, + IdxT(0), + dim, + n_lists, std::move(data), std::move(indices), std::move(list_sizes), diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index c04ece3858..a52fbc69de 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -825,7 +825,7 @@ void launch_kernel(Lambda lambda, interleaved_scan_kernel; const int max_query_smem = 16384; int query_smem_elems = - std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); + std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim)); int smem_size = query_smem_elems * sizeof(T); constexpr int kSubwarpSize = std::min(Capacity, WarpSize); smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide( @@ -861,10 +861,10 @@ void launch_kernel(Lambda lambda, index.list_offsets.data(), n_probes, k, - index.dim(), + index.dim, neighbors, distances); - queries += grid_dim_y * index.dim(); + queries += grid_dim_y * index.dim; neighbors += grid_dim_y * grid_dim_x * k; distances += grid_dim_y * grid_dim_x * k; } @@ -1072,7 +1072,7 @@ void search_impl(const handle_t& handle, // The norm of query rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); // The distance value of cluster(list) and queries - rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); + rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists, stream, search_mr); // The topk distance value of cluster(list) and queries rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); // The topk index of cluster(list) and queries @@ -1084,7 +1084,7 @@ void search_impl(const handle_t& handle, size_t float_query_size; if constexpr (std::is_integral_v) { - float_query_size = n_queries * index.dim(); + float_query_size = n_queries * index.dim; } else { float_query_size = 0; } @@ -1095,7 +1095,7 @@ void search_impl(const handle_t& handle, converted_queries_ptr = const_cast(queries); } else { linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); + converted_queries_ptr, queries, n_queries * index.dim, utils::mapping{}, stream); } float alpha = 1.0f; @@ -1105,11 +1105,11 @@ void search_impl(const handle_t& handle, alpha = -2.0f; beta = 1.0f; utils::dots_along_rows( - n_queries, index.dim(), converted_queries_ptr, query_norm_dev.data(), stream); + n_queries, index.dim, converted_queries_ptr, query_norm_dev.data(), stream); utils::outer_add(query_norm_dev.data(), n_queries, index.center_norms->data(), - index.n_lists(), + index.n_lists, distance_buffer_dev.data(), stream); RAFT_LOG_TRACE_VEC(index.center_norms->data(), 20); @@ -1122,17 +1122,17 @@ void search_impl(const handle_t& handle, linalg::gemm(handle, true, false, - index.n_lists(), + index.n_lists, n_queries, - index.dim(), + index.dim, &alpha, index.centers.data(), - index.dim(), + index.dim, converted_queries_ptr, - index.dim(), + index.dim, &beta, distance_buffer_dev.data(), - index.n_lists(), + index.n_lists, stream); RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), 20); @@ -1140,7 +1140,7 @@ void search_impl(const handle_t& handle, topk::warp_sort_topk(distance_buffer_dev.data(), nullptr, n_queries, - index.n_lists(), + index.n_lists, n_probes, coarse_distances_dev.data(), coarse_indices_dev.data(), @@ -1151,7 +1151,7 @@ void search_impl(const handle_t& handle, topk::radix_topk(distance_buffer_dev.data(), nullptr, n_queries, - index.n_lists(), + index.n_lists, n_probes, coarse_distances_dev.data(), coarse_indices_dev.data(), @@ -1249,11 +1249,11 @@ inline void search(const handle_t& handle, rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope( - "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); + "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim); RAFT_EXPECTS(params.n_probes > 0, "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); + auto n_probes = std::min(params.n_probes, index.n_lists); bool select_min; switch (index.metric) { diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 81c9bba998..6c46a288c1 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -18,10 +18,12 @@ #include "common.hpp" -#include +#include #include #include +#include + #include namespace raft::spatial::knn::ivf_flat { @@ -62,6 +64,12 @@ struct index : knn::index { const uint32_t veclen; /** Distance metric used for clustering. */ const raft::distance::DistanceType metric; + /** Total length of the index. */ + const IdxT size; + /** Dimensionality of the data. */ + const uint32_t dim; + /** Number of clusters/inverted lists. */ + const uint32_t n_lists; /** * Inverted list data [size, dim]. @@ -86,20 +94,20 @@ struct index : knn::index { * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , * */ - device_mdarray data; + rmm::device_uvector data; /** Inverted list indices: ids of items in the source data [size] */ - device_mdarray indices; + rmm::device_uvector indices; /** Sizes of the lists (clusters) [n_lists] */ - device_mdarray list_sizes; + rmm::device_uvector list_sizes; /** * Offsets into the lists [n_lists + 1]. * The last value contains the total length of the index. */ - device_mdarray list_offsets; + rmm::device_uvector list_offsets; /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - device_mdarray centers; + rmm::device_uvector centers; /** (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists] */ - std::optional> center_norms; + std::optional> center_norms; // Don't allow copying the index for performance reasons (try avoiding copying data) index(const index&) = delete; @@ -108,33 +116,12 @@ struct index : knn::index { auto operator=(index&&) -> index& = default; ~index() = default; - /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT - { - return static_cast(data.extent(0)); - } - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return static_cast(data.extent(1)); - } - /** Number of clusters/inverted lists. */ - [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t - { - return static_cast(centers.extent(0)); - } - /** Throw an error if the index content is inconsistent. */ inline void check_consistency() const { - RAFT_EXPECTS(dim() % veclen == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(data.extent(0) == indices.extent(0), "inconsistent index size"); - RAFT_EXPECTS(data.extent(1) == centers.extent(1), "inconsistent data dimensionality"); - RAFT_EXPECTS( // - (centers.extent(0) == list_sizes.extent(0)) && // - (centers.extent(0) + 1 == list_offsets.extent(0)) && // - (!center_norms.has_value() || centers.extent(0) == center_norms->extent(0)), - "inconsistent number of lists (clusters)"); + RAFT_EXPECTS(dim % veclen == 0, "dimensionality is not a multiple of the veclen"); + RAFT_EXPECTS(list_offsets.size() == list_sizes.size() + 1, + "inconsistent number of lists (clusters)"); RAFT_EXPECTS(reinterpret_cast(data.data()) % (veclen * sizeof(T)) == 0, "The data storage pointer is not aligned to the vector length"); }