From cbaba370273ab9ef86f4605eabe057f49e8ebac2 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 18 Dec 2023 14:30:18 -0800 Subject: [PATCH] [FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067) - Add support for SDDMM by wrapping the `cusparseSDDMM` - This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file. Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2067 --- .../raft/sparse/detail/cusparse_wrappers.h | 112 +++++++++ .../raft/sparse/linalg/detail/sddmm.hpp | 101 ++++++++ .../raft/sparse/linalg/detail/spmm.hpp | 72 ------ .../raft/sparse/linalg/detail/utils.cuh | 107 ++++++++ cpp/include/raft/sparse/linalg/sddmm.cuh | 80 ++++++ cpp/include/raft/sparse/linalg/spmm.cuh | 1 + cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/sddmm.cu | 230 ++++++++++++++++++ 8 files changed, 632 insertions(+), 72 deletions(-) create mode 100644 cpp/include/raft/sparse/linalg/detail/sddmm.hpp create mode 100644 cpp/include/raft/sparse/linalg/detail/utils.cuh create mode 100644 cpp/include/raft/sparse/linalg/sddmm.cuh create mode 100644 cpp/test/sparse/sddmm.cu diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index e8bf9c6de5..e559f34028 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, alg, static_cast(externalBuffer)); } + +template +cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize); +} +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize); +} +template +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + T* externalBuffer, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + float* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_32F, + alg, + static_cast(externalBuffer)); +} +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + double* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_64F, + alg, + static_cast(externalBuffer)); +} + /** @} */ #else /** diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp new file mode 100644 index 0000000000..f299a2b086 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. + * It computes the following equation: C = alpha · (A * B ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of C + * + * @param[in] handle raft resource handle + * @param[in] trans_a transpose operation for A + * @param[in] trans_b transpose operation for B + * @param[in] is_row_major data layout of A,B + * @param[in] alpha scalar pointer + * @param[in] descr_a input dense descriptor + * @param[in] descr_b input dense descriptor + * @param[in] beta scalar pointer + * @param[out] descr_c output sparse descriptor + */ +template +void sddmm(raft::resources const& handle, + const bool trans_a, + const bool trans_b, + const bool is_row_major, + const ValueType* alpha, + cusparseDnMatDescr_t& descr_a, + cusparseDnMatDescr_t& descr_b, + const ValueType* beta, + cusparseSpMatDescr_t& descr_c) +{ + auto opA = trans_a ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto opB = trans_b ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto alg = CUSPARSE_SDDMM_ALG_DEFAULT; + size_t bufferSize; + RAFT_CUSPARSE_TRY( + raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle), + opA, + opB, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + &bufferSize, + resource::get_cuda_stream(handle))); + + raft::interruptible::synchronize(resource::get_cuda_stream(handle)); + + rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle), + opA, + opB, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + tmp.data(), + resource::get_cuda_stream(handle))); +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index d8d73ee83f..5f4ef427f0 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -28,78 +28,6 @@ namespace sparse { namespace linalg { namespace detail { -/** - * @brief determine common data layout for both dense matrices - * @tparam ValueType Data type of Y,Z (float/double) - * @tparam IndexType Type of Y,Z - * @tparam LayoutPolicyY layout of Y - * @tparam LayoutPolicyZ layout of Z - * @param[in] x input raft::device_matrix_view - * @param[in] y input raft::device_matrix_view - * @returns dense matrix descriptor to be used by cuSparse API - */ -template -bool is_row_major(raft::device_matrix_view& y, - raft::device_matrix_view& z) -{ - bool is_row_major = z.stride(1) == 1 && y.stride(1) == 1; - bool is_col_major = z.stride(0) == 1 && y.stride(0) == 1; - ASSERT(is_row_major || is_col_major, "Both matrices need to be either row or col major"); - return is_row_major; -} - -/** - * @brief create a cuSparse dense descriptor - * @tparam ValueType Data type of dense_view (float/double) - * @tparam IndexType Type of dense_view - * @tparam LayoutPolicy layout of dense_view - * @param[in] dense_view input raft::device_matrix_view - * @param[in] is_row_major data layout of raft::device_matrix_view - * @returns dense matrix descriptor to be used by cuSparse API - */ -template -cusparseDnMatDescr_t create_descriptor( - raft::device_matrix_view& dense_view, const bool is_row_major) -{ - auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); - cusparseDnMatDescr_t descr; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descr, - dense_view.extent(0), - dense_view.extent(1), - ld, - const_cast*>(dense_view.data_handle()), - order)); - return descr; -} - -/** - * @brief create a cuSparse sparse descriptor - * @tparam ValueType Data type of sparse_view (float/double) - * @tparam IndptrType Data type of csr_matrix_view index pointers - * @tparam IndicesType Data type of csr_matrix_view indices - * @tparam NZType Type of sparse_view - * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns - * @returns sparse matrix descriptor to be used by cuSparse API - */ -template -cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) -{ - cusparseSpMatDescr_t descr; - auto csr_structure = sparse_view.structure_view(); - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descr, - static_cast(csr_structure.get_n_rows()), - static_cast(csr_structure.get_n_cols()), - static_cast(csr_structure.get_nnz()), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), - const_cast*>(sparse_view.get_elements().data()))); - return descr; -} - /** * @brief SPMM function designed for handling all CSR * DENSE * combinations of operand layouts for cuSparse. diff --git a/cpp/include/raft/sparse/linalg/detail/utils.cuh b/cpp/include/raft/sparse/linalg/detail/utils.cuh new file mode 100644 index 0000000000..de04755ef7 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/utils.cuh @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief determine common data layout for both dense matrices + * @tparam ValueType Data type of A, B (float/double) + * @tparam IndexType Type of A, B + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @param[in] a input raft::device_matrix_view + * @param[in] b input raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +bool is_row_major(raft::device_matrix_view& a, + raft::device_matrix_view& b) +{ + bool is_row_major = a.stride(1) == 1 && b.stride(1) == 1; + bool is_col_major = a.stride(0) == 1 && b.stride(0) == 1; + ASSERT(is_row_major || is_col_major, "Both matrices need to be either row or col major"); + return is_row_major; +} + +/** + * @brief create a cuSparse dense descriptor + * @tparam ValueType Data type of dense_view (float/double) + * @tparam IndexType Type of dense_view + * @tparam LayoutPolicy layout of dense_view + * @param[in] dense_view input raft::device_matrix_view + * @param[in] is_row_major data layout of raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +cusparseDnMatDescr_t create_descriptor( + raft::device_matrix_view& dense_view, const bool is_row_major) +{ + auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); + cusparseDnMatDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descr, + dense_view.extent(0), + dense_view.extent(1), + ld, + const_cast*>(dense_view.data_handle()), + order)); + return descr; +} + +/** + * @brief create a cuSparse sparse descriptor + * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices + * @tparam NZType Type of sparse_view + * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns + * @returns sparse matrix descriptor to be used by cuSparse API + */ +template +cusparseSpMatDescr_t create_descriptor( + raft::device_csr_matrix_view& sparse_view) +{ + cusparseSpMatDescr_t descr; + auto csr_structure = sparse_view.structure_view(); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descr, + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), + const_cast*>(sparse_view.get_elements().data()))); + return descr; +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/sddmm.cuh b/cpp/include/raft/sparse/linalg/sddmm.cuh new file mode 100644 index 0000000000..7609b3103a --- /dev/null +++ b/cpp/include/raft/sparse/linalg/sddmm.cuh @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. + * It computes the following equation: C = alpha · (A * B ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of Cz + * + * @param[in] handle raft handle + * @param[in] trans_a transpose operation for A + * @param[in] trans_b transpose operation for B + * @param[in] alpha scalar + * @param[in] a input raft::device_matrix_view + * @param[in] b input raft::device_matrix_view + * @param[in] beta scalar + * @param[out] c output raft::device_csr_matrix_view + */ +template +void sddmm(raft::resources const& handle, + const bool trans_a, + const bool trans_b, + const ValueType* alpha, + raft::device_matrix_view a, + raft::device_matrix_view b, + const ValueType* beta, + raft::device_csr_matrix_view c) +{ + static_assert(std::is_same_v || std::is_same_v, + "The `ValueType` of sddmm only supports float/double."); + + bool is_row_major = detail::is_row_major(a, b); + + auto descr_a = detail::create_descriptor(a, is_row_major); + auto descr_b = detail::create_descriptor(b, is_row_major); + auto descr_c = detail::create_descriptor(c); + + detail::sddmm(handle, trans_a, trans_b, is_row_major, alpha, descr_a, descr_b, beta, descr_c); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_a)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_b)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_c)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh index 064da4d8fb..053c9d915b 100644 --- a/cpp/include/raft/sparse/linalg/spmm.cuh +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/spmm.hpp" +#include namespace raft { namespace sparse { diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index f043442840..3f1f0cec60 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -314,6 +314,7 @@ if(BUILD_TESTS) test/sparse/normalize.cu test/sparse/reduce.cu test/sparse/row_op.cu + test/sparse/sddmm.cu test/sparse/sort.cu test/sparse/spgemmi.cu test/sparse/symmetrize.cu diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu new file mode 100644 index 0000000000..62133c41f8 --- /dev/null +++ b/cpp/test/sparse/sddmm.cu @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include "../test_utils.cuh" + +namespace raft { +namespace sparse { + +template +struct SDDMMInputs { + bool trans_a; + bool trans_b; + + IndexType m; + IndexType k; + IndexType n; + + ValueType alpha; + ValueType beta; + + std::vector a_data; + std::vector b_data; + + std::vector c_indptr; + std::vector c_indices; + std::vector c_data; + + std::vector c_expected_data; +}; + +template +class SDDMMTest : public ::testing::TestWithParam> { + public: + SDDMMTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + a_data_d(0, resource::get_cuda_stream(handle)), + b_data_d(0, resource::get_cuda_stream(handle)), + c_indptr_d(0, resource::get_cuda_stream(handle)), + c_indices_d(0, resource::get_cuda_stream(handle)), + c_data_d(0, resource::get_cuda_stream(handle)), + c_expected_data_d(0, resource::get_cuda_stream(handle)) + { + } + + protected: + void make_data() + { + std::vector a_data_h = params.a_data; + std::vector b_data_h = params.b_data; + + std::vector c_indptr_h = params.c_indptr; + std::vector c_indices_h = params.c_indices; + std::vector c_data_h = params.c_data; + std::vector c_expected_data_h = params.c_expected_data; + + a_data_d.resize(a_data_h.size(), stream); + b_data_d.resize(b_data_h.size(), stream); + c_indptr_d.resize(c_indptr_h.size(), stream); + c_indices_d.resize(c_indices_h.size(), stream); + c_data_d.resize(c_data_h.size(), stream); + c_expected_data_d.resize(c_expected_data_h.size(), stream); + + update_device(a_data_d.data(), a_data_h.data(), a_data_h.size(), stream); + update_device(b_data_d.data(), b_data_h.data(), b_data_h.size(), stream); + + update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream); + update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream); + update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + update_device( + c_expected_data_d.data(), c_expected_data_h.data(), c_expected_data_h.size(), stream); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void SetUp() override { make_data(); } + + void Run() + { + // Check params + ASSERT_EQ(params.a_data.size(), params.m * params.k); + ASSERT_EQ(params.b_data.size(), params.n * params.k); + ASSERT_EQ(params.c_data.size(), params.c_indices.size()); + ASSERT_GE(params.c_indices.size(), 0); + + auto a = raft::make_device_matrix_view( + a_data_d.data(), params.m, params.k); + auto b = raft::make_device_matrix_view( + b_data_d.data(), params.k, params.n); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); + + raft::sparse::linalg::sddmm( + handle, params.trans_a, params.trans_b, ¶ms.alpha, a, b, ¶ms.beta, c); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + c.get_elements().data(), + params.c_indices.size(), + raft::CompareApprox(1e-6f), + stream)); + } + + raft::resources handle; + SDDMMInputs params; + cudaStream_t stream; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; + + rmm::device_uvector c_expected_data_d; +}; + +using SDDMMTestF = SDDMMTest; +TEST_P(SDDMMTestF, Result) { Run(); } + +using SDDMMTestD = SDDMMTest; +TEST_P(SDDMMTestD, Result) { Run(); } + +const std::vector> sddmm_inputs_f = { + { + false, + false, + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + false, + false, + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; + +const std::vector> sddmm_inputs_d = { + { + false, + false, + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + + }, + { + false, + false, + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF, ::testing::ValuesIn(sddmm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD, ::testing::ValuesIn(sddmm_inputs_d)); + +} // namespace sparse +} // namespace raft