Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Support migration of cusparse<T>csrsm2_bufferSizeExt, cusparse<T>csrsm2_analysis and cusparse<T>csrsm2_solve #2629

Open
wants to merge 5 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ TYPE_REWRITE_ENTRY("csrsv2Info_t",
TYPE_FACTORY(STR("std::shared_ptr<" +
MapNames::getLibraryHelperNamespace() +
"sparse::optimize_info>")))
TYPE_REWRITE_ENTRY("csrsm2Info_t",
TYPE_FACTORY(STR("std::shared_ptr<" +
MapNames::getLibraryHelperNamespace() +
"sparse::optimize_info>")))
TYPE_REWRITE_ENTRY("cusparseSolvePolicy_t", TYPE_FACTORY(STR("int")))
TYPE_REWRITE_ENTRY("cusparseAction_t",
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
Expand Down
57 changes: 57 additions & 0 deletions clang/lib/DPCT/RulesMathLib/APINamesCUSPARSE.inc
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,63 @@ ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseCreateCsrsv2Info", DEREF(0),
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cusparseDestroyCsrsv2Info",
ARG(0), false, "reset"))

ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY(
"cusparseCreateCsrsm2Info", DEREF(0),
CALL("std::make_shared<" + MapNames::getLibraryHelperNamespace() +
"sparse::optimize_info>")))
ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY("cusparseDestroyCsrsm2Info",
ARG(0), false, "reset"))

ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseScsrsm2_bufferSizeExt",
DEREF(16), LITERAL("0")))
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseDcsrsm2_bufferSizeExt",
DEREF(16), LITERAL("0")))
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseCcsrsm2_bufferSizeExt",
DEREF(16), LITERAL("0")))
ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseZcsrsm2_bufferSizeExt",
DEREF(16), LITERAL("0")))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseScsrsm2_analysis",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::optimize_csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(8), ARG(9), ARG(10), ARG(11), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseDcsrsm2_analysis",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::optimize_csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(8), ARG(9), ARG(10), ARG(11), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseCcsrsm2_analysis",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::optimize_csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(8), ARG(9), ARG(10), ARG(11), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseZcsrsm2_analysis",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::optimize_csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(8), ARG(9), ARG(10), ARG(11), ARG(14))))

ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseScsrsm2_solve",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseDcsrsm2_solve",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseCcsrsm2_solve",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14))))
ASSIGNABLE_FACTORY(CALL_FACTORY_ENTRY(
"cusparseZcsrsm2_solve",
CALL(MapNames::getLibraryHelperNamespace() + "sparse::csrsm",
MEMBER_CALL(ARG(0), true, "get_queue"), ARG(2), ARG(3), ARG(4), ARG(5),
ARG(7), ARG(8), ARG(9), ARG(10), ARG(11), ARG(12), ARG(13), ARG(14))))

ASSIGNABLE_FACTORY(ASSIGN_FACTORY_ENTRY("cusparseSpSM_bufferSize", DEREF(10),
LITERAL("0")))
REMOVE_API_FACTORY_ENTRY("cusparseSpSM_createDescr")
Expand Down
11 changes: 8 additions & 3 deletions clang/lib/DPCT/RulesMathLib/SpBLASAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace clang::ast_matchers;
void SpBLASTypeLocRule::registerMatcher(ast_matchers::MatchFinder &MF) {
auto TargetTypeName = [&]() {
return hasAnyName("csrsv2Info_t", "cusparseSolvePolicy_t",
"cusparseAction_t");
"cusparseAction_t", "csrsm2Info_t");
};

MF.addMatcher(
Expand All @@ -41,8 +41,6 @@ void SpBLASTypeLocRule::runRule(

// Rule for spBLAS function calls.
void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {


auto functionName = [&]() {
return hasAnyName(
/*management*/
Expand Down Expand Up @@ -79,6 +77,13 @@ void SPBLASFunctionCallRule::registerMatcher(MatchFinder &MF) {
"cusparseScsrgemm", "cusparseDcsrgemm", "cusparseCcsrgemm",
"cusparseZcsrgemm", "cusparseXcsrgemmNnz", "cusparseScsrmm2",
"cusparseDcsrmm2", "cusparseCcsrmm2", "cusparseZcsrmm2",
"cusparseCreateCsrsm2Info", "cusparseDestroyCsrsm2Info",
"cusparseScsrsm2_bufferSizeExt", "cusparseDcsrsm2_bufferSizeExt",
"cusparseCcsrsm2_bufferSizeExt", "cusparseZcsrsm2_bufferSizeExt",
"cusparseScsrsm2_analysis", "cusparseDcsrsm2_analysis",
"cusparseCcsrsm2_analysis", "cusparseZcsrsm2_analysis",
"cusparseScsrsm2_solve", "cusparseDcsrsm2_solve",
"cusparseCcsrsm2_solve", "cusparseZcsrsm2_solve",
/*Generic*/
"cusparseCreateCsr", "cusparseDestroySpMat", "cusparseCsrGet",
"cusparseSpMatGetFormat", "cusparseSpMatGetIndexBase",
Expand Down
28 changes: 14 additions & 14 deletions clang/lib/DPCT/SrcAPI/APINames_cuSPARSE.inc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ ENTRY(cusparseSetStream, cusparseSetStream, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseGetStream, cusparseGetStream, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseCreateCsrsv2Info, cusparseCreateCsrsv2Info, true, NO_FLAG, P4, "comment")
ENTRY(cusparseDestroyCsrsv2Info, cusparseDestroyCsrsv2Info, true, NO_FLAG, P4, "comment")
ENTRY(cusparseCreateCsrsm2Info, cusparseCreateCsrsm2Info, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDestroyCsrsm2Info, cusparseDestroyCsrsm2Info, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCreateCsrsm2Info, cusparseCreateCsrsm2Info, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseDestroyCsrsm2Info, cusparseDestroyCsrsm2Info, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseCreateCsric02Info, cusparseCreateCsric02Info, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDestroyCsric02Info, cusparseDestroyCsric02Info, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCreateCsrilu02Info, cusparseCreateCsrilu02Info, false, NO_FLAG, P4, "comment")
Expand Down Expand Up @@ -211,18 +211,18 @@ ENTRY(cusparseScsrsm_solve, cusparseScsrsm_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrsm_solve, cusparseDcsrsm_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrsm_solve, cusparseCcsrsm_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrsm_solve, cusparseZcsrsm_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrsm2_bufferSizeExt, cusparseScsrsm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrsm2_bufferSizeExt, cusparseDcsrsm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrsm2_bufferSizeExt, cusparseCcsrsm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrsm2_bufferSizeExt, cusparseZcsrsm2_bufferSizeExt, false, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrsm2_analysis, cusparseScsrsm2_analysis, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrsm2_analysis, cusparseDcsrsm2_analysis, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrsm2_analysis, cusparseCcsrsm2_analysis, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrsm2_analysis, cusparseZcsrsm2_analysis, false, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrsm2_solve, cusparseScsrsm2_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDcsrsm2_solve, cusparseDcsrsm2_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseCcsrsm2_solve, cusparseCcsrsm2_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseZcsrsm2_solve, cusparseZcsrsm2_solve, false, NO_FLAG, P4, "comment")
ENTRY(cusparseScsrsm2_bufferSizeExt, cusparseScsrsm2_bufferSizeExt, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseDcsrsm2_bufferSizeExt, cusparseDcsrsm2_bufferSizeExt, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseCcsrsm2_bufferSizeExt, cusparseCcsrsm2_bufferSizeExt, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseZcsrsm2_bufferSizeExt, cusparseZcsrsm2_bufferSizeExt, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseScsrsm2_analysis, cusparseScsrsm2_analysis, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseDcsrsm2_analysis, cusparseDcsrsm2_analysis, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseCcsrsm2_analysis, cusparseCcsrsm2_analysis, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseZcsrsm2_analysis, cusparseZcsrsm2_analysis, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseScsrsm2_solve, cusparseScsrsm2_solve, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseDcsrsm2_solve, cusparseDcsrsm2_solve, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseCcsrsm2_solve, cusparseCcsrsm2_solve, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseZcsrsm2_solve, cusparseZcsrsm2_solve, true, NO_FLAG, P4, "Successful")
ENTRY(cusparseXcsrsm2_zeroPivot, cusparseXcsrsm2_zeroPivot, false, NO_FLAG, P4, "comment")
ENTRY(cusparseSgemmi, cusparseSgemmi, false, NO_FLAG, P4, "comment")
ENTRY(cusparseDgemmi, cusparseDgemmi, false, NO_FLAG, P4, "comment")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,85 @@ template <typename T> struct csrsv_impl {
}
};

template <typename T> struct optimize_csrsm_impl {
void operator()(sycl::queue &queue, oneapi::mkl::transpose transa,
oneapi::mkl::transpose transb, int row_col, int nrhs,
const std::shared_ptr<matrix_info> info, const void *val,
const int *row_ptr, const int *col_ind,
std::shared_ptr<optimize_info> optimize_info) {
using Ty = typename ::dpct::detail::lib_data_traits_t<T>;
auto temp_row_ptr = dpct::detail::get_memory<int>(row_ptr);
auto temp_col_ind = dpct::detail::get_memory<int>(col_ind);
auto temp_val = dpct::detail::get_memory<Ty>(val);
#ifdef DPCT_USM_LEVEL_NONE
optimize_info->_row_ptr_buf = temp_row_ptr;
optimize_info->_col_ind_buf = temp_col_ind;
optimize_info->_val_buf = temp_val;
auto &data_row_ptr = optimize_info->_row_ptr_buf;
auto &data_col_ind = optimize_info->_col_ind_buf;
auto &data_val = std::get<sycl::buffer<Ty>>(optimize_info->_val_buf);
#else
auto data_row_ptr = temp_row_ptr;
auto data_col_ind = temp_col_ind;
auto data_val = temp_val;
#endif
oneapi::mkl::sparse::set_csr_data(queue, optimize_info->get_matrix_handle(),
row_col, row_col, info->get_index_base(),
data_row_ptr, data_col_ind, data_val);
if (info->get_matrix_type() != matrix_info::matrix_type::tr)
throw std::runtime_error("dpct::sparse::optimize_csrsv_impl()(): "
"oneapi::mkl::sparse::optimize_trsm "
"only accept triangular matrix.");
SPARSE_CALL(oneapi::mkl::sparse::optimize_trsm(
queue,
transb == oneapi::mkl::transpose::nontrans
? oneapi::mkl::layout::col_major
: oneapi::mkl::layout::row_major,
info->get_uplo(), transa, info->get_diag(),
optimize_info->get_matrix_handle(), nrhs),
optimize_info);
}
};
template <typename T> struct csrsm_impl {
void operator()(sycl::queue &queue, oneapi::mkl::transpose transa,
oneapi::mkl::transpose transb, int row_col, int nrhs,
const void *alpha, const std::shared_ptr<matrix_info> info,
const void *val, const int *row_ptr, const int *col_ind,
void *b, int ldb,
std::shared_ptr<optimize_info> optimize_info) {
using Ty = typename ::dpct::detail::lib_data_traits_t<T>;
auto alpha_value =
dpct::detail::get_value(static_cast<const Ty *>(alpha), queue);
auto data_b = dpct::detail::get_memory<Ty>(b);

int x_size =
ldb * (transb == oneapi::mkl::transpose::nontrans ? nrhs : row_col);
Ty *x = (Ty *)::dpct::cs::malloc(sizeof(Ty) * x_size, queue);

auto data_x = dpct::detail::get_memory<Ty>(x);

sycl::event e1;
#ifndef DPCT_USM_LEVEL_NONE
e1 =
#endif
oneapi::mkl::sparse::trsm(
queue,
transb == oneapi::mkl::transpose::nontrans
? oneapi::mkl::layout::col_major
: oneapi::mkl::layout::row_major,
transa, oneapi::mkl::transpose::nontrans, info->get_uplo(),
info->get_diag(), alpha_value, optimize_info->get_matrix_handle(),
data_b, nrhs, ldb, data_x, ldb);

sycl::event e2 =
::dpct::cs::memcpy(queue, b, x, sizeof(Ty) * x_size,
::dpct::cs::memcpy_direction::automatic, {e1});

sycl::event e3 = ::dpct::cs::enqueue_free({x}, {e2}, queue);
optimize_info->add_dependency(e3);
}
};

template <typename T> struct spmv_impl {
void operator()(sycl::queue &queue, oneapi::mkl::transpose trans,
const void *alpha, sparse_matrix_desc_t a,
Expand Down
53 changes: 53 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/sparse_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ enum class conversion_scope : int { index = 0, index_and_value };
// Forward declaration
namespace detail {
template <typename T> struct optimize_csrsv_impl;
template <typename T> struct optimize_csrsm_impl;
}

/// Saving the optimization information for solving a system of linear
Expand All @@ -74,6 +75,7 @@ class optimize_info {
}
#ifdef DPCT_USM_LEVEL_NONE
template <typename T> friend struct detail::optimize_csrsv_impl;
template <typename T> friend struct detail::optimize_csrsm_impl;
#endif

private:
Expand Down Expand Up @@ -830,6 +832,57 @@ inline void csrsv(sycl::queue &queue, oneapi::mkl::transpose trans, int row_col,
optimize_info, x, y);
}

/// Performs internal optimizations for dpct::sparse::csrsm by analyzing
/// the provided matrix structure and operation parameters.
/// \param [in] queue The queue where the routine should be executed. It must
/// have the in_order property when using the USM mode.
/// \param [in] transa The operation applied to A.
/// \param [in] transb The operation applied to B and X.
/// \param [in] row_col Number of rows and columns of A.
/// \param [in] nrhs Number of columns op_b(B).
/// \param [in] info Matrix info of A.
/// \param [in] val An array containing the non-zero elements of A.
/// \param [in] row_ptr An array of length \p num_rows + 1.
/// \param [in] col_ind An array containing the column indices in index-based
/// numbering.
/// \param [out] optimize_info The result of the optimizations.
template <typename T>
void optimize_csrsm(sycl::queue &queue, oneapi::mkl::transpose transa,
oneapi::mkl::transpose transb, int row_col, int nrhs,
const std::shared_ptr<matrix_info> info, const T *val,
const int *row_ptr, const int *col_ind,
std::shared_ptr<optimize_info> optimize_info) {
detail::optimize_csrsm_impl<T>()(queue, transa, transb, row_col, nrhs, info,
val, row_ptr, col_ind, optimize_info);
}

/// Solves the sparse triangular system op_a(A) * op_b(X) = alpha * op_b(B)
/// where A is a sparse triangular matrix of size \p row_col by \p row_col .
/// \param [in] queue The queue where the routine should be executed. It must
/// have the in_order property when using the USM mode.
/// \param [in] transa The operation applied to A.
/// \param [in] transb The operation applied to B and X.
/// \param [in] row_col Number of rows and columns of A.
/// \param [in] nrhs Number of columns op_b(B).
/// \param [in] alpha Specifies the scalar.
/// \param [in] info Matrix info of A.
/// \param [in] val An array containing the non-zero elements of A.
/// \param [in] row_ptr An array of length \p num_rows + 1.
/// \param [in] col_ind An array containing the column indices in index-based
/// numbering.
/// \param [in, out] b The RHS matrix. It will be overwritten by the X.
/// \param [in] ldb The leading dimension of B and X.
/// \param [in] optimize_info The result of the optimizations.
template <typename T>
void csrsm(sycl::queue &queue, oneapi::mkl::transpose transa,
oneapi::mkl::transpose transb, int row_col, int nrhs, const T *alpha,
const std::shared_ptr<matrix_info> info, const T *val,
const int *row_ptr, const int *col_ind, T *b, int ldb,
std::shared_ptr<optimize_info> optimize_info) {
detail::csrsm_impl<T>()(queue, transa, transb, row_col, nrhs, alpha, info,
val, row_ptr, col_ind, b, ldb, optimize_info);
}

/// Computes a sparse matrix-dense vector product: y = alpha * op(a) * x + beta * y.
/// \param [in] queue The queue where the routine should be executed. It must
/// have the in_order property when using the USM mode.
Expand Down
Loading
Loading