From 21229ee8f22a65a125b9c667c0458d9f76f7e073 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Fri, 11 Oct 2024 09:44:09 +0100 Subject: [PATCH] [cublas] add missing support for gemv_batch (#586) Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 68 ++++++++++++++--------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 9f198b653..2975e6c58 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -502,35 +502,53 @@ sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t throw unimplemented("blas", "gemv_batch", "for column_major layout"); } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, - const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, - float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); +template +inline sycl::event gemv_batch(const char *func_name, Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, T *alpha, const T **a, int64_t *lda, const T **x, + int64_t *incx, T *beta, T **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], lda[i], incx[i], incy[i], group_size[i]); + } + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + cublasStatus_t err; + auto **a_ = reinterpret_cast(a); + auto **x_ = reinterpret_cast(x); + auto **y_ = reinterpret_cast(y); + for (int64_t i = 0; i < group_count; i++) { + cublas_native_named_func( + func_name, func, err, handle, get_cublas_operation(trans[i]), + (int)m[i], (int)n[i], + (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], x_ + offset, (int)incx[i], + (cuDataType *)&beta[i], y_ + offset, (int)incy[i], (int)group_size[i]); + offset += group_size[i]; + } + }); + }); + return done; } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, - const double **a, int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#define GEMV_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event gemv_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, \ + TYPE *alpha, const TYPE **a, int64_t *lda, const TYPE **x, \ + int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemv_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, \ + x, incx, beta, y, incy, group_count, group_size, dependencies); \ + } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +GEMV_BATCH_LAUNCHER_USM(float, cublasSgemvBatched) +GEMV_BATCH_LAUNCHER_USM(double, cublasDgemvBatched) +GEMV_BATCH_LAUNCHER_USM(std::complex, cublasCgemvBatched) +GEMV_BATCH_LAUNCHER_USM(std::complex, cublasZgemvBatched) -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#undef GEMV_BATCH_LAUNCHER_USM sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx,