From 9d4d04deb82beb7c46b311f9abc1a338998a1f2a Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 18 Apr 2023 17:30:25 -0700 Subject: [PATCH 1/8] working through --- .../raft/sparse/distance/detail/bin_distance.cuh | 4 ++-- .../distance/{common.h => detail/common.hpp} | 4 +++- .../raft/sparse/distance/detail/coo_spmv.cuh | 4 ++-- .../detail/coo_spmv_strategies/base_strategy.cuh | 4 ++-- .../coo_spmv_strategies/coo_mask_row_iterators.cuh | 4 ++-- .../raft/sparse/distance/detail/ip_distance.cuh | 4 ++-- .../raft/sparse/distance/detail/l2_distance.cuh | 2 +- .../raft/sparse/distance/detail/lp_distance.cuh | 2 +- cpp/include/raft/sparse/distance/distance.cuh | 14 +++++++++++--- 9 files changed, 26 insertions(+), 16 deletions(-) rename cpp/include/raft/sparse/distance/{common.h => detail/common.hpp} (95%) diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh index cdcb0b7322..466aa7decd 100644 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ #include +#include "common.hpp" #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/common.h b/cpp/include/raft/sparse/distance/detail/common.hpp similarity index 95% rename from cpp/include/raft/sparse/distance/common.h rename to cpp/include/raft/sparse/distance/detail/common.hpp index 1e5aeb5210..9cc019f27c 100644 --- a/cpp/include/raft/sparse/distance/common.h +++ b/cpp/include/raft/sparse/distance/detail/common.hpp @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template struct distances_config_t { @@ -52,6 +53,7 @@ class distances_t { virtual ~distances_t() = default; }; +}; // namespace detail }; // namespace distance -} // namespace sparse +}; // namespace sparse }; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 53ef0326fb..a80d422901 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ #include "../../csr.hpp" #include "../../detail/utils.h" -#include "../common.h" +#include "common.hpp" #include diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh index c4e39c11a0..6ea804314b 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../coo_spmv_kernel.cuh" #include "../utils.cuh" #include "coo_mask_row_iterators.cuh" diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 1fbce51caf..4c061336b3 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include "../../common.h" +#include "../common.hpp" #include "../utils.cuh" #include diff --git a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh index d45e643780..559ad7687d 100644 --- a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,11 +22,11 @@ #include #include +#include "common.hpp" #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 2f165b3ff2..3ab2f18c1f 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -18,12 +18,12 @@ #include +#include "common.hpp" #include #include #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index f67109afbc..f85a3751e8 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -28,8 +28,8 @@ #include #include +#include "common.hpp" #include -#include #include diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 510e02822e..fc9c3ad784 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ #pragma once -#include +#include "detail/common.hpp" #include +#include + #include #include @@ -66,7 +68,7 @@ static const std::unordered_set supportedDistance{ */ template void pairwiseDistance(value_t* out, - distances_config_t input_config, + detail::distances_config_t input_config, raft::distance::DistanceType metric, float metric_arg) { @@ -130,6 +132,12 @@ void pairwiseDistance(value_t* out, } } +template >> +void pairwise_distance(raft::resources const& handle, ) +{ +} + }; // namespace distance }; // namespace sparse }; // namespace raft From fb42a008184429dffd31724ed8d2f28bc07bb8cf Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 18 Apr 2023 18:56:12 -0700 Subject: [PATCH 2/8] tests building --- cpp/include/raft/sparse/distance/distance.cuh | 34 ++++++++++++++++++- .../raft/sparse/neighbors/detail/knn.cuh | 2 +- cpp/test/sparse/dist_coo_spmv.cu | 2 +- cpp/test/sparse/distance.cu | 2 +- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index fc9c3ad784..a8ac9c2ccb 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -133,9 +133,41 @@ void pairwiseDistance(value_t* out, } template >> -void pairwise_distance(raft::resources const& handle, ) +void pairwise_distance( + raft::device_resources const& handle, + DeviceCSRMatrix x, + DeviceCSRMatrix y, + raft::device_matrix_view dist, + raft::distance::DistanceType metric, + float metric_arg = 2.0f) { + RAFT_EXPECTS(x.get_n_cols() == y.get_n_cols(), "Number of columns must be equal"); + RAFT_EXPECTS(dist.extent(0) == x.get_n_rows(), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.get_n_rows(), + "Number of columns in output must be equal to " + "number of rows in Y"); + + detail::distances_config_t input_config(handle); + input_config.a_nrows = x.get_n_rows(); + input_config.a_ncols = x.get_n_cols(); + input_config.a_nnz = x.get_nnz(); + input_config.a_indptr = x.get_indptr().data(); + input_config.a_indices = x.get_indices().data(); + input_config.a_data = x.get_elements().data(); + + input_config.b_nrows = y.get_n_rows(); + input_config.b_ncols = y.get_n_cols(); + input_config.b_nnz = y.get_nnz(); + input_config.b_indptr = y.get_indptr().data(); + input_config.b_indices = y.get_indices().data(); + input_config.b_data = y.get_elements().data(); + + pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); } }; // namespace distance diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index f9f07c13ca..c94df73b85 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -384,7 +384,7 @@ class sparse_knn_t { /** * Compute distances */ - raft::sparse::distance::distances_config_t dist_config(handle); + raft::sparse::distance::detail::distances_config_t dist_config(handle); dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; dist_config.b_nnz = idx_batch_nnz; diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index e768e49f6c..6a442869e8 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -244,7 +244,7 @@ class SparseDistanceCOOSPMVTest // output data rmm::device_uvector out_dists, out_dists_ref; - raft::sparse::distance::distances_config_t dist_config; + raft::sparse::distance::detail::distances_config_t dist_config; SparseDistanceCOOSPMVInputs params; }; diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 2a973d675c..0600d46cb7 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -139,7 +139,7 @@ class SparseDistanceTest rmm::device_uvector out_dists, out_dists_ref; SparseDistanceInputs params; - raft::sparse::distance::distances_config_t dist_config; + raft::sparse::distance::detail::distances_config_t dist_config; }; const std::vector> inputs_i32_f = { From ea16a20d33dc9ad8cc42dc53a2708031495a69ad Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 9 May 2023 07:26:13 -0700 Subject: [PATCH 3/8] working through --- cpp/include/raft/sparse/distance/distance.cuh | 2 +- cpp/test/sparse/distance.cu | 44 ++++++++++++------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index a8ac9c2ccb..6253e69e42 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -138,9 +138,9 @@ template >> void pairwise_distance( raft::device_resources const& handle, + raft::device_matrix_view dist, DeviceCSRMatrix x, DeviceCSRMatrix y, - raft::device_matrix_view dist, raft::distance::DistanceType metric, float metric_arg = 2.0f) { diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 0600d46cb7..9befa8c8fd 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -73,24 +73,38 @@ class SparseDistanceTest { make_data(); - dist_config.b_nrows = params.indptr_h.size() - 1; - dist_config.b_ncols = params.n_cols; - dist_config.b_nnz = params.indices_h.size(); - dist_config.b_indptr = indptr.data(); - dist_config.b_indices = indices.data(); - dist_config.b_data = data.data(); - dist_config.a_nrows = params.indptr_h.size() - 1; - dist_config.a_ncols = params.n_cols; - dist_config.a_nnz = params.indices_h.size(); - dist_config.a_indptr = indptr.data(); - dist_config.a_indices = indices.data(); - dist_config.a_data = data.data(); - - int out_size = dist_config.a_nrows * dist_config.b_nrows; + // dist_config.b_nrows = params.indptr_h.size() - 1; + // dist_config.b_ncols = params.n_cols; + // dist_config.b_nnz = params.indices_h.size(); + // dist_config.b_indptr = indptr.data(); + // dist_config.b_indices = indices.data(); + // dist_config.b_data = data.data(); + // dist_config.a_nrows = params.indptr_h.size() - 1; + // dist_config.a_ncols = params.n_cols; + // dist_config.a_nnz = params.indices_h.size(); + // dist_config.a_indptr = indptr.data(); + // dist_config.a_indices = indices.data(); + // dist_config.a_data = data.data(); + + // int out_size = dist_config.a_nrows * dist_config.b_nrows; out_dists.resize(out_size, handle.get_stream()); - pairwiseDistance(out_dists.data(), dist_config, params.metric, params.metric_arg); + // pairwiseDistance(out_dists.data(), dist_config, params.metric, params.metric_arg); + + auto out = raft::make_device_matrix_view( + out_dists.data(), dist_config.a_nrows, dist_config.b_nrows); + + auto x_structure = raft::make_device_compressed_structure_view( + handle, + indptr.data(), + indices.data(), + params.indptr_h.size() - 1, + params.n_cols, + params.indices_h.size()); + auto x = raft::make_device_csr_view(data.data(), x_structure); + + pairwiseDistance(handle, out, x, x, params.metric, params.metric_arg); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } From 3bf98ec10e890c053c62977059e935407a517665 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 9 May 2023 10:06:55 -0700 Subject: [PATCH 4/8] passing tests --- cpp/include/raft/core/coo_matrix.hpp | 1 + cpp/include/raft/core/csr_matrix.hpp | 1 + cpp/include/raft/core/device_coo_matrix.hpp | 52 +++++---- cpp/include/raft/core/device_csr_matrix.hpp | 96 +++++++-------- cpp/include/raft/core/host_coo_matrix.hpp | 52 +++++---- cpp/include/raft/core/host_csr_matrix.hpp | 96 +++++++-------- .../distance/detail/kernels/gram_matrix.cuh | 39 ++----- cpp/include/raft/sparse/distance/distance.cuh | 109 +++++++++++++----- cpp/test/sparse/distance.cu | 39 ++----- 9 files changed, 261 insertions(+), 224 deletions(-) diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index a5f7c05493..2fedc641cc 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -207,6 +207,7 @@ class coo_matrix_view using row_type = RowType; using col_type = ColType; using nnz_type = NZType; + static constexpr auto get_sparsity_type() { return SparsityType::PRESERVING; } coo_matrix_view(raft::span element_span, coordinate_structure_view structure_view) : sparse_matrix_view element_span, compressed_structure_view structure_view) diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index ce016dd5e0..590f483e4a 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -23,14 +23,21 @@ namespace raft { -template +using device_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using device_coo_matrix = - coo_matrix; + template typename ContainerPolicy = device_uvector_policy> +using device_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses device memory @@ -38,6 +45,15 @@ using device_coo_matrix = template using device_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = device_uvector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using device_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses device memory */ @@ -62,21 +78,12 @@ using device_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_coordinate_structure = - coordinate_structure; +template +struct is_device_coo_matrix_view : std::false_type {}; -/** - * Specialization for a sparsity-preserving coordinate structure view which uses device memory - */ -template -using device_coordinate_structure_view = coordinate_structure_view; +template +struct is_device_coo_matrix_view> + : std::true_type {}; template struct is_device_coo_matrix : std::false_type {}; @@ -100,8 +107,9 @@ constexpr bool is_device_coo_sparsity_owning_v = is_device_coo_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_device_coo_sparsity_preserving_v = - is_device_coo_matrix::value and T::get_sparsity_type() == PRESERVING; +constexpr bool is_device_coo_sparsity_preserving_v = std::disjunction_v< + is_device_coo_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the coordinate format. sparsity-owning means that diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index 869034e925..76be55a7da 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -25,6 +25,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses device memory + */ +template +using device_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses device memory + */ +template +using device_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses device memory + */ +template typename ContainerPolicy = device_uvector_policy> +using device_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_device_csr_matrix_view : std::false_type {}; + +template +struct is_device_csr_matrix_view< + device_csr_matrix_view> : std::true_type {}; + template struct is_device_csr_matrix : std::false_type {}; @@ -67,53 +113,9 @@ constexpr bool is_device_csr_sparsity_owning_v = is_device_csr_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_device_csr_sparsity_preserving_v = - is_device_csr_matrix::value and T::get_sparsity_type() == PRESERVING; - -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses device memory - */ -template -using device_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed structure which uses device memory - */ -template typename ContainerPolicy = device_uvector_policy> -using device_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses device memory - */ -template -using device_compressed_structure_view = - compressed_structure_view; +constexpr bool is_device_csr_sparsity_preserving_v = std::disjunction_v< + is_device_csr_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 32e7a9e3c4..1c1ce720fc 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -22,14 +22,21 @@ namespace raft { -template +using host_coordinate_structure_view = coordinate_structure_view; + +/** + * Specialization for a sparsity-owning coordinate structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy, - SparsityType sparsity_type = SparsityType::OWNING> -using host_coo_matrix = - coo_matrix; + template typename ContainerPolicy = host_vector_policy> +using host_coordinate_structure = + coordinate_structure; /** * Specialization for a coo matrix view which uses host memory @@ -37,6 +44,15 @@ using host_coo_matrix = template using host_coo_matrix_view = coo_matrix_view; +template typename ContainerPolicy = host_vector_policy, + SparsityType sparsity_type = SparsityType::OWNING> +using host_coo_matrix = + coo_matrix; + /** * Specialization for a sparsity-owning coo matrix which uses host memory */ @@ -61,21 +77,12 @@ using host_sparsity_preserving_coo_matrix = coo_matrix; -/** - * Specialization for a sparsity-owning coordinate structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_coordinate_structure = - coordinate_structure; +template +struct is_host_coo_matrix_view : std::false_type {}; -/** - * Specialization for a sparsity-preserving coordinate structure view which uses host memory - */ -template -using host_coordinate_structure_view = coordinate_structure_view; +template +struct is_host_coo_matrix_view> + : std::true_type {}; template struct is_host_coo_matrix : std::false_type {}; @@ -99,8 +106,9 @@ constexpr bool is_host_coo_sparsity_owning_v = is_host_coo_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_host_coo_sparsity_preserving_v = - is_host_coo_matrix::value and T::get_sparsity_type() == PRESERVING; +constexpr bool is_host_coo_sparsity_preserving_v = std::disjunction_v< + is_host_coo_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the coordinate format. sparsity-owning means that diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index 86199335f2..eb637b1983 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -24,6 +24,29 @@ namespace raft { +/** + * Specialization for a sparsity-preserving compressed structure view which uses host memory + */ +template +using host_compressed_structure_view = + compressed_structure_view; + +/** + * Specialization for a sparsity-owning compressed structure which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_compressed_structure = + compressed_structure; + +/** + * Specialization for a csr matrix view which uses host memory + */ +template +using host_csr_matrix_view = csr_matrix_view; + template ; +/** + * Specialization for a sparsity-preserving csr matrix which uses host memory + */ +template typename ContainerPolicy = host_vector_policy> +using host_sparsity_preserving_csr_matrix = csr_matrix; + +template +struct is_host_csr_matrix_view : std::false_type {}; + +template +struct is_host_csr_matrix_view> + : std::true_type {}; + template struct is_host_csr_matrix : std::false_type {}; @@ -66,53 +112,9 @@ constexpr bool is_host_csr_sparsity_owning_v = is_host_csr_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_host_csr_sparsity_preserving_v = - is_host_csr_matrix::value and T::get_sparsity_type() == PRESERVING; - -/** - * Specialization for a csr matrix view which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-preserving csr matrix which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_sparsity_preserving_csr_matrix = csr_matrix; - -/** - * Specialization for a csr matrix view which uses host memory - */ -template -using host_csr_matrix_view = csr_matrix_view; - -/** - * Specialization for a sparsity-owning compressed structure which uses host memory - */ -template typename ContainerPolicy = host_vector_policy> -using host_compressed_structure = - compressed_structure; - -/** - * Specialization for a sparsity-preserving compressed structure view which uses host memory - */ -template -using host_compressed_structure_view = - compressed_structure_view; +constexpr bool is_host_csr_sparsity_preserving_v = std::disjunction_v< + is_host_csr_matrix_view, + std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 2154aa560c..d80e93d1e8 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -469,41 +469,18 @@ class GramMatrixBase { int minor_out = is_row_major ? out.extent(1) : out.extent(0); ASSERT(ld_out == minor_out, "Sparse linear Kernel distance does not support ld_out parameter"); - auto x1_structure = x1.structure_view(); - auto x2_structure = x2.structure_view(); - raft::sparse::distance::distances_config_t dist_config(handle); - // switch a,b based on is_row_major if (!is_row_major) { - dist_config.a_nrows = x2_structure.get_n_rows(); - dist_config.a_ncols = x2_structure.get_n_cols(); - dist_config.a_nnz = x2_structure.get_nnz(); - dist_config.a_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x2_structure.get_indices().data()); - dist_config.a_data = const_cast(x2.get_elements().data()); - dist_config.b_nrows = x1_structure.get_n_rows(); - dist_config.b_ncols = x1_structure.get_n_cols(); - dist_config.b_nnz = x1_structure.get_nnz(); - dist_config.b_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x1_structure.get_indices().data()); - dist_config.b_data = const_cast(x1.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(1), out.extent(0)); + raft::sparse::distance::pairwise_distance( + handle, out_row_major, x2, x1, raft::distance::DistanceType::InnerProduct, 0.0); } else { - dist_config.a_nrows = x1_structure.get_n_rows(); - dist_config.a_ncols = x1_structure.get_n_cols(); - dist_config.a_nnz = x1_structure.get_nnz(); - dist_config.a_indptr = const_cast(x1_structure.get_indptr().data()); - dist_config.a_indices = const_cast(x1_structure.get_indices().data()); - dist_config.a_data = const_cast(x1.get_elements().data()); - dist_config.b_nrows = x2_structure.get_n_rows(); - dist_config.b_ncols = x2_structure.get_n_cols(); - dist_config.b_nnz = x2_structure.get_nnz(); - dist_config.b_indptr = const_cast(x2_structure.get_indptr().data()); - dist_config.b_indices = const_cast(x2_structure.get_indices().data()); - dist_config.b_data = const_cast(x2.get_elements().data()); + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(0), out.extent(1)); + raft::sparse::distance::pairwise_distance( + handle, out_row_major, x1, x2, raft::distance::DistanceType::InnerProduct, 0.0); } - - raft::sparse::distance::pairwiseDistance( - out.data_handle(), dist_config, raft::distance::DistanceType::InnerProduct, 0.0); } }; diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 6253e69e42..d0f120da5c 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -132,46 +132,101 @@ void pairwiseDistance(value_t* out, } } +/** + * @defgroup sparse_distance Sparse Pairwise Distance + * @{ + */ + +/** + * @brief Compute pairwise distances between x and y, using the provided + * input configuration and distance function. + * + * @code{.cpp} + * #include + * #include + * #include + * + * int x_n_rows = 100000; + * int y_n_rows = 50000; + * int n_cols = 10000; + * + * raft::device_resources handle; + * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); + * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); + * ... + * // compute expected sparsity + * ... + * int x_nnz = 5000; + * int y_nnz = 10000; + * x.initialize_sparsity(nnz); + * y.initialize_sparsity(nnz); + * ... + * // populate data + * ... + * + * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); + * + * raft::sparse::distance_pairwise_distance(handle, out, x, y, + * raft::distance::DistanceType::L2Expanded); + * @endcode + * + * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view + * @tparam ElementType data-type of inputs and output + * @tparam IndexType data-type for indexing + * + * @param[in] handle raft::device_resources + * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] x raft::SparsityType::PRESERVING sparse matrix + * @param[in] y raft::SparsityType::PRESERVING sparse matrix + * @param[in] metric distance metric to use + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ template >> -void pairwise_distance( - raft::device_resources const& handle, - raft::device_matrix_view dist, - DeviceCSRMatrix x, - DeviceCSRMatrix y, - raft::distance::DistanceType metric, - float metric_arg = 2.0f) +void pairwise_distance(raft::device_resources const& handle, + raft::device_matrix_view dist, + DeviceCSRMatrix x, + DeviceCSRMatrix y, + raft::distance::DistanceType metric, + float metric_arg = 2.0f) { - RAFT_EXPECTS(x.get_n_cols() == y.get_n_cols(), "Number of columns must be equal"); - RAFT_EXPECTS(dist.extent(0) == x.get_n_rows(), + auto x_structure = x.structure_view(); + auto y_structure = y.structure_view(); + + RAFT_EXPECTS(x_structure.get_n_cols() == y_structure.get_n_cols(), + "Number of columns must be equal"); + + RAFT_EXPECTS(dist.extent(0) == x_structure.get_n_rows(), "Number of rows in output must be equal to " "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.get_n_rows(), + RAFT_EXPECTS(dist.extent(1) == y_structure.get_n_rows(), "Number of columns in output must be equal to " "number of rows in Y"); - detail::distances_config_t input_config(handle); - input_config.a_nrows = x.get_n_rows(); - input_config.a_ncols = x.get_n_cols(); - input_config.a_nnz = x.get_nnz(); - input_config.a_indptr = x.get_indptr().data(); - input_config.a_indices = x.get_indices().data(); - input_config.a_data = x.get_elements().data(); - - input_config.b_nrows = y.get_n_rows(); - input_config.b_ncols = y.get_n_cols(); - input_config.b_nnz = y.get_nnz(); - input_config.b_indptr = y.get_indptr().data(); - input_config.b_indices = y.get_indices().data(); - input_config.b_data = y.get_elements().data(); + detail::distances_config_t input_config(handle); + input_config.a_nrows = x_structure.get_n_rows(); + input_config.a_ncols = x_structure.get_n_cols(); + input_config.a_nnz = x_structure.get_nnz(); + input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + input_config.a_indices = const_cast(x_structure.get_indices().data()); + input_config.a_data = const_cast(x.get_elements().data()); + + input_config.b_nrows = y_structure.get_n_rows(); + input_config.b_ncols = y_structure.get_n_cols(); + input_config.b_nnz = y_structure.get_nnz(); + input_config.b_indptr = const_cast(y_structure.get_indptr().data()); + input_config.b_indices = const_cast(y_structure.get_indices().data()); + input_config.b_data = const_cast(y.get_elements().data()); pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); } -}; // namespace distance -}; // namespace sparse -}; // namespace raft +/** @} */ // end of sparse_distance + +}; // namespace distance +}; // namespace sparse +}; // namespace raft #endif \ No newline at end of file diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 9befa8c8fd..8607879a78 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -60,7 +60,6 @@ class SparseDistanceTest public: SparseDistanceTest() : params(::testing::TestWithParam>::GetParam()), - dist_config(handle), indptr(0, handle.get_stream()), indices(0, handle.get_stream()), data(0, handle.get_stream()), @@ -73,38 +72,25 @@ class SparseDistanceTest { make_data(); - // dist_config.b_nrows = params.indptr_h.size() - 1; - // dist_config.b_ncols = params.n_cols; - // dist_config.b_nnz = params.indices_h.size(); - // dist_config.b_indptr = indptr.data(); - // dist_config.b_indices = indices.data(); - // dist_config.b_data = data.data(); - // dist_config.a_nrows = params.indptr_h.size() - 1; - // dist_config.a_ncols = params.n_cols; - // dist_config.a_nnz = params.indices_h.size(); - // dist_config.a_indptr = indptr.data(); - // dist_config.a_indices = indices.data(); - // dist_config.a_data = data.data(); - - // int out_size = dist_config.a_nrows * dist_config.b_nrows; + int out_size = static_cast(params.indptr_h.size() - 1) * + static_cast(params.indptr_h.size() - 1); out_dists.resize(out_size, handle.get_stream()); - // pairwiseDistance(out_dists.data(), dist_config, params.metric, params.metric_arg); - auto out = raft::make_device_matrix_view( - out_dists.data(), dist_config.a_nrows, dist_config.b_nrows); + out_dists.data(), + static_cast(params.indptr_h.size() - 1), + static_cast(params.indptr_h.size() - 1)); auto x_structure = raft::make_device_compressed_structure_view( - handle, indptr.data(), indices.data(), - params.indptr_h.size() - 1, + static_cast(params.indptr_h.size() - 1), params.n_cols, - params.indices_h.size()); - auto x = raft::make_device_csr_view(data.data(), x_structure); + static_cast(params.indices_h.size())); + auto x = raft::make_device_csr_matrix_view(data.data(), x_structure); - pairwiseDistance(handle, out, x, x, params.metric, params.metric_arg); + pairwise_distance(handle, out, x, x, params.metric, params.metric_arg); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } @@ -137,10 +123,8 @@ class SparseDistanceTest out_dists_ref.resize((indptr_h.size() - 1) * (indptr_h.size() - 1), stream); - update_device(out_dists_ref.data(), - out_dists_ref_h.data(), - out_dists_ref_h.size(), - dist_config.handle.get_stream()); + update_device( + out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), handle.get_stream()); } raft::device_resources handle; @@ -153,7 +137,6 @@ class SparseDistanceTest rmm::device_uvector out_dists, out_dists_ref; SparseDistanceInputs params; - raft::sparse::distance::detail::distances_config_t dist_config; }; const std::vector> inputs_i32_f = { From 7a47e46442c2df937a9a946f2298458ebfd0937b Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 18 May 2023 10:46:45 -0700 Subject: [PATCH 5/8] address review --- .../distance/detail/kernels/gram_matrix.cuh | 4 ++-- cpp/include/raft/sparse/distance/distance.cuh | 23 +++++++------------ cpp/test/sparse/distance.cu | 2 +- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index d80e93d1e8..42d83600a6 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -474,12 +474,12 @@ class GramMatrixBase { auto out_row_major = raft::make_device_matrix_view( out.data_handle(), out.extent(1), out.extent(0)); raft::sparse::distance::pairwise_distance( - handle, out_row_major, x2, x1, raft::distance::DistanceType::InnerProduct, 0.0); + handle, x2, x1, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } else { auto out_row_major = raft::make_device_matrix_view( out.data_handle(), out.extent(0), out.extent(1)); raft::sparse::distance::pairwise_distance( - handle, out_row_major, x1, x2, raft::distance::DistanceType::InnerProduct, 0.0); + handle, x1, x2, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); } } }; diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index d0f120da5c..0b90727c54 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -143,7 +143,7 @@ void pairwiseDistance(value_t* out, * * @code{.cpp} * #include - * #include + * #include * #include * * int x_n_rows = 100000; @@ -153,31 +153,24 @@ void pairwiseDistance(value_t* out, * raft::device_resources handle; * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); - * ... - * // compute expected sparsity - * ... - * int x_nnz = 5000; - * int y_nnz = 10000; - * x.initialize_sparsity(nnz); - * y.initialize_sparsity(nnz); + * * ... * // populate data * ... * * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); - * - * raft::sparse::distance_pairwise_distance(handle, out, x, y, - * raft::distance::DistanceType::L2Expanded); + * auto metric = raft::distance::DistanceType::L2Expanded; + * raft::sparse::distance_pairwise_distance(handle, x, y, out, metric); * @endcode * * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view * @tparam ElementType data-type of inputs and output * @tparam IndexType data-type for indexing * - * @param[in] handle raft::device_resources - * @param[out] dist raft::device_matrix_view dense matrix + * @param[in] handle raft::resources * @param[in] x raft::SparsityType::PRESERVING sparse matrix * @param[in] y raft::SparsityType::PRESERVING sparse matrix + * @param[out] dist raft::device_matrix_view dense matrix * @param[in] metric distance metric to use * @param[in] metric_arg metric argument (used for Minkowski distance) */ @@ -185,10 +178,10 @@ template >> -void pairwise_distance(raft::device_resources const& handle, - raft::device_matrix_view dist, +void pairwise_distance(raft::resources const& handle, DeviceCSRMatrix x, DeviceCSRMatrix y, + raft::device_matrix_view dist, raft::distance::DistanceType metric, float metric_arg = 2.0f) { diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 8607879a78..7948326396 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -90,7 +90,7 @@ class SparseDistanceTest static_cast(params.indices_h.size())); auto x = raft::make_device_csr_matrix_view(data.data(), x_structure); - pairwise_distance(handle, out, x, x, params.metric, params.metric_arg); + pairwise_distance(handle, x, x, out, params.metric, params.metric_arg); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } From 06b9c150dcacf82d526446fbcdf8250a2b62ff0f Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 18 May 2023 11:30:34 -0700 Subject: [PATCH 6/8] remove dist_config --- cpp/test/sparse/distance.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index f087519cc6..6b4e5c7cfa 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -61,7 +61,6 @@ class SparseDistanceTest public: SparseDistanceTest() : params(::testing::TestWithParam>::GetParam()), - dist_config(handle), indptr(0, resource::get_cuda_stream(handle)), indices(0, resource::get_cuda_stream(handle)), data(0, resource::get_cuda_stream(handle)), From 66a563eb82eb1d0f0d91a543c0bc700a4d10ee69 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Fri, 19 May 2023 12:27:10 -0400 Subject: [PATCH 7/8] Update distance.cuh --- cpp/include/raft/sparse/distance/distance.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 0b90727c54..07e15f1c2b 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -160,7 +160,7 @@ void pairwiseDistance(value_t* out, * * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); * auto metric = raft::distance::DistanceType::L2Expanded; - * raft::sparse::distance_pairwise_distance(handle, x, y, out, metric); + * raft::sparse::distance::pairwise_distance(handle, x, y, out, metric); * @endcode * * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view From 3cac8820479d13e99a377b24e2bd880b97e5672b Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 28 Jun 2023 11:57:16 -0700 Subject: [PATCH 8/8] sparse api meta functions for views --- cpp/include/raft/core/coo_matrix.hpp | 1 - cpp/include/raft/core/csr_matrix.hpp | 1 - cpp/include/raft/core/device_coo_matrix.hpp | 8 +++++--- cpp/include/raft/core/device_csr_matrix.hpp | 8 +++++--- cpp/include/raft/core/host_coo_matrix.hpp | 8 +++++--- cpp/include/raft/core/host_csr_matrix.hpp | 3 +++ cpp/include/raft/sparse/distance/distance.cuh | 8 ++++---- 7 files changed, 22 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index 2fedc641cc..a5f7c05493 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -207,7 +207,6 @@ class coo_matrix_view using row_type = RowType; using col_type = ColType; using nnz_type = NZType; - static constexpr auto get_sparsity_type() { return SparsityType::PRESERVING; } coo_matrix_view(raft::span element_span, coordinate_structure_view structure_view) : sparse_matrix_view element_span, compressed_structure_view structure_view) diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index d352fe6368..9ab5160ef5 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -85,6 +85,9 @@ template > : std::true_type {}; +template +constexpr bool is_device_coo_matrix_view_v = is_device_coo_matrix_view::value; + template struct is_device_coo_matrix : std::false_type {}; @@ -107,9 +110,8 @@ constexpr bool is_device_coo_sparsity_owning_v = is_device_coo_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_device_coo_sparsity_preserving_v = std::disjunction_v< - is_device_coo_matrix_view, - std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; +constexpr bool is_device_coo_sparsity_preserving_v = + is_device_coo_matrix::value and T::get_sparsity_type() == PRESERVING; /** * Create a sparsity-owning sparse matrix in the coordinate format. sparsity-owning means that diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index e5c93b244f..df186cc194 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -91,6 +91,9 @@ template > : std::true_type {}; +template +constexpr bool is_device_csr_matrix_view_v = is_device_csr_matrix_view::value; + template struct is_device_csr_matrix : std::false_type {}; @@ -113,9 +116,8 @@ constexpr bool is_device_csr_sparsity_owning_v = is_device_csr_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_device_csr_sparsity_preserving_v = std::disjunction_v< - is_device_csr_matrix_view, - std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; +constexpr bool is_device_csr_sparsity_preserving_v = + is_device_csr_matrix::value and T::get_sparsity_type() == PRESERVING; /** * Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 1c1ce720fc..9e6aacfa48 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -84,6 +84,9 @@ template > : std::true_type {}; +template +constexpr bool is_host_coo_matrix_view_v = is_host_coo_matrix_view::value; + template struct is_host_coo_matrix : std::false_type {}; @@ -106,9 +109,8 @@ constexpr bool is_host_coo_sparsity_owning_v = is_host_coo_matrix::value and T::get_sparsity_type() == OWNING; template -constexpr bool is_host_coo_sparsity_preserving_v = std::disjunction_v< - is_host_coo_matrix_view, - std::bool_constant::value and T::get_sparsity_type() == PRESERVING>>; +constexpr bool is_host_coo_sparsity_preserving_v = + is_host_coo_matrix::value and T::get_sparsity_type() == PRESERVING; /** * Create a sparsity-owning sparse matrix in the coordinate format. sparsity-owning means that diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index eb637b1983..4b4df823db 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -90,6 +90,9 @@ template > : std::true_type {}; +template +constexpr bool is_host_csr_matrix_view_v = is_host_csr_matrix_view::value; + template struct is_host_csr_matrix : std::false_type {}; diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 07e15f1c2b..b60940341a 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -160,7 +160,7 @@ void pairwiseDistance(value_t* out, * * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); * auto metric = raft::distance::DistanceType::L2Expanded; - * raft::sparse::distance::pairwise_distance(handle, x, y, out, metric); + * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); * @endcode * * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view @@ -168,8 +168,8 @@ void pairwiseDistance(value_t* out, * @tparam IndexType data-type for indexing * * @param[in] handle raft::resources - * @param[in] x raft::SparsityType::PRESERVING sparse matrix - * @param[in] y raft::SparsityType::PRESERVING sparse matrix + * @param[in] x raft::device_csr_matrix_view + * @param[in] y raft::device_csr_matrix_view * @param[out] dist raft::device_matrix_view dense matrix * @param[in] metric distance metric to use * @param[in] metric_arg metric argument (used for Minkowski distance) @@ -177,7 +177,7 @@ void pairwiseDistance(value_t* out, template >> + typename = std::enable_if_t>> void pairwise_distance(raft::resources const& handle, DeviceCSRMatrix x, DeviceCSRMatrix y,