Skip to content

Commit

Permalink
[BLAS::portBLAS backend] Add USM support for batch operators (uxlfoun…
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick authored and normallytangent committed Aug 6, 2024
1 parent a025616 commit 3c75a15
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/blas/backends/portblas/portblas_batch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, float beta, float *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -713,7 +715,9 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
std::int64_t stride_b, double beta, double *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "gemm_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand Down Expand Up @@ -825,15 +829,17 @@ sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std
std::int64_t n, float alpha, const float *a, std::int64_t lda,
std::int64_t stride_a, float *b, std::int64_t ldb, std::int64_t stride_b,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "omatcopy_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size, dependencies);
}

sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m,
std::int64_t n, double alpha, const double *a, std::int64_t lda,
std::int64_t stride_a, double *b, std::int64_t ldb,
std::int64_t stride_b, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "omatcopy_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size, dependencies);
}

sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m,
Expand Down Expand Up @@ -886,7 +892,9 @@ sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
float beta, const float *b, std::int64_t ldb, std::int64_t stride_b,
float *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "omatadd_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand All @@ -895,7 +903,9 @@ sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
double beta, const double *b, std::int64_t ldb, std::int64_t stride_b,
double *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies) {
throw unimplemented("blas", "omatadd_batch", " for USM");
CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa,
Expand Down

0 comments on commit 3c75a15

Please sign in to comment.