Skip to content

Commit

Permalink
[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067)
Browse files Browse the repository at this point in the history
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
  • Loading branch information
rhdong committed Dec 20, 2023
1 parent 7e098b2 commit 8016277
Show file tree
Hide file tree
Showing 8 changed files with 797 additions and 52 deletions.
112 changes: 112 additions & 0 deletions cpp/include/raft/sparse/detail/cusparse_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle,
alg,
static_cast<void*>(externalBuffer));
}

template <typename T>
cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const T* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const T* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream);
template <>
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const float* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const float* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM_bufferSize(
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize);
}
template <>
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const double* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const double* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
size_t* bufferSize,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM_bufferSize(
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize);
}
template <typename T>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const T* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const T* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
T* externalBuffer,
cudaStream_t stream);
template <>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const float* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const float* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
float* externalBuffer,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM(handle,
opA,
opB,
static_cast<void const*>(alpha),
matA,
matB,
static_cast<void const*>(beta),
matC,
CUDA_R_32F,
alg,
static_cast<void*>(externalBuffer));
}
template <>
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
cusparseOperation_t opA,
cusparseOperation_t opB,
const double* alpha,
const cusparseDnMatDescr_t matA,
const cusparseDnMatDescr_t matB,
const double* beta,
cusparseSpMatDescr_t matC,
cusparseSDDMMAlg_t alg,
double* externalBuffer,
cudaStream_t stream)
{
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
return cusparseSDDMM(handle,
opA,
opB,
static_cast<void const*>(alpha),
matA,
matB,
static_cast<void const*>(beta),
matC,
CUDA_R_64F,
alg,
static_cast<void*>(externalBuffer));
}

/** @} */
#else
/**
Expand Down
100 changes: 100 additions & 0 deletions cpp/include/raft/sparse/linalg/detail/sddmm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusparse_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/detail/cusparse_wrappers.h>

namespace raft {
namespace sparse {
namespace linalg {
namespace detail {

/**
* @brief This function performs the multiplication of dense matrix A and dense matrix B,
* followed by an element-wise multiplication with the sparsity pattern of C.
* It computes the following equation: C = alpha · (A * B ∘ spy(C)) + beta · C
* where A,B are device matrix views and C is a CSR device matrix view
*
* @tparam ValueType Data type of input/output matrices (float/double)
* @tparam IndexType Type of C
* @tparam LayoutPolicyA layout of A
* @tparam LayoutPolicyB layout of B
* @tparam NZType Type of C
*
* @param[in] handle raft resource handle
* @param[in] descr_a input dense descriptor
* @param[in] descr_b input dense descriptor
* @param[in/out] descr_c output sparse descriptor
* @param[in] alpha scalar pointer
* @param[in] beta scalar pointer
*/
template <typename ValueType>
void sddmm(raft::resources const& handle,
cusparseDnMatDescr_t& descr_a,
cusparseDnMatDescr_t& descr_b,
cusparseSpMatDescr_t& descr_c,
const bool is_row_major_a,
const bool is_row_major_b,
const ValueType* alpha,
const ValueType* beta)
{
auto opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto opB = (is_row_major_a != is_row_major_b) ? CUSPARSE_OPERATION_NON_TRANSPOSE
: CUSPARSE_OPERATION_TRANSPOSE;

auto alg = CUSPARSE_SDDMM_ALG_DEFAULT;
size_t bufferSize;

RAFT_CUSPARSE_TRY(
raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle),
opA,
opB,
alpha,
descr_a,
descr_b,
beta,
descr_c,
alg,
&bufferSize,
resource::get_cuda_stream(handle)));

raft::interruptible::synchronize(resource::get_cuda_stream(handle));

rmm::device_uvector<ValueType> tmp(bufferSize, resource::get_cuda_stream(handle));

RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle),
opA,
opB,
alpha,
descr_a,
descr_b,
beta,
descr_c,
alg,
tmp.data(),
resource::get_cuda_stream(handle)));
}

} // end namespace detail
} // end namespace linalg
} // end namespace sparse
} // end namespace raft
52 changes: 0 additions & 52 deletions cpp/include/raft/sparse/linalg/detail/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view<const ValueType, IndexType, LayoutPol
return is_row_major;
}

/**
* @brief create a cuSparse dense descriptor
* @tparam ValueType Data type of dense_view (float/double)
* @tparam IndexType Type of dense_view
* @tparam LayoutPolicy layout of dense_view
* @param[in] dense_view input raft::device_matrix_view
* @param[in] is_row_major data layout of raft::device_matrix_view
* @returns dense matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndexType, typename LayoutPolicy>
cusparseDnMatDescr_t create_descriptor(
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
{
auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
cusparseDnMatDescr_t descr;
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
&descr,
dense_view.extent(0),
dense_view.extent(1),
ld,
const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
order));
return descr;
}

/**
* @brief create a cuSparse sparse descriptor
* @tparam ValueType Data type of sparse_view (float/double)
* @tparam IndptrType Data type of csr_matrix_view index pointers
* @tparam IndicesType Data type of csr_matrix_view indices
* @tparam NZType Type of sparse_view
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
* @returns sparse matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
cusparseSpMatDescr_t create_descriptor(
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
{
cusparseSpMatDescr_t descr;
auto csr_structure = sparse_view.structure_view();
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
&descr,
static_cast<int64_t>(csr_structure.get_n_rows()),
static_cast<int64_t>(csr_structure.get_n_cols()),
static_cast<int64_t>(csr_structure.get_nnz()),
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
const_cast<IndicesType*>(csr_structure.get_indices().data()),
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
return descr;
}

/**
* @brief SPMM function designed for handling all CSR * DENSE
* combinations of operand layouts for cuSparse.
Expand Down
83 changes: 83 additions & 0 deletions cpp/include/raft/sparse/linalg/detail/utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/sparse/detail/cusparse_wrappers.h>

namespace raft {
namespace sparse {
namespace linalg {
namespace detail {

/**
* @brief create a cuSparse dense descriptor
* @tparam ValueType Data type of dense_view (float/double)
* @tparam IndexType Type of dense_view
* @tparam LayoutPolicy layout of dense_view
* @param[in] dense_view input raft::device_matrix_view
* @param[in] is_row_major data layout of raft::device_matrix_view
* @returns dense matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndexType, typename LayoutPolicy>
cusparseDnMatDescr_t create_descriptor(
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
{
auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
cusparseDnMatDescr_t descr;
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
&descr,
dense_view.extent(0),
dense_view.extent(1),
ld,
const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
order));
return descr;
}

/**
* @brief create a cuSparse sparse descriptor
* @tparam ValueType Data type of sparse_view (float/double)
* @tparam IndptrType Data type of csr_matrix_view index pointers
* @tparam IndicesType Data type of csr_matrix_view indices
* @tparam NZType Type of sparse_view
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
* @returns sparse matrix descriptor to be used by cuSparse API
*/
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
cusparseSpMatDescr_t create_descriptor(
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
{
cusparseSpMatDescr_t descr;
auto csr_structure = sparse_view.structure_view();
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
&descr,
static_cast<int64_t>(csr_structure.get_n_rows()),
static_cast<int64_t>(csr_structure.get_n_cols()),
static_cast<int64_t>(csr_structure.get_nnz()),
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
const_cast<IndicesType*>(csr_structure.get_indices().data()),
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
return descr;
}

} // end namespace detail
} // end namespace linalg
} // end namespace sparse
} // end namespace raft
Loading

0 comments on commit 8016277

Please sign in to comment.