diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index 67aa4e12f1..9ab5160ef5 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -23,14 +23,21 @@ namespace raft { -template +using device_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using device_coo_matrix = - coo_matrix; + template typename ContainerPolicy = device_uvector_policy> +using device_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses device memory @@ -38,6 +45,15 @@ using device_coo_matrix = template using device_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = device_uvector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using device_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses device memory */ @@ -62,21 +78,15 @@ using device_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_coordinate_structure = - coordinate_structure; +template +struct is_device_coo_matrix_view : std::false_type {}; -/** - * Specialization for a sparsity-preserving coordinate structure view which uses device memory - */ -template -using device_coordinate_structure_view = coordinate_structure_view; +template +struct is_device_coo_matrix_view> + : std::true_type {}; + +template +constexpr bool is_device_coo_matrix_view_v = is_device_coo_matrix_view::value; template struct is_device_coo_matrix : std::false_type {}; diff --git a/cpp/include/raft/core/device_csr_matrix.hpp 
b/cpp/include/raft/core/device_csr_matrix.hpp index 1495609d75..df186cc194 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -25,6 +25,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses device memory + */ +template +using device_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses device memory + */ +template +using device_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_device_csr_matrix_view : std::false_type {}; + +template +struct is_device_csr_matrix_view< + device_csr_matrix_view> : std::true_type {}; + +template +constexpr bool is_device_csr_matrix_view_v = is_device_csr_matrix_view::value; + template struct is_device_csr_matrix : std::false_type {}; @@ -70,51 +119,6 @@ template constexpr bool is_device_csr_sparsity_preserving_v = is_device_csr_matrix::value and T::get_sparsity_type() == PRESERVING; -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed 
structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses device memory - */ -template -using device_compressed_structure_view = - compressed_structure_view; - /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning * means that all of the underlying vectors (data, indptr, indices) are owned by the csr_matrix diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 32e7a9e3c4..9e6aacfa48 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -22,14 +22,21 @@ namespace raft { -template +using host_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using host_coo_matrix = - coo_matrix; + template typename ContainerPolicy = host_vector_policy> +using host_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses host memory @@ -37,6 +44,15 @@ using host_coo_matrix = template using host_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = host_vector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using host_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses host memory */ @@ -61,21 +77,15 @@ using host_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_coordinate_structure = - coordinate_structure; +template +struct is_host_coo_matrix_view : std::false_type {}; 
-/** - * Specialization for a sparsity-preserving coordinate structure view which uses host memory - */ -template -using host_coordinate_structure_view = coordinate_structure_view; +template +struct is_host_coo_matrix_view> + : std::true_type {}; + +template +constexpr bool is_host_coo_matrix_view_v = is_host_coo_matrix_view::value; template struct is_host_coo_matrix : std::false_type {}; diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index 86199335f2..4b4df823db 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -24,6 +24,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses host memory + */ +template +using host_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses host memory + */ +template +using host_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_host_csr_matrix_view : std::false_type {}; + +template +struct is_host_csr_matrix_view> + : std::true_type {}; + +template +constexpr bool is_host_csr_matrix_view_v = is_host_csr_matrix_view::value; + template struct is_host_csr_matrix : std::false_type {}; @@ -66,53 +115,9 @@ constexpr bool is_host_csr_sparsity_owning_v = is_host_csr_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_host_csr_sparsity_preserving_v = - is_host_csr_matrix::value and T::get_sparsity_type() == PRESERVING; - -/** - * Specialization for a csr matrix view 
which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses host memory - */ -template -using host_compressed_structure_view = - compressed_structure_view; +constexpr bool is_host_csr_sparsity_preserving_v = std::disjunction_v< + is_host_csr_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. 
sparsity-owning diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 9b079a8539..e121c1be9c 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -471,41 +471,18 @@ class GramMatrixBase { ASSERT(is_row_major_nopad || is_col_major_nopad, "Sparse linear Kernel distance does not support ld_out parameter"); - auto x1_structure = x1.structure_view(); - auto x2_structure = x2.structure_view(); - raft::sparse::distance::distances_config_t dist_config(handle); - - // switch a,b based on data layout + // switch a,b based on is_row_major if (is_col_major_nopad) { - dist_config.a_nrows = x2_structure.get_n_rows(); - dist_config.a_ncols = x2_structure.get_n_cols(); - dist_config.a_nnz = x2_structure.get_nnz(); - dist_config.a_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x2_structure.get_indices().data()); - dist_config.a_data = const_cast(x2.get_elements().data()); - dist_config.b_nrows = x1_structure.get_n_rows(); - dist_config.b_ncols = x1_structure.get_n_cols(); - dist_config.b_nnz = x1_structure.get_nnz(); - dist_config.b_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x1_structure.get_indices().data()); - dist_config.b_data = const_cast(x1.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(1), out.extent(0)); + raft::sparse::distance::pairwise_distance( + handle, x2, x1, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } else { - dist_config.a_nrows = x1_structure.get_n_rows(); - dist_config.a_ncols = x1_structure.get_n_cols(); - dist_config.a_nnz = x1_structure.get_nnz(); - dist_config.a_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x1_structure.get_indices().data()); - dist_config.a_data = 
const_cast(x1.get_elements().data()); - dist_config.b_nrows = x2_structure.get_n_rows(); - dist_config.b_ncols = x2_structure.get_n_cols(); - dist_config.b_nnz = x2_structure.get_nnz(); - dist_config.b_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x2_structure.get_indices().data()); - dist_config.b_data = const_cast(x2.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(0), out.extent(1)); + raft::sparse::distance::pairwise_distance( + handle, x1, x2, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } - - raft::sparse::distance::pairwiseDistance( - out.data_handle(), dist_config, raft::distance::DistanceType::InnerProduct, 0.0); } }; diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index e05c8882fe..f934d7e3b4 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -18,6 +18,7 @@ #include // uint32_t #include // __half +#include #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view #include // rmm::mr::device_memory_resource @@ -27,7 +28,8 @@ namespace raft::matrix::detail { template -void select_k(const T* in_val, +void select_k(raft::resources const& handle, + const T* in_val, const IdxT* in_idx, size_t batch_size, size_t len, @@ -35,24 +37,24 @@ void select_k(const T* in_val, T* out_val, IdxT* out_idx, bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + rmm::mr::device_memory_resource* mr = nullptr, + bool sorted = false) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - extern template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool 
select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) - +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index dba2d1d841..b852f26e2e 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -19,11 +19,16 @@ #include "select_radix.cuh" #include "select_warpsort.cuh" +#include +#include #include +#include +#include #include #include #include +#include namespace raft::matrix::detail { @@ -116,6 +121,121 @@ inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k) } } +/** + * Performs a segmented sorting of a keys array with respect to + * the segments of a values array. 
+ * @tparam KeyT + * @tparam ValT + * @param handle + * @param values + * @param keys + * @param n_segments + * @param k + * @param select_min + */ +template +void segmented_sort_by_key(raft::resources const& handle, + KeyT* keys, + ValT* values, + size_t n_segments, + size_t n_elements, + const ValT* offsets, + bool asc) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto out_inds = raft::make_device_vector(handle, n_elements); + auto out_dists = raft::make_device_vector(handle, n_elements); + + // Determine temporary device storage requirements + auto d_temp_storage = raft::make_device_vector(handle, 0); + size_t temp_storage_bytes = 0; + if (asc) { + cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } else { + cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } + + d_temp_storage = raft::make_device_vector(handle, temp_storage_bytes); + + if (asc) { + // Run sorting operation + cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + + } else { + // Run sorting operation + cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } + + raft::copy(values, out_inds.data_handle(), out_inds.size(), stream); + 
raft::copy(keys, out_dists.data_handle(), out_dists.size(), stream); +} + +template +void segmented_sort_by_key(raft::resources const& handle, + raft::device_vector_view offsets, + raft::device_vector_view keys, + raft::device_vector_view values, + bool asc) +{ + RAFT_EXPECTS(keys.size() == values.size(), + "Keys and values must contain the same number of elements."); + segmented_sort_by_key(handle, + keys.data_handle(), + values.data_handle(), + offsets.size() - 1, + keys.size(), + offsets.data_handle(), + asc); +} + /** * Select k smallest or largest key/values from each row in the input data. * @@ -154,7 +274,8 @@ inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k) * memory pool here to avoid memory allocations within the call). */ template -void select_k(const T* in_val, +void select_k(raft::resources const& handle, + const T* in_val, const IdxT* in_idx, size_t batch_size, size_t len, @@ -162,25 +283,46 @@ void select_k(const T* in_val, T* out_val, IdxT* out_idx, bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + bool sorted = false) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); - auto algo = choose_select_k_algorithm(batch_size, len, k); + auto stream = raft::resource::get_cuda_stream(handle); + auto algo = choose_select_k_algorithm(batch_size, len, k); + switch (algo) { case Algo::kRadix11bits: - return detail::select::radix::select_k(in_val, - in_idx, - batch_size, - len, - k, - out_val, - out_idx, - select_min, - true, // fused_last_filter - stream); + detail::select::radix::select_k(in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + true, // fused_last_filter + stream); + + if (sorted) { + auto offsets = raft::make_device_vector(handle, (IdxT)(batch_size + 1)); + + raft::matrix::fill(handle, offsets.view(), (IdxT)k); + + 
thrust::exclusive_scan(raft::resource::get_thrust_policy(handle), + offsets.data_handle(), + offsets.data_handle() + offsets.size(), + offsets.data_handle(), + 0); + + auto keys = raft::make_device_vector_view(out_val, (IdxT)(batch_size * k)); + auto vals = raft::make_device_vector_view(out_idx, (IdxT)(batch_size * k)); + + segmented_sort_by_key( + handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min); + } + return; case Algo::kWarpDistributedShm: return detail::select::warpsort:: select_k_impl( @@ -188,6 +330,7 @@ void select_k(const T* in_val, case Algo::kFaissBlockSelect: return neighbors::detail::select_k( in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream); + default: RAFT_FAIL("K-selection Algorithm not supported."); } } } // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 8e6dbaafa8..37a36cbf6b 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -58,7 +58,7 @@ namespace raft::matrix { * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] handle + * @param[in] handle container of reusable resources * @param[in] in_val * inputs values [batch_size, len]; * these are compared and selected. @@ -74,14 +74,17 @@ namespace raft::matrix { * the payload selected together with `out_val`. * @param[in] select_min * whether to select k smallest (true) or largest (false) keys. 
+ * @param[in] sorted + * whether to make sure selected pairs are sorted by value */ template -void select_k(const resources& handle, +void select_k(raft::resources const& handle, raft::device_matrix_view in_val, std::optional> in_idx, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, - bool select_min) + bool select_min, + bool sorted = false) { RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), "output k must fit the int type."); @@ -95,7 +98,9 @@ void select_k(const resources& handle, RAFT_EXPECTS(len == in_idx->extent(1), "value and index input lengths must be equal"); } RAFT_EXPECTS(int64_t(k) == out_idx.extent(1), "value and index output lengths must be equal"); - return detail::select_k(in_val.data_handle(), + + return detail::select_k(handle, + in_val.data_handle(), in_idx.has_value() ? in_idx->data_handle() : nullptr, batch_size, len, @@ -103,7 +108,8 @@ void select_k(const resources& handle, out_val.data_handle(), out_idx.data_handle(), select_min, - resource::get_cuda_stream(handle)); + nullptr, + sorted); } /** @} */ // end of group select_k diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index a88a449a68..4a384b90e1 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -55,8 +55,8 @@ enum class search_algo { enum class hash_mode { HASH, SMALL, AUTO }; struct search_params : ann::search_params { - /** Maximum number of queries to search at the same time (batch size). */ - size_t max_queries = 1; + /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ + size_t max_queries = 0; /** Number of intermediate search results retained during the search. 
* diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 7b35af4417..1561a3bb8d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -65,7 +65,9 @@ void search_main(raft::resources const& res, static_cast(queries.extent(0)), static_cast(queries.extent(1))); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); - uint32_t topk = neighbors.extent(1); + const uint32_t topk = neighbors.extent(1); + + if (params.max_queries == 0) { params.max_queries = queries.extent(0); } std::unique_ptr> plan = factory::create( @@ -74,8 +76,8 @@ void search_main(raft::resources const& res, plan->check(neighbors.extent(1)); RAFT_LOG_DEBUG("Cagra search"); - uint32_t max_queries = plan->max_queries; - uint32_t query_dim = queries.extent(1); + const uint32_t max_queries = plan->max_queries; + const uint32_t query_dim = queries.extent(1); for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 8ab6b19b98..2f34febdd2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -158,7 +158,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( assert(blockDim.x == BLOCK_SIZE); assert(dataset_dim <= MAX_DATASET_DIM); - // const auto num_queries = gridDim.y; + const auto num_queries = gridDim.y; const auto query_id = blockIdx.y; const auto num_cta_per_query = gridDim.x; const auto cta_id = blockIdx.x; // local CTA ID @@ -225,6 +225,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( // compute distance to randomly selecting nodes _CLK_START(); 
const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; device::compute_distance_to_random_nodes( result_indices_buffer, result_distances_buffer, @@ -240,8 +242,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( num_seeds, local_visited_hashmap_ptr, hash_bitlen, - cta_id, - num_cta_per_query); + block_id, + num_blocks); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -472,14 +474,14 @@ struct search : public search_plan_impl { topk_workspace(0, resource::get_cuda_stream(res)) { - set_params(res); + set_params(res, params); } - void set_params(raft::resources const& res) + void set_params(raft::resources const& res, const search_params& params) { this->itopk_size = 32; num_parents = 1; - num_cta_per_query = max(num_parents, itopk_size / 32); + num_cta_per_query = max(params.num_parents, params.itopk_size / 32); result_buffer_size = itopk_size + num_parents * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); @@ -532,8 +534,10 @@ struct search : public search_plan_impl { // Allocate memory for intermediate buffer and workspace. 
// uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - intermediate_indices.resize(num_intermediate_results, resource::get_cuda_stream(res)); - intermediate_distances.resize(num_intermediate_results, resource::get_cuda_stream(res)); + intermediate_indices.resize(num_intermediate_results * max_queries, + resource::get_cuda_stream(res)); + intermediate_distances.resize(num_intermediate_results * max_queries, + resource::get_cuda_stream(res)); hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 366e9bfcd5..93eeb0dead 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -127,7 +127,8 @@ void search_impl(raft::resources const& handle, stream); RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - matrix::detail::select_k(distance_buffer_dev.data(), + matrix::detail::select_k(handle, + distance_buffer_dev.data(), nullptr, n_queries, index.n_lists(), @@ -135,7 +136,6 @@ void search_impl(raft::resources const& handle, coarse_distances_dev.data(), coarse_indices_dev.data(), select_min, - stream, search_mr); RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); @@ -191,7 +191,8 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (grid_dim_x > 1) { - matrix::detail::select_k(refined_distances_dev.data(), + matrix::detail::select_k(handle, + refined_distances_dev.data(), refined_indices_dev.data(), n_queries, k * grid_dim_x, @@ -199,7 +200,6 @@ void search_impl(raft::resources const& handle, distances, neighbors, select_min, - stream, search_mr); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 8257f5ed35..82b1ac1542 
100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -151,7 +151,8 @@ void select_clusters(raft::resources const& handle, // Select neighbor clusters for each query. rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); - matrix::detail::select_k(qc_distances.data(), + matrix::detail::select_k(handle, + qc_distances.data(), nullptr, n_queries, n_lists, @@ -159,7 +160,6 @@ void select_clusters(raft::resources const& handle, cluster_dists.data(), clusters_to_probe, true, - stream, mr); } @@ -581,7 +581,8 @@ void ivfpq_search_worker(raft::resources const& handle, // Select topk vectors for each query rmm::device_uvector topk_dists(n_queries * topK, stream, mr); - matrix::detail::select_k(distances_buf.data(), + matrix::detail::select_k(handle, + distances_buf.data(), neighbors_ptr, n_queries, topk_len, @@ -589,7 +590,6 @@ void ivfpq_search_worker(raft::resources const& handle, topk_dists.data(), neighbors_uint32, true, - stream, mr); // Postprocessing diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 5cb9f6d0ab..123a902ef9 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -238,7 +238,8 @@ void tiled_brute_force_knn(const raft::resources& handle, distances + i * k, current_query_size, current_k), raft::make_device_matrix_view( indices + i * k, current_query_size, current_k), - select_min); + select_min, + true); // if we're tiling over columns, we need to do a couple things to fix up // the output of select_k @@ -280,7 +281,8 @@ void tiled_brute_force_knn(const raft::resources& handle, distances + i * k, current_query_size, k), raft::make_device_matrix_view( indices + i * k, current_query_size, k), - select_min); + select_min, + true); } } } diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh 
b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh index 630457158b..e87ef99469 100644 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -19,9 +19,9 @@ #include #include +#include "common.hpp" #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/common.h b/cpp/include/raft/sparse/distance/detail/common.hpp similarity index 95% rename from cpp/include/raft/sparse/distance/common.h rename to cpp/include/raft/sparse/distance/detail/common.hpp index 0b866bdc55..0f463dac80 100644 --- a/cpp/include/raft/sparse/distance/common.h +++ b/cpp/include/raft/sparse/distance/detail/common.hpp @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template struct distances_config_t { @@ -52,6 +53,7 @@ class distances_t { virtual ~distances_t() = default; }; +}; // namespace detail }; // namespace distance -} // namespace sparse +}; // namespace sparse }; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 3a8cf53b6e..9c233ecc19 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -26,7 +26,7 @@ #include "../../csr.hpp" #include "../../detail/utils.h" -#include "../common.h" +#include "common.hpp" #include diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh index 138471c6cf..1c2f83c69b 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../coo_spmv_kernel.cuh" #include "../utils.cuh" 
#include "coo_mask_row_iterators.cuh" diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 1fbce51caf..4c061336b3 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../utils.cuh" #include diff --git a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh index ef5bae8aa0..39e67acdea 100644 --- a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -23,11 +23,11 @@ #include #include +#include "common.hpp" #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 5293b36a26..acae3dc445 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -19,12 +19,12 @@ #include #include +#include "common.hpp" #include #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index ac78068247..5ee2cd7b15 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -29,8 +29,8 @@ #include #include 
+#include "common.hpp" #include -#include #include diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 510e02822e..b60940341a 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ #pragma once -#include +#include "detail/common.hpp" #include +#include + #include #include @@ -66,7 +68,7 @@ static const std::unordered_set supportedDistance{ */ template void pairwiseDistance(value_t* out, - distances_config_t input_config, + detail::distances_config_t input_config, raft::distance::DistanceType metric, float metric_arg) { @@ -130,8 +132,94 @@ void pairwiseDistance(value_t* out, } } -}; // namespace distance -}; // namespace sparse -}; // namespace raft +/** + * @defgroup sparse_distance Sparse Pairwise Distance + * @{ + */ + +/** + * @brief Compute pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * + * ... + * // populate data + * ... 
+ * + * auto out = raft::make_device_matrix(handle, x_n_rows, y_n_rows); + * auto metric = raft::distance::DistanceType::L2Expanded; + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out.view(), metric); + * @endcode + * + * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view + * @tparam ElementType data-type of inputs and output + * @tparam IndexType data-type for indexing + * + * @param[in] handle raft::resources + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +template >> +void pairwise_distance(raft::resources const& handle, + DeviceCSRMatrix x, + DeviceCSRMatrix y, + raft::device_matrix_view dist, + raft::distance::DistanceType metric, + float metric_arg = 2.0f) +{ + auto x_structure = x.structure_view(); + auto y_structure = y.structure_view(); + + RAFT_EXPECTS(x_structure.get_n_cols() == y_structure.get_n_cols(), + "Number of columns must be equal"); + + RAFT_EXPECTS(dist.extent(0) == x_structure.get_n_rows(), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y_structure.get_n_rows(), + "Number of columns in output must be equal to " + "number of rows in Y"); + + detail::distances_config_t input_config(handle); + input_config.a_nrows = x_structure.get_n_rows(); + input_config.a_ncols = x_structure.get_n_cols(); + input_config.a_nnz = x_structure.get_nnz(); + input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + input_config.a_indices = const_cast(x_structure.get_indices().data()); + input_config.a_data = const_cast(x.get_elements().data()); + + input_config.b_nrows = y_structure.get_n_rows(); + input_config.b_ncols = y_structure.get_n_cols(); + input_config.b_nnz = y_structure.get_nnz(); + input_config.b_indptr = 
const_cast(y_structure.get_indptr().data()); + input_config.b_indices = const_cast(y_structure.get_indices().data()); + input_config.b_data = const_cast(y.get_elements().data()); + + pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); +} + +/** @} */ // end of sparse_distance + +}; // namespace distance +}; // namespace sparse +}; // namespace raft #endif \ No newline at end of file diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index 7d7bcba443..cfb1a6403b 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -390,7 +390,7 @@ class sparse_knn_t { /** * Compute distances */ - raft::sparse::distance::distances_config_t dist_config(handle); + raft::sparse::distance::detail::distances_config_t dist_config(handle); dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; dist_config.b_nnz = idx_batch_nnz; diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index 013a61886f..b72e67580a 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -101,10 +101,15 @@ void select_k_impl(const resources& handle, if (in_idx == nullptr) { // NB: std::nullopt prevents automatic inference of the template parameters. 
return matrix::select_k( - handle, in_span, std::nullopt, out_span, out_idx_span, select_min); + handle, in_span, std::nullopt, out_span, out_idx_span, select_min, true); } else { - return matrix::select_k( - handle, in_span, std::make_optional(in_idx_span), out_span, out_idx_span, select_min); + return matrix::select_k(handle, + in_span, + std::make_optional(in_idx_span), + out_span, + out_idx_span, + select_min, + true); } } case Algo::kRadix8bits: diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh index 3ad055272b..8565735672 100644 --- a/cpp/internal/raft_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -78,7 +79,8 @@ __global__ void naive_distance_kernel(EvalT* dist, * when either distance or brute_force_knn support 8-bit int inputs. */ template -void naive_knn(EvalT* dist_topk, +void naive_knn(raft::resources const& handle, + EvalT* dist_topk, IdxT* indices_topk, const DataT* x, const DataT* y, @@ -86,12 +88,12 @@ void naive_knn(EvalT* dist_topk, size_t input_len, size_t dim, uint32_t k, - raft::distance::DistanceType type, - rmm::cuda_stream_view stream) + raft::distance::DistanceType type) { rmm::mr::device_memory_resource* mr = nullptr; auto pool_guard = raft::get_pool_memory_resource(mr, 1024 * 1024); + auto stream = raft::resource::get_cuda_stream(handle); dim3 block_dim(16, 32, 1); // maximum reasonable grid size in `y` direction auto grid_y = @@ -109,7 +111,8 @@ void naive_knn(EvalT* dist_topk, naive_distance_kernel<<>>( dist.data(), x + offset * dim, y, batch_size, input_len, dim, type); - matrix::detail::select_k(dist.data(), + matrix::detail::select_k(handle, + dist.data(), nullptr, batch_size, input_len, @@ -117,7 +120,6 @@ void naive_knn(EvalT* dist_topk, dist_topk + offset * k, indices_topk + offset * k, type != raft::distance::DistanceType::InnerProduct, - stream, 
mr); } RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); diff --git a/cpp/internal/raft_internal/neighbors/refine_helper.cuh b/cpp/internal/raft_internal/neighbors/refine_helper.cuh index 67217d1e0e..ee06d90851 100644 --- a/cpp/internal/raft_internal/neighbors/refine_helper.cuh +++ b/cpp/internal/raft_internal/neighbors/refine_helper.cuh @@ -80,7 +80,8 @@ class RefineHelper { { candidates = raft::make_device_matrix(handle_, p.n_queries, p.k0); rmm::device_uvector distances_tmp(p.n_queries * p.k0, stream_); - naive_knn(distances_tmp.data(), + naive_knn(handle_, + distances_tmp.data(), candidates.data_handle(), queries.data_handle(), dataset.data_handle(), @@ -88,8 +89,7 @@ class RefineHelper { p.n_rows, p.dim, p.k0, - p.metric, - stream_); + p.metric); resource::sync_stream(handle_, stream_); } @@ -112,7 +112,8 @@ class RefineHelper { { rmm::device_uvector distances_dev(p.n_queries * p.k, stream_); rmm::device_uvector indices_dev(p.n_queries * p.k, stream_); - naive_knn(distances_dev.data(), + naive_knn(handle_, + distances_dev.data(), indices_dev.data(), queries.data_handle(), dataset.data_handle(), @@ -120,8 +121,7 @@ class RefineHelper { p.n_rows, p.dim, p.k, - p.metric, - stream_); + p.metric); true_refined_distances_host.resize(p.n_queries * p.k); true_refined_indices_host.resize(p.n_queries * p.k); raft::copy(true_refined_indices_host.data(), indices_dev.data(), indices_dev.size(), stream_); diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index 022627283a..c75a5b5261 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - 
rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 22c6989337..171c8a1ae7 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -17,17 +17,18 @@ #include // uint32_t #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 42094bbb67..a21444dc0c 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, 
\ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 1f1d686048..9542874ec0 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index 3bb47acbf2..fbf311d9bd 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, 
\ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index cf4e15959d..fdbfd66c46 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index b18887bfc0..48a3e91f9d 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -16,17 +16,18 @@ #include -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - template void raft::matrix::detail::select_k(const T* in_val, \ - const IdxT* in_idx, \ - size_t 
batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::cuda_stream_view stream, \ - rmm::mr::device_memory_resource* mr) +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr, \ + bool sorted) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 702fd1c407..487b6d0bfd 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -190,6 +190,7 @@ struct io_computed { }; template + using Params = std::tuple; template typename ParamsReader> diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index d3bd5ba31d..3e929f9f3b 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -166,9 +166,6 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { - if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) { - GTEST_SKIP() << "Skipping test due to issue #1575"; - } size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -178,7 +175,8 @@ class AnnCagraTest : public ::testing::TestWithParam { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(distances_naive_dev.data(), + naive_knn(handle_, + distances_naive_dev.data(), indices_naive_dev.data(), search_queries.data(), database.data(), @@ -186,8 +184,7 @@ class AnnCagraTest : public ::testing::TestWithParam { ps.n_rows, ps.dim, ps.k, - ps.metric, - stream_); + ps.metric); update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); 
update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); resource::sync_stream(handle_); @@ -377,9 +374,9 @@ inline std::vector generate_inputs() {100}, {1000}, {1, 8, 17}, - {1, 16}, // k + {1, 16}, // k {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {1, 10, 100}, // query size + {0, 1, 10, 100}, // query size {0}, {256}, {1}, diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 88bf53280b..a252b26600 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -88,7 +88,8 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(distances_naive_dev.data(), + naive_knn(handle_, + distances_naive_dev.data(), indices_naive_dev.data(), search_queries.data(), database.data(), @@ -96,8 +97,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ps.num_db_vecs, ps.dim, ps.k, - ps.metric, - stream_); + ps.metric); update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); resource::sync_stream(handle_); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index de4453a034..e03d09ae50 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -186,7 +186,8 @@ class ivf_pq_test : public ::testing::TestWithParam { size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(distances_naive_dev.data(), + naive_knn(handle_, + distances_naive_dev.data(), indices_naive_dev.data(), search_queries.data(), database.data(), @@ -194,8 +195,7 @@ class ivf_pq_test : public ::testing::TestWithParam { 
ps.num_db_vecs, ps.dim, ps.k, - ps.index_params.metric, - stream_); + ps.index_params.metric); distances_ref.resize(queries_size); update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); indices_ref.resize(queries_size); diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 2b7e8233a5..c729334d00 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -245,7 +245,7 @@ class SparseDistanceCOOSPMVTest // output data rmm::device_uvector out_dists, out_dists_ref; - raft::sparse::distance::distances_config_t dist_config; + raft::sparse::distance::detail::distances_config_t dist_config; SparseDistanceCOOSPMVInputs params; }; diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index debb439345..6b4e5c7cfa 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -61,7 +61,6 @@ class SparseDistanceTest public: SparseDistanceTest() : params(::testing::TestWithParam>::GetParam()), - dist_config(handle), indptr(0, resource::get_cuda_stream(handle)), indices(0, resource::get_cuda_stream(handle)), data(0, resource::get_cuda_stream(handle)), @@ -74,24 +73,25 @@ class SparseDistanceTest { make_data(); - dist_config.b_nrows = params.indptr_h.size() - 1; - dist_config.b_ncols = params.n_cols; - dist_config.b_nnz = params.indices_h.size(); - dist_config.b_indptr = indptr.data(); - dist_config.b_indices = indices.data(); - dist_config.b_data = data.data(); - dist_config.a_nrows = params.indptr_h.size() - 1; - dist_config.a_ncols = params.n_cols; - dist_config.a_nnz = params.indices_h.size(); - dist_config.a_indptr = indptr.data(); - dist_config.a_indices = indices.data(); - dist_config.a_data = data.data(); - - int out_size = dist_config.a_nrows * dist_config.b_nrows; + int out_size = static_cast(params.indptr_h.size() - 1) * + static_cast(params.indptr_h.size() - 1); out_dists.resize(out_size, resource::get_cuda_stream(handle)); - 
pairwiseDistance(out_dists.data(), dist_config, params.metric, params.metric_arg); + auto out = raft::make_device_matrix_view( + out_dists.data(), + static_cast(params.indptr_h.size() - 1), + static_cast(params.indptr_h.size() - 1)); + + auto x_structure = raft::make_device_compressed_structure_view( + indptr.data(), + indices.data(), + static_cast(params.indptr_h.size() - 1), + params.n_cols, + static_cast(params.indices_h.size())); + auto x = raft::make_device_csr_matrix_view(data.data(), x_structure); + + pairwise_distance(handle, x, x, out, params.metric, params.metric_arg); RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); } @@ -127,7 +127,7 @@ class SparseDistanceTest update_device(out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), - resource::get_cuda_stream(dist_config.handle)); + resource::get_cuda_stream(handle)); } raft::resources handle; @@ -140,7 +140,6 @@ class SparseDistanceTest rmm::device_uvector out_dists, out_dists_ref; SparseDistanceInputs params; - raft::sparse::distance::distances_config_t dist_config; }; const std::vector> inputs_i32_f = { diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py deleted file mode 100644 index 764e0f32fd..0000000000 --- a/python/raft-dask/raft_dask/test/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 5c69a94fd8..68c9fee556 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,9 +17,8 @@ import pytest -from dask.distributed import get_worker, wait - -from .conftest import create_client +from dask.distributed import Client, get_worker, wait +from dask_cuda import LocalCUDACluster try: from raft_dask.common import ( @@ -44,6 +43,29 @@ pytestmark = pytest.mark.skip +def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. + + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) + + def test_comms_init_no_p2p(cluster): client = create_client(cluster) try: