Skip to content

Commit

Permalink
[SPARSE] Add support for sparse gemv
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 26, 2023
1 parent 39a18c5 commit d1294a4
Show file tree
Hide file tree
Showing 8 changed files with 864 additions and 19 deletions.
26 changes: 14 additions & 12 deletions src/sparse_blas/backends/mkl_common/mkl_operations.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

sycl::event optimize_gemv(sycl::queue& /*queue*/, transpose /*transpose_val*/,
detail::matrix_handle* /*handle*/,
const std::vector<sycl::event>& /*dependencies*/) {
throw unimplemented("SPARSE_BLAS", "optimize_gemv");
// Forwards the gemv optimization request to oneapi::mkl::sparse::optimize_gemv.
// The matrix handle is passed through get_handle() before the call; queue,
// transpose_val and dependencies are forwarded unchanged, and the event
// returned by the backend call is returned to the caller as-is.
sycl::event optimize_gemv(sycl::queue& queue, transpose transpose_val,
detail::matrix_handle* handle,
const std::vector<sycl::event>& dependencies) {
return oneapi::mkl::sparse::optimize_gemv(queue, transpose_val, get_handle(handle),
dependencies);
}

sycl::event optimize_trmv(sycl::queue& /*queue*/, uplo /*uplo_val*/, transpose /*transpose_val*/,
Expand All @@ -37,18 +38,19 @@ sycl::event optimize_trsv(sycl::queue& /*queue*/, uplo /*uplo_val*/, transpose /

template <typename fpType>
std::enable_if_t<detail::is_fp_supported_v<fpType>> gemv(
sycl::queue& /*queue*/, transpose /*transpose_val*/, const fpType /*alpha*/,
detail::matrix_handle* /*A_handle*/, sycl::buffer<fpType, 1>& /*x*/, const fpType /*beta*/,
sycl::buffer<fpType, 1>& /*y*/) {
throw unimplemented("SPARSE_BLAS", "gemv");
sycl::queue& queue, transpose transpose_val, const fpType alpha,
detail::matrix_handle* A_handle, sycl::buffer<fpType, 1>& x, const fpType beta,
sycl::buffer<fpType, 1>& y) {
oneapi::mkl::sparse::gemv(queue, transpose_val, alpha, get_handle(A_handle), x, beta, y);
}

template <typename fpType>
std::enable_if_t<detail::is_fp_supported_v<fpType>, sycl::event> gemv(
sycl::queue& /*queue*/, transpose /*transpose_val*/, const fpType /*alpha*/,
detail::matrix_handle* /*A_handle*/, const fpType* /*x*/, const fpType /*beta*/, fpType* /*y*/,
const std::vector<sycl::event>& /*dependencies*/) {
throw unimplemented("SPARSE_BLAS", "gemv");
sycl::queue& queue, transpose transpose_val, const fpType alpha,
detail::matrix_handle* A_handle, const fpType* x, const fpType beta, fpType* y,
const std::vector<sycl::event>& dependencies) {
return oneapi::mkl::sparse::gemv(queue, transpose_val, alpha, get_handle(A_handle), x, beta, y,
dependencies);
}

template <typename fpType>
Expand Down
11 changes: 6 additions & 5 deletions tests/unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ set(dft_TEST_LIST

set(dft_TEST_LINK "")

foreach(domain ${TARGET_DOMAINS})
# TODO: Add sparse_blas tests and CMake logic
if (domain STREQUAL "sparse_blas")
continue()
endif()
# Sparse BLAS config
set(sparse_blas_TEST_LIST
spblas_source)

set(sparse_blas_TEST_LINK "")

foreach(domain ${TARGET_DOMAINS})
# Generate RT and CT test lists
set(${domain}_TEST_LIST_RT ${${domain}_TEST_LIST})
set(${domain}_TEST_LIST_CT ${${domain}_TEST_LIST})
Expand Down
3 changes: 1 addition & 2 deletions tests/unit_tests/sparse_blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
# SPDX-License-Identifier: Apache-2.0
#===============================================================================

# TODO: Add tests
#add_subdirectory(source)
add_subdirectory(source)
148 changes: 148 additions & 0 deletions tests/unit_tests/sparse_blas/include/sparse_reference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*
*
* SPDX-License-Identifier: Apache-2.0
*******************************************************************************/

#ifndef _SPARSE_REFERENCE_HPP__
#define _SPARSE_REFERENCE_HPP__

#include <complex>
#include <cstddef>
#include <stdexcept>
#include <vector>

#include "oneapi/mkl.hpp"

// conjugate(t): complex conjugate of a scalar used by the reference kernels.
//
// Supported types are float, double, std::complex<float> and
// std::complex<double>, each provided as an explicit specialization below.
// Real values are returned unchanged; complex values go through std::conj.
//
// The primary template is deleted so that any other type is rejected at the
// call site. A deleted primary is used instead of `static_assert(false, ...)`
// inside the body because a non-dependent failing static_assert in a template
// is ill-formed before C++23 and is rejected at definition time by older
// compilers, even when the template is never instantiated.
template <typename T>
inline T conjugate(T) = delete;

// Real single precision: conjugation is the identity.
template <>
inline float conjugate(float t) {
    return t;
}

// Real double precision: conjugation is the identity.
template <>
inline double conjugate(double t) {
    return t;
}

// Complex single precision: negate the imaginary part.
template <>
inline std::complex<float> conjugate(std::complex<float> t) {
    return std::conj(t);
}

// Complex double precision: negate the imaginary part.
template <>
inline std::complex<double> conjugate(std::complex<double> t) {
    return std::conj(t);
}

// Applies the optional conjugation of a transpose operation to one value:
// returns conjugate(t) when isConj is true, otherwise t unchanged.
template <typename T>
inline T opVal(const T t, const bool isConj) {
    if (isConj) {
        return conjugate(t);
    }
    return t;
}

// Builds the CSR representation of op(A) = A^T (or A^H for conjtrans) from the
// CSR arrays (ia, ja, a) of A.
//
//   opA            only inspected to decide whether values are conjugated
//                  (conjtrans); the structure is transposed in every case
//   ia_t           output row pointers; indices 0..a_ncols are written
//   ja_t, a_t      output column indices / values, written once per visited
//                  nonzero of A
//   a_nrows/ncols  dimensions of A
//   a_ind          index base; it is subtracted from every entry read from
//                  ia/ja and added back to every entry written to ia_t/ja_t
//   ia, ja, a      input CSR arrays, accessed only through operator[]
//   structOnlyFlag when true, a is never read and a_t is never written
template <typename fpType, typename intType, typename accIntType, typename accFpType>
void do_csr_transpose(const oneapi::mkl::transpose opA, intType *ia_t, intType *ja_t, fpType *a_t,
                      intType a_nrows, intType a_ncols, intType a_ind, accIntType &ia,
                      accIntType &ja, accFpType &a, const bool structOnlyFlag = false) {
    const bool conjugate_values = (opA == oneapi::mkl::transpose::conjtrans);

    // Phase 1: clear the per-column counters (one extra slot for the end pointer).
    for (intType slot = 0; slot < a_ncols + 1; ++slot) {
        ia_t[slot] = 0;
    }

    // Phase 2: count the nonzeros of every column of A, stored one slot to the
    // right so that the prefix sum below directly yields row start offsets.
    for (intType row = 0; row < a_nrows; ++row) {
        const intType row_begin = ia[row] - a_ind;
        const intType row_end = ia[row + 1] - a_ind;
        for (intType nz = row_begin; nz < row_end; ++nz) {
            const intType column = ja[nz] - a_ind;
            ia_t[column + 1]++;
        }
    }

    // Phase 3: inclusive prefix sum, seeded with the index base, turns the
    // counts into row pointers of the transposed matrix.
    intType running_offset = a_ind;
    ia_t[0] = running_offset;
    for (intType column = 0; column < a_ncols; ++column) {
        running_offset += ia_t[column + 1];
        ia_t[column + 1] = running_offset;
    }

    // Phase 4: scatter every entry of A into its slot in the transpose.
    // ia_t[column] is used as the insertion cursor of that output row and is
    // advanced after each insertion.
    for (intType row = 0; row < a_nrows; ++row) {
        const intType row_begin = ia[row] - a_ind;
        const intType row_end = ia[row + 1] - a_ind;
        for (intType nz = row_begin; nz < row_end; ++nz) {
            const intType column = ja[nz] - a_ind;
            const intType destination = ia_t[column] - a_ind;
            ia_t[column]++;
            ja_t[destination] = row + a_ind;
            if (!structOnlyFlag) {
                const fpType value = a[nz];
                a_t[destination] = opVal(value, conjugate_values);
            }
        }
    }

    // Phase 5: every cursor now sits on the start of the next row, so shift
    // the array one slot to the right to restore the row pointers.
    for (intType slot = a_ncols; slot > 0; --slot) {
        ia_t[slot] = ia_t[slot - 1];
    }
    ia_t[0] = a_ind;
}

// Host reference for sparse GEMV on a CSR matrix:
//
//     y_ref <- alpha * op(A) * x + beta * y_ref
//
//   ia, ja, a      CSR arrays of A (a_nrows + 1 row pointers, a_nnz column
//                  indices and values are read)
//   a_nrows/ncols  dimensions of A
//   a_ind          index base subtracted from the entries of ia and ja
//   opA            nontrans, trans or conjtrans; any other value throws
//                  std::runtime_error before anything is computed
//   x              input vector, indexed by the column indices of op(A)
//   y_ref          in/out vector; a_nrows entries are updated for nontrans,
//                  a_ncols entries otherwise
template <typename fpType, typename intType>
void prepare_reference_gemv_data(const intType *ia, const intType *ja, const fpType *a,
                                 intType a_nrows, intType a_ncols, std::size_t a_nnz, intType a_ind,
                                 oneapi::mkl::transpose opA, fpType alpha, fpType beta,
                                 const fpType *x, fpType *y_ref) {
    const bool use_a_as_is = (opA == oneapi::mkl::transpose::nontrans);
    const bool use_transposed =
        (opA == oneapi::mkl::transpose::trans || opA == oneapi::mkl::transpose::conjtrans);

    if (!use_a_as_is && !use_transposed) {
        throw std::runtime_error(
            "unsupported transpose_val (opA) in prepare_reference_gemv_data()");
    }

    // Number of rows of op(A), i.e. the number of entries of y_ref to update.
    const std::size_t result_rows = static_cast<std::size_t>(use_a_as_is ? a_nrows : a_ncols);

    // Local CSR copy of op(A).
    std::vector<intType> row_ptr;
    std::vector<intType> col_idx;
    std::vector<fpType> values;
    if (use_a_as_is) {
        row_ptr.assign(ia, ia + a_nrows + 1);
        col_idx.assign(ja, ja + a_nnz);
        values.assign(a, a + a_nnz);
    }
    else {
        row_ptr.resize(result_rows + 1);
        col_idx.resize(a_nnz);
        values.resize(a_nnz);
        do_csr_transpose(opA, row_ptr.data(), col_idx.data(), values.data(), a_nrows, a_ncols,
                         a_ind, ia, ja, a);
    }

    // Row-by-row dot products of op(A) with x, then the alpha/beta update.
    for (std::size_t r = 0; r < result_rows; r++) {
        fpType dot = 0;
        for (intType k = row_ptr[r] - a_ind; k < row_ptr[r + 1] - a_ind; k++) {
            const std::size_t nz = static_cast<std::size_t>(k);
            const std::size_t col = static_cast<std::size_t>(col_idx[nz] - a_ind);
            dot += values[nz] * x[col];
        }

        y_ref[r] = alpha * dot + beta * y_ref[r];
    }
}

#endif // _SPARSE_REFERENCE_HPP__
Loading

0 comments on commit d1294a4

Please sign in to comment.