diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index 67aa4e12f1..9ab5160ef5 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -23,14 +23,21 @@ namespace raft { -template +using device_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using device_coo_matrix = - coo_matrix; + template typename ContainerPolicy = device_uvector_policy> +using device_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses device memory @@ -38,6 +45,15 @@ using device_coo_matrix = template using device_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = device_uvector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using device_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses device memory */ @@ -62,21 +78,15 @@ using device_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_coordinate_structure = - coordinate_structure; +template +struct is_device_coo_matrix_view : std::false_type {}; -/** - * Specialization for a sparsity-preserving coordinate structure view which uses device memory - */ -template -using device_coordinate_structure_view = coordinate_structure_view; +template +struct is_device_coo_matrix_view> + : std::true_type {}; + +template +constexpr bool is_device_coo_matrix_view_v = is_device_coo_matrix_view::value; template struct is_device_coo_matrix : std::false_type {}; diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index 1495609d75..df186cc194 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -25,6 +25,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses device memory + */ +template +using device_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses device memory + */ +template +using device_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_device_csr_matrix_view : std::false_type {}; + +template +struct is_device_csr_matrix_view< + device_csr_matrix_view> : std::true_type {}; + +template +constexpr bool is_device_csr_matrix_view_v = is_device_csr_matrix_view::value; + template struct is_device_csr_matrix : std::false_type {}; @@ -70,51 +119,6 @@ template constexpr bool is_device_csr_sparsity_preserving_v = is_device_csr_matrix::value and T::get_sparsity_type() == PRESERVING; -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses device memory - */ -template -using device_compressed_structure_view = - compressed_structure_view; - /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning * means that all of the underlying vectors (data, indptr, indices) are owned by the csr_matrix diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 32e7a9e3c4..9e6aacfa48 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -22,14 +22,21 @@ namespace raft { -template +using host_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using host_coo_matrix = - coo_matrix; + template typename ContainerPolicy = host_vector_policy> +using host_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses host memory @@ -37,6 +44,15 @@ using host_coo_matrix = template using host_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = host_vector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using host_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses host memory */ @@ -61,21 +77,15 @@ using host_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_coordinate_structure = - coordinate_structure; +template +struct is_host_coo_matrix_view : std::false_type {}; -/** - * Specialization for a sparsity-preserving coordinate structure view which uses host memory - */ -template -using host_coordinate_structure_view = coordinate_structure_view; +template +struct is_host_coo_matrix_view> + : std::true_type {}; + +template +constexpr bool is_host_coo_matrix_view_v = is_host_coo_matrix_view::value; template struct is_host_coo_matrix : std::false_type {}; diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index 86199335f2..4b4df823db 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -24,6 +24,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses host memory + */ +template +using host_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses host memory + */ +template +using host_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_host_csr_matrix_view : std::false_type {}; + +template +struct is_host_csr_matrix_view> + : std::true_type {}; + +template +constexpr bool is_host_csr_matrix_view_v = is_host_csr_matrix_view::value; + template struct is_host_csr_matrix : std::false_type {}; @@ -66,53 +115,9 @@ constexpr bool is_host_csr_sparsity_owning_v = is_host_csr_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_host_csr_sparsity_preserving_v = - is_host_csr_matrix::value and T::get_sparsity_type() == PRESERVING; - -/** - * Specialization for a csr matrix view which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses host memory - */ -template -using host_compressed_structure_view = - compressed_structure_view; +constexpr bool is_host_csr_sparsity_preserving_v = std::disjunction_v< + is_host_csr_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 9b079a8539..e121c1be9c 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -471,41 +471,18 @@ class GramMatrixBase { ASSERT(is_row_major_nopad || is_col_major_nopad, "Sparse linear Kernel distance does not support ld_out parameter"); - auto x1_structure = x1.structure_view(); - auto x2_structure = x2.structure_view(); - raft::sparse::distance::distances_config_t dist_config(handle); - - // switch a,b based on data layout + // switch a,b based on is_row_major if (is_col_major_nopad) { - dist_config.a_nrows = x2_structure.get_n_rows(); - dist_config.a_ncols = x2_structure.get_n_cols(); - dist_config.a_nnz = x2_structure.get_nnz(); - dist_config.a_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x2_structure.get_indices().data()); - dist_config.a_data = const_cast(x2.get_elements().data()); - dist_config.b_nrows = x1_structure.get_n_rows(); - dist_config.b_ncols = x1_structure.get_n_cols(); - dist_config.b_nnz = x1_structure.get_nnz(); - dist_config.b_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x1_structure.get_indices().data()); - dist_config.b_data = const_cast(x1.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(1), out.extent(0)); + raft::sparse::distance::pairwise_distance( + handle, x2, x1, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } else { - dist_config.a_nrows = x1_structure.get_n_rows(); - dist_config.a_ncols = x1_structure.get_n_cols(); - dist_config.a_nnz = x1_structure.get_nnz(); - dist_config.a_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x1_structure.get_indices().data()); - dist_config.a_data = const_cast(x1.get_elements().data()); - dist_config.b_nrows = x2_structure.get_n_rows(); - dist_config.b_ncols = x2_structure.get_n_cols(); - dist_config.b_nnz = x2_structure.get_nnz(); - dist_config.b_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x2_structure.get_indices().data()); - dist_config.b_data = const_cast(x2.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(0), out.extent(1)); + raft::sparse::distance::pairwise_distance( + handle, x1, x2, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } - - raft::sparse::distance::pairwiseDistance( - out.data_handle(), dist_config, raft::distance::DistanceType::InnerProduct, 0.0); } }; diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh index 630457158b..e87ef99469 100644 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -19,9 +19,9 @@ #include #include +#include "common.hpp" #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/common.h b/cpp/include/raft/sparse/distance/detail/common.hpp similarity index 95% rename from cpp/include/raft/sparse/distance/common.h rename to cpp/include/raft/sparse/distance/detail/common.hpp index 0b866bdc55..0f463dac80 100644 --- a/cpp/include/raft/sparse/distance/common.h +++ b/cpp/include/raft/sparse/distance/detail/common.hpp @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template struct distances_config_t { @@ -52,6 +53,7 @@ class distances_t { virtual ~distances_t() = default; }; +}; // namespace detail }; // namespace distance -} // namespace sparse +}; // namespace sparse }; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 3a8cf53b6e..9c233ecc19 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -26,7 +26,7 @@ #include "../../csr.hpp" #include "../../detail/utils.h" -#include "../common.h" +#include "common.hpp" #include diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh index 138471c6cf..1c2f83c69b 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../coo_spmv_kernel.cuh" #include "../utils.cuh" #include "coo_mask_row_iterators.cuh" diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 1fbce51caf..4c061336b3 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../utils.cuh" #include diff --git a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh index ef5bae8aa0..39e67acdea 100644 --- a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -23,11 +23,11 @@ #include #include +#include "common.hpp" #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 5293b36a26..acae3dc445 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -19,12 +19,12 @@ #include #include +#include "common.hpp" #include #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index ac78068247..5ee2cd7b15 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -29,8 +29,8 @@ #include #include +#include "common.hpp" #include -#include #include diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 510e02822e..b60940341a 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ #pragma once -#include +#include "detail/common.hpp" #include +#include + #include #include @@ -66,7 +68,7 @@ static const std::unordered_set supportedDistance{ */ template void pairwiseDistance(value_t* out, - distances_config_t input_config, + detail::distances_config_t input_config, raft::distance::DistanceType metric, float metric_arg) { @@ -130,8 +132,94 @@ void pairwiseDistance(value_t* out, } } -}; // namespace distance -}; // namespace sparse -}; // namespace raft +/** + * @defgroup sparse_distance Sparse Pairwise Distance + * @{ + */ + +/** + * @brief Compute pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * + * ... + * // populate data + * ... + * + * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); + * auto metric = raft::distance::DistanceType::L2Expanded; + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); + * @endcode + * + * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view + * @tparam ElementType data-type of inputs and output + * @tparam IndexType data-type for indexing + * + * @param[in] handle raft::resources + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +template >> +void pairwise_distance(raft::resources const& handle, + DeviceCSRMatrix x, + DeviceCSRMatrix y, + raft::device_matrix_view dist, + raft::distance::DistanceType metric, + float metric_arg = 2.0f) +{ + auto x_structure = x.structure_view(); + auto y_structure = y.structure_view(); + + RAFT_EXPECTS(x_structure.get_n_cols() == y_structure.get_n_cols(), + "Number of columns must be equal"); + + RAFT_EXPECTS(dist.extent(0) == x_structure.get_n_rows(), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y_structure.get_n_rows(), + "Number of columns in output must be equal to " + "number of rows in Y"); + + detail::distances_config_t input_config(handle); + input_config.a_nrows = x_structure.get_n_rows(); + input_config.a_ncols = x_structure.get_n_cols(); + input_config.a_nnz = x_structure.get_nnz(); + input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + input_config.a_indices = const_cast(x_structure.get_indices().data()); + input_config.a_data = const_cast(x.get_elements().data()); + + input_config.b_nrows = y_structure.get_n_rows(); + input_config.b_ncols = y_structure.get_n_cols(); + input_config.b_nnz = y_structure.get_nnz(); + input_config.b_indptr = const_cast(y_structure.get_indptr().data()); + input_config.b_indices = const_cast(y_structure.get_indices().data()); + input_config.b_data = const_cast(y.get_elements().data()); + + pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); +} + +/** @} */ // end of sparse_distance + +}; // namespace distance +}; // namespace sparse +}; // namespace raft #endif \ No newline at end of file diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index 7d7bcba443..cfb1a6403b 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -390,7 +390,7 @@ class sparse_knn_t { /** * Compute distances */ - raft::sparse::distance::distances_config_t dist_config(handle); + raft::sparse::distance::detail::distances_config_t dist_config(handle); dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; dist_config.b_nnz = idx_batch_nnz; diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 2b7e8233a5..c729334d00 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -245,7 +245,7 @@ class SparseDistanceCOOSPMVTest // output data rmm::device_uvector out_dists, out_dists_ref; - raft::sparse::distance::distances_config_t dist_config; + raft::sparse::distance::detail::distances_config_t dist_config; SparseDistanceCOOSPMVInputs params; }; diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index debb439345..6b4e5c7cfa 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -61,7 +61,6 @@ class SparseDistanceTest public: SparseDistanceTest() : params(::testing::TestWithParam>::GetParam()), - dist_config(handle), indptr(0, resource::get_cuda_stream(handle)), indices(0, resource::get_cuda_stream(handle)), data(0, resource::get_cuda_stream(handle)), @@ -74,24 +73,25 @@ class SparseDistanceTest { make_data(); - dist_config.b_nrows = params.indptr_h.size() - 1; - dist_config.b_ncols = params.n_cols; - dist_config.b_nnz = params.indices_h.size(); - dist_config.b_indptr = indptr.data(); - dist_config.b_indices = indices.data(); - dist_config.b_data = data.data(); - dist_config.a_nrows = params.indptr_h.size() - 1; - dist_config.a_ncols = params.n_cols; - dist_config.a_nnz = params.indices_h.size(); - dist_config.a_indptr = indptr.data(); - dist_config.a_indices = indices.data(); - dist_config.a_data = data.data(); - - int out_size = dist_config.a_nrows * dist_config.b_nrows; + int out_size = static_cast(params.indptr_h.size() - 1) * + static_cast(params.indptr_h.size() - 1); out_dists.resize(out_size, resource::get_cuda_stream(handle)); - pairwiseDistance(out_dists.data(), dist_config, params.metric, params.metric_arg); + auto out = raft::make_device_matrix_view( + out_dists.data(), + static_cast(params.indptr_h.size() - 1), + static_cast(params.indptr_h.size() - 1)); + + auto x_structure = raft::make_device_compressed_structure_view( + indptr.data(), + indices.data(), + static_cast(params.indptr_h.size() - 1), + params.n_cols, + static_cast(params.indices_h.size())); + auto x = raft::make_device_csr_matrix_view(data.data(), x_structure); + + pairwise_distance(handle, x, x, out, params.metric, params.metric_arg); RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); } @@ -127,7 +127,7 @@ class SparseDistanceTest update_device(out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), - resource::get_cuda_stream(dist_config.handle)); + resource::get_cuda_stream(handle)); } raft::resources handle; @@ -140,7 +140,6 @@ class SparseDistanceTest rmm::device_uvector out_dists, out_dists_ref; SparseDistanceInputs params; - raft::sparse::distance::distances_config_t dist_config; }; const std::vector> inputs_i32_f = {