[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067)
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moves some APIs shared with `SpMM` into the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
rhdong committed Jan 9, 2024
1 parent 6762fe5 commit 234c09d
Showing 9 changed files with 839 additions and 58 deletions.
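
For context, SDDMM (sampled dense-dense matrix multiplication) evaluates the dense product op_a(A) · op_b(B) only at the positions present in the sparsity pattern of the sparse output C. In the notation used by the documentation added in sddmm.hpp below, where ∘ is the element-wise product and spy(C) is the sparsity pattern of C, the operation can be stated as:

C = \alpha \cdot \bigl(\mathrm{op_a}(A) \cdot \mathrm{op_b}(B) \circ \mathrm{spy}(C)\bigr) + \beta \cdot C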
9 changes: 8 additions & 1 deletion cpp/include/raft/linalg/linalg_types.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS };
*/
enum class FillMode { UPPER, LOWER };

/**
* @brief Enum indicating which operation is applied to the corresponding input (e.g. a sparse
* matrix or a vector).
*/
enum class Operation { NON_TRANSPOSE, TRANSPOSE };

} // end namespace raft::linalg
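
As a point of reference, here is a minimal sketch (not part of this commit) of how the new raft::linalg::Operation values could be translated to the cuSPARSE operation type consumed by the wrappers below; the helper name convert_operation is hypothetical:

#include <cusparse.h>
#include <raft/linalg/linalg_types.hpp>

// Hypothetical helper, for illustration only: map raft::linalg::Operation onto the
// cusparseOperation_t values expected by the cuSPARSE generic API.
inline cusparseOperation_t convert_operation(raft::linalg::Operation op)
{
  return op == raft::linalg::Operation::TRANSPOSE ? CUSPARSE_OPERATION_TRANSPOSE
                                                   : CUSPARSE_OPERATION_NON_TRANSPOSE;
}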
114 changes: 113 additions & 1 deletion cpp/include/raft/sparse/detail/cusparse_wrappers.h
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle,
alg,
static_cast<void*>(externalBuffer));
}

template <typename T>
cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const T* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const T* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream);
template <>
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const float* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const float* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM_bufferSize(
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize);
}
template <>
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const double* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const double* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM_bufferSize(
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize);
}
template <typename T>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const T* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const T* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
T* externalBuffer,
cudaStream_t stream);
template <>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const float* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const float* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
float* externalBuffer,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM(handle,
opA,
opB,
static_cast<void const*>(alpha),
matA,
matB,
static_cast<void const*>(beta),
matC,
CUDA_R_32F,
alg,
static_cast<void*>(externalBuffer));
}
template <>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const double* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const double* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
double* externalBuffer,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM(handle,
opA,
opB,
static_cast<void const*>(alpha),
matA,
matB,
static_cast<void const*>(beta),
matC,
CUDA_R_64F,
alg,
static_cast<void*>(externalBuffer));
}

/** @} */
#else
/**
99 changes: 99 additions & 0 deletions cpp/include/raft/sparse/linalg/detail/sddmm.hpp
@@ -0,0 +1,99 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusparse_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/linalg_types.hpp>
#include <raft/sparse/detail/cusparse_wrappers.h>

namespace raft {
namespace sparse {
namespace linalg {
namespace detail {

/**
* @brief This function multiplies dense matrix A by dense matrix B and then applies an
* element-wise multiplication with the sparsity pattern of C.
* It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C
* where ∘ denotes the element-wise (Hadamard) product, spy(C) is the sparsity pattern of C,
* A and B are dense cuSPARSE matrix descriptors, and C is a CSR cuSPARSE matrix descriptor
*
* @tparam ValueType Data type of the input/output matrices (float/double)
*
* @param[in] handle raft resource handle
* @param[in] descr_a dense matrix A descriptor
* @param[in] descr_b dense matrix B descriptor
* @param[in,out] descr_c sparse (CSR) matrix C descriptor, updated in place
* @param[in] op_a operation applied to A (op_a(A))
* @param[in] op_b operation applied to B (op_b(B))
* @param[in] alpha pointer to the scalar alpha
* @param[in] beta pointer to the scalar beta
*/
template <typename ValueType>
void sddmm(raft::resources const& handle,
cusparseDnMatDescr_t& descr_a,
cusparseDnMatDescr_t& descr_b,
cusparseSpMatDescr_t& descr_c,
cusparseOperation_t op_a,
cusparseOperation_t op_b,
const ValueType* alpha,
const ValueType* beta)
{
auto alg = CUSPARSE_SDDMM_ALG_DEFAULT;
size_t bufferSize;

RAFT_CUSPARSE_TRY(
raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle),
op_a,
op_b,
alpha,
descr_a,
descr_b,
beta,
descr_c,
alg,
&bufferSize,
resource::get_cuda_stream(handle)));

raft::interruptible::synchronize(resource::get_cuda_stream(handle));

rmm::device_uvector<ValueType> tmp(bufferSize, resource::get_cuda_stream(handle));

RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle),
op_a,
op_b,
alpha,
descr_a,
descr_b,
beta,
descr_c,
alg,
tmp.data(),
resource::get_cuda_stream(handle)));
}

} // end namespace detail
} // end namespace linalg
} // end namespace sparse
} // end namespace raft
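
The detail-level sddmm above follows the usual cuSPARSE generic-API pattern: query the workspace size, allocate it, then execute. Below is a rough usage sketch, assuming the cuSPARSE descriptors have already been created (e.g. with the create_descriptor helpers this PR moves into utils.cuh); the function and variable names are illustrative only:

#include <cusparse.h>
#include <raft/core/resources.hpp>
#include <raft/sparse/linalg/detail/sddmm.hpp>

// Illustrative sketch: descr_a/descr_b describe the dense inputs A and B, and descr_c is a
// CSR descriptor holding the sparsity pattern (and output values) of C.
void sddmm_example(raft::resources const& handle,
                   cusparseDnMatDescr_t& descr_a,
                   cusparseDnMatDescr_t& descr_b,
                   cusparseSpMatDescr_t& descr_c)
{
  float alpha = 1.0f;  // scales op_a(A) * op_b(B) ∘ spy(C)
  float beta  = 0.0f;  // scales the existing values of C
  raft::sparse::linalg::detail::sddmm(handle,
                                      descr_a,
                                      descr_b,
                                      descr_c,
                                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                                      &alpha,
                                      &beta);
}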
54 changes: 1 addition & 53 deletions cpp/include/raft/sparse/linalg/detail/spmm.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view<const ValueType, IndexType, LayoutPol
return is_row_major;
}

/**
* @brief create a cuSparse dense descriptor
* @tparam ValueType Data type of dense_view (float/double)
* @tparam IndexType Type of dense_view
* @tparam LayoutPolicy layout of dense_view
* @param[in] dense_view input raft::device_matrix_view
* @param[in] is_row_major data layout of raft::device_matrix_view
* @returns dense matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndexType, typename LayoutPolicy>
cusparseDnMatDescr_t create_descriptor(
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
{
auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
cusparseDnMatDescr_t descr;
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
&descr,
dense_view.extent(0),
dense_view.extent(1),
ld,
const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
order));
return descr;
}

/**
* @brief create a cuSparse sparse descriptor
* @tparam ValueType Data type of sparse_view (float/double)
* @tparam IndptrType Data type of csr_matrix_view index pointers
* @tparam IndicesType Data type of csr_matrix_view indices
* @tparam NZType Type of sparse_view
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
* @returns sparse matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
cusparseSpMatDescr_t create_descriptor(
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
{
cusparseSpMatDescr_t descr;
auto csr_structure = sparse_view.structure_view();
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
&descr,
static_cast<int64_t>(csr_structure.get_n_rows()),
static_cast<int64_t>(csr_structure.get_n_cols()),
static_cast<int64_t>(csr_structure.get_nnz()),
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
const_cast<IndicesType*>(csr_structure.get_indices().data()),
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
return descr;
}

/**
* @brief SPMM function designed for handling all CSR * DENSE
* combinations of operand layouts for cuSparse.
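
For reference, a brief sketch of how the create_descriptor helper removed above (and, per the commit description, relocated to utils.cuh) is typically used to build a cuSPARSE dense descriptor from a RAFT view. The new include path of the helper and the example values are assumptions, not taken from this PR:

#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>
// #include <...>/utils.cuh  // assumed new home of create_descriptor after this PR

void descriptor_example(raft::resources const& handle)
{
  // Allocate a small row-major device matrix and take a view of it.
  auto a      = raft::make_device_matrix<float, int, raft::row_major>(handle, 4, 8);
  auto a_view = a.view();

  // The boolean states whether the view is row-major; it decides CUSPARSE_ORDER_ROW vs
  // CUSPARSE_ORDER_COL and the leading dimension inside create_descriptor.
  cusparseDnMatDescr_t descr_a = raft::sparse::linalg::detail::create_descriptor(a_view, true);

  // ... pass descr_a to the SpMM/SDDMM wrappers, then release it ...
  RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descr_a));
}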
