diff --git a/cpp/include/raft/linalg/linalg_types.hpp b/cpp/include/raft/linalg/linalg_types.hpp index e50d3a8e79..9c81fbc177 100644 --- a/cpp/include/raft/linalg/linalg_types.hpp +++ b/cpp/include/raft/linalg/linalg_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS }; */ enum class FillMode { UPPER, LOWER }; +/** + * @brief Indicates which operation is applied to the related input (e.g. a sparse matrix or a + * vector). + * + */ +enum class Operation { NON_TRANSPOSE, TRANSPOSE }; + } // end namespace raft::linalg \ No newline at end of file diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index e8bf9c6de5..cc3ae3ab87 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle, alg, static_cast(externalBuffer)); } + +template +cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize); +} +template <> +inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + size_t* bufferSize, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM_bufferSize( + handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize); +} +template +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const T* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const T* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + T* externalBuffer, + cudaStream_t stream); +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const 
float* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const float* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + float* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_32F, + alg, + static_cast(externalBuffer)); +} +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const double* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const double* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + double* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_64F, + alg, + static_cast(externalBuffer)); +} + /** @} */ #else /** diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp new file mode 100644 index 0000000000..95da3f1266 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. + * It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of C + * + * @param[in] handle raft resource handle + * @param[in] descr_a input dense descriptor + * @param[in] descr_b input dense descriptor + * @param[in,out] descr_c output sparse descriptor + * @param[in] op_a input Operation op(A) + * @param[in] op_b input Operation op(B) + * @param[in] alpha scalar pointer + * @param[in] beta scalar pointer + */ +template +void sddmm(raft::resources const& handle, + cusparseDnMatDescr_t& descr_a, + cusparseDnMatDescr_t& descr_b, + cusparseSpMatDescr_t& descr_c, + cusparseOperation_t op_a, + cusparseOperation_t op_b, + const ValueType* alpha, + const ValueType* beta) +{ + auto alg = CUSPARSE_SDDMM_ALG_DEFAULT; + size_t bufferSize; + + RAFT_CUSPARSE_TRY( + raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + &bufferSize, + resource::get_cuda_stream(handle))); + + raft::interruptible::synchronize(resource::get_cuda_stream(handle)); + + rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle), + op_a, + op_b, + alpha, + descr_a, + descr_b, + beta, + descr_c, + alg, + tmp.data(), + 
resource::get_cuda_stream(handle))); +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp index d8d73ee83f..6206348b02 100644 --- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view -cusparseDnMatDescr_t create_descriptor( - raft::device_matrix_view& dense_view, const bool is_row_major) -{ - auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; - IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); - cusparseDnMatDescr_t descr; - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( - &descr, - dense_view.extent(0), - dense_view.extent(1), - ld, - const_cast*>(dense_view.data_handle()), - order)); - return descr; -} - -/** - * @brief create a cuSparse sparse descriptor - * @tparam ValueType Data type of sparse_view (float/double) - * @tparam IndptrType Data type of csr_matrix_view index pointers - * @tparam IndicesType Data type of csr_matrix_view indices - * @tparam NZType Type of sparse_view - * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns - * @returns sparse matrix descriptor to be used by cuSparse API - */ -template -cusparseSpMatDescr_t create_descriptor( - raft::device_csr_matrix_view& sparse_view) -{ - cusparseSpMatDescr_t descr; - auto csr_structure = sparse_view.structure_view(); - RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( - &descr, - static_cast(csr_structure.get_n_rows()), - static_cast(csr_structure.get_n_cols()), - 
static_cast(csr_structure.get_nnz()), - const_cast(csr_structure.get_indptr().data()), - const_cast(csr_structure.get_indices().data()), - const_cast*>(sparse_view.get_elements().data()))); - return descr; -} - /** * @brief SPMM function designed for handling all CSR * DENSE * combinations of operand layouts for cuSparse. diff --git a/cpp/include/raft/sparse/linalg/detail/utils.cuh b/cpp/include/raft/sparse/linalg/detail/utils.cuh new file mode 100644 index 0000000000..c72fc20f87 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/utils.cuh @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief create a cuSparse dense descriptor + * @tparam ValueType Data type of dense_view (float/double) + * @tparam IndexType Type of dense_view + * @tparam LayoutPolicy layout of dense_view + * @param[in] dense_view input raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +cusparseDnMatDescr_t create_descriptor( + raft::device_matrix_view& dense_view) +{ + bool is_row_major = raft::is_row_major(dense_view); + auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + IndexType ld = is_row_major ? 
dense_view.stride(0) : dense_view.stride(1); + cusparseDnMatDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descr, + dense_view.extent(0), + dense_view.extent(1), + ld, + const_cast*>(dense_view.data_handle()), + order)); + return descr; +} + +/** + * @brief create a cuSparse sparse descriptor + * @tparam ValueType Data type of sparse_view (float/double) + * @tparam IndptrType Data type of csr_matrix_view index pointers + * @tparam IndicesType Data type of csr_matrix_view indices + * @tparam NZType Type of sparse_view + * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns + * @returns sparse matrix descriptor to be used by cuSparse API + */ +template +cusparseSpMatDescr_t create_descriptor( + raft::device_csr_matrix_view& sparse_view) +{ + cusparseSpMatDescr_t descr; + auto csr_structure = sparse_view.structure_view(); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descr, + static_cast(csr_structure.get_n_rows()), + static_cast(csr_structure.get_n_cols()), + static_cast(csr_structure.get_nnz()), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), + const_cast*>(sparse_view.get_elements().data()))); + return descr; +} + +/** + * @brief convert the operation to cusparseOperation_t type + * @param[in] op type of operation + */ +inline cusparseOperation_t convert_operation(const raft::linalg::Operation op) +{ + if (op == raft::linalg::Operation::TRANSPOSE) { + return CUSPARSE_OPERATION_TRANSPOSE; + } else if (op == raft::linalg::Operation::NON_TRANSPOSE) { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } else { + RAFT_EXPECTS(false, "The operation type is not allowed."); + } + return CUSPARSE_OPERATION_NON_TRANSPOSE; +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/sddmm.cuh b/cpp/include/raft/sparse/linalg/sddmm.cuh new file mode 
100644 index 0000000000..a270a29133 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/sddmm.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief This function performs the multiplication of dense matrix A and dense matrix B, + * followed by an element-wise multiplication with the sparsity pattern of C. 
+ * It computes the following equation: C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C + * where A,B are device matrix views and C is a CSR device matrix view + * + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of C + * @tparam LayoutPolicyA layout of A + * @tparam LayoutPolicyB layout of B + * @tparam NZType Type of C + * + * @param[in] handle raft handle + * @param[in] A input raft::device_matrix_view + * @param[in] B input raft::device_matrix_view + * @param[in,out] C output raft::device_csr_matrix_view + * @param[in] opA input Operation op(A) + * @param[in] opB input Operation op(B) + * @param[in] alpha input raft::host_scalar_view + * @param[in] beta input raft::host_scalar_view + */ +template +void sddmm(raft::resources const& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + raft::device_csr_matrix_view C, + const raft::linalg::Operation opA, + const raft::linalg::Operation opB, + raft::host_scalar_view alpha, + raft::host_scalar_view beta) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(B), "B is not contiguous"); + + static_assert(std::is_same_v || std::is_same_v, + "The `ValueType` of sddmm only supports float/double."); + + auto descrA = detail::create_descriptor(A); + auto descrB = detail::create_descriptor(B); + auto descrC = detail::create_descriptor(C); + auto op_A = detail::convert_operation(opA); + auto op_B = detail::convert_operation(opB); + + detail::sddmm( + handle, descrA, descrB, descrC, op_A, op_B, alpha.data_handle(), beta.data_handle()); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrA)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrB)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descrC)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh 
b/cpp/include/raft/sparse/linalg/spmm.cuh index 064da4d8fb..72925199e4 100644 --- a/cpp/include/raft/sparse/linalg/spmm.cuh +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #pragma once #include "detail/spmm.hpp" +#include namespace raft { namespace sparse { @@ -60,8 +61,8 @@ void spmm(raft::resources const& handle, bool is_row_major = detail::is_row_major(y, z); auto descr_x = detail::create_descriptor(x); - auto descr_y = detail::create_descriptor(y, is_row_major); - auto descr_z = detail::create_descriptor(z, is_row_major); + auto descr_y = detail::create_descriptor(y); + auto descr_z = detail::create_descriptor(z); detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6e32281ec0..931530b66a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -315,6 +315,7 @@ if(BUILD_TESTS) test/sparse/normalize.cu test/sparse/reduce.cu test/sparse/row_op.cu + test/sparse/sddmm.cu test/sparse/sort.cu test/sparse/spgemmi.cu test/sparse/symmetrize.cu diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu new file mode 100644 index 0000000000..3436445192 --- /dev/null +++ b/cpp/test/sparse/sddmm.cu @@ -0,0 +1,425 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include "../test_utils.cuh" + +namespace raft { +namespace sparse { + +template +struct SDDMMInputs { + IndexType m; + IndexType k; + IndexType n; + + ValueType alpha; + ValueType beta; + + std::vector a_data; + std::vector b_data; + + std::vector c_indptr; + std::vector c_indices; + std::vector c_data; + + std::vector c_expected_data; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs& params) +{ + return os; +} + +template +class SDDMMTest : public ::testing::TestWithParam> { + public: + SDDMMTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + a_data_d(0, resource::get_cuda_stream(handle)), + b_data_d(0, resource::get_cuda_stream(handle)), + c_indptr_d(0, resource::get_cuda_stream(handle)), + c_indices_d(0, resource::get_cuda_stream(handle)), + c_data_d(0, resource::get_cuda_stream(handle)), + c_expected_data_d(0, resource::get_cuda_stream(handle)) + { + } + + protected: + void make_data() + { + std::vector a_data_h = params.a_data; + std::vector b_data_h = params.b_data; + + std::vector c_indptr_h = params.c_indptr; + std::vector c_indices_h = params.c_indices; + std::vector c_data_h = params.c_data; + std::vector c_expected_data_h = params.c_expected_data; + + a_data_d.resize(a_data_h.size(), stream); + b_data_d.resize(b_data_h.size(), stream); + c_indptr_d.resize(c_indptr_h.size(), stream); + c_indices_d.resize(c_indices_h.size(), stream); + c_data_d.resize(c_data_h.size(), 
stream); + c_expected_data_d.resize(c_expected_data_h.size(), stream); + + update_device(a_data_d.data(), a_data_h.data(), a_data_h.size(), stream); + update_device(b_data_d.data(), b_data_h.data(), b_data_h.size(), stream); + + update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream); + update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream); + update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream); + update_device( + c_expected_data_d.data(), c_expected_data_h.data(), c_expected_data_h.size(), stream); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void SetUp() override { make_data(); } + + void Run() + { + // Check params + ASSERT_EQ(params.a_data.size(), params.m * params.k); + ASSERT_EQ(params.b_data.size(), params.n * params.k); + ASSERT_EQ(params.c_data.size(), params.c_indices.size()); + ASSERT_GE(params.c_indices.size(), 0); + + auto a = raft::make_device_matrix_view( + a_data_d.data(), params.m, params.k); + auto b = raft::make_device_matrix_view( + b_data_d.data(), + ((std::is_same_v) ? params.n : params.k), + ((std::is_same_v) ? params.k : params.n)); + + auto c_structure = raft::make_device_compressed_structure_view( + c_indptr_d.data(), + c_indices_d.data(), + params.m, + params.n, + static_cast(c_indices_d.size())); + + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + + RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); + + auto op_a = raft::linalg::Operation::NON_TRANSPOSE; + auto op_b = !(std::is_same_v) + ? 
raft::linalg::Operation::NON_TRANSPOSE + : raft::linalg::Operation::TRANSPOSE; + + raft::sparse::linalg::sddmm(handle, + a, + b, + c, + op_a, + op_b, + raft::make_host_scalar_view(&params.alpha), + raft::make_host_scalar_view(&params.beta)); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + c.get_elements().data(), + params.c_indices.size(), + raft::CompareApprox(1e-6f), + stream)); + } + + raft::resources handle; + SDDMMInputs params; + cudaStream_t stream; + + rmm::device_uvector a_data_d; + rmm::device_uvector b_data_d; + + rmm::device_uvector c_indptr_d; + rmm::device_uvector c_indices_d; + rmm::device_uvector c_data_d; + + rmm::device_uvector c_expected_data_d; +}; + +using SDDMMTestF_Row_Col = SDDMMTest; +TEST_P(SDDMMTestF_Row_Col, Result) { Run(); } + +using SDDMMTestF_Col_Row = SDDMMTest; +TEST_P(SDDMMTestF_Col_Row, Result) { Run(); } + +using SDDMMTestF_Row_Row = SDDMMTest; +TEST_P(SDDMMTestF_Row_Row, Result) { Run(); } + +using SDDMMTestF_Col_Col = SDDMMTest; +TEST_P(SDDMMTestF_Col_Col, Result) { Run(); } + +using SDDMMTestD_Row_Col = SDDMMTest; +TEST_P(SDDMMTestD_Row_Col, Result) { Run(); } + +using SDDMMTestD_Col_Row = SDDMMTest; +TEST_P(SDDMMTestD_Col_Row, Result) { Run(); } + +using SDDMMTestD_Row_Row = SDDMMTest; +TEST_P(SDDMMTestD_Row_Row, Result) { Run(); } + +using SDDMMTestD_Col_Col = SDDMMTest; +TEST_P(SDDMMTestD_Col_Col, Result) { Run(); } + +const std::vector> sddmm_inputs_row_col_f = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 
15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; +const std::vector> sddmm_inputs_col_row_f = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; +const std::vector> sddmm_inputs_row_row_f = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; +const std::vector> sddmm_inputs_col_col_f = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 
10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; + +const std::vector> sddmm_inputs_row_col_d = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; + +const std::vector> sddmm_inputs_col_row_d = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; +const std::vector> sddmm_inputs_row_row_d = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}, + {1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; +const std::vector> sddmm_inputs_col_col_d = { + { + 4, + 4, + 3, + 1.0, + 0.0, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.0, 80.0, 90.0, 184.0, 246.0, 288.0, 330.0, 334.0, 450.0}, + }, + { + 4, + 4, + 3, + 1.0, + 0.5, + {1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0, 4.0, 8.0, 12.0, 16.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + {0, 3, 4, 7, 9}, + {0, 1, 2, 1, 0, 1, 2, 0, 2}, + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, + {70.5, 80.5, 90.5, 184.5, 246.5, 288.5, 330.5, 334.5, 450.5}, + }}; + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Col, ::testing::ValuesIn(sddmm_inputs_row_col_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Row, ::testing::ValuesIn(sddmm_inputs_col_row_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Row, ::testing::ValuesIn(sddmm_inputs_row_row_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Col, 
::testing::ValuesIn(sddmm_inputs_col_col_f)); + +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Col, ::testing::ValuesIn(sddmm_inputs_row_col_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Row, ::testing::ValuesIn(sddmm_inputs_col_row_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Row, ::testing::ValuesIn(sddmm_inputs_row_row_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Col, ::testing::ValuesIn(sddmm_inputs_col_col_d)); + +} // namespace sparse +} // namespace raft