From 154902a13b180abe448ab2d368cc1e136dd4c677 Mon Sep 17 00:00:00 2001 From: pgorlani Date: Thu, 11 May 2023 14:33:48 +0100 Subject: [PATCH 1/2] Add rotmg, sbmv, tbmv, spmv, tbsv, trsv --- src/blas/backends/syclblas/syclblas_level1.cxx | 6 ++---- src/blas/backends/syclblas/syclblas_level2.cxx | 11 ++++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/blas/backends/syclblas/syclblas_level1.cxx b/src/blas/backends/syclblas/syclblas_level1.cxx index 1722cc440..722645e11 100644 --- a/src/blas/backends/syclblas/syclblas_level1.cxx +++ b/src/blas/backends/syclblas/syclblas_level1.cxx @@ -150,10 +150,8 @@ void rotm(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::i void rotmg(sycl::queue &queue, sycl::buffer &d1, sycl::buffer &d2, sycl::buffer &x1, real_t y1, sycl::buffer ¶m) { - //TODO(codeplay): Enable rotmg - //sycl::buffer y1_buffer(&y1, sycl::range<1>{ 1 }); - //CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param); - throw unimplemented("blas", "rotmg", ""); + sycl::buffer y1_buffer(&y1, sycl::range<1>{ 1 }); + CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param); } void scal(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer &x, diff --git a/src/blas/backends/syclblas/syclblas_level2.cxx b/src/blas/backends/syclblas/syclblas_level2.cxx index 6d2b5d753..b1149b07e 100644 --- a/src/blas/backends/syclblas/syclblas_level2.cxx +++ b/src/blas/backends/syclblas/syclblas_level2.cxx @@ -118,7 +118,8 @@ void hpr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, void sbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { - throw unimplemented("blas", "sbmv", ""); + CALL_SYCLBLAS_FN(::blas::_sbmv, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, + incy); } void symv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, @@ -142,7 +143,7 @@ void syr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea void spmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { - throw unimplemented("blas", "spmv", ""); + CALL_SYCLBLAS_FN(::blas::_spmv, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); } void spr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, @@ -159,7 +160,7 @@ void spr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tbmv", ""); + CALL_SYCLBLAS_FN(::blas::_tbmv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); } void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, @@ -172,7 +173,7 @@ void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tbsv", ""); + CALL_SYCLBLAS_FN(::blas::_tbsv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); } void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, @@ -221,7 +222,7 @@ void trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "trsv", ""); + CALL_SYCLBLAS_FN(::blas::_trsv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); } void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, From 79bfe76d6c09cd72b23e01ee1db56204de0aad34 Mon Sep 17 00:00:00 2001 From: pgorlani Date: Thu, 11 May 2023 16:28:24 +0100 Subject: [PATCH 2/2] Add tpmv --- src/blas/backends/syclblas/syclblas_level2.cxx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/syclblas/syclblas_level2.cxx b/src/blas/backends/syclblas/syclblas_level2.cxx index b1149b07e..91e633be2 100644 --- a/src/blas/backends/syclblas/syclblas_level2.cxx +++ b/src/blas/backends/syclblas/syclblas_level2.cxx @@ -186,7 +186,7 @@ void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tpmv", ""); + CALL_SYCLBLAS_FN(::blas::_tpmv, queue, upper_lower, trans, unit_diag, n, a, x, incx); } void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,