From 3c75a158d54b4d5b333a7f15302b7ce8e9341328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= <9421873+s-Nick@users.noreply.github.com> Date: Thu, 26 Oct 2023 11:12:22 +0200 Subject: [PATCH] [BLAS::portBLAS backend] Add USM support for batch operators (#399) --- src/blas/backends/portblas/portblas_batch.cxx | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index a1be7f7aa..96b34e6bd 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -703,7 +703,9 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -713,7 +715,9 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, double beta, double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -825,7 +829,8 @@ sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std std::int64_t n, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size, dependencies); } sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, @@ -833,7 +838,8 @@ sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std std::int64_t stride_a, double *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size, dependencies); } sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, @@ -886,7 +892,9 @@ sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, float beta, const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, + dependencies); } sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -895,7 +903,9 @@ sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, double beta, const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", " for USM"); + CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, + dependencies); } sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa,