From 9abb8cf1a8b4fa56c1a81db885e6f545d8a9623b Mon Sep 17 00:00:00 2001 From: "Andrew T. Barker" Date: Tue, 20 Sep 2022 22:15:52 +0000 Subject: [PATCH] [BLAS] add interfaces to matrix copy/transposition routines (#227) --- include/oneapi/mkl/blas.hxx | 342 ++++++++++++ .../mkl/blas/detail/blas_ct_backends.hxx | 166 ++++++ .../oneapi/mkl/blas/detail/blas_loader.hxx | 149 +++++ .../oneapi/mkl/blas/detail/cublas/blas_ct.hxx | 335 +++++++++++ .../blas/detail/cublas/onemkl_blas_cublas.hxx | 122 ++++ .../oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx | 335 +++++++++++ .../oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx | 335 +++++++++++ .../oneapi/mkl/blas/detail/netlib/blas_ct.hxx | 340 +++++++++++ .../mkl/blas/detail/onemkl_blas_backends.hxx | 166 ++++++ .../mkl/blas/detail/rocblas/blas_ct.hxx | 335 +++++++++++ .../detail/rocblas/onemkl_blas_rocblas.hxx | 122 ++++ include/oneapi/mkl/blas/predicates.hxx | 527 ++++++++++++++++++ src/blas/backends/backend_wrappers.cxx | 24 + src/blas/backends/cublas/cublas_batch.cpp | 340 +++++++++++ src/blas/backends/cublas/cublas_wrappers.cpp | 48 ++ src/blas/backends/mkl_common/mkl_batch.cxx | 190 +++++++ .../backends/mkl_common/mkl_blas_backend.hxx | 141 +++++ src/blas/backends/netlib/netlib_batch.cxx | 290 ++++++++++ src/blas/backends/rocblas/rocblas_batch.cpp | 340 +++++++++++ .../backends/rocblas/rocblas_wrappers.cpp | 48 ++ src/blas/blas_loader.cpp | 450 +++++++++++++++ src/blas/function_table.hpp | 264 +++++++++ tests/unit_tests/blas/batch/CMakeLists.txt | 2 +- .../blas/batch/imatcopy_batch_stride.cpp | 209 +++++++ .../blas/batch/imatcopy_batch_stride_usm.cpp | 234 ++++++++ .../blas/batch/omatadd_batch_stride.cpp | 232 ++++++++ .../blas/batch/omatadd_batch_stride_usm.cpp | 263 +++++++++ .../blas/batch/omatcopy_batch_stride.cpp | 214 +++++++ .../blas/batch/omatcopy_batch_stride_usm.cpp | 245 ++++++++ .../blas/include/reference_blas_templates.hpp | 170 ++++++ 30 files changed, 6977 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp create mode 100644 tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp create mode 100644 tests/unit_tests/blas/batch/omatadd_batch_stride.cpp create mode 100644 tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp create mode 100644 tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp create mode 100644 tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index b8063a21a..2a3400cbb 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -1937,6 +1937,159 @@ static inline void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, d trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); } +static inline void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); + detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +static inline void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); + detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +static inline void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); + detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +static inline void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); + detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +static inline void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); + detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +static inline void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); + detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +static inline void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); + detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +static inline void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); + detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +static inline void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +static inline void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +static inline void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +static inline void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs static inline sycl::event asum(sycl::queue &queue, std::int64_t n, @@ -4775,3 +4928,192 @@ static inline sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose t trsv_postcondition(queue, upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); return done; } + +static inline sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + auto done = detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + return done; +} + +static inline sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + auto done = detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + return done; +} + +static inline sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + auto done = detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + return done; +} + +static inline sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + auto done = detail::omatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + omatcopy_batch_postcondition(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); + return done; +} + +static inline sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + auto done = detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, + ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + return done; +} + +static inline sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + auto done = detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, + ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + return done; +} + +static inline sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + auto done = detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, + ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + return done; +} + +static inline sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + auto done = detail::imatcopy_batch(get_device_id(queue), queue, trans, m, n, alpha, ab, lda, + ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); + return done; +} + +static inline sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, + lda, stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, + batch_size, dependencies); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +static inline sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, + double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, + lda, stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, + batch_size, dependencies); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +static inline sycl::event omatadd_batch( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies = {}) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, + lda, stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, + batch_size, dependencies); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +static inline sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = detail::omatadd_batch(get_device_id(queue), queue, transa, transb, m, n, alpha, a, + lda, stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, + batch_size, dependencies); + omatadd_batch_postcondition(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index 6109120fa..9c4896aea 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -1083,6 +1083,82 @@ static inline void symv(backend_selector selector, uplo upper_ std::int64_t lda, sycl::buffer &x, std::int64_t incx, double beta, sycl::buffer &y, std::int64_t incy); +static inline void omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +static inline void omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +static inline void omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +static inline void omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +static inline void imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + +static inline void imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + +static inline void imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); + +static inline void imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); + +static inline void omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +static inline void omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +static inline void omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +static inline void omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + // USM APIs static inline sycl::event syr2(backend_selector selector, uplo upper_lower, @@ -2484,3 +2560,93 @@ static inline sycl::event symv(backend_selector selector, uplo const double *x, std::int64_t incx, double beta, double *y, std::int64_t incy, const std::vector &dependencies = {}); + +static inline sycl::event omatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event imatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + float alpha, float *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event imatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event imatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event imatcopy_batch(backend_selector selector, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatadd_batch(backend_selector selector, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatadd_batch(backend_selector selector, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, + double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event omatadd_batch( + backend_selector selector, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies = {}); + +static inline sycl::event omatadd_batch( + backend_selector selector, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index 6c8a747ed..ea7a35d17 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -960,6 +960,73 @@ ONEMKL_EXPORT void rotg(oneapi::mkl::device libkey, sycl::queue &queue, sycl::buffer &c, sycl::buffer, 1> &s); +ONEMKL_EXPORT void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +ONEMKL_EXPORT void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +ONEMKL_EXPORT void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); +ONEMKL_EXPORT void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +ONEMKL_EXPORT void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); +ONEMKL_EXPORT void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); +ONEMKL_EXPORT void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); +ONEMKL_EXPORT void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); + +ONEMKL_EXPORT void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + // USM APIs ONEMKL_EXPORT sycl::event herk(oneapi::mkl::device libkey, sycl::queue &queue, @@ -2271,3 +2338,85 @@ ONEMKL_EXPORT sycl::event rotg(oneapi::mkl::device libkey, sycl::queue &queue, std::complex *a, std::complex *b, double *c, std::complex *s, const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + float alpha, float *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, + double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatadd_batch( + oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatadd_batch( + oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index cdbddc6c0..0f5e4e63b 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -2023,6 +2023,161 @@ void symv(backend_selector selector, uplo upper_lower, std::int symv_postcondition(selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); } +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::cublas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, @@ -4955,3 +5110,183 @@ sycl::event symv(backend_selector selector, uplo upper_lower, s dependencies); return done; } + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::cublas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index a2ba03249..d6d59c435 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -875,6 +875,64 @@ void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset of sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, sycl::buffer &c, std::int64_t ldc, sycl::buffer &co); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + // USM APIs sycl::event asum(sycl::queue &queue, std::int64_t n, const std::complex *x, @@ -2012,3 +2070,67 @@ sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index b9b4f01ed..f97d3afb3 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -2025,6 +2025,161 @@ void symv(backend_selector selector, uplo upper_lower, std::int symv_postcondition(selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); } +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, @@ -4957,3 +5112,183 @@ sycl::event symv(backend_selector selector, uplo upper_lower, s dependencies); return done; } + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index 48d2de7cd..cba36992a 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -2025,6 +2025,161 @@ void symv(backend_selector selector, uplo upper_lower, std::int symv_postcondition(selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); } +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, @@ -4957,3 +5112,183 @@ sycl::event symv(backend_selector selector, uplo upper_lower, s dependencies); return done; } + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index c1100949a..a2f7681d6 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -2025,6 +2025,161 @@ void symv(backend_selector selector, uplo upper_lower, std::int symv_postcondition(selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); } +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::netlib::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, @@ -4957,3 +5112,188 @@ sycl::event symv(backend_selector selector, uplo upper_lower, s dependencies); return done; } + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::netlib::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index 5ba428c7e..6180b0da1 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -1064,6 +1064,82 @@ ONEMKL_EXPORT void gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, float beta, sycl::buffer &c, std::int64_t ldc, sycl::buffer &co); +ONEMKL_EXPORT void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + +ONEMKL_EXPORT void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + +ONEMKL_EXPORT void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +ONEMKL_EXPORT void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size); + +ONEMKL_EXPORT void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + +ONEMKL_EXPORT void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + +ONEMKL_EXPORT void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); + +ONEMKL_EXPORT void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size); + +ONEMKL_EXPORT void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + // USM APIs ONEMKL_EXPORT sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -2470,3 +2546,93 @@ ONEMKL_EXPORT sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lowe std::complex beta, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, + float *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, + double *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, + double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatadd_batch( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatadd_batch( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx index f00bd0587..c8278159f 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx @@ -1949,6 +1949,161 @@ void symv(backend_selector selector, uplo upper_lower, int64_t symv_postcondition(selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); } +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); + oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); + oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event syr2(backend_selector selector, uplo upper_lower, int64_t n, @@ -4760,3 +4915,183 @@ sycl::event symv(backend_selector selector, uplo upper_lower, dependencies); return done; } + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + omatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + omatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + imatcopy_batch_precondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + imatcopy_batch_postcondition(selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, + batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + omatadd_batch_precondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + omatadd_batch_postcondition(selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx index b1b99abad..39cca773d 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx @@ -806,6 +806,64 @@ void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset of uint8_t ao, sycl::buffer &b, int64_t ldb, uint8_t bo, float beta, sycl::buffer &c, int64_t ldc, sycl::buffer &co); +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + // USM APIs sycl::event asum(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, @@ -1806,3 +1864,67 @@ sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, of int64_t lda, std::uint8_t ao, const std::uint8_t *b, int64_t ldb, std::uint8_t bo, float beta, std::int32_t *c, int64_t ldc, const std::int32_t *co, const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/predicates.hxx b/include/oneapi/mkl/blas/predicates.hxx index 73f5812a8..cbd51a6d1 100644 --- a/include/oneapi/mkl/blas/predicates.hxx +++ b/include/oneapi/mkl/blas/predicates.hxx @@ -3699,6 +3699,260 @@ inline void rotg_postcondition(sycl::queue &queue, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, float beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, double beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, double beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, + sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, + sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, + sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, + sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + // USM APIs inline void herk_precondition(sycl::queue &queue, uplo upper_lower, transpose trans, @@ -8049,3 +8303,276 @@ inline void rotg_postcondition(sycl::queue &queue, std::complex *a, /* add postchecks to queue here for input args. */ #endif } + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_precondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void imatcopy_batch_postcondition(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, + double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, + double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition( + sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_precondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add prechecks to queue here for input args. */ +#endif +} + +inline void omatadd_batch_postcondition(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies) { +#ifndef ONEMKL_DISABLE_PREDICATES + /* add postchecks to queue here for input args. */ +#endif +} diff --git a/src/blas/backends/backend_wrappers.cxx b/src/blas/backends/backend_wrappers.cxx index a0b49c7c4..196ec1138 100644 --- a/src/blas/backends/backend_wrappers.cxx +++ b/src/blas/backends/backend_wrappers.cxx @@ -212,6 +212,18 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, oneapi::mkl::blas::BACKEND::MAJOR::asum, oneapi::mkl::blas::BACKEND::MAJOR::asum, oneapi::mkl::blas::BACKEND::MAJOR::asum, @@ -435,4 +447,16 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, oneapi::mkl::blas::BACKEND::MAJOR::gemm_bias, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::imatcopy_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, +oneapi::mkl::blas::BACKEND::MAJOR::omatadd_batch, // clang-format on diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 85eadbe42..b262df060 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -246,6 +246,88 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n throw unimplemented("blas", "syrk_batch", "for column_major layout"); } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, @@ -723,6 +805,94 @@ sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, in throw unimplemented("blas", "syrk_batch", "for column_major layout"); } +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + } // namespace column_major namespace row_major { @@ -924,6 +1094,88 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n throw unimplemented("blas", "syrk_batch", "for row_major layout"); } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, @@ -1322,6 +1574,94 @@ sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, in throw unimplemented("blas", "syrk_batch", "for row_major layout"); } +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + } // namespace row_major } // namespace cublas } // namespace blas diff --git a/src/blas/backends/cublas/cublas_wrappers.cpp b/src/blas/backends/cublas/cublas_wrappers.cpp index 123c4e438..6c0615afe 100644 --- a/src/blas/backends/cublas/cublas_wrappers.cpp +++ b/src/blas/backends/cublas/cublas_wrappers.cpp @@ -217,6 +217,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_bias, oneapi::mkl::blas::cublas::column_major::gemm_bias, oneapi::mkl::blas::cublas::column_major::gemm_bias, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, oneapi::mkl::blas::cublas::column_major::asum, oneapi::mkl::blas::cublas::column_major::asum, oneapi::mkl::blas::cublas::column_major::asum, @@ -440,6 +452,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_bias, oneapi::mkl::blas::cublas::column_major::gemm_bias, oneapi::mkl::blas::cublas::column_major::gemm_bias, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::imatcopy_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, + oneapi::mkl::blas::cublas::column_major::omatadd_batch, oneapi::mkl::blas::cublas::row_major::asum, oneapi::mkl::blas::cublas::row_major::asum, oneapi::mkl::blas::cublas::row_major::asum, @@ -634,6 +658,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_bias, oneapi::mkl::blas::cublas::row_major::gemm_bias, oneapi::mkl::blas::cublas::row_major::gemm_bias, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, oneapi::mkl::blas::cublas::row_major::asum, oneapi::mkl::blas::cublas::row_major::asum, oneapi::mkl::blas::cublas::row_major::asum, @@ -857,4 +893,16 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_bias, oneapi::mkl::blas::cublas::row_major::gemm_bias, oneapi::mkl::blas::cublas::row_major::gemm_bias, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::imatcopy_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, + oneapi::mkl::blas::cublas::row_major::omatadd_batch, }; diff --git a/src/blas/backends/mkl_common/mkl_batch.cxx b/src/blas/backends/mkl_common/mkl_batch.cxx index c66d04176..3ade02ae3 100644 --- a/src/blas/backends/mkl_common/mkl_batch.cxx +++ b/src/blas/backends/mkl_common/mkl_batch.cxx @@ -248,6 +248,96 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n stride_c, batch_size); } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, @@ -734,3 +824,103 @@ sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, return blas_major::syrk_batch(queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, groupsize, dependencies); } + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + return blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + return blas_major::omatcopy_batch(queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, + batch_size, dependencies); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::imatcopy_batch(queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, + dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, + ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, + ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, + ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + return blas_major::omatadd_batch(queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, + ldb, stride_b, c, ldc, stride_c, batch_size, dependencies); +} diff --git a/src/blas/backends/mkl_common/mkl_blas_backend.hxx b/src/blas/backends/mkl_common/mkl_blas_backend.hxx index 5e22ef117..498a036af 100644 --- a/src/blas/backends/mkl_common/mkl_blas_backend.hxx +++ b/src/blas/backends/mkl_common/mkl_blas_backend.hxx @@ -1649,6 +1649,73 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + +void omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + +void imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + // batch, usm sycl::event syrk_batch(sycl::queue &queue, const uplo *upper_lower, const transpose *trans, @@ -2098,3 +2165,77 @@ sycl::event trsm_batch(sycl::queue &queue, const side *left_right, const uplo *u std::complex **b, const std::int64_t *ldb, std::int64_t group_count, const std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, + float *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, + double *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, float *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, double *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float beta, const float *b, std::int64_t ldb, + std::int64_t stride_b, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); diff --git a/src/blas/backends/netlib/netlib_batch.cxx b/src/blas/backends/netlib/netlib_batch.cxx index 23009707c..5d6c08795 100644 --- a/src/blas/backends/netlib/netlib_batch.cxx +++ b/src/blas/backends/netlib/netlib_batch.cxx @@ -379,6 +379,148 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n #endif } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, @@ -1114,3 +1256,151 @@ sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, in throw unimplemented("blas", "syrk_batch", "for row_major layout"); #endif } + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +#endif +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +#endif +} diff --git a/src/blas/backends/rocblas/rocblas_batch.cpp b/src/blas/backends/rocblas/rocblas_batch.cpp index 97c369c99..801686a62 100644 --- a/src/blas/backends/rocblas/rocblas_batch.cpp +++ b/src/blas/backends/rocblas/rocblas_batch.cpp @@ -245,6 +245,88 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n int64_t batch_size) { throw unimplemented("blas", "syrk_batch", "for column_major layout"); } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, int64_t *incy, int64_t group_count, int64_t *group_size, @@ -710,6 +792,94 @@ sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, in throw unimplemented("blas", "syrk_batch", "for column_major layout"); } +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for column_major layout"); +} + } // namespace column_major namespace row_major { @@ -909,6 +1079,88 @@ void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n throw unimplemented("blas", "syrk_batch", "for row_major layout"); } +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, int64_t stride_a, + sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, + int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, + int64_t ldb, int64_t stride_b, int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, + int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, + double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, + sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, + sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + // USM APIs sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, int64_t *incy, int64_t group_count, int64_t *group_size, @@ -1301,6 +1553,94 @@ sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, in throw unimplemented("blas", "syrk_batch", "for row_major layout"); } +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, + int64_t stride_b, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + double *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, std::complex *ab, int64_t lda, + int64_t ldb, int64_t stride, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, + float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, + double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + +sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, std::complex alpha, const std::complex *a, + int64_t lda, int64_t stride_a, std::complex beta, + const std::complex *b, int64_t ldb, int64_t stride_b, + std::complex *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", "for row_major layout"); +} + } // namespace row_major } // namespace rocblas } // namespace blas diff --git a/src/blas/backends/rocblas/rocblas_wrappers.cpp b/src/blas/backends/rocblas/rocblas_wrappers.cpp index 94d23ce5c..a0030f33f 100644 --- a/src/blas/backends/rocblas/rocblas_wrappers.cpp +++ b/src/blas/backends/rocblas/rocblas_wrappers.cpp @@ -219,6 +219,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_bias, oneapi::mkl::blas::rocblas::column_major::gemm_bias, oneapi::mkl::blas::rocblas::column_major::gemm_bias, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, oneapi::mkl::blas::rocblas::column_major::asum, oneapi::mkl::blas::rocblas::column_major::asum, oneapi::mkl::blas::rocblas::column_major::asum, @@ -442,6 +454,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_bias, oneapi::mkl::blas::rocblas::column_major::gemm_bias, oneapi::mkl::blas::rocblas::column_major::gemm_bias, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, + oneapi::mkl::blas::rocblas::column_major::omatadd_batch, oneapi::mkl::blas::rocblas::row_major::asum, oneapi::mkl::blas::rocblas::row_major::asum, oneapi::mkl::blas::rocblas::row_major::asum, @@ -636,6 +660,18 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_bias, oneapi::mkl::blas::rocblas::row_major::gemm_bias, oneapi::mkl::blas::rocblas::row_major::gemm_bias, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, oneapi::mkl::blas::rocblas::row_major::asum, oneapi::mkl::blas::rocblas::row_major::asum, oneapi::mkl::blas::rocblas::row_major::asum, @@ -859,4 +895,16 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_bias, oneapi::mkl::blas::rocblas::row_major::gemm_bias, oneapi::mkl::blas::rocblas::row_major::gemm_bias, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::imatcopy_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, + oneapi::mkl::blas::rocblas::row_major::omatadd_batch, }; diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 4eefaa06c..ce01cdbb1 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -1452,6 +1452,115 @@ void gemm_bias(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, queue, transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, co); } +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].column_major_somatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].column_major_domatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].column_major_comatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].column_major_zomatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].column_major_simatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].column_major_dimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].column_major_cimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].column_major_zimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, float beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_somatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, double beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_domatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].column_major_comatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].column_major_zomatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + // USM APIs sycl::event asum(oneapi::mkl::device libkey, sycl::queue &queue, std::int64_t n, @@ -3324,6 +3433,122 @@ sycl::event gemm_bias(oneapi::mkl::device libkey, sycl::queue &queue, transpose dependencies); } +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_somatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_domatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_comatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_zomatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_simatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_dimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_cimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_zimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_somatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_domatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_comatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_zomatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + } //namespace detail } //namespace column_major namespace row_major { @@ -4749,6 +4974,115 @@ void gemm_bias(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, queue, transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, co); } +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].row_major_somatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].row_major_domatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].row_major_comatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + function_tables[libkey].row_major_zomatcopy_batch_strided_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].row_major_simatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].row_major_dimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].row_major_cimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + function_tables[libkey].row_major_zimatcopy_batch_strided_sycl(queue, trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, float beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_somatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, double beta, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_domatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].row_major_comatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + +void omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].row_major_zomatadd_batch_strided_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size); +} + // USM APIs sycl::event asum(oneapi::mkl::device libkey, sycl::queue &queue, std::int64_t n, @@ -6617,6 +6951,122 @@ sycl::event gemm_bias(oneapi::mkl::device libkey, sycl::queue &queue, transpose dependencies); } +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_somatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_domatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_comatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_zomatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_simatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_dimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_cimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event imatcopy_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_zimatcopy_batch_strided_usm_sycl( + queue, trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_somatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_domatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_comatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event omatadd_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_zomatadd_batch_strided_usm_sycl( + queue, transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, c, ldc, + stride_c, batch_size, dependencies); +} + } //namespace detail } //namespace row_major } //namespace blas diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 4af1106d8..2785dace0 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -943,6 +943,76 @@ typedef struct { sycl::buffer &a, std::int64_t lda, uint8_t ao, sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, sycl::buffer &c, std::int64_t ldc, sycl::buffer &co); + void (*column_major_somatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + void (*column_major_domatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + void (*column_major_comatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + void (*column_major_zomatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + void (*column_major_simatcopy_batch_strided_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, + sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + void (*column_major_dimatcopy_batch_strided_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, + sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + void (*column_major_cimatcopy_batch_strided_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + void (*column_major_zimatcopy_batch_strided_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + void (*column_major_somatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*column_major_domatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*column_major_comatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_zomatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); // USM APIs @@ -2112,6 +2182,68 @@ typedef struct { const std::uint8_t *a, std::int64_t lda, std::uint8_t ao, const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, const std::vector &dependencies); + sycl::event (*column_major_somatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*column_major_domatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*column_major_comatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_zomatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_simatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, float *ab, std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_dimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, double *ab, std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_cimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_zimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_somatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float beta, const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*column_major_domatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_comatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*column_major_zomatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); // Buffer APIs @@ -3002,6 +3134,76 @@ typedef struct { sycl::buffer &a, std::int64_t lda, uint8_t ao, sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, sycl::buffer &c, std::int64_t ldc, sycl::buffer &co); + void (*row_major_somatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size); + void (*row_major_domatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + void (*row_major_comatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + void (*row_major_zomatcopy_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size); + void (*row_major_simatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + void (*row_major_dimatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size); + void (*row_major_cimatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, + std::complex alpha, + sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + void (*row_major_zimatcopy_batch_strided_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, + std::complex alpha, + sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size); + void (*row_major_somatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, float beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*row_major_domatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, double beta, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*row_major_comatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_zomatadd_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); // USM APIs @@ -4175,6 +4377,68 @@ typedef struct { const std::uint8_t *a, std::int64_t lda, std::uint8_t ao, const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, const std::vector &dependencies); + sycl::event (*row_major_somatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*row_major_domatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*row_major_comatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_zomatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_simatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, float *ab, std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_dimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, double *ab, std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_cimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_zimatcopy_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_somatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float beta, const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*row_major_domatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double beta, const double *b, std::int64_t ldb, + std::int64_t stride_b, double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_comatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); + sycl::event (*row_major_zomatadd_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies); } blas_function_table_t; diff --git a/tests/unit_tests/blas/batch/CMakeLists.txt b/tests/unit_tests/blas/batch/CMakeLists.txt index d28828e8a..ba939de29 100644 --- a/tests/unit_tests/blas/batch/CMakeLists.txt +++ b/tests/unit_tests/blas/batch/CMakeLists.txt @@ -18,7 +18,7 @@ #=============================================================================== # Build object from all test sources -set(BATCH_SOURCES "copy_batch_stride.cpp" "axpy_batch_stride.cpp" "dgmm_batch_stride.cpp" "gemm_batch_stride.cpp" "gemv_batch_stride.cpp" "trsm_batch_stride.cpp" "syrk_batch_stride.cpp" "copy_batch_usm.cpp" "copy_batch_stride_usm.cpp" "axpy_batch_usm.cpp" "axpy_batch_stride_usm.cpp" "dgmm_batch_usm.cpp" "dgmm_batch_stride_usm.cpp" "gemm_batch_usm.cpp" "gemm_batch_stride_usm.cpp" "gemv_batch_usm.cpp" "gemv_batch_stride_usm.cpp" "trsm_batch_usm.cpp" "trsm_batch_stride_usm.cpp" "syrk_batch_usm.cpp" "syrk_batch_stride_usm.cpp") +set(BATCH_SOURCES "copy_batch_stride.cpp" "axpy_batch_stride.cpp" "dgmm_batch_stride.cpp" "gemm_batch_stride.cpp" "gemv_batch_stride.cpp" "trsm_batch_stride.cpp" "syrk_batch_stride.cpp" "copy_batch_usm.cpp" "copy_batch_stride_usm.cpp" "axpy_batch_usm.cpp" "axpy_batch_stride_usm.cpp" "dgmm_batch_usm.cpp" "dgmm_batch_stride_usm.cpp" "gemm_batch_usm.cpp" "gemm_batch_stride_usm.cpp" "gemv_batch_usm.cpp" "gemv_batch_stride_usm.cpp" "trsm_batch_usm.cpp" "trsm_batch_stride_usm.cpp" "syrk_batch_usm.cpp" "syrk_batch_stride_usm.cpp" "omatcopy_batch_stride.cpp" "omatcopy_batch_stride_usm.cpp" "imatcopy_batch_stride.cpp" "imatcopy_batch_stride_usm.cpp" "omatadd_batch_stride.cpp" "omatadd_batch_stride_usm.cpp") if(BUILD_SHARED_LIBS) add_library(blas_batch_rt OBJECT ${BATCH_SOURCES}) diff --git a/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp b/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp new file mode 100644 index 000000000..f2535e84b --- /dev/null +++ b/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp @@ -0,0 +1,209 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + oneapi::mkl::transpose trans; + fp alpha; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + alpha = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + trans = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + trans = oneapi::mkl::transpose::conjtrans; + else + trans = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b, stride; + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = lda * m; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + stride = std::max(stride_a, stride_b); + break; + case oneapi::mkl::layout::row_major: + stride_a = lda * n; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + stride = std::max(stride_a, stride_b); + break; + default: break; + } + + vector> AB(stride * batch_size), AB_ref(stride * batch_size); + + rand_matrix(AB.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride * batch_size, 1, stride * batch_size); + copy_matrix(AB.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride * batch_size, 1, stride * batch_size, AB_ref.data()); + + // Call reference IMATCOPY_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + imatcopy_ref(layout, trans, m_ref, n_ref, alpha, AB_ref.data() + stride * i, + lda_ref, ldb_ref); + } + + // Call DPC++ IMATCOPY_BATCH_STRIDE + + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during IMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + + buffer AB_buffer(AB.data(), range<1>(AB.size())); + + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + oneapi::mkl::blas::column_major::imatcopy_batch( + main_queue, trans, m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + break; + case oneapi::mkl::layout::row_major: + oneapi::mkl::blas::row_major::imatcopy_batch( + main_queue, trans, m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + break; + default: break; + } +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, + trans, m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, trans, + m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + break; + default: break; + } +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during IMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of IMATCOPY_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + + auto AB_accessor = AB_buffer.template get_access(); + bool good = + check_equal_matrix(AB_accessor, AB_ref, oneapi::mkl::layout::column_major, + stride * batch_size, 1, stride * batch_size, 10, std::cout); + + return (int)good; +} + +class ImatcopyBatchStrideTests + : public ::testing::TestWithParam> {}; + +TEST_P(ImatcopyBatchStrideTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(ImatcopyBatchStrideTestSuite, ImatcopyBatchStrideTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp new file mode 100644 index 000000000..446b654f3 --- /dev/null +++ b/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp @@ -0,0 +1,234 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + context cxt = main_queue.get_context(); + event done; + std::vector dependencies; + + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + oneapi::mkl::transpose trans; + fp alpha; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + alpha = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + trans = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + trans = oneapi::mkl::transpose::conjtrans; + else + trans = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b, stride; + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = lda * m; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + stride = std::max(stride_a, stride_b); + break; + case oneapi::mkl::layout::row_major: + stride_a = lda * n; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + stride = std::max(stride_a, stride_b); + break; + default: break; + } + + auto ua = usm_allocator(cxt, *dev); + vector AB(ua), AB_ref(ua); + + AB.resize(stride * batch_size); + AB_ref.resize(stride * batch_size); + fp **ab_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **ab_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + if ((ab_array == NULL) || (ab_ref_array == NULL)) { + std::cout << "Error cannot allocate arrays of pointers\n"; + oneapi::mkl::free_shared(ab_array, cxt); + oneapi::mkl::free_shared(ab_ref_array, cxt); + return false; + } + + for (i = 0; i < batch_size; i++) { + ab_array[i] = &AB[i * stride]; + ab_ref_array[i] = &AB_ref[i * stride]; + } + + rand_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride * batch_size, 1, stride * batch_size); + copy_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride * batch_size, 1, stride * batch_size, AB_ref); + + // Call reference IMATCOPY_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + imatcopy_ref(layout, trans, m_ref, n_ref, alpha, ab_ref_array[i], + lda_ref, ldb_ref); + } + + // Call DPC++ IMATCOPY_BATCH_STRIDE + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + done = oneapi::mkl::blas::column_major::imatcopy_batch( + main_queue, trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, + dependencies); + break; + case oneapi::mkl::layout::row_major: + done = oneapi::mkl::blas::row_major::imatcopy_batch(main_queue, trans, m, n, alpha, + &AB[0], lda, ldb, stride, + batch_size, dependencies); + break; + default: break; + } + done.wait(); +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, + trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, + dependencies); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, trans, + m, n, alpha, &AB[0], lda, ldb, stride, batch_size, dependencies); + break; + default: break; + } + main_queue.wait(); +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during IMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + oneapi::mkl::free_shared(ab_array, cxt); + oneapi::mkl::free_shared(ab_ref_array, cxt); + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of IMATCOPY_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + bool good = + check_equal_matrix(AB, AB_ref, oneapi::mkl::layout::column_major, stride * batch_size, 1, + stride * batch_size, 10, std::cout); + + oneapi::mkl::free_shared(ab_array, cxt); + oneapi::mkl::free_shared(ab_ref_array, cxt); + + return (int)good; +} + +class ImatcopyBatchStrideUsmTests + : public ::testing::TestWithParam> {}; + +TEST_P(ImatcopyBatchStrideUsmTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideUsmTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideUsmTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(ImatcopyBatchStrideUsmTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(ImatcopyBatchStrideUsmTestSuite, ImatcopyBatchStrideUsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp b/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp new file mode 100644 index 000000000..faa9e2c85 --- /dev/null +++ b/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp @@ -0,0 +1,232 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Prepare data. + int64_t m, n; + int64_t lda, ldb, ldc; + oneapi::mkl::transpose transa, transb; + fp alpha, beta; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + ldc = std::max(m, n); + alpha = rand_scalar(); + beta = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + transa = oneapi::mkl::transpose::conjtrans; + else + transa = (oneapi::mkl::transpose)tmp; + tmp = std::rand() % 3; + if (tmp == 2) + transb = oneapi::mkl::transpose::conjtrans; + else + transb = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b, stride_c; + + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; + stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + stride_c = ldc * n; + break; + case oneapi::mkl::layout::row_major: + stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * m : lda * n; + stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + stride_c = ldc * m; + break; + default: break; + } + + vector> A(stride_a * batch_size), B(stride_b * batch_size), + C(stride_c * batch_size), C_ref(stride_c * batch_size); + + rand_matrix(A.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size); + rand_matrix(B.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size); + rand_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_c * batch_size, 1, stride_c * batch_size); + copy_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_c * batch_size, 1, stride_c * batch_size, C_ref.data()); + + + // Call reference OMATADD_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int ldc_ref = (int)ldc; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + omatadd_ref(layout, transa, transb, m_ref, n_ref, alpha, A.data() + stride_a * i, + lda_ref, beta, B.data() + stride_b * i, ldb_ref, + C_ref.data() + stride_c * i, ldc_ref); + } + + // Call DPC++ OMATADD_BATCH_STRIDE + + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATADD_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + buffer C_buffer(C.data(), range<1>(C.size())); + + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + oneapi::mkl::blas::column_major::omatadd_batch( + main_queue, transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, + B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + break; + case oneapi::mkl::layout::row_major: + oneapi::mkl::blas::row_major::omatadd_batch( + main_queue, transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, + B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + break; + default: break; + } +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, + transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, + B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, transa, + transb, m, n, alpha, A_buffer, lda, stride_a, beta, B_buffer, + ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + break; + default: break; + } +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATADD_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATADD_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + + auto C_accessor = C_buffer.template get_access(); + bool good = + check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::column_major, + stride_c * batch_size, 1, stride_c * batch_size, 10, std::cout); + + return (int)good; +} + +class OmataddBatchStrideTests + : public ::testing::TestWithParam> {}; + +TEST_P(OmataddBatchStrideTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(OmataddBatchStrideTestSuite, OmataddBatchStrideTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp new file mode 100644 index 000000000..9bb6b266d --- /dev/null +++ b/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp @@ -0,0 +1,263 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATADD_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + context cxt = main_queue.get_context(); + event done; + std::vector dependencies; + + // Prepare data. + int64_t m, n; + int64_t lda, ldb, ldc; + oneapi::mkl::transpose transa, transb; + fp alpha, beta; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + ldc = std::max(m, n); + alpha = rand_scalar(); + beta = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + transa = oneapi::mkl::transpose::conjtrans; + else + transa = (oneapi::mkl::transpose)tmp; + tmp = std::rand() % 3; + if (tmp == 2) + transb = oneapi::mkl::transpose::conjtrans; + else + transb = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b, stride_c; + + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; + stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + stride_c = ldc * n; + break; + case oneapi::mkl::layout::row_major: + stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * m : lda * n; + stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + stride_c = ldc * m; + break; + default: break; + } + + auto ua = usm_allocator(cxt, *dev); + vector A(ua), B(ua), C(ua), C_ref(ua); + + A.resize(stride_a * batch_size); + B.resize(stride_b * batch_size); + C.resize(stride_c * batch_size); + C_ref.resize(stride_c * batch_size); + + fp **a_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **b_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **c_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **c_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + + if ((a_array == NULL) || (b_array == NULL) || (c_array == NULL) || (c_ref_array == NULL)) { + std::cout << "Error cannot allocate arrays of pointers\n"; + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(c_array, cxt); + oneapi::mkl::free_shared(c_ref_array, cxt); + return false; + } + + for (i = 0; i < batch_size; i++) { + a_array[i] = &A[i * stride_a]; + b_array[i] = &B[i * stride_b]; + c_array[i] = &C[i * stride_c]; + c_ref_array[i] = &C_ref[i * stride_c]; + } + + rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size); + rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size); + rand_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_c * batch_size, 1, stride_c * batch_size); + copy_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_c * batch_size, 1, stride_c * batch_size, C_ref); + + // Call reference OMATADD_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int ldc_ref = (int)ldc; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + omatadd_ref(layout, transa, transb, m_ref, n_ref, alpha, a_array[i], + lda_ref, beta, b_array[i], ldb_ref, c_ref_array[i], ldc_ref); + } + + // Call DPC++ OMATADD_BATCH_STRIDE + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + done = oneapi::mkl::blas::column_major::omatadd_batch( + main_queue, transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], ldb, + stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + break; + case oneapi::mkl::layout::row_major: + done = oneapi::mkl::blas::row_major::omatadd_batch( + main_queue, transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], ldb, + stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + break; + default: break; + } + done.wait(); +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, + transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], + ldb, stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, transa, + transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], ldb, + stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + break; + default: break; + } + main_queue.wait(); +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATADD_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(c_array, cxt); + oneapi::mkl::free_shared(c_ref_array, cxt); + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATADD_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + bool good = + check_equal_matrix(C, C_ref, oneapi::mkl::layout::column_major, stride_c * batch_size, 1, + stride_c * batch_size, 10, std::cout); + + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(c_array, cxt); + oneapi::mkl::free_shared(c_ref_array, cxt); + + return (int)good; +} + +class OmataddBatchStrideUsmTests + : public ::testing::TestWithParam> {}; + +TEST_P(OmataddBatchStrideUsmTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideUsmTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideUsmTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmataddBatchStrideUsmTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(OmataddBatchStrideUsmTestSuite, OmataddBatchStrideUsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp b/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp new file mode 100644 index 000000000..a31026693 --- /dev/null +++ b/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp @@ -0,0 +1,214 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + oneapi::mkl::transpose trans; + fp alpha; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + alpha = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + trans = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + trans = oneapi::mkl::transpose::conjtrans; + else + trans = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b; + + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = lda * n; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + break; + case oneapi::mkl::layout::row_major: + stride_a = lda * m; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + break; + default: break; + } + + vector> A(stride_a * batch_size), B(stride_b * batch_size), + B_ref(stride_b * batch_size); + + for (i = 0; i < batch_size; i++) { + rand_matrix(A.data() + stride_a * i, layout, oneapi::mkl::transpose::nontrans, m, n, lda); + rand_matrix(B.data() + stride_b * i, layout, trans, m, n, ldb); + } + + // Call reference OMATCOPY_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + omatcopy_ref(layout, trans, m_ref, n_ref, alpha, A.data() + stride_a * i, + lda_ref, B_ref.data() + stride_b * i, ldb_ref); + } + + // Call DPC++ OMATCOPY_BATCH_STRIDE + + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + oneapi::mkl::blas::column_major::omatcopy_batch(main_queue, trans, m, n, alpha, + A_buffer, lda, stride_a, B_buffer, + ldb, stride_b, batch_size); + break; + case oneapi::mkl::layout::row_major: + oneapi::mkl::blas::row_major::omatcopy_batch(main_queue, trans, m, n, alpha, + A_buffer, lda, stride_a, B_buffer, ldb, + stride_b, batch_size); + break; + default: break; + } +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, + trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, + stride_b, batch_size); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, trans, + m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, stride_b, + batch_size); + break; + default: break; + } +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATCOPY_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + + auto B_accessor = B_buffer.template get_access(); + bool good = + check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::column_major, + stride_b * batch_size, 1, stride_b * batch_size, 10, std::cout); + + return (int)good; +} + +class OmatcopyBatchStrideTests + : public ::testing::TestWithParam> {}; + +TEST_P(OmatcopyBatchStrideTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(OmatcopyBatchStrideTestSuite, OmatcopyBatchStrideTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp new file mode 100644 index 000000000..cab4df84c --- /dev/null +++ b/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp @@ -0,0 +1,245 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + context cxt = main_queue.get_context(); + event done; + std::vector dependencies; + + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + oneapi::mkl::transpose trans; + fp alpha; + int64_t i, tmp; + + batch_size = 1 + std::rand() % 20; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = std::max(m, n); + ldb = std::max(m, n); + alpha = rand_scalar(); + + if ((std::is_same::value) || (std::is_same::value)) { + trans = (oneapi::mkl::transpose)(std::rand() % 2); + } + else { + tmp = std::rand() % 3; + if (tmp == 2) + trans = oneapi::mkl::transpose::conjtrans; + else + trans = (oneapi::mkl::transpose)tmp; + } + + int64_t stride_a, stride_b; + + switch (layout) { + case oneapi::mkl::layout::column_major: + stride_a = lda * n; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + break; + case oneapi::mkl::layout::row_major: + stride_a = lda * m; + stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + break; + default: break; + } + + auto ua = usm_allocator(cxt, *dev); + vector A(ua), B(ua), B_ref(ua); + + A.resize(stride_a * batch_size); + B.resize(stride_b * batch_size); + B_ref.resize(stride_b * batch_size); + + fp **a_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **b_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + fp **b_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + + if ((a_array == NULL) || (b_array == NULL) || (b_ref_array == NULL)) { + std::cout << "Error cannot allocate arrays of pointers\n"; + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(b_ref_array, cxt); + return false; + } + + for (i = 0; i < batch_size; i++) { + a_array[i] = &A[i * stride_a]; + b_array[i] = &B[i * stride_b]; + b_ref_array[i] = &B_ref[i * stride_b]; + } + + rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size); + rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size); + copy_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size, B_ref); + + // Call reference OMATCOPY_BATCH_STRIDE. + int m_ref = (int)m; + int n_ref = (int)n; + int lda_ref = (int)lda; + int ldb_ref = (int)ldb; + int batch_size_ref = (int)batch_size; + for (i = 0; i < batch_size_ref; i++) { + omatcopy_ref(layout, trans, m_ref, n_ref, alpha, a_array[i], + lda_ref, b_ref_array[i], ldb_ref); + } + + // Call DPC++ OMATCOPY_BATCH_STRIDE + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::column_major: + done = oneapi::mkl::blas::column_major::omatcopy_batch( + main_queue, trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, + batch_size, dependencies); + break; + case oneapi::mkl::layout::row_major: + done = oneapi::mkl::blas::row_major::omatcopy_batch( + main_queue, trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, + batch_size, dependencies); + break; + default: break; + } + done.wait(); +#else + switch (layout) { + case oneapi::mkl::layout::column_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, + trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, + batch_size, dependencies); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, trans, + m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, + batch_size, dependencies); + break; + default: break; + } + main_queue.wait(); +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATCOPY_BATCH_STRIDE:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(b_ref_array, cxt); + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATCOPY_BATCH_STRIDE:\n" + << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + bool good = + check_equal_matrix(B, B_ref, oneapi::mkl::layout::column_major, stride_b * batch_size, 1, + stride_b * batch_size, 10, std::cout); + + oneapi::mkl::free_shared(a_array, cxt); + oneapi::mkl::free_shared(b_array, cxt); + oneapi::mkl::free_shared(b_ref_array, cxt); + + return (int)good; +} + +class OmatcopyBatchStrideUsmTests + : public ::testing::TestWithParam> {}; + +TEST_P(OmatcopyBatchStrideUsmTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideUsmTests, RealDoublePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideUsmTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +TEST_P(OmatcopyBatchStrideUsmTests, ComplexDoublePrecision) { + EXPECT_TRUEORSKIP( + test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); +} + +INSTANTIATE_TEST_SUITE_P(OmatcopyBatchStrideUsmTestSuite, OmatcopyBatchStrideUsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::column_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/include/reference_blas_templates.hpp b/tests/unit_tests/blas/include/reference_blas_templates.hpp index 8e1a16202..f0ec7536b 100644 --- a/tests/unit_tests/blas/include/reference_blas_templates.hpp +++ b/tests/unit_tests/blas/include/reference_blas_templates.hpp @@ -1983,4 +1983,174 @@ void dgmm(CBLAS_LAYOUT layout, CBLAS_SIDE left_right, const int *m, const int *n } } +// std::conj can take a real type as input, but still returns a complex type. +// This version always returns the same type it has as input +template +fp sametype_conj(fp x) { + if constexpr (std::is_same_v> || + std::is_same_v>) { + return std::conj(x); + } + else { + return x; + } +} + +template +void omatcopy_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int64_t m, int64_t n, + fp alpha, fp *A, int64_t lda, fp *B, int64_t ldb) { + int64_t logical_m, logical_n; + if (layout == oneapi::mkl::layout::column_major) { + logical_m = m; + logical_n = n; + } + else { + logical_m = n; + logical_n = m; + } + if (trans == oneapi::mkl::transpose::nontrans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + B[j*ldb + i] = alpha * A[j*lda + i]; + } + } + } + else if (trans == oneapi::mkl::transpose::trans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + B[i*ldb + j] = alpha * A[j*lda + i]; + } + } + } + else { + // conjtrans + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + B[i*ldb + j] = alpha * sametype_conj(A[j*lda + i]); + } + } + } +} + +template +void imatcopy_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int64_t m, int64_t n, + fp alpha, fp *A, int64_t lda, int64_t ldb) { + int64_t logical_m, logical_n; + if (layout == oneapi::mkl::layout::column_major) { + logical_m = m; + logical_n = n; + } + else { + logical_m = n; + logical_n = m; + } + std::vector temp(m * n); + int64_t ld_temp = (trans == oneapi::mkl::transpose::nontrans ? logical_m : logical_n); + + if (trans == oneapi::mkl::transpose::nontrans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + temp[j*ld_temp + i] = alpha * A[j*lda + i]; + } + } + } + else if (trans == oneapi::mkl::transpose::trans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + temp[i*ld_temp + j] = alpha * A[j*lda + i]; + } + } + } + else { + // conjtrans + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + temp[i*ld_temp + j] = alpha * sametype_conj(A[j*lda + i]); + } + } + } + + if (trans == oneapi::mkl::transpose::nontrans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + A[j*ldb + i] = temp[j*ld_temp + i]; + } + } + } + else { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + A[i*ldb + j] = temp[i*ld_temp + j]; + } + } + } +} + +template +void omatadd_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, int64_t m, int64_t n, + fp alpha, fp *A, int64_t lda, fp beta, fp *B, int64_t ldb, fp *C, int64_t ldc) { + int64_t logical_m, logical_n; + if (layout == oneapi::mkl::layout::column_major) { + logical_m = m; + logical_n = n; + } + else { + logical_m = n; + logical_n = m; + } + + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] = 0.0; + } + } + + if (transa == oneapi::mkl::transpose::nontrans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += alpha * A[j*lda + i]; + } + } + } + else if (transa == oneapi::mkl::transpose::trans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += alpha * A[i*lda + j]; + } + } + } + else { + // conjtrans + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += alpha * sametype_conj(A[i*lda + j]); + } + } + } + + if (transb == oneapi::mkl::transpose::nontrans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += beta * B[j*ldb + i]; + } + } + } + else if (transa == oneapi::mkl::transpose::trans) { + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += beta * B[i*ldb + j]; + } + } + } + else { + // conjtrans + for (int64_t j = 0; j < logical_n; j++) { + for (int64_t i = 0; i < logical_m; i++) { + C[j*ldc + i] += beta * sametype_conj(B[i*ldb + j]); + } + } + } +} + #endif /* header guard */