diff --git a/src/blas/backends/syclblas/syclblas_level1.cxx b/src/blas/backends/syclblas/syclblas_level1.cxx index 1722cc440..722645e11 100644 --- a/src/blas/backends/syclblas/syclblas_level1.cxx +++ b/src/blas/backends/syclblas/syclblas_level1.cxx @@ -150,10 +150,8 @@ void rotm(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::i void rotmg(sycl::queue &queue, sycl::buffer &d1, sycl::buffer &d2, sycl::buffer &x1, real_t y1, sycl::buffer ¶m) { - //TODO(codeplay): Enable rotmg - //sycl::buffer y1_buffer(&y1, sycl::range<1>{ 1 }); - //CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param); - throw unimplemented("blas", "rotmg", ""); + sycl::buffer y1_buffer(&y1, sycl::range<1>{ 1 }); + CALL_SYCLBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param); } void scal(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer &x, diff --git a/src/blas/backends/syclblas/syclblas_level2.cxx b/src/blas/backends/syclblas/syclblas_level2.cxx index 6d2b5d753..91e633be2 100644 --- a/src/blas/backends/syclblas/syclblas_level2.cxx +++ b/src/blas/backends/syclblas/syclblas_level2.cxx @@ -118,7 +118,8 @@ void hpr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, void sbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { - throw unimplemented("blas", "sbmv", ""); + CALL_SYCLBLAS_FN(::blas::_sbmv, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, + incy); } void symv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, @@ -142,7 +143,7 @@ void syr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea void spmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { - throw unimplemented("blas", "spmv", ""); + CALL_SYCLBLAS_FN(::blas::_spmv, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); } void spr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, @@ -159,7 +160,7 @@ void spr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, rea void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tbmv", ""); + CALL_SYCLBLAS_FN(::blas::_tbmv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); } void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, @@ -172,7 +173,7 @@ void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tbsv", ""); + CALL_SYCLBLAS_FN(::blas::_tbsv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); } void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, @@ -185,7 +186,7 @@ void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "tpmv", ""); + CALL_SYCLBLAS_FN(::blas::_tpmv, queue, upper_lower, trans, unit_diag, n, a, x, incx); } void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, @@ -221,7 +222,7 @@ void trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transp void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, std::int64_t incx) { - throw unimplemented("blas", "trsv", ""); + CALL_SYCLBLAS_FN(::blas::_trsv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); } void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans,