From 1d9adab59d6eb273b5244b232813d8f7c86d74a9 Mon Sep 17 00:00:00 2001
From: Tamas Bela Feher
Date: Tue, 16 Jan 2024 20:10:18 +0100
Subject: [PATCH 1/2] Add AIR-Top-k reference (#2031)

Add reference to AIR top-k paper.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: https://github.com/rapidsai/raft/pull/2031
---
 README.md                                    | 35 +++++++++++++++++++
 .../raft/matrix/detail/select_radix.cuh      |  8 ++++-
 .../raft/neighbors/nn_descent_types.hpp      |  6 ++++
 3 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 9ab1168bdb..26ddc30ed4 100755
--- a/README.md
+++ b/README.md
@@ -354,3 +354,38 @@ If citing CAGRA, please consider the following bibtex:
     primaryClass={cs.DS}
 }
 ```
+
+If citing the k-selection routines, please consider the following bibtex:
+
+```bibtex
+@proceedings{10.1145/3581784,
+    title = {SC '23: Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis},
+    year = {2023},
+    isbn = {9798400701092},
+    publisher = {Association for Computing Machinery},
+    address = {New York, NY, USA},
+    abstract = {Started in 1988, the SC Conference has become the annual nexus for researchers and practitioners from academia, industry and government to share information and foster collaborations to advance the state of the art in High Performance Computing (HPC), Networking, Storage, and Analysis.},
+    location = {Denver, CO, USA}
+}
+```
+
+If citing the nearest neighbors descent API, please consider the following bibtex:
+```bibtex
+@inproceedings{10.1145/3459637.3482344,
+    author = {Wang, Hui and Zhao, Wan-Lei and Zeng, Xiangxiang and Yang, Jianye},
+    title = {Fast K-NN Graph Construction by GPU Based NN-Descent},
+    year = {2021},
+    isbn = {9781450384469},
+    publisher = {Association for Computing Machinery},
+    address = {New York, NY, USA},
+    url = {https://doi.org/10.1145/3459637.3482344},
+    doi = {10.1145/3459637.3482344},
+    abstract = {NN-Descent is a classic k-NN graph construction approach. It is still widely employed in machine learning, computer vision, and information retrieval tasks due to its efficiency and genericness. However, the current design only works well on CPU. In this paper, NN-Descent has been redesigned to adapt to the GPU architecture. A new graph update strategy called selective update is proposed. It reduces the data exchange between GPU cores and GPU global memory significantly, which is the processing bottleneck under GPU computation architecture. This redesign leads to full exploitation of the parallelism of the GPU hardware. In the meantime, the genericness, as well as the simplicity of NN-Descent, are well-preserved. Moreover, a procedure that allows the k-NN graph to be merged efficiently on GPU is proposed. It makes the construction of high-quality k-NN graphs for out-of-GPU-memory datasets tractable.
Our approach is 100-250\texttimes{} faster than the single-thread NN-Descent and is 2.5-5\texttimes{} faster than the existing GPU-based approaches as we tested on million as well as billion scale datasets.},
+    booktitle = {Proceedings of the 30th ACM International Conference on Information \& Knowledge Management},
+    pages = {1929–1938},
+    numpages = {10},
+    keywords = {high-dimensional, nn-descent, gpu, k-nearest neighbor graph},
+    location = {Virtual Event, Queensland, Australia},
+    series = {CIKM '21}
+}
+```
\ No newline at end of file
diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
index 4245be42d6..b6ed03b93d 100644
--- a/cpp/include/raft/matrix/detail/select_radix.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -1141,6 +1141,12 @@ void radix_topk_one_block(const T* in,
  *
  * Note, the output is NOT sorted within the groups of `k` selected elements.
  *
+ * Reference:
+ * Jingrong Zhang, Akira Naruse, Xipeng Li, and Yong Wang. 2023. Parallel Top-K Algorithms on GPU:
+ * A Comprehensive Study and New Methods. In The International Conference for High Performance
+ * Computing, Networking, Storage and Analysis (SC ’23), November 12–17, 2023, Denver, CO, USA.
+ * ACM, New York, NY, USA. https://doi.org/10.1145/3581784.3607062
+ *
  * @tparam T
  *   the type of the keys (what is being compared).
  * @tparam IdxT
diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp
index 7d4f3d615b..fd1df2965e 100644
--- a/cpp/include/raft/neighbors/nn_descent_types.hpp
+++ b/cpp/include/raft/neighbors/nn_descent_types.hpp
@@ -58,6 +58,12 @@ struct index_params : ann::index_params {
  * The index contains an all-neighbors graph of the input dataset
  * stored in host memory of dimensions (n_rows, n_cols)
  *
+ * Reference:
+ * Hui Wang, Wan-Lei Zhao, Xiangxiang Zeng, and Jianye Yang. 2021.
+ * Fast k-NN Graph Construction by GPU based NN-Descent. In Proceedings of the 30th ACM
+ * International Conference on Information & Knowledge Management (CIKM '21). Association for
+ * Computing Machinery, New York, NY, USA, 1929–1938. https://doi.org/10.1145/3459637.3482344
+ *
  * @tparam IdxT dtype to be used for constructing knn-graph
  */
 template

From 3c7586f813973c5489df70f25c2e221343b65853 Mon Sep 17 00:00:00 2001
From: rhdong
Date: Tue, 16 Jan 2024 13:00:42 -0800
Subject: [PATCH 2/2] [FEA] Add support for SDDMM by wrapping the
 cusparseSDDMM (#2067)

- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - rhdong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J.
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2067 --- cpp/bench/prims/CMakeLists.txt | 3 +- cpp/bench/prims/linalg/sddmm.cu | 275 +++++++++++++ .../distance/detail/kernels/gram_matrix.cuh | 4 +- cpp/include/raft/linalg/linalg_types.hpp | 9 +- .../raft/sparse/detail/cusparse_wrappers.h | 114 +++++- .../sparse/linalg/detail/cusparse_utils.hpp | 103 +++++ .../raft/sparse/linalg/detail/sddmm.hpp | 99 +++++ .../raft/sparse/linalg/detail/spmm.hpp | 54 +-- cpp/include/raft/sparse/linalg/sddmm.hpp | 83 ++++ cpp/include/raft/sparse/linalg/spmm.cuh | 66 +--- cpp/include/raft/sparse/linalg/spmm.hpp | 79 ++++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/sddmm.cu | 365 ++++++++++++++++++ 13 files changed, 1136 insertions(+), 119 deletions(-) create mode 100644 cpp/bench/prims/linalg/sddmm.cu create mode 100644 cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp create mode 100644 cpp/include/raft/sparse/linalg/detail/sddmm.hpp create mode 100644 cpp/include/raft/sparse/linalg/sddmm.hpp create mode 100644 cpp/include/raft/sparse/linalg/spmm.hpp create mode 100644 cpp/test/sparse/sddmm.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index fe58453d0d..3a2431cd34 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/linalg/reduce_cols_by_key.cu bench/prims/linalg/reduce_rows_by_key.cu bench/prims/linalg/reduce.cu + bench/prims/linalg/sddmm.cu bench/prims/main.cpp ) diff --git a/cpp/bench/prims/linalg/sddmm.cu b/cpp/bench/prims/linalg/sddmm.cu new file mode 100644 index 0000000000..139a2b838d --- /dev/null +++ b/cpp/bench/prims/linalg/sddmm.cu @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace raft::bench::linalg { + +template +struct SDDMMBenchParams { + size_t m; + size_t k; + size_t n; + float sparsity; + bool transpose_a; + bool transpose_b; + ValueType alpha = 1.0; + ValueType beta = 0.0; +}; + +enum Alg { SDDMM, Inner }; + +template +inline auto operator<<(std::ostream& os, const SDDMMBenchParams& params) -> std::ostream& +{ + os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n + << "\tsparsity=" << params.sparsity << "\ttrans_a=" << (params.transpose_a ? "T" : "F") + << " trans_b=" << (params.transpose_b ? 
"T" : "F"); + return os; +} + +template +struct SDDMMBench : public fixture { + SDDMMBench(const SDDMMBenchParams& p) + : fixture(true), + params(p), + handle(stream), + a_data_d(0, stream), + b_data_d(0, stream), + c_indptr_d(0, stream), + c_indices_d(0, stream), + c_data_d(0, stream), + c_dense_data_d(0, stream) + { + a_data_d.resize(params.m * params.k, stream); + b_data_d.resize(params.k * params.n, stream); + + raft::random::RngState rng(2024ULL); + raft::random::uniform( + handle, rng, a_data_d.data(), params.m * params.k, ValueType(-1.0), ValueType(1.0)); + raft::random::uniform( + handle, rng, b_data_d.data(), params.k * params.n, ValueType(-1.0), ValueType(1.0)); + + std::vector c_dense_data_h(params.m * params.n); + + c_true_nnz = create_sparse_matrix(c_dense_data_h); + std::vector values(c_true_nnz); + std::vector indices(c_true_nnz); + std::vector indptr(params.m + 1); + + c_data_d.resize(c_true_nnz, stream); + c_indptr_d.resize(params.m + 1, stream); + c_indices_d.resize(c_true_nnz, stream); + + if (SDDMMorInner == Alg::Inner) { c_dense_data_d.resize(params.m * params.n, stream); } + + convert_to_csr(c_dense_data_h, params.m, params.n, values, indices, indptr); + RAFT_EXPECTS(c_true_nnz == c_indices_d.size(), + "Something wrong. The c_true_nnz != c_indices_d.size()!"); + + update_device(c_data_d.data(), values.data(), c_true_nnz, stream); + update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream); + update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream); + } + + void convert_to_csr(std::vector& matrix, + IndexType rows, + IndexType cols, + std::vector& values, + std::vector& indices, + std::vector& indptr) + { + IndexType offset_indptr = 0; + IndexType offset_values = 0; + indptr[offset_indptr++] = 0; + + for (IndexType i = 0; i < rows; ++i) { + for (IndexType j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + values[offset_values] = static_cast(1.0); + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + size_t create_sparse_matrix(std::vector& matrix) + { + size_t total_elements = static_cast(params.m * params.n); + size_t num_ones = static_cast((total_elements * 1.0f) * params.sparsity); + size_t res = num_ones; + + for (size_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis(gen); + + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + ~SDDMMBench() {} + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto a = raft::make_device_matrix_view( + a_data_d.data(), + (!params.transpose_a ? params.m : params.k), + (!params.transpose_a ? params.k : params.m)); + + auto b = raft::make_device_matrix_view( + b_data_d.data(), + (!params.transpose_b ? params.k : params.n), + (!params.transpose_b ? params.n : params.k)); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + raft::resource::get_cusparse_handle(handle); + + resource::sync_stream(handle); + + auto op_a = params.transpose_a ? 
raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + resource::sync_stream(handle); + + loop_on_state(state, [this, &a, &b, &c, &op_a, &op_b]() { + if (SDDMMorInner == Alg::SDDMM) { + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + resource::sync_stream(handle); + } else { + raft::distance::pairwise_distance(handle, + a_data_d.data(), + b_data_d.data(), + c_dense_data_d.data(), + static_cast(params.m), + static_cast(params.n), + static_cast(params.k), + raft::distance::DistanceType::InnerProduct, + std::is_same_v); + resource::sync_stream(handle); + } + }); + } + + private: + const raft::device_resources handle; + SDDMMBenchParams params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + rmm::device_uvector c_dense_data_d; + + size_t c_true_nnz = 0; + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; +}; + +template +static std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + bool transpose_a; + bool transpose_b; + size_t m; + size_t k; + size_t n; + float sparsity; + }; + + const std::vector params_group = + raft::util::itertools::product({false, true}, + {false, true}, + {size_t(10), size_t(1024)}, + {size_t(128), size_t(1024)}, + {size_t(1024 * 1024)}, + {0.01f, 0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(SDDMMBenchParams( + {params.m, params.k, params.n, params.sparsity, params.transpose_a, params.transpose_b})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); + +RAFT_BENCH_REGISTER((SDDMMBench), "", getInputs()); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index e121c1be9c..14b4ba12c6 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include // #include #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/linalg_types.hpp b/cpp/include/raft/linalg/linalg_types.hpp index e50d3a8e79..9c81fbc177 100644 --- a/cpp/include/raft/linalg/linalg_types.hpp +++ b/cpp/include/raft/linalg/linalg_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS }; */ enum class FillMode { UPPER, LOWER }; +/** + * @brief Enum for this type indicates which operation is applied to the related input (e.g. sparse + * matrix, or vector). + * + */ +enum class Operation { NON_TRANSPOSE, TRANSPOSE }; + } // end namespace raft::linalg \ No newline at end of file diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index e8bf9c6de5..cc3ae3ab87 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, alg, static_cast(externalBuffer)); } + +template +cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize); +} +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize); +} +template +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + T* externalBuffer, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + float* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_32F, + alg, + static_cast(externalBuffer)); +} +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + 
cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + double* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_64F, + alg, + static_cast(externalBuffer)); +} + /** @} */ #else /** diff --git a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp new file mode 100644 index 0000000000..b15614905b --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief create a cuSparse dense descriptor + * @tparam ValueType Data type of dense_view (float/double) + * @tparam IndexType Type of dense_view + * @tparam LayoutPolicy layout of dense_view + * @param[in] dense_view input raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +cusparseDnMatDescr_t create_descriptor( + raft::device_matrix_view dense_view) +{ + bool is_row_major = raft::is_row_major(dense_view); + auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + IndexType ld = is_row_major ? 
dense_view.stride(0) : dense_view.stride(1); + cusparseDnMatDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descr, + dense_view.extent(0), + dense_view.extent(1), + ld, + const_cast*>(dense_view.data_handle()), + order)); + return descr; +} + +/** + * @brief create a cuSparse sparse descriptor + * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices + * @tparam NZType Type of sparse_view + * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns + * @returns sparse matrix descriptor to be used by cuSparse API + */ +template +cusparseSpMatDescr_t create_descriptor( + raft::device_csr_matrix_view sparse_view) +{ + cusparseSpMatDescr_t descr; + auto csr_structure = sparse_view.structure_view(); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descr, + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), + const_cast*>(sparse_view.get_elements().data()))); + return descr; +} + +/** + * @brief convert the operation to cusparseOperation_t type + * @param param[in] op type of operation + */ +inline cusparseOperation_t convert_operation(const raft::linalg::Operation op) +{ + if (op == raft::linalg::Operation::TRANSPOSE) { + return CUSPARSE_OPERATION_TRANSPOSE; + } else if (op == raft::linalg::Operation::NON_TRANSPOSE) { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } else { + RAFT_EXPECTS(false, "The operation type is not allowed."); + } + return CUSPARSE_OPERATION_NON_TRANSPOSE; +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp new file mode 100644 index 0000000000..5088a20f46 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. 
+ * It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of C + * + * @param[in] handle raft resource handle + * @param[in] descr_a input dense descriptor + * @param[in] descr_b input dense descriptor + * @param[in/out] descr_c output sparse descriptor + * @param[in] op_a input Operation op(A) + * @param[in] op_b input Operation op(B) + * @param[in] alpha scalar pointer + * @param[in] beta scalar pointer + */ +template +void sddmm(raft::resources const& handle, + cusparseDnMatDescr_t& descr_a, + cusparseDnMatDescr_t& descr_b, + cusparseSpMatDescr_t& descr_c, + cusparseOperation_t op_a, + cusparseOperation_t op_b, + const ValueType* alpha, + const ValueType* beta) +{ + auto alg = CUSPARSE_SDDMM_ALG_DEFAULT; + size_t bufferSize; + + RAFT_CUSPARSE_TRY( + raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + &bufferSize, + resource::get_cuda_stream(handle))); + + resource::sync_stream(handle); + + rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + tmp.data(), + resource::get_cuda_stream(handle))); +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index d8d73ee83f..6206348b02 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view -cusparseDnMatDescr_t create_descriptor( - raft::device_matrix_view& dense_view, const bool is_row_major) -{ - auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - IndexType ld = is_row_major ? 
dense_view.stride(0) : dense_view.stride(1); - cusparseDnMatDescr_t descr; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descr, - dense_view.extent(0), - dense_view.extent(1), - ld, - const_cast*>(dense_view.data_handle()), - order)); - return descr; -} - -/** - * @brief create a cuSparse sparse descriptor - * @tparam ValueType Data type of sparse_view (float/double) - * @tparam IndptrType Data type of csr_matrix_view index pointers - * @tparam IndicesType Data type of csr_matrix_view indices - * @tparam NZType Type of sparse_view - * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns - * @returns sparse matrix descriptor to be used by cuSparse API - */ -template -cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) -{ - cusparseSpMatDescr_t descr; - auto csr_structure = sparse_view.structure_view(); - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descr, - static_cast(csr_structure.get_n_rows()), - static_cast(csr_structure.get_n_cols()), - static_cast(csr_structure.get_nnz()), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), - const_cast*>(sparse_view.get_elements().data()))); - return descr; -} - /** * @brief SPMM function designed for handling all CSR * DENSE * combinations of operand layouts for cuSparse. diff --git a/cpp/include/raft/sparse/linalg/sddmm.hpp b/cpp/include/raft/sparse/linalg/sddmm.hpp new file mode 100644 index 0000000000..c19f1d9081 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/sddmm.hpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. 
+ * It computes the following equation: C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam NZType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @param[in] handle raft handle + * @param[in] A input raft::device_matrix_view + * @param[in] B input raft::device_matrix_view + * @param[inout] C output raft::device_csr_matrix_view + * @param[in] opA input Operation op(A) + * @param[in] opB input Operation op(B) + * @param[in] alpha input raft::host_scalar_view + * @param[in] beta input raft::host_scalar_view + */ +template +void sddmm(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::device_csr_matrix_view C, + const raft::linalg::Operation opA, + const raft::linalg::Operation opB, + raft::host_scalar_view alpha, + raft::host_scalar_view beta) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(B), "B is not contiguous"); + + static_assert(std::is_same_v || std::is_same_v, + "The `ValueType` of sddmm only supports float/double."); + + auto descrA = detail::create_descriptor(A); + auto descrB = detail::create_descriptor(B); + auto descrC = detail::create_descriptor(C); + auto op_A = detail::convert_operation(opA); + auto op_B = detail::convert_operation(opB); + + detail::sddmm( + handle, descrA, descrB, descrC, op_A, op_B, alpha.data_handle(), beta.data_handle()); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrB)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descrC)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh index 064da4d8fb..439ed8c341 100644 --- a/cpp/include/raft/sparse/linalg/spmm.cuh +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,66 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __SPMM_H -#define __SPMM_H - #pragma once -#include "detail/spmm.hpp" - -namespace raft { -namespace sparse { -namespace linalg { - -/** - * @brief SPMM function designed for handling all CSR * DENSE - * combinations of operand layouts for cuSparse. - * It computes the following equation: Z = alpha . X * Y + beta . 
Z - * where X is a CSR device matrix view and Y,Z are device matrix views - * @tparam ValueType Data type of input/output matrices (float/double) - * @tparam IndexType Type of Y and Z - * @tparam NZType Type of X - * @tparam LayoutPolicyY layout of Y - * @tparam LayoutPolicyZ layout of Z - * @param[in] handle raft handle - * @param[in] trans_x transpose operation for X - * @param[in] trans_y transpose operation for Y - * @param[in] alpha scalar - * @param[in] x input raft::device_csr_matrix_view - * @param[in] y input raft::device_matrix_view - * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view - */ -template -void spmm(raft::resources const& handle, - const bool trans_x, - const bool trans_y, - const ValueType* alpha, - raft::device_csr_matrix_view x, - raft::device_matrix_view y, - const ValueType* beta, - raft::device_matrix_view z) -{ - bool is_row_major = detail::is_row_major(y, z); - - auto descr_x = detail::create_descriptor(x); - auto descr_y = detail::create_descriptor(y, is_row_major); - auto descr_z = detail::create_descriptor(z, is_row_major); - - detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); - - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -} // end namespace linalg -} // end namespace sparse -} // end namespace raft +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the spmm.hpp at the same path instead.") -#endif +#include diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp new file mode 100644 index 0000000000..c2fdd64574 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __SPMM_H +#define __SPMM_H + +#pragma once + +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief SPMM function designed for handling all CSR * DENSE + * combinations of operand layouts for cuSparse. + * It computes the following equation: Z = alpha . X * Y + beta . 
Z + * where X is a CSR device matrix view and Y,Z are device matrix views + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of Y and Z + * @tparam NZType Type of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] trans_x transpose operation for X + * @param[in] trans_y transpose operation for Y + * @param[in] alpha scalar + * @param[in] x input raft::device_csr_matrix_view + * @param[in] y input raft::device_matrix_view + * @param[in] beta scalar + * @param[out] z output raft::device_matrix_view + */ +template +void spmm(raft::resources const& handle, + const bool trans_x, + const bool trans_y, + const ValueType* alpha, + raft::device_csr_matrix_view x, + raft::device_matrix_view y, + const ValueType* beta, + raft::device_matrix_view z) +{ + bool is_row_major = detail::is_row_major(y, z); + + auto descr_x = detail::create_descriptor(x); + auto descr_y = detail::create_descriptor(y); + auto descr_z = detail::create_descriptor(z); + + detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft + +#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6e32281ec0..931530b66a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -315,6 +315,7 @@ if(BUILD_TESTS) test/sparse/normalize.cu test/sparse/reduce.cu test/sparse/row_op.cu + test/sparse/sddmm.cu test/sparse/sort.cu test/sparse/spgemmi.cu test/sparse/symmetrize.cu diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu new file mode 100644 index 0000000000..9323ee8c2b --- /dev/null +++ b/cpp/test/sparse/sddmm.cu @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_utils.cuh" + +namespace raft { +namespace sparse { + +template +struct SDDMMInputs { + ValueType tolerance; + + IndexType m; + IndexType k; + IndexType n; + + ValueType alpha; + ValueType beta; + + bool transpose_a; + bool transpose_b; + + ValueType sparsity; + + unsigned long long int seed; +}; + +template +struct sum_abs_op { + __host__ __device__ ValueType operator()(const ValueType& x, const ValueType& y) const + { + return y >= ValueType(0.0) ? 
(x + y) : (x - y); + } +}; + +template +::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs& params) +{ + os << " m: " << params.m << "\tk: " << params.k << "\tn: " << params.n + << "\talpha: " << params.alpha << "\tbeta: " << params.beta + << "\tsparsity: " << params.sparsity; + + return os; +} + +template +class SDDMMTest : public ::testing::TestWithParam> { + public: + SDDMMTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + a_data_d(0, resource::get_cuda_stream(handle)), + b_data_d(0, resource::get_cuda_stream(handle)), + c_indptr_d(0, resource::get_cuda_stream(handle)), + c_indices_d(0, resource::get_cuda_stream(handle)), + c_data_d(0, resource::get_cuda_stream(handle)), + c_expected_data_d(0, resource::get_cuda_stream(handle)) + { + } + + protected: + IndexType create_sparse_matrix(IndexType m, + IndexType n, + ValueType sparsity, + std::vector& matrix) + { + IndexType total_elements = static_cast(m * n); + IndexType num_ones = static_cast((total_elements * 1.0f) * sparsity); + IndexType res = num_ones; + + for (IndexType i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis(gen); + + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + IndexType rows, + IndexType cols, + std::vector& values, + std::vector& indices, + std::vector& indptr) + { + IndexType offset_indptr = 0; + IndexType offset_values = 0; + indptr[offset_indptr++] = 0; + + for (IndexType i = 0; i < rows; ++i) { + for (IndexType j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + values[offset_values] = static_cast(1.0); + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_sddmm(const std::vector& A, + const std::vector& B, + std::vector& vals, + const std::vector& cols, + const std::vector& row_ptrs, + bool is_row_major_A, + bool is_row_major_B) + { + if (params.m * params.k != static_cast(A.size()) || + params.k * params.n != static_cast(B.size())) { + std::cerr << "Matrix dimensions and vector size do not match!" << std::endl; + return; + } + + bool trans_a = params.transpose_a ? !is_row_major_A : is_row_major_A; + bool trans_b = params.transpose_b ? !is_row_major_B : is_row_major_B; + + for (IndexType i = 0; i < params.m; ++i) { + for (IndexType j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) { + ValueType sum = 0; + for (IndexType l = 0; l < params.k; ++l) { + IndexType a_index = trans_a ? i * params.k + l : l * params.m + i; + IndexType b_index = trans_b ? 
l * params.n + cols[j] : cols[j] * params.k + l; + sum += A[a_index] * B[b_index]; + } + vals[j] = params.alpha * sum + params.beta * vals[j]; + } + } + } + + void make_data() + { + IndexType a_size = params.m * params.k; + IndexType b_size = params.k * params.n; + IndexType c_size = params.m * params.n; + + std::vector a_data_h(a_size); + std::vector b_data_h(b_size); + + a_data_d.resize(a_size, stream); + b_data_d.resize(b_size, stream); + + auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_a_b.data_handle(), + labels.data_handle(), + 1, + a_size + b_size, + 1, + stream, + false, + nullptr, + nullptr, + ValueType(1.0), + false, + ValueType(-1.0f), + ValueType(1.0f), + uint64_t(2024)); + + raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + resource::sync_stream(handle); + + std::vector c_dense_data_h(c_size); + IndexType c_true_nnz = + create_sparse_matrix(params.m, params.n, params.sparsity, c_dense_data_h); + + std::vector c_indptr_h(params.m + 1); + std::vector c_indices_h(c_true_nnz); + std::vector c_data_h(c_true_nnz); + + convert_to_csr(c_dense_data_h, params.m, params.n, c_data_h, c_indices_h, c_indptr_h); + + bool is_row_major_A = (std::is_same_v); + bool is_row_major_B = (std::is_same_v); + + c_data_d.resize(c_data_h.size(), stream); + update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + resource::sync_stream(handle); + + cpu_sddmm( + a_data_h, b_data_h, c_data_h, c_indices_h, c_indptr_h, is_row_major_A, is_row_major_B); + + c_indptr_d.resize(c_indptr_h.size(), stream); + c_indices_d.resize(c_indices_h.size(), stream); + c_expected_data_d.resize(c_data_h.size(), stream); + + update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream); + update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream); + update_device(c_expected_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + + resource::sync_stream(handle); + } + + void SetUp() override { make_data(); } + + void Run() + { + auto a = raft::make_device_matrix_view( + a_data_d.data(), + (!params.transpose_a ? params.m : params.k), + (!params.transpose_a ? params.k : params.m)); + auto b = raft::make_device_matrix_view( + b_data_d.data(), + (!params.transpose_b ? params.k : params.n), + (!params.transpose_b ? params.n : params.k)); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + auto op_b = params.transpose_b ? 
raft::linalg::Operation::TRANSPOSE + : raft::linalg::Operation::NON_TRANSPOSE; + + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); + + resource::sync_stream(handle); + + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + c.get_elements().data(), + c_expected_data_d.size(), + raft::CompareApprox(params.tolerance), + stream)); + + thrust::device_ptr expected_data_ptr = + thrust::device_pointer_cast(c_expected_data_d.data()); + ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + ValueType(0.0f), + sum_abs_op()); + ValueType avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); + } + + raft::resources handle; + cudaStream_t stream; + SDDMMInputs params; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; + + rmm::device_uvector c_expected_data_d; +}; + +using SDDMMTestF_Row_Col = SDDMMTest; +TEST_P(SDDMMTestF_Row_Col, Result) { Run(); } + +using SDDMMTestF_Col_Row = SDDMMTest; +TEST_P(SDDMMTestF_Col_Row, Result) { Run(); } + +using SDDMMTestF_Row_Row = SDDMMTest; +TEST_P(SDDMMTestF_Row_Row, Result) { Run(); } + +using SDDMMTestF_Col_Col = SDDMMTest; +TEST_P(SDDMMTestF_Col_Col, Result) { Run(); } + +using SDDMMTestD_Row_Col = SDDMMTest; +TEST_P(SDDMMTestD_Row_Col, Result) { Run(); } + +using SDDMMTestD_Col_Row = SDDMMTest; +TEST_P(SDDMMTestD_Col_Row, Result) { Run(); } + +using SDDMMTestD_Row_Row = SDDMMTest; +TEST_P(SDDMMTestD_Row_Row, Result) { Run(); } + +using SDDMMTestD_Col_Col = SDDMMTest; +TEST_P(SDDMMTestD_Col_Col, Result) { Run(); } + +const std::vector> sddmm_inputs_f = { + {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, + {0.0003f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL}, + {0.0003f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; + +const std::vector> sddmm_inputs_d = { + {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, + {0.0001f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL}, + {0.0001f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Col, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Row, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Row, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Col, ::testing::ValuesIn(sddmm_inputs_f)); + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Col, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Row, ::testing::ValuesIn(sddmm_inputs_d)); 
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Row, ::testing::ValuesIn(sddmm_inputs_d));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Col, ::testing::ValuesIn(sddmm_inputs_d));
+
+} // namespace sparse
+} // namespace raft
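
Below is a minimal usage sketch of the new `raft::sparse::linalg::sddmm` public API added by PATCH 2/2, following the call pattern of `cpp/test/sparse/sddmm.cu` above: it computes C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C over a preallocated CSR sparsity pattern. The helper name `sddmm_example`, the raw device pointer arguments, and the exact include paths are illustrative assumptions, not part of the patch.

```cpp
// Hypothetical helper demonstrating the sddmm() entry point added by this patch.
// Assumes d_a, d_b and the CSR buffers of C are already allocated and populated
// on the device (e.g. with rmm::device_uvector, as in the test above).
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/linalg_types.hpp>
#include <raft/sparse/linalg/sddmm.hpp>

void sddmm_example(raft::resources const& handle,
                   const float* d_a,  // dense A: m x k, row-major
                   const float* d_b,  // dense B: k x n, row-major
                   int* d_indptr,     // CSR row offsets of C, length m + 1
                   int* d_indices,    // CSR column indices of C, length nnz
                   float* d_values,   // CSR values of C, updated in place
                   int m, int k, int n, int nnz)
{
  auto a = raft::make_device_matrix_view<const float, int, raft::row_major>(d_a, m, k);
  auto b = raft::make_device_matrix_view<const float, int, raft::row_major>(d_b, k, n);

  // Wrap the preallocated CSR buffers of C; only the positions present in this
  // sparsity pattern are computed and written by SDDMM.
  auto c_structure = raft::make_device_compressed_structure_view<int, int, int>(
    d_indptr, d_indices, m, n, nnz);
  auto c = raft::make_device_csr_matrix_view<float, int, int, int>(d_values, c_structure);

  // C = 1.0 · (A * B ∘ spy(C)) + 0.0 · C
  float alpha = 1.0f;
  float beta  = 0.0f;

  raft::sparse::linalg::sddmm(handle,
                              a,
                              b,
                              c,
                              raft::linalg::Operation::NON_TRANSPOSE,
                              raft::linalg::Operation::NON_TRANSPOSE,
                              raft::make_host_scalar_view<float>(&alpha),
                              raft::make_host_scalar_view<float>(&beta));
}
```

Since the existing values of C are scaled by beta, `d_values` should be initialized before the call whenever beta is nonzero; with beta = 0 they are simply overwritten.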