Skip to content

Commit

Permalink
[BLAS:: SYCL-BLAS backend] Enable SYCL-BLAS routines (#277)
Browse files Browse the repository at this point in the history
Added rotmg, sbmv, tbmv, spmv, tbsv, trsv and tpmv BLAS operators.
  • Loading branch information
pgorlani authored Jun 19, 2023
1 parent 57c5f2b commit efa1165
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
6 changes: 2 additions & 4 deletions src/blas/backends/syclblas/syclblas_level1.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,8 @@ void rotm(sycl::queue &queue, std::int64_t n, sycl::buffer<real_t, 1> &x, std::i

void rotmg(sycl::queue &queue, sycl::buffer<real_t, 1> &d1, sycl::buffer<real_t, 1> &d2,
sycl::buffer<real_t, 1> &x1, real_t y1, sycl::buffer<real_t, 1> &param) {
//TODO(codeplay): Enable rotmg
//sycl::buffer<real_t, 1> y1_buffer(&y1, sycl::range<1>{ 1 });
//CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param);
throw unimplemented("blas", "rotmg", "");
sycl::buffer<real_t, 1> y1_buffer(&y1, sycl::range<1>{ 1 });
CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param);
}

void scal(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer<real_t, 1> &x,
Expand Down
13 changes: 7 additions & 6 deletions src/blas/backends/syclblas/syclblas_level2.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ void hpr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n,
void sbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k,
real_t alpha, sycl::buffer<real_t, 1> &a, std::int64_t lda, sycl::buffer<real_t, 1> &x,
std::int64_t incx, real_t beta, sycl::buffer<real_t, 1> &y, std::int64_t incy) {
throw unimplemented("blas", "sbmv", "");
CALL_SYCLBLAS_FN(::blas::_sbmv, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y,
incy);
}

void symv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha,
Expand All @@ -142,7 +143,7 @@ void syr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea
void spmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha,
sycl::buffer<real_t, 1> &a, sycl::buffer<real_t, 1> &x, std::int64_t incx, real_t beta,
sycl::buffer<real_t, 1> &y, std::int64_t incy) {
throw unimplemented("blas", "spmv", "");
CALL_SYCLBLAS_FN(::blas::_spmv, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy);
}

void spr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha,
Expand All @@ -159,7 +160,7 @@ void spr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea
void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer<real_t, 1> &a,
std::int64_t lda, sycl::buffer<real_t, 1> &x, std::int64_t incx) {
throw unimplemented("blas", "tbmv", "");
CALL_SYCLBLAS_FN(::blas::_tbmv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
}

void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
Expand All @@ -172,7 +173,7 @@ void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp
void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer<real_t, 1> &a,
std::int64_t lda, sycl::buffer<real_t, 1> &x, std::int64_t incx) {
throw unimplemented("blas", "tbsv", "");
CALL_SYCLBLAS_FN(::blas::_tbsv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx);
}

void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
Expand All @@ -185,7 +186,7 @@ void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp
void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer<real_t, 1> &a,
sycl::buffer<real_t, 1> &x, std::int64_t incx) {
throw unimplemented("blas", "tpmv", "");
CALL_SYCLBLAS_FN(::blas::_tpmv, queue, upper_lower, trans, unit_diag, n, a, x, incx);
}

void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
Expand Down Expand Up @@ -221,7 +222,7 @@ void trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp
void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer<real_t, 1> &a, std::int64_t lda,
sycl::buffer<real_t, 1> &x, std::int64_t incx) {
throw unimplemented("blas", "trsv", "");
CALL_SYCLBLAS_FN(::blas::_trsv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx);
}

void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,
Expand Down

0 comments on commit efa1165

Please sign in to comment.