diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt
index fe58453d0d..3a2431cd34 100644
--- a/cpp/bench/prims/CMakeLists.txt
+++ b/cpp/bench/prims/CMakeLists.txt
@@ -1,5 +1,5 @@
 # =============================================================================
-# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 # in compliance with the License. You may obtain a copy of the License at
@@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH)
     bench/prims/linalg/reduce_cols_by_key.cu
     bench/prims/linalg/reduce_rows_by_key.cu
     bench/prims/linalg/reduce.cu
+    bench/prims/linalg/sddmm.cu
     bench/prims/main.cpp
   )
diff --git a/cpp/bench/prims/linalg/sddmm.cu b/cpp/bench/prims/linalg/sddmm.cu
new file mode 100644
index 0000000000..139a2b838d
--- /dev/null
+++ b/cpp/bench/prims/linalg/sddmm.cu
@@ -0,0 +1,275 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+
+namespace raft::bench::linalg {
+
+template <typename ValueType>
+struct SDDMMBenchParams {
+  size_t m;
+  size_t k;
+  size_t n;
+  float sparsity;
+  bool transpose_a;
+  bool transpose_b;
+  ValueType alpha = 1.0;
+  ValueType beta  = 0.0;
+};
+
+enum Alg { SDDMM, Inner };
+
+template <typename ValueType>
+inline auto operator<<(std::ostream& os, const SDDMMBenchParams<ValueType>& params)
+  -> std::ostream&
+{
+  os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
+     << "\tsparsity=" << params.sparsity << "\ttrans_a=" << (params.transpose_a ? "T" : "F")
+     << " trans_b=" << (params.transpose_b ? "T" : "F");
+  return os;
+}
+
+template <typename ValueType,
+          typename IndexType,
+          typename LayoutPolicyA,
+          typename LayoutPolicyB,
+          Alg SDDMMorInner = Alg::SDDMM>
+struct SDDMMBench : public fixture {
+  SDDMMBench(const SDDMMBenchParams<ValueType>& p)
+    : fixture(true),
+      params(p),
+      handle(stream),
+      a_data_d(0, stream),
+      b_data_d(0, stream),
+      c_indptr_d(0, stream),
+      c_indices_d(0, stream),
+      c_data_d(0, stream),
+      c_dense_data_d(0, stream)
+  {
+    a_data_d.resize(params.m * params.k, stream);
+    b_data_d.resize(params.k * params.n, stream);
+
+    raft::random::RngState rng(2024ULL);
+    raft::random::uniform(
+      handle, rng, a_data_d.data(), params.m * params.k, ValueType(-1.0), ValueType(1.0));
+    raft::random::uniform(
+      handle, rng, b_data_d.data(), params.k * params.n, ValueType(-1.0), ValueType(1.0));
+
+    std::vector<bool> c_dense_data_h(params.m * params.n);
+
+    c_true_nnz = create_sparse_matrix(c_dense_data_h);
+    std::vector<ValueType> values(c_true_nnz);
+    std::vector<IndexType> indices(c_true_nnz);
+    std::vector<IndexType> indptr(params.m + 1);
+
+    c_data_d.resize(c_true_nnz, stream);
+    c_indptr_d.resize(params.m + 1, stream);
+    c_indices_d.resize(c_true_nnz, stream);
+
+    if (SDDMMorInner == Alg::Inner) { c_dense_data_d.resize(params.m * params.n, stream); }
+
+    convert_to_csr(c_dense_data_h, params.m, params.n, values, indices, indptr);
+    RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
+                 "Something went wrong: c_true_nnz != c_indices_d.size()!");
+
+    update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
+    update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
+    update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
+  }
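+
+  // Illustration (not part of the benchmark): for the 2x3 boolean pattern
+  //   [1 0 1]
+  //   [0 1 0]
+  // convert_to_csr() below fills values = {1, 1, 1}, indices = {0, 2, 1} and
+  // indptr = {0, 2, 3}: row i owns the nnz range [indptr[i], indptr[i + 1]).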
+  void convert_to_csr(std::vector<bool>& matrix,
+                      IndexType rows,
+                      IndexType cols,
+                      std::vector<ValueType>& values,
+                      std::vector<IndexType>& indices,
+                      std::vector<IndexType>& indptr)
+  {
+    IndexType offset_indptr = 0;
+    IndexType offset_values = 0;
+    indptr[offset_indptr++] = 0;
+
+    for (IndexType i = 0; i < rows; ++i) {
+      for (IndexType j = 0; j < cols; ++j) {
+        if (matrix[i * cols + j]) {
+          values[offset_values]  = static_cast<ValueType>(1.0);
+          indices[offset_values] = static_cast<IndexType>(j);
+          offset_values++;
+        }
+      }
+      indptr[offset_indptr++] = static_cast<IndexType>(offset_values);
+    }
+  }
+
+  size_t create_sparse_matrix(std::vector<bool>& matrix)
+  {
+    size_t total_elements = static_cast<size_t>(params.m * params.n);
+    size_t num_ones       = static_cast<size_t>((total_elements * 1.0f) * params.sparsity);
+    size_t res            = num_ones;
+
+    for (size_t i = 0; i < total_elements; ++i) {
+      matrix[i] = false;
+    }
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<> dis(0, total_elements - 1);
+
+    while (num_ones > 0) {
+      size_t index = dis(gen);
+
+      if (matrix[index] == false) {
+        matrix[index] = true;
+        num_ones--;
+      }
+    }
+    return res;
+  }
+
+  ~SDDMMBench() {}
+
+  void run_benchmark(::benchmark::State& state) override
+  {
+    std::ostringstream label_stream;
+    label_stream << params;
+    state.SetLabel(label_stream.str());
+
+    auto a = raft::make_device_matrix_view<const ValueType, IndexType, LayoutPolicyA>(
+      a_data_d.data(),
+      (!params.transpose_a ? params.m : params.k),
+      (!params.transpose_a ? params.k : params.m));
+
+    auto b = raft::make_device_matrix_view<const ValueType, IndexType, LayoutPolicyB>(
+      b_data_d.data(),
+      (!params.transpose_b ? params.k : params.n),
+      (!params.transpose_b ? params.n : params.k));
+
+    auto c_structure = raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
+      c_indptr_d.data(),
+      c_indices_d.data(),
+      params.m,
+      params.n,
+      static_cast<IndexType>(c_indices_d.size()));
+
+    auto c = raft::make_device_csr_matrix_view<ValueType>(c_data_d.data(), c_structure);
+    raft::resource::get_cusparse_handle(handle);
+
+    resource::sync_stream(handle);
+
+    auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
+    auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
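+
+    // One untimed call before loop_on_state() below, so that one-off cuSPARSE
+    // setup costs (handle initialization, algorithm selection) are kept out of
+    // the measured iterations.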
+    raft::sparse::linalg::sddmm(handle,
+                                a,
+                                b,
+                                c,
+                                op_a,
+                                op_b,
+                                raft::make_host_scalar_view<ValueType>(&params.alpha),
+                                raft::make_host_scalar_view<ValueType>(&params.beta));
+    resource::sync_stream(handle);
+
+    loop_on_state(state, [this, &a, &b, &c, &op_a, &op_b]() {
+      if (SDDMMorInner == Alg::SDDMM) {
+        raft::sparse::linalg::sddmm(handle,
+                                    a,
+                                    b,
+                                    c,
+                                    op_a,
+                                    op_b,
+                                    raft::make_host_scalar_view<ValueType>(&params.alpha),
+                                    raft::make_host_scalar_view<ValueType>(&params.beta));
+        resource::sync_stream(handle);
+      } else {
+        raft::distance::pairwise_distance(handle,
+                                          a_data_d.data(),
+                                          b_data_d.data(),
+                                          c_dense_data_d.data(),
+                                          static_cast<int>(params.m),
+                                          static_cast<int>(params.n),
+                                          static_cast<int>(params.k),
+                                          raft::distance::DistanceType::InnerProduct,
+                                          std::is_same_v<LayoutPolicyA, raft::row_major>);
+        resource::sync_stream(handle);
+      }
+    });
+  }
+
+ private:
+  const raft::device_resources handle;
+  SDDMMBenchParams<ValueType> params;
+
+  rmm::device_uvector<ValueType> a_data_d;
+  rmm::device_uvector<ValueType> b_data_d;
+  rmm::device_uvector<ValueType> c_dense_data_d;
+
+  size_t c_true_nnz = 0;
+  rmm::device_uvector<IndexType> c_indptr_d;
+  rmm::device_uvector<IndexType> c_indices_d;
+  rmm::device_uvector<ValueType> c_data_d;
+};
+
+template <typename ValueType>
+static std::vector<SDDMMBenchParams<ValueType>> getInputs()
+{
+  std::vector<SDDMMBenchParams<ValueType>> param_vec;
+  struct TestParams {
+    bool transpose_a;
+    bool transpose_b;
+    size_t m;
+    size_t k;
+    size_t n;
+    float sparsity;
+  };
+
+  const std::vector<TestParams> params_group =
+    raft::util::itertools::product<TestParams>({false, true},
+                                               {false, true},
+                                               {size_t(10), size_t(1024)},
+                                               {size_t(128), size_t(1024)},
+                                               {size_t(1024 * 1024)},
+                                               {0.01f, 0.1f, 0.2f, 0.5f});
+
+  param_vec.reserve(params_group.size());
+  for (TestParams params : params_group) {
+    param_vec.push_back(SDDMMBenchParams<ValueType>(
+      {params.m, params.k, params.n, params.sparsity, params.transpose_a, params.transpose_b}));
+  }
+  return param_vec;
+}
+
+RAFT_BENCH_REGISTER((SDDMMBench<float, int64_t, raft::row_major, raft::row_major>), "", getInputs<float>());
+RAFT_BENCH_REGISTER((SDDMMBench<float, int64_t, raft::row_major, raft::col_major>), "", getInputs<float>());
+RAFT_BENCH_REGISTER((SDDMMBench<float, int64_t, raft::col_major, raft::row_major>), "", getInputs<float>());
+RAFT_BENCH_REGISTER((SDDMMBench<float, int64_t, raft::col_major, raft::col_major>), "", getInputs<float>());
+
+RAFT_BENCH_REGISTER((SDDMMBench<float, int64_t, raft::row_major, raft::col_major, Alg::Inner>), "", getInputs<float>());
+
+}  // namespace raft::bench::linalg
diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh
index e121c1be9c..14b4ba12c6 100644
--- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh
+++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
 #include
 // #include
 #include
-#include <raft/sparse/linalg/spmm.cuh>
+#include <raft/sparse/linalg/spmm.hpp>
 #include
 #include
diff --git a/cpp/include/raft/linalg/linalg_types.hpp b/cpp/include/raft/linalg/linalg_types.hpp
index e50d3a8e79..9c81fbc177 100644
--- a/cpp/include/raft/linalg/linalg_types.hpp
+++ b/cpp/include/raft/linalg/linalg_types.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS };
 */
 enum class FillMode { UPPER, LOWER };
 
+/**
+ * @brief Enum indicating which operation is applied to the related input
+ * (e.g. a sparse matrix, or a vector).
+ */
+enum class Operation { NON_TRANSPOSE, TRANSPOSE };
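+
+// Usage sketch (illustrative only): callers typically map a boolean
+// "transpose" flag onto this enum, e.g.
+//   auto op = transpose ? raft::linalg::Operation::TRANSPOSE
+//                       : raft::linalg::Operation::NON_TRANSPOSE;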
+
 } // end namespace raft::linalg
\ No newline at end of file
diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
index e8bf9c6de5..cc3ae3ab87 100644
--- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h
+++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle,
                     alg,
                     static_cast<void*>(externalBuffer));
 }
+
+template <typename T>
+cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
+                                          cusparseOperation_t opA,
+                                          cusparseOperation_t opB,
+                                          const T* alpha,
+                                          const cusparseDnMatDescr_t matA,
+                                          const cusparseDnMatDescr_t matB,
+                                          const T* beta,
+                                          cusparseSpMatDescr_t matC,
+                                          cusparseSDDMMAlg_t alg,
+                                          size_t* bufferSize,
+                                          cudaStream_t stream);
+template <>
+inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
+                                                 cusparseOperation_t opA,
+                                                 cusparseOperation_t opB,
+                                                 const float* alpha,
+                                                 const cusparseDnMatDescr_t matA,
+                                                 const cusparseDnMatDescr_t matB,
+                                                 const float* beta,
+                                                 cusparseSpMatDescr_t matC,
+                                                 cusparseSDDMMAlg_t alg,
+                                                 size_t* bufferSize,
+                                                 cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+  return cusparseSDDMM_bufferSize(
+    handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize);
+}
+template <>
+inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
+                                                 cusparseOperation_t opA,
+                                                 cusparseOperation_t opB,
+                                                 const double* alpha,
+                                                 const cusparseDnMatDescr_t matA,
+                                                 const cusparseDnMatDescr_t matB,
+                                                 const double* beta,
+                                                 cusparseSpMatDescr_t matC,
+                                                 cusparseSDDMMAlg_t alg,
+                                                 size_t* bufferSize,
+                                                 cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+  return cusparseSDDMM_bufferSize(
+    handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize);
+}
+template <typename T>
+inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
+                                      cusparseOperation_t opA,
+                                      cusparseOperation_t opB,
+                                      const T* alpha,
+                                      const cusparseDnMatDescr_t matA,
+                                      const cusparseDnMatDescr_t matB,
+                                      const T* beta,
+                                      cusparseSpMatDescr_t matC,
+                                      cusparseSDDMMAlg_t alg,
+                                      T* externalBuffer,
+                                      cudaStream_t stream);
+template <>
+inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
+                                      cusparseOperation_t opA,
+                                      cusparseOperation_t opB,
+                                      const float* alpha,
+                                      const cusparseDnMatDescr_t matA,
+                                      const cusparseDnMatDescr_t matB,
+                                      const float* beta,
+                                      cusparseSpMatDescr_t matC,
+                                      cusparseSDDMMAlg_t alg,
+                                      float* externalBuffer,
+                                      cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+  return cusparseSDDMM(handle,
+                       opA,
+                       opB,
+                       static_cast<const void*>(alpha),
+                       matA,
+                       matB,
+                       static_cast<const void*>(beta),
+                       matC,
+                       CUDA_R_32F,
+                       alg,
+                       static_cast<void*>(externalBuffer));
+}
+template <>
+inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
+                                      cusparseOperation_t opA,
+                                      cusparseOperation_t opB,
+                                      const double* alpha,
+                                      const cusparseDnMatDescr_t matA,
+                                      const cusparseDnMatDescr_t matB,
+                                      const double* beta,
+                                      cusparseSpMatDescr_t matC,
+                                      cusparseSDDMMAlg_t alg,
+                                      double* externalBuffer,
+                                      cudaStream_t stream)
+{
+  CUSPARSE_CHECK(cusparseSetStream(handle, stream));
+  return cusparseSDDMM(handle,
+                       opA,
+                       opB,
+                       static_cast<const void*>(alpha),
+                       matA,
+                       matB,
+                       static_cast<const void*>(beta),
+                       matC,
+                       CUDA_R_64F,
+                       alg,
+                       static_cast<void*>(externalBuffer));
+}
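+
+// Usage note (illustrative): cuSPARSE SDDMM is a two-phase protocol. Callers
+// first query the workspace size with cusparsesddmm_bufferSize(), allocate a
+// device buffer of that size, and then pass it to cusparsesddmm().
+// raft::sparse::linalg::detail::sddmm() composes these two calls.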
+
 /** @} */
 #else
 /**
diff --git a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
new file mode 100644
index 0000000000..b15614905b
--- /dev/null
+++ b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace raft {
+namespace sparse {
+namespace linalg {
+namespace detail {
+
+/**
+ * @brief create a cuSparse dense descriptor
+ * @tparam ValueType Data type of dense_view (float/double)
+ * @tparam IndexType Type of dense_view
+ * @tparam LayoutPolicy layout of dense_view
+ * @param[in] dense_view input raft::device_matrix_view
+ * @returns dense matrix descriptor to be used by cuSparse API
+ */
+template <typename ValueType, typename IndexType, typename LayoutPolicy>
+cusparseDnMatDescr_t create_descriptor(
+  raft::device_matrix_view<ValueType, IndexType, LayoutPolicy> dense_view)
+{
+  bool is_row_major = raft::is_row_major(dense_view);
+  auto order        = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
+  IndexType ld      = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
+  cusparseDnMatDescr_t descr;
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
+    &descr,
+    dense_view.extent(0),
+    dense_view.extent(1),
+    ld,
+    const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
+    order));
+  return descr;
+}
+
+/**
+ * @brief create a cuSparse sparse descriptor
+ * @tparam ValueType Data type of sparse_view (float/double)
+ * @tparam IndptrType Data type of csr_matrix_view index pointers
+ * @tparam IndicesType Data type of csr_matrix_view indices
+ * @tparam NZType Type of sparse_view
+ * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
+ * @returns sparse matrix descriptor to be used by cuSparse API
+ */
+template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
+cusparseSpMatDescr_t create_descriptor(
+  raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType> sparse_view)
+{
+  cusparseSpMatDescr_t descr;
+  auto csr_structure = sparse_view.structure_view();
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
+    &descr,
+    static_cast<int64_t>(csr_structure.get_n_rows()),
+    static_cast<int64_t>(csr_structure.get_n_cols()),
+    static_cast<int64_t>(csr_structure.get_nnz()),
+    const_cast<IndptrType*>(csr_structure.get_indptr().data()),
+    const_cast<IndicesType*>(csr_structure.get_indices().data()),
+    const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
+  return descr;
+}
+
+/**
+ * @brief convert the operation to cusparseOperation_t type
+ * @param[in] op type of operation
+ */
+inline cusparseOperation_t convert_operation(const raft::linalg::Operation op)
+{
+  if (op == raft::linalg::Operation::TRANSPOSE) {
+    return CUSPARSE_OPERATION_TRANSPOSE;
+  } else if (op == raft::linalg::Operation::NON_TRANSPOSE) {
+    return CUSPARSE_OPERATION_NON_TRANSPOSE;
+  } else {
+    RAFT_EXPECTS(false, "The operation type is not allowed.");
+  }
+  return CUSPARSE_OPERATION_NON_TRANSPOSE;
+}
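+
+// Usage sketch (illustrative only): for a row-major 4 x 8 view, the dense
+// overload above selects CUSPARSE_ORDER_ROW and uses stride(0) == 8 as the
+// leading dimension. The returned descriptor must be released by the caller:
+//   auto view  = raft::make_device_matrix_view<const float, int, raft::row_major>(ptr, 4, 8);
+//   auto descr = create_descriptor(view);
+//   ...
+//   RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr));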
+
+}  // end namespace detail
+}  // end namespace linalg
+}  // end namespace sparse
+}  // end namespace raft
diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp
new file mode 100644
index 0000000000..5088a20f46
--- /dev/null
+++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft {
+namespace sparse {
+namespace linalg {
+namespace detail {
+
+/**
+ * @brief This function performs the multiplication of dense matrix A and dense matrix B,
+ * followed by an element-wise multiplication with the sparsity pattern of C.
+ * It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C
+ * where A, B are device matrix views and C is a CSR device matrix view
+ *
+ * @tparam ValueType Data type of input/output matrices (float/double)
+ *
+ * @param[in] handle raft resource handle
+ * @param[in] descr_a input dense descriptor
+ * @param[in] descr_b input dense descriptor
+ * @param[inout] descr_c output sparse descriptor
+ * @param[in] op_a input Operation op(A)
+ * @param[in] op_b input Operation op(B)
+ * @param[in] alpha scalar pointer
+ * @param[in] beta scalar pointer
+ */
+template <typename ValueType>
+void sddmm(raft::resources const& handle,
+           cusparseDnMatDescr_t& descr_a,
+           cusparseDnMatDescr_t& descr_b,
+           cusparseSpMatDescr_t& descr_c,
+           cusparseOperation_t op_a,
+           cusparseOperation_t op_b,
+           const ValueType* alpha,
+           const ValueType* beta)
+{
+  auto alg = CUSPARSE_SDDMM_ALG_DEFAULT;
+  size_t bufferSize;
+
+  RAFT_CUSPARSE_TRY(
+    raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle),
+                                                   op_a,
+                                                   op_b,
+                                                   alpha,
+                                                   descr_a,
+                                                   descr_b,
+                                                   beta,
+                                                   descr_c,
+                                                   alg,
+                                                   &bufferSize,
+                                                   resource::get_cuda_stream(handle)));
+
+  resource::sync_stream(handle);
+
+  rmm::device_uvector<ValueType> tmp(bufferSize, resource::get_cuda_stream(handle));
+
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle),
+                                                        op_a,
+                                                        op_b,
+                                                        alpha,
+                                                        descr_a,
+                                                        descr_b,
+                                                        beta,
+                                                        descr_c,
+                                                        alg,
+                                                        tmp.data(),
+                                                        resource::get_cuda_stream(handle)));
+}
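+
+// Note (illustrative): cuSPARSE reports bufferSize in bytes, while the wrapper
+// takes a ValueType* workspace, so the device_uvector<ValueType> above holds
+// bufferSize elements rather than bufferSize bytes; it over-allocates slightly
+// and correctness is unaffected.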
+
+}  // end namespace detail
+}  // end namespace linalg
+}  // end namespace sparse
+}  // end namespace raft
diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp
index d8d73ee83f..6206348b02 100644
--- a/cpp/include/raft/sparse/linalg/detail/spmm.hpp
+++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view
-template <typename ValueType, typename IndexType, typename LayoutPolicy>
-cusparseDnMatDescr_t create_descriptor(
-  raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
-{
-  auto order   = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
-  IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
-  cusparseDnMatDescr_t descr;
-  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
-    &descr,
-    dense_view.extent(0),
-    dense_view.extent(1),
-    ld,
-    const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
-    order));
-  return descr;
-}
-
-/**
- * @brief create a cuSparse sparse descriptor
- * @tparam ValueType Data type of sparse_view (float/double)
- * @tparam IndptrType Data type of csr_matrix_view index pointers
- * @tparam IndicesType Data type of csr_matrix_view indices
- * @tparam NZType Type of sparse_view
- * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
- * @returns sparse matrix descriptor to be used by cuSparse API
- */
-template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
-cusparseSpMatDescr_t create_descriptor(
-  raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
-{
-  cusparseSpMatDescr_t descr;
-  auto csr_structure = sparse_view.structure_view();
-  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
-    &descr,
-    static_cast<int64_t>(csr_structure.get_n_rows()),
-    static_cast<int64_t>(csr_structure.get_n_cols()),
-    static_cast<int64_t>(csr_structure.get_nnz()),
-    const_cast<IndptrType*>(csr_structure.get_indptr().data()),
-    const_cast<IndicesType*>(csr_structure.get_indices().data()),
-    const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
-  return descr;
-}
-
 /**
  * @brief SPMM function designed for handling all CSR * DENSE
  * combinations of operand layouts for cuSparse.
diff --git a/cpp/include/raft/sparse/linalg/sddmm.hpp b/cpp/include/raft/sparse/linalg/sddmm.hpp
new file mode 100644
index 0000000000..c19f1d9081
--- /dev/null
+++ b/cpp/include/raft/sparse/linalg/sddmm.hpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft {
+namespace sparse {
+namespace linalg {
+
+/**
+ * @brief This function performs the multiplication of dense matrix A and dense matrix B,
+ * followed by an element-wise multiplication with the sparsity pattern of C.
+ * It computes the following equation: C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C
+ * where A, B are device matrix views and C is a CSR device matrix view
+ * @tparam ValueType Data type of input/output matrices (float/double)
+ * @tparam IndexType Type of C
+ * @tparam NZType Type of C
+ * @tparam LayoutPolicyA layout of A
+ * @tparam LayoutPolicyB layout of B
+ * @param[in] handle raft handle
+ * @param[in] A input raft::device_matrix_view
+ * @param[in] B input raft::device_matrix_view
+ * @param[inout] C output raft::device_csr_matrix_view
+ * @param[in] opA input Operation op(A)
+ * @param[in] opB input Operation op(B)
+ * @param[in] alpha input raft::host_scalar_view
+ * @param[in] beta input raft::host_scalar_view
+ */
+template <typename ValueType,
+          typename IndexType,
+          typename NZType,
+          typename LayoutPolicyA,
+          typename LayoutPolicyB>
+void sddmm(raft::resources const& handle,
+           raft::device_matrix_view<const ValueType, IndexType, LayoutPolicyA> A,
+           raft::device_matrix_view<const ValueType, IndexType, LayoutPolicyB> B,
+           raft::device_csr_matrix_view<ValueType, IndexType, IndexType, NZType> C,
+           const raft::linalg::Operation opA,
+           const raft::linalg::Operation opB,
+           raft::host_scalar_view<ValueType> alpha,
+           raft::host_scalar_view<ValueType> beta)
+{
+  RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous");
+  RAFT_EXPECTS(raft::is_row_or_column_major(B), "B is not contiguous");
+
+  static_assert(std::is_same_v<ValueType, float> || std::is_same_v<ValueType, double>,
+                "The `ValueType` of sddmm only supports float/double.");
+
+  auto descrA = detail::create_descriptor(A);
+  auto descrB = detail::create_descriptor(B);
+  auto descrC = detail::create_descriptor(C);
+  auto op_A   = detail::convert_operation(opA);
+  auto op_B   = detail::convert_operation(opB);
+
+  detail::sddmm(
+    handle, descrA, descrB, descrC, op_A, op_B, alpha.data_handle(), beta.data_handle());
+
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrA));
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descrB));
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descrC));
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
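+
+// Usage sketch (illustrative, not part of the header): computing
+// C = alpha * (A * B ∘ spy(C)) + beta * C for a row-major A (m x k), a
+// col-major B (k x n) and a CSR matrix C (m x n) that have already been
+// allocated and filled on device. All variable names are hypothetical.
+//
+//   float alpha = 1.0f, beta = 0.0f;
+//   auto A = raft::make_device_matrix_view<const float, int, raft::row_major>(a_ptr, m, k);
+//   auto B = raft::make_device_matrix_view<const float, int, raft::col_major>(b_ptr, k, n);
+//   auto C_structure = raft::make_device_compressed_structure_view<int, int, int>(
+//     indptr_ptr, indices_ptr, m, n, nnz);
+//   auto C = raft::make_device_csr_matrix_view<float>(values_ptr, C_structure);
+//   raft::sparse::linalg::sddmm(handle,
+//                               A,
+//                               B,
+//                               C,
+//                               raft::linalg::Operation::NON_TRANSPOSE,
+//                               raft::linalg::Operation::NON_TRANSPOSE,
+//                               raft::make_host_scalar_view<float>(&alpha),
+//                               raft::make_host_scalar_view<float>(&beta));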
+
+}  // end namespace linalg
+}  // end namespace sparse
+}  // end namespace raft
diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh
index 064da4d8fb..439ed8c341 100644
--- a/cpp/include/raft/sparse/linalg/spmm.cuh
+++ b/cpp/include/raft/sparse/linalg/spmm.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -13,66 +13,10 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#ifndef __SPMM_H
-#define __SPMM_H
-
 #pragma once
 
-#include "detail/spmm.hpp"
-
-namespace raft {
-namespace sparse {
-namespace linalg {
-
-/**
- * @brief SPMM function designed for handling all CSR * DENSE
- * combinations of operand layouts for cuSparse.
- * It computes the following equation: Z = alpha . X * Y + beta . Z
- * where X is a CSR device matrix view and Y,Z are device matrix views
- * @tparam ValueType Data type of input/output matrices (float/double)
- * @tparam IndexType Type of Y and Z
- * @tparam NZType Type of X
- * @tparam LayoutPolicyY layout of Y
- * @tparam LayoutPolicyZ layout of Z
- * @param[in] handle raft handle
- * @param[in] trans_x transpose operation for X
- * @param[in] trans_y transpose operation for Y
- * @param[in] alpha scalar
- * @param[in] x input raft::device_csr_matrix_view
- * @param[in] y input raft::device_matrix_view
- * @param[in] beta scalar
- * @param[out] z output raft::device_matrix_view
- */
-template <typename ValueType,
-          typename IndexType,
-          typename NZType,
-          typename LayoutPolicyY,
-          typename LayoutPolicyZ>
-void spmm(raft::resources const& handle,
-          const bool trans_x,
-          const bool trans_y,
-          const ValueType* alpha,
-          raft::device_csr_matrix_view<const ValueType, int, int, NZType> x,
-          raft::device_matrix_view<const ValueType, IndexType, LayoutPolicyY> y,
-          const ValueType* beta,
-          raft::device_matrix_view<ValueType, IndexType, LayoutPolicyZ> z)
-{
-  bool is_row_major = detail::is_row_major(y, z);
-
-  auto descr_x = detail::create_descriptor(x);
-  auto descr_y = detail::create_descriptor(y, is_row_major);
-  auto descr_z = detail::create_descriptor(z, is_row_major);
-
-  detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);
-
-  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x));
-  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y));
-  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));
-  RAFT_CUDA_TRY(cudaPeekAtLastError());
-}
-
-}  // end namespace linalg
-}  // end namespace sparse
-}  // end namespace raft
+#pragma message(__FILE__ \
+                " is deprecated and will be removed in a future release." \
+                " Please use the spmm.hpp at the same path instead.")
 
-#endif
+#include <raft/sparse/linalg/spmm.hpp>
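+
+// Migration note (illustrative): existing callers keep compiling via this
+// shim, but new code should include the new header directly:
+//   #include <raft/sparse/linalg/spmm.hpp>  // replaces <raft/sparse/linalg/spmm.cuh>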
diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp
new file mode 100644
index 0000000000..c2fdd64574
--- /dev/null
+++ b/cpp/include/raft/sparse/linalg/spmm.hpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef __SPMM_H
+#define __SPMM_H
+
+#pragma once
+
+#include
+#include
+
+namespace raft {
+namespace sparse {
+namespace linalg {
+
+/**
+ * @brief SPMM function designed for handling all CSR * DENSE
+ * combinations of operand layouts for cuSparse.
+ * It computes the following equation: Z = alpha . X * Y + beta . Z
+ * where X is a CSR device matrix view and Y,Z are device matrix views
+ * @tparam ValueType Data type of input/output matrices (float/double)
+ * @tparam IndexType Type of Y and Z
+ * @tparam NZType Type of X
+ * @tparam LayoutPolicyY layout of Y
+ * @tparam LayoutPolicyZ layout of Z
+ * @param[in] handle raft handle
+ * @param[in] trans_x transpose operation for X
+ * @param[in] trans_y transpose operation for Y
+ * @param[in] alpha scalar
+ * @param[in] x input raft::device_csr_matrix_view
+ * @param[in] y input raft::device_matrix_view
+ * @param[in] beta scalar
+ * @param[out] z output raft::device_matrix_view
+ */
+template <typename ValueType,
+          typename IndexType,
+          typename NZType,
+          typename LayoutPolicyY,
+          typename LayoutPolicyZ>
+void spmm(raft::resources const& handle,
+          const bool trans_x,
+          const bool trans_y,
+          const ValueType* alpha,
+          raft::device_csr_matrix_view<const ValueType, int, int, NZType> x,
+          raft::device_matrix_view<const ValueType, IndexType, LayoutPolicyY> y,
+          const ValueType* beta,
+          raft::device_matrix_view<ValueType, IndexType, LayoutPolicyZ> z)
+{
+  bool is_row_major = detail::is_row_major(y, z);
+
+  auto descr_x = detail::create_descriptor(x);
+  auto descr_y = detail::create_descriptor(y);
+  auto descr_z = detail::create_descriptor(z);
+
+  detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z);
+
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x));
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y));
+  RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+}
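+
+// Usage sketch (illustrative, names hypothetical): Z = 2 * X * Y for a CSR
+// matrix X (m x k) and dense views Y (k x n) and Z (m x n):
+//   float alpha = 2.0f, beta = 0.0f;
+//   raft::sparse::linalg::spmm(handle, false, false, &alpha, x_view, y_view, &beta, z_view);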
+
+}  // end namespace linalg
+}  // end namespace sparse
+}  // end namespace raft
+
+#endif
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 6e32281ec0..931530b66a 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -315,6 +315,7 @@ if(BUILD_TESTS)
     test/sparse/normalize.cu
     test/sparse/reduce.cu
     test/sparse/row_op.cu
+    test/sparse/sddmm.cu
     test/sparse/sort.cu
     test/sparse/spgemmi.cu
     test/sparse/symmetrize.cu
diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu
new file mode 100644
index 0000000000..9323ee8c2b
--- /dev/null
+++ b/cpp/test/sparse/sddmm.cu
@@ -0,0 +1,365 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../test_utils.cuh"
+
+namespace raft {
+namespace sparse {
+
+template <typename ValueType, typename IndexType>
+struct SDDMMInputs {
+  ValueType tolerance;
+
+  IndexType m;
+  IndexType k;
+  IndexType n;
+
+  ValueType alpha;
+  ValueType beta;
+
+  bool transpose_a;
+  bool transpose_b;
+
+  ValueType sparsity;
+
+  unsigned long long int seed;
+};
+
+template <typename ValueType>
+struct sum_abs_op {
+  __host__ __device__ ValueType operator()(const ValueType& x, const ValueType& y) const
+  {
+    return y >= ValueType(0.0) ? (x + y) : (x - y);
+  }
+};
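+
+// Note (illustrative): with x as the running total and y the next element,
+// this functor adds |y|, so reducing over {v_i} with a zero initial value
+// yields sum(|v_i|). It is used below to confirm the expected output is not
+// all zeros.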
+
+template <typename ValueType, typename IndexType>
+::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs<ValueType, IndexType>& params)
+{
+  os << " m: " << params.m << "\tk: " << params.k << "\tn: " << params.n
+     << "\talpha: " << params.alpha << "\tbeta: " << params.beta
+     << "\tsparsity: " << params.sparsity;
+
+  return os;
+}
+
+template <typename ValueType, typename IndexType, typename LayoutPolicyA, typename LayoutPolicyB>
+class SDDMMTest : public ::testing::TestWithParam<SDDMMInputs<ValueType, IndexType>> {
+ public:
+  SDDMMTest()
+    : params(::testing::TestWithParam<SDDMMInputs<ValueType, IndexType>>::GetParam()),
+      stream(resource::get_cuda_stream(handle)),
+      a_data_d(0, resource::get_cuda_stream(handle)),
+      b_data_d(0, resource::get_cuda_stream(handle)),
+      c_indptr_d(0, resource::get_cuda_stream(handle)),
+      c_indices_d(0, resource::get_cuda_stream(handle)),
+      c_data_d(0, resource::get_cuda_stream(handle)),
+      c_expected_data_d(0, resource::get_cuda_stream(handle))
+  {
+  }
+
+ protected:
+  IndexType create_sparse_matrix(IndexType m,
+                                 IndexType n,
+                                 ValueType sparsity,
+                                 std::vector<bool>& matrix)
+  {
+    IndexType total_elements = static_cast<IndexType>(m * n);
+    IndexType num_ones       = static_cast<IndexType>((total_elements * 1.0f) * sparsity);
+    IndexType res            = num_ones;
+
+    for (IndexType i = 0; i < total_elements; ++i) {
+      matrix[i] = false;
+    }
+
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<> dis(0, total_elements - 1);
+
+    while (num_ones > 0) {
+      size_t index = dis(gen);
+
+      if (matrix[index] == false) {
+        matrix[index] = true;
+        num_ones--;
+      }
+    }
+    return res;
+  }
+
+  void convert_to_csr(std::vector<bool>& matrix,
+                      IndexType rows,
+                      IndexType cols,
+                      std::vector<ValueType>& values,
+                      std::vector<IndexType>& indices,
+                      std::vector<IndexType>& indptr)
+  {
+    IndexType offset_indptr = 0;
+    IndexType offset_values = 0;
+    indptr[offset_indptr++] = 0;
+
+    for (IndexType i = 0; i < rows; ++i) {
+      for (IndexType j = 0; j < cols; ++j) {
+        if (matrix[i * cols + j]) {
+          values[offset_values]  = static_cast<ValueType>(1.0);
+          indices[offset_values] = static_cast<IndexType>(j);
+          offset_values++;
+        }
+      }
+      indptr[offset_indptr++] = static_cast<IndexType>(offset_values);
+    }
+  }
+
+  void cpu_sddmm(const std::vector<ValueType>& A,
+                 const std::vector<ValueType>& B,
+                 std::vector<ValueType>& vals,
+                 const std::vector<IndexType>& cols,
+                 const std::vector<IndexType>& row_ptrs,
+                 bool is_row_major_A,
+                 bool is_row_major_B)
+  {
+    if (params.m * params.k != static_cast<IndexType>(A.size()) ||
+        params.k * params.n != static_cast<IndexType>(B.size())) {
+      std::cerr << "Matrix dimensions and vector size do not match!" << std::endl;
+      return;
+    }
+
+    bool trans_a = params.transpose_a ? !is_row_major_A : is_row_major_A;
+    bool trans_b = params.transpose_b ? !is_row_major_B : is_row_major_B;
+
+    for (IndexType i = 0; i < params.m; ++i) {
+      for (IndexType j = row_ptrs[i]; j < row_ptrs[i + 1]; ++j) {
+        ValueType sum = 0;
+        for (IndexType l = 0; l < params.k; ++l) {
+          IndexType a_index = trans_a ? i * params.k + l : l * params.m + i;
+          IndexType b_index = trans_b ? l * params.n + cols[j] : cols[j] * params.k + l;
+          sum += A[a_index] * B[b_index];
+        }
+        vals[j] = params.alpha * sum + params.beta * vals[j];
+      }
+    }
+  }
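+
+  // Index bookkeeping in cpu_sddmm, spelled out: trans_a folds the transpose
+  // flag into the storage order, so trans_a == true means element (i, l) of
+  // op(A) lives at the row-major offset i * k + l, while trans_a == false
+  // means it lives at the column-major offset l * m + i; b_index follows the
+  // same reasoning for element (l, cols[j]) of op(B).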
+
+  void make_data()
+  {
+    IndexType a_size = params.m * params.k;
+    IndexType b_size = params.k * params.n;
+    IndexType c_size = params.m * params.n;
+
+    std::vector<ValueType> a_data_h(a_size);
+    std::vector<ValueType> b_data_h(b_size);
+
+    a_data_d.resize(a_size, stream);
+    b_data_d.resize(b_size, stream);
+
+    auto blobs_a_b = raft::make_device_matrix<ValueType, IndexType>(handle, 1, a_size + b_size);
+    auto labels    = raft::make_device_vector<IndexType, IndexType>(handle, 1);
+
+    raft::random::make_blobs(blobs_a_b.data_handle(),
+                             labels.data_handle(),
+                             1,
+                             a_size + b_size,
+                             1,
+                             stream,
+                             false,
+                             nullptr,
+                             nullptr,
+                             ValueType(1.0),
+                             false,
+                             ValueType(-1.0f),
+                             ValueType(1.0f),
+                             uint64_t(2024));
+
+    raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream);
+    raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream);
+
+    raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream);
+    raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream);
+
+    resource::sync_stream(handle);
+
+    std::vector<bool> c_dense_data_h(c_size);
+    IndexType c_true_nnz =
+      create_sparse_matrix(params.m, params.n, params.sparsity, c_dense_data_h);
+
+    std::vector<IndexType> c_indptr_h(params.m + 1);
+    std::vector<IndexType> c_indices_h(c_true_nnz);
+    std::vector<ValueType> c_data_h(c_true_nnz);
+
+    convert_to_csr(c_dense_data_h, params.m, params.n, c_data_h, c_indices_h, c_indptr_h);
+
+    bool is_row_major_A = (std::is_same_v<LayoutPolicyA, raft::row_major>);
+    bool is_row_major_B = (std::is_same_v<LayoutPolicyB, raft::row_major>);
+
+    c_data_d.resize(c_data_h.size(), stream);
+    update_device(c_data_d.data(), c_data_h.data(), c_data_h.size(), stream);
+    resource::sync_stream(handle);
+
+    cpu_sddmm(
+      a_data_h, b_data_h, c_data_h, c_indices_h, c_indptr_h, is_row_major_A, is_row_major_B);
+
+    c_indptr_d.resize(c_indptr_h.size(), stream);
+    c_indices_d.resize(c_indices_h.size(), stream);
+    c_expected_data_d.resize(c_data_h.size(), stream);
+
+    update_device(c_indptr_d.data(), c_indptr_h.data(), c_indptr_h.size(), stream);
+    update_device(c_indices_d.data(), c_indices_h.data(), c_indices_h.size(), stream);
+    update_device(c_expected_data_d.data(), c_data_h.data(), c_data_h.size(), stream);
+
+    resource::sync_stream(handle);
+  }
+
+  void SetUp() override { make_data(); }
+
+  void Run()
+  {
+    auto a = raft::make_device_matrix_view<const ValueType, IndexType, LayoutPolicyA>(
+      a_data_d.data(),
+      (!params.transpose_a ? params.m : params.k),
+      (!params.transpose_a ? params.k : params.m));
+    auto b = raft::make_device_matrix_view<const ValueType, IndexType, LayoutPolicyB>(
+      b_data_d.data(),
+      (!params.transpose_b ? params.k : params.n),
+      (!params.transpose_b ? params.n : params.k));
+
+    auto c_structure = raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
+      c_indptr_d.data(),
+      c_indices_d.data(),
+      params.m,
+      params.n,
+      static_cast<IndexType>(c_indices_d.size()));
+
+    auto c = raft::make_device_csr_matrix_view<ValueType>(c_data_d.data(), c_structure);
+
+    auto op_a = params.transpose_a ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
+    auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE
+                                   : raft::linalg::Operation::NON_TRANSPOSE;
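+
+    // Run the library SDDMM on device and compare elementwise against the
+    // CPU reference values computed by cpu_sddmm() in make_data().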
+    raft::sparse::linalg::sddmm(handle,
+                                a,
+                                b,
+                                c,
+                                op_a,
+                                op_b,
+                                raft::make_host_scalar_view<ValueType>(&params.alpha),
+                                raft::make_host_scalar_view<ValueType>(&params.beta));
+
+    resource::sync_stream(handle);
+
+    ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(),
+                                  c.get_elements().data(),
+                                  c_expected_data_d.size(),
+                                  raft::CompareApprox<ValueType>(params.tolerance),
+                                  stream));
+
+    // Sanity check: guard against an (almost) all-zero expected result, which
+    // would let the comparison above pass vacuously.
+    thrust::device_ptr<ValueType> expected_data_ptr =
+      thrust::device_pointer_cast(c_expected_data_d.data());
+    ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream),
+                                       expected_data_ptr,
+                                       expected_data_ptr + c_expected_data_d.size(),
+                                       ValueType(0.0f),
+                                       sum_abs_op<ValueType>());
+    ValueType avg = sum_abs / (1.0f * c_expected_data_d.size());
+
+    ASSERT_GE(avg, (params.tolerance * static_cast<ValueType>(0.001f)));
+  }
+
+  raft::resources handle;
+  cudaStream_t stream;
+  SDDMMInputs<ValueType, IndexType> params;
+
+  rmm::device_uvector<ValueType> a_data_d;
+  rmm::device_uvector<ValueType> b_data_d;
+
+  rmm::device_uvector<IndexType> c_indptr_d;
+  rmm::device_uvector<IndexType> c_indices_d;
+  rmm::device_uvector<ValueType> c_data_d;
+
+  rmm::device_uvector<ValueType> c_expected_data_d;
+};
+
+using SDDMMTestF_Row_Col = SDDMMTest<float, int, raft::row_major, raft::col_major>;
+TEST_P(SDDMMTestF_Row_Col, Result) { Run(); }
+
+using SDDMMTestF_Col_Row = SDDMMTest<float, int, raft::col_major, raft::row_major>;
+TEST_P(SDDMMTestF_Col_Row, Result) { Run(); }
+
+using SDDMMTestF_Row_Row = SDDMMTest<float, int, raft::row_major, raft::row_major>;
+TEST_P(SDDMMTestF_Row_Row, Result) { Run(); }
+
+using SDDMMTestF_Col_Col = SDDMMTest<float, int, raft::col_major, raft::col_major>;
+TEST_P(SDDMMTestF_Col_Col, Result) { Run(); }
+
+using SDDMMTestD_Row_Col = SDDMMTest<double, int, raft::row_major, raft::col_major>;
+TEST_P(SDDMMTestD_Row_Col, Result) { Run(); }
+
+using SDDMMTestD_Col_Row = SDDMMTest<double, int, raft::col_major, raft::row_major>;
+TEST_P(SDDMMTestD_Col_Row, Result) { Run(); }
+
+using SDDMMTestD_Row_Row = SDDMMTest<double, int, raft::row_major, raft::row_major>;
+TEST_P(SDDMMTestD_Row_Row, Result) { Run(); }
+
+using SDDMMTestD_Col_Col = SDDMMTest<double, int, raft::col_major, raft::col_major>;
+TEST_P(SDDMMTestD_Col_Col, Result) { Run(); }
+
+const std::vector<SDDMMInputs<float, int>> sddmm_inputs_f = {
+  {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL},
+  {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL},
+  {0.0003f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL},
+  {0.001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL},
+  {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL},
+  {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL},
+  {0.0003f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL},
+  {0.001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}};
+
+const std::vector<SDDMMInputs<double, int>> sddmm_inputs_d = {
+  {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL},
+  {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL},
+  {0.0001f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL},
+  {0.0001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL},
+  {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL},
+  {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL},
+  {0.0001f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL},
+  {0.0001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}};
+
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Col, ::testing::ValuesIn(sddmm_inputs_f));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Row, ::testing::ValuesIn(sddmm_inputs_f));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Row, ::testing::ValuesIn(sddmm_inputs_f));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Col, ::testing::ValuesIn(sddmm_inputs_f));
+
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Col, ::testing::ValuesIn(sddmm_inputs_d));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Row, ::testing::ValuesIn(sddmm_inputs_d));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Row, ::testing::ValuesIn(sddmm_inputs_d));
+INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Col, ::testing::ValuesIn(sddmm_inputs_d));
+
+}  // namespace sparse
+}  // namespace raft