From d848b4d10ebe4c09ea1cb89dba1a2ce1002107b8 Mon Sep 17 00:00:00 2001
From: Hugh Delaney <46290137+hdelan@users.noreply.github.com>
Date: Fri, 2 Jun 2023 16:45:05 +0100
Subject: [PATCH] [LAPACK][CUSOLVER] Add getri batch funcs (#248)

* Add getri batch funcs

* Responding to comments
---
 src/lapack/backends/cusolver/CMakeLists.txt   |   7 +-
 .../backends/cusolver/cusolver_batch.cpp      | 422 +++++++++++-------
 .../backends/cusolver/cusolver_lapack.cpp     |  98 ++--
 3 files changed, 304 insertions(+), 223 deletions(-)

diff --git a/src/lapack/backends/cusolver/CMakeLists.txt b/src/lapack/backends/cusolver/CMakeLists.txt
index e40119dfe..bb515165c 100644
--- a/src/lapack/backends/cusolver/CMakeLists.txt
+++ b/src/lapack/backends/cusolver/CMakeLists.txt
@@ -20,6 +20,7 @@ set(LIB_NAME onemkl_lapack_cusolver)
 set(LIB_OBJ ${LIB_NAME}_obj)
 
 find_package(cuSOLVER REQUIRED)
+find_package(cuBLAS REQUIRED)
 
 set(SOURCES cusolver_lapack.cpp cusolver_batch.cpp
     $<$<BOOL:${BUILD_SHARED_LIBS}>:cusolver_scope_handle.cpp>
@@ -31,10 +32,14 @@ target_include_directories(${LIB_OBJ}
     PRIVATE ${PROJECT_SOURCE_DIR}/include
             ${PROJECT_SOURCE_DIR}/src/include
             ${PROJECT_SOURCE_DIR}/src
+            ${PROJECT_SOURCE_DIR}/src/blas/backends/cublas
             ${CMAKE_BINARY_DIR}/bin
 )
 target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT})
-target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::cuSOLVER::cuSOLVER)
+target_link_libraries(${LIB_OBJ}
+    PUBLIC ONEMKL::SYCL::SYCL
+           ONEMKL::cuSOLVER::cuSOLVER
+           ONEMKL::cuBLAS::cuBLAS)
 target_compile_features(${LIB_OBJ} PUBLIC cxx_std_11)
 set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON)
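For context on the new cuBLAS dependency: batched LU-based matrix inversion is exposed by cuBLAS rather than cuSOLVER, so the LAPACK backend must now link ONEMKL::cuBLAS::cuBLAS as well. The routines wrapped below take a device array of per-matrix device pointers plus 32-bit pivots. The single-precision entry point, as documented in the cuBLAS library, is:

    // cuBLAS batched inversion (single precision). Aarray and Carray are
    // device arrays of per-matrix device pointers; PivotArray holds the
    // 32-bit pivots produced by cublasSgetrfBatched.
    cublasStatus_t cublasSgetriBatched(cublasHandle_t handle, int n,
                                       const float *const Aarray[], int lda,
                                       const int *PivotArray,
                                       float *const Carray[], int ldc,
                                       int *infoArray, int batchSize);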
diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp
index 57b9f4a88..1890ca770 100644
--- a/src/lapack/backends/cusolver/cusolver_batch.cpp
+++ b/src/lapack/backends/cusolver/cusolver_batch.cpp
@@ -16,6 +16,7 @@
  * limitations under the License.
  *
  **************************************************************************/
+#include "cublas_helper.hpp"
 #include "cusolver_helper.hpp"
 #include "cusolver_task.hpp"
 
@@ -76,31 +77,114 @@ GEQRF_STRIDED_BATCH_LAUNCHER(std::complex<double>, cusolverDnZgeqrf)
 
 #undef GEQRF_STRIDED_BATCH_LAUNCHER
 
-void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer<float> &a, std::int64_t lda,
-                 std::int64_t stride_a, sycl::buffer<std::int64_t> &ipiv, std::int64_t stride_ipiv,
-                 std::int64_t batch_size, sycl::buffer<float> &scratchpad,
-                 std::int64_t scratchpad_size) {
-    throw unimplemented("lapack", "getri_batch");
-}
-void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer<double> &a, std::int64_t lda,
-                 std::int64_t stride_a, sycl::buffer<std::int64_t> &ipiv, std::int64_t stride_ipiv,
-                 std::int64_t batch_size, sycl::buffer<double> &scratchpad,
-                 std::int64_t scratchpad_size) {
-    throw unimplemented("lapack", "getri_batch");
-}
-void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer<std::complex<float>> &a,
-                 std::int64_t lda, std::int64_t stride_a, sycl::buffer<std::int64_t> &ipiv,
-                 std::int64_t stride_ipiv, std::int64_t batch_size,
-                 sycl::buffer<std::complex<float>> &scratchpad, std::int64_t scratchpad_size) {
-    throw unimplemented("lapack", "getri_batch");
-}
-void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer<std::complex<double>> &a,
-                 std::int64_t lda, std::int64_t stride_a, sycl::buffer<std::int64_t> &ipiv,
-                 std::int64_t stride_ipiv, std::int64_t batch_size,
-                 sycl::buffer<std::complex<double>> &scratchpad, std::int64_t scratchpad_size) {
-    throw unimplemented("lapack", "getri_batch");
+template <typename Func, typename T>
+inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t n,
+                        sycl::buffer<T> &a, std::int64_t lda, std::int64_t stride_a,
+                        sycl::buffer<std::int64_t> &ipiv, std::int64_t stride_ipiv,
+                        std::int64_t batch_size, sycl::buffer<T> &scratchpad,
+                        std::int64_t scratchpad_size) {
+    using cuDataType = typename CudaEquivalentType<T>::Type;
+
+    overflow_check(n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size);
+
+    std::uint64_t ipiv32_size = n * batch_size;
+    sycl::buffer<int> ipiv32(sycl::range<1>{ ipiv32_size });
+    sycl::buffer<int> devInfo{ batch_size };
+
+    queue.submit([&](sycl::handler &cgh) {
+        auto ipiv_acc = sycl::accessor{ ipiv, cgh, sycl::read_only };
+        auto ipiv32_acc = sycl::accessor{ ipiv32, cgh, sycl::write_only };
+        cgh.parallel_for(sycl::range<1>{ ipiv32_size }, [=](sycl::id<1> index) {
+            ipiv32_acc[index] =
+                static_cast<int>(ipiv_acc[(index / n) * stride_ipiv + index % n]);
+        });
+    });
+
+    // getri_batched is contained within cublas, not cusolver. For this reason
+    // we need to use cublas types instead of cusolver types (as is needed for
+    // other lapack routines)
+    queue.submit([&](sycl::handler &cgh) {
+        using blas::cublas::cublas_error;
+
+        sycl::accessor a_acc{ a, cgh, sycl::read_only };
+        sycl::accessor scratch_acc{ scratchpad, cgh, sycl::write_only };
+        sycl::accessor ipiv32_acc{ ipiv32, cgh };
+        sycl::accessor devInfo_acc{ devInfo, cgh, sycl::write_only };
+
+        onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
+            cublasStatus_t err;
+            CUresult cuda_result;
+            cublasHandle_t cublas_handle;
+            CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle);
+            CUstream cu_stream = sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
+            CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream);
+
+            auto a_ = sc.get_mem<cuDataType *>(a_acc);
+            auto scratch_ = sc.get_mem<cuDataType *>(scratch_acc);
+            auto ipiv32_ = sc.get_mem<int *>(ipiv32_acc);
+            auto info_ = sc.get_mem<int *>(devInfo_acc);
+
+            CUdeviceptr a_dev;
+            cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size);
+            CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size);
+            CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size);
+            auto **a_dev_ = reinterpret_cast<cuDataType **>(a_dev);
+
+            CUdeviceptr scratch_dev;
+            cuDataType **scratch_batched =
+                create_ptr_list_from_stride(scratch_, stride_a, batch_size);
+            CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &scratch_dev, sizeof(T *) * batch_size);
+            CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, scratch_dev, scratch_batched,
+                            sizeof(T *) * batch_size);
+            auto **scratch_dev_ = reinterpret_cast<cuDataType **>(scratch_dev);
+
+            CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_,
+                                     scratch_dev_, lda, info_, batch_size)
+
+            free(a_batched);
+            free(scratch_batched);
+            cuMemFree(a_dev);
+            cuMemFree(scratch_dev);
+        });
+    });
+
+    // The inverted matrices stored in scratch_ need to be stored in a_
+    queue.submit([&](sycl::handler &cgh) {
+        sycl::accessor a_acc{ a, cgh, sycl::write_only };
+        sycl::accessor scratch_acc{ scratchpad, cgh, sycl::read_only };
+        cgh.parallel_for(sycl::range<1>{ static_cast<size_t>(
+                             sycl::max(stride_a * batch_size, lda * n * batch_size)) },
+                         [=](sycl::id<1> index) { a_acc[index] = scratch_acc[index]; });
+    });
+
+    queue.submit([&](sycl::handler &cgh) {
+        sycl::accessor ipiv32_acc{ ipiv32, cgh, sycl::read_only };
+        sycl::accessor ipiv_acc{ ipiv, cgh, sycl::write_only };
+        cgh.parallel_for(sycl::range<1>{ static_cast<size_t>(ipiv32_size) },
+                         [=](sycl::id<1> index) {
+                             ipiv_acc[(index / n) * stride_ipiv + index % n] =
+                                 static_cast<std::int64_t>(ipiv32_acc[index]);
+                         });
+    });
 }
 
+#define GETRI_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE)                                      \
+    void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer<TYPE> &a, std::int64_t lda, \
+                     std::int64_t stride_a, sycl::buffer<std::int64_t> &ipiv,                     \
+                     std::int64_t stride_ipiv, std::int64_t batch_size,                           \
+                     sycl::buffer<TYPE> &scratchpad, std::int64_t scratchpad_size) {              \
+        return getri_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, n, a, lda, stride_a, ipiv, \
+                           stride_ipiv, batch_size, scratchpad, scratchpad_size);                 \
+    }
+
+GETRI_STRIDED_BATCH_LAUNCHER(float, cublasSgetriBatched)
+GETRI_STRIDED_BATCH_LAUNCHER(double, cublasDgetriBatched)
+GETRI_STRIDED_BATCH_LAUNCHER(std::complex<float>, cublasCgetriBatched)
+GETRI_STRIDED_BATCH_LAUNCHER(std::complex<double>, cublasZgetriBatched)
+
+#undef GETRI_STRIDED_BATCH_LAUNCHER
+
 template <typename Func, typename T>
 inline void getrs_batch(const char *func_name, Func func, sycl::queue &queue,
                         oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs,
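The pointer-list helper used above is not shown in this diff; it comes from the cuBLAS helper headers pulled in via cublas_helper.hpp. A minimal sketch of what such a helper presumably looks like (hypothetical reconstruction, consistent with the malloc/free pairing in the host task above):

    // Hypothetical sketch: build a host-side array of per-matrix device
    // pointers from one strided allocation. The caller uploads it with
    // cuMemcpyHtoD and releases it with free(), as the host task does.
    template <typename T>
    T **create_ptr_list_from_stride(T *base, std::int64_t stride, std::int64_t batch_size) {
        T **ptr_list = (T **)malloc(sizeof(T *) * batch_size);
        for (std::int64_t i = 0; i < batch_size; ++i)
            ptr_list[i] = base + i * stride;
        return ptr_list;
    }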
@@ -459,10 +543,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
     overflow_check(m, n, lda, stride_a, stride_tau, batch_size, scratchpad_size);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -513,10 +594,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu
         overflow_check(m[i], n[i], lda[i], group_sizes[i]);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -574,10 +652,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
     int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -659,10 +734,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
     int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -701,10 +773,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
 
     // Enqueue free memory
     sycl::event done_freeing = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = casting_dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(casting_dependencies[i]);
-        }
+        cgh.depends_on(casting_dependencies);
         cgh.host_task([=](sycl::interop_handle ih) {
             for (int64_t global_id = 0; global_id < batch_size; ++global_id)
                 sycl::free(ipiv32[global_id], queue);
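The recurring change in the hunks above replaces a manual loop with a single call: SYCL 2020 gives handler::depends_on an overload that accepts a std::vector<sycl::event> directly, so the two forms are equivalent:

    // Sketch: dependencies is a std::vector<sycl::event>.
    queue.submit([&](sycl::handler &cgh) {
        // Old form, one call per event:
        //     for (int64_t i = 0; i < (int64_t)dependencies.size(); i++)
        //         cgh.depends_on(dependencies[i]);
        // New form, SYCL 2020 overload taking the whole vector:
        cgh.depends_on(dependencies);
        cgh.single_task([] {});
    });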
@@ -736,32 +805,108 @@ GETRF_BATCH_LAUNCHER_USM(std::complex<double>, cusolverDnZgetrf)
 
 #undef GETRS_BATCH_LAUNCHER_USM
 
-sycl::event getri_batch(sycl::queue &queue, std::int64_t n, float *a, std::int64_t lda,
-                        std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv,
-                        std::int64_t batch_size, float *scratchpad, std::int64_t scratchpad_size,
-                        const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("lapack", "getri_batch");
-}
-sycl::event getri_batch(sycl::queue &queue, std::int64_t n, double *a, std::int64_t lda,
-                        std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv,
-                        std::int64_t batch_size, double *scratchpad, std::int64_t scratchpad_size,
-                        const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("lapack", "getri_batch");
-}
-sycl::event getri_batch(sycl::queue &queue, std::int64_t n, std::complex<float> *a,
-                        std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv,
-                        std::int64_t stride_ipiv, std::int64_t batch_size,
-                        std::complex<float> *scratchpad, std::int64_t scratchpad_size,
-                        const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("lapack", "getri_batch");
-}
-sycl::event getri_batch(sycl::queue &queue, std::int64_t n, std::complex<double> *a,
+template <typename Func, typename T>
+sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t n, T *a,
                         std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv,
-                        std::int64_t stride_ipiv, std::int64_t batch_size,
-                        std::complex<double> *scratchpad, std::int64_t scratchpad_size,
+                        std::int64_t stride_ipiv, std::int64_t batch_size, T *scratchpad,
+                        std::int64_t scratchpad_size,
                         const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("lapack", "getri_batch");
+    using cuDataType = typename CudaEquivalentType<T>::Type;
+
+    overflow_check(n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size);
+
+    std::uint64_t ipiv32_size = n * batch_size;
+    int *ipiv32 = sycl::malloc_device<int>(ipiv32_size, queue);
+    int *devInfo = sycl::malloc_device<int>(batch_size, queue);
+
+    sycl::event done_casting = queue.submit([&](sycl::handler &cgh) {
+        cgh.parallel_for(
+            sycl::range<1>{ static_cast<size_t>(ipiv32_size) }, [=](sycl::id<1> index) {
+                ipiv32[index] = static_cast<int>(ipiv[(index / n) * stride_ipiv + index % n]);
+            });
+    });
+
+    // getri_batched is contained within cublas, not cusolver. For this reason
+    // we need to use cublas types instead of cusolver types (as is needed for
+    // other lapack routines)
+    auto done = queue.submit([&](sycl::handler &cgh) {
+        using blas::cublas::cublas_error;
+
+        cgh.depends_on(done_casting);
+        cgh.depends_on(dependencies);
+
+        onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
+            cublasStatus_t err;
+            CUresult cuda_result;
+            cublasHandle_t cublas_handle;
+            CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle);
+            CUstream cu_stream = sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
+            CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream);
+
+            CUdeviceptr a_dev;
+            auto *a_ = reinterpret_cast<cuDataType *>(a);
+            cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size);
+            CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size);
+            CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size);
+            auto **a_dev_ = reinterpret_cast<cuDataType **>(a_dev);
+
+            CUdeviceptr scratch_dev;
+            auto *scratch_ = reinterpret_cast<cuDataType *>(scratchpad);
+            cuDataType **scratch_batched =
+                create_ptr_list_from_stride(scratch_, stride_a, batch_size);
+            CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &scratch_dev, sizeof(T *) * batch_size);
+            CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, scratch_dev, scratch_batched,
+                            sizeof(T *) * batch_size);
+            auto **scratch_dev_ = reinterpret_cast<cuDataType **>(scratch_dev);
+
+            CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32,
+                                     scratch_dev_, lda, devInfo, batch_size)
+
+            free(a_batched);
+            free(scratch_batched);
+            cuMemFree(a_dev);
+            cuMemFree(scratch_dev);
+        });
+    });
+
+    // The inverted matrices stored in scratch_ need to be stored in a_
+    auto copy1 = queue.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(done);
+        cgh.parallel_for(
+            sycl::range<1>{ static_cast<size_t>(stride_a * (batch_size - 1) + lda * n) },
+            [=](sycl::id<1> index) { a[index] = scratchpad[index]; });
+    });
+
+    auto copy2 = queue.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(done);
+        cgh.parallel_for(
+            sycl::range<1>{ static_cast<size_t>(ipiv32_size) }, [=](sycl::id<1> index) {
+                ipiv[(index / n) * stride_ipiv + index % n] =
+                    static_cast<std::int64_t>(ipiv32[index]);
+            });
+    });
+
+    copy1.wait();
+    copy2.wait();
+    sycl::free(ipiv32, queue);
+    sycl::free(devInfo, queue);
+
+    return done;
 }
+
+#define GETRI_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE)                                          \
+    sycl::event getri_batch(                                                                      \
+        sycl::queue &queue, std::int64_t n, TYPE *a, std::int64_t lda, std::int64_t stride_a,     \
+        std::int64_t *ipiv, std::int64_t stride_ipiv, std::int64_t batch_size, TYPE *scratchpad,  \
+        std::int64_t scratchpad_size, const std::vector<sycl::event> &dependencies) {             \
+        return getri_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, n, a, lda, stride_a, ipiv, \
+                           stride_ipiv, batch_size, scratchpad, scratchpad_size, dependencies);   \
+    }
+
+GETRI_BATCH_LAUNCHER_USM(float, cublasSgetriBatched)
+GETRI_BATCH_LAUNCHER_USM(double, cublasDgetriBatched)
+GETRI_BATCH_LAUNCHER_USM(std::complex<float>, cublasCgetriBatched)
+GETRI_BATCH_LAUNCHER_USM(std::complex<double>, cublasZgetriBatched)
+
+#undef GETRI_BATCH_LAUNCHER_USM
+
 sycl::event getri_batch(sycl::queue &queue, std::int64_t *n, float **a, std::int64_t *lda,
                         std::int64_t **ipiv, std::int64_t group_count, std::int64_t *group_sizes,
                         float *scratchpad, std::int64_t scratchpad_size,
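One design note on the USM path above: copy1.wait() and copy2.wait() make the call synchronous so that ipiv32 and devInfo can be freed before returning. A non-blocking alternative, following the "Enqueue free memory" pattern that getrf_batch uses earlier in this file, would defer the frees to a dependent host task (a sketch only, not what this patch does):

    // Sketch: free the int32 temporaries only after both copy-back
    // kernels finish, letting getri_batch return without blocking.
    queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy1);
        cgh.depends_on(copy2);
        cgh.host_task([=](sycl::interop_handle ih) {
            sycl::free(ipiv32, queue);
            sycl::free(devInfo, queue);
        });
    });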
@@ -814,10 +959,7 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
     });
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         cgh.depends_on(done_casting);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
@@ -902,13 +1044,8 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu
     }
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
-        for (int64_t i = 0; i < batch_size; i++) {
-            cgh.depends_on(casting_dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
+        cgh.depends_on(casting_dependencies);
 
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
@@ -967,10 +1104,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
     overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1020,10 +1154,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu
         overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1074,10 +1205,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
     overflow_check(n, lda, stride_a, batch_size, scratchpad_size);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue,
                                   [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             CUdeviceptr a_dev;
@@ -1135,10 +1263,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu
     }
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             int64_t offset = 0;
@@ -1199,10 +1324,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
         throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1");
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             CUresult cuda_result;
@@ -1283,10 +1405,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu
     queue.submit([&](sycl::handler &h) { h.memcpy(b_dev, b, batch_size * sizeof(T *)); });
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         cgh.depends_on(done_cpy_a);
         cgh.depends_on(done_cpy_b);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
@@ -1340,10 +1459,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
     overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1393,10 +1509,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu
         overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]);
 
     auto done = queue.submit([&](sycl::handler &cgh) {
-        int64_t num_events = dependencies.size();
-        for (int64_t i = 0; i < num_events; i++) {
-            cgh.depends_on(dependencies[i]);
-        }
+        cgh.depends_on(dependencies);
         onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
             auto handle = sc.get_handle(queue);
             auto a_ = reinterpret_cast<cuDataType *>(a);
@@ -1472,35 +1585,22 @@ GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex<double>, cusolverDnZgetrf_buff
 
 #undef GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH
 
-template <>
-std::int64_t getri_batch_scratchpad_size<float>(sycl::queue &queue, std::int64_t n,
-                                                std::int64_t lda, std::int64_t stride_a,
-                                                std::int64_t stride_ipiv, std::int64_t batch_size) {
-    throw unimplemented("lapack", "getri_batch_scratchpad_size");
-}
-template <>
-std::int64_t getri_batch_scratchpad_size<double>(sycl::queue &queue, std::int64_t n,
-                                                 std::int64_t lda, std::int64_t stride_a,
-                                                 std::int64_t stride_ipiv,
-                                                 std::int64_t batch_size) {
-    throw unimplemented("lapack", "getri_batch_scratchpad_size");
-}
-template <>
-std::int64_t getri_batch_scratchpad_size<std::complex<float>>(sycl::queue &queue, std::int64_t n,
-                                                              std::int64_t lda,
-                                                              std::int64_t stride_a,
-                                                              std::int64_t stride_ipiv,
-                                                              std::int64_t batch_size) {
-    throw unimplemented("lapack", "getri_batch_scratchpad_size");
-}
unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} +// Scratch memory needs to be the same size as a +#define GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getri_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t n, std::int64_t lda, std::int64_t stride_a, \ + std::int64_t stride_ipiv, std::int64_t batch_size) { \ + assert(stride_a >= lda * n && "A matrices must not overlap"); \ + return stride_a * (batch_size - 1) + lda * n; \ + } + +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(float) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(double) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) + +#undef GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH // cusolverDnXgetrs does not use scratchpad memory #define GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ @@ -1696,32 +1796,26 @@ GETRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) #undef GETRF_GROUP_LAUNCHER_SCRATCH -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} +#define GETRI_GROUP_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getri_batch_scratchpad_size(sycl::queue & queue, std::int64_t * n, \ + std::int64_t * lda, std::int64_t group_count, \ + std::int64_t * group_sizes) { \ + std::int64_t max_scratch_sz = 0; \ + for (auto group_id = 0; group_id < group_count; ++group_id) { \ + auto scratch_sz = lda[group_id] * n[group_id]; \ + if (scratch_sz > max_scratch_sz) \ + max_scratch_sz = scratch_sz; \ + } \ + return max_scratch_sz; \ + } + +GETRI_GROUP_LAUNCHER_SCRATCH(float) +GETRI_GROUP_LAUNCHER_SCRATCH(double) +GETRI_GROUP_LAUNCHER_SCRATCH(std::complex) +GETRI_GROUP_LAUNCHER_SCRATCH(std::complex) + +#undef GETRI_GROUP_LAUNCHER_SCRATCH #define GETRS_GROUP_LAUNCHER_SCRATCH(TYPE) \ template <> \ diff --git a/src/lapack/backends/cusolver/cusolver_lapack.cpp b/src/lapack/backends/cusolver/cusolver_lapack.cpp index 4fbdccc72..3fcb733db 100644 --- a/src/lapack/backends/cusolver/cusolver_lapack.cpp +++ b/src/lapack/backends/cusolver/cusolver_lapack.cpp @@ -195,26 +195,19 @@ GETRF_LAUNCHER(std::complex, cusolverDnZgetrf) #undef GETRF_LAUNCHER -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, sycl::buffer &ipiv, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer 
&a, std::int64_t lda, - sycl::buffer &ipiv, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, - sycl::buffer &ipiv, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, sycl::buffer &ipiv, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} +#define GETRI_LAUNCHER(TYPE) \ + void getri(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, \ + sycl::buffer &ipiv, sycl::buffer &scratchpad, \ + std::int64_t scratchpad_size) { \ + return getri_batch(queue, n, a, lda, lda * n, ipiv, n, 1, scratchpad, scratchpad_size); \ + } + +GETRI_LAUNCHER(float) +GETRI_LAUNCHER(double) +GETRI_LAUNCHER(std::complex) +GETRI_LAUNCHER(std::complex) + +#undef GETRI_LAUNCHER // cusolverDnXgetrs does not use scratchpad memory template @@ -1380,26 +1373,20 @@ GETRF_LAUNCHER_USM(std::complex, cusolverDnZgetrf) #undef GETRF_LAUNCHER_USM -sycl::event getri(sycl::queue &queue, std::int64_t n, std::complex *a, std::int64_t lda, - std::int64_t *ipiv, std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, double *a, std::int64_t lda, - std::int64_t *ipiv, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, float *a, std::int64_t lda, - std::int64_t *ipiv, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, std::complex *a, std::int64_t lda, - std::int64_t *ipiv, std::complex *scratchpad, - std::int64_t scratchpad_size, const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} +#define GETRI_LAUNCHER_USM(TYPE) \ + sycl::event getri(sycl::queue &queue, std::int64_t n, TYPE *a, std::int64_t lda, \ + std::int64_t *ipiv, TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return getri_batch(queue, n, a, lda, lda * n, ipiv, n, 1, scratchpad, scratchpad_size, \ + dependencies); \ + } + +GETRI_LAUNCHER_USM(float) +GETRI_LAUNCHER_USM(double) +GETRI_LAUNCHER_USM(std::complex) +GETRI_LAUNCHER_USM(std::complex) + +#undef GETRI_LAUNCHER_USM // cusolverDnXgetrs does not use scratchpad memory template @@ -2603,24 +2590,19 @@ GETRF_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) #undef GETRF_LAUNCHER_SCRATCH -template <> -std::int64_t getri_scratchpad_size(sycl::queue &queue, std::int64_t n, std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size(sycl::queue &queue, std::int64_t n, std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} +#define GETRI_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + 
+    std::int64_t getri_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t n,       \
+                                             std::int64_t lda) {                        \
+        return lda * n;                                                                 \
+    }
+
+GETRI_LAUNCHER_SCRATCH(float)
+GETRI_LAUNCHER_SCRATCH(double)
+GETRI_LAUNCHER_SCRATCH(std::complex<float>)
+GETRI_LAUNCHER_SCRATCH(std::complex<double>)
+
+#undef GETRI_LAUNCHER_SCRATCH
 
 // cusolverDnXgetrs does not use scratchpad memory
 #define GETRS_LAUNCHER_SCRATCH(TYPE)                                                    \
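Taken together, the changes route getri and the strided getri_batch through cublas<T>getriBatched. A minimal end-to-end usage sketch of the USM strided API (assumes a CUDA-backed queue, the oneapi::mkl::lapack entry points as specified by oneMKL, and a prior getrf_batch factorization; error handling omitted):

    #include <sycl/sycl.hpp>
    #include <oneapi/mkl.hpp>

    int main() {
        sycl::queue q{ sycl::gpu_selector_v };
        std::int64_t n = 4, lda = 4, stride_a = lda * n, stride_ipiv = n, batch = 8;

        float *a = sycl::malloc_device<float>(stride_a * batch, q);
        std::int64_t *ipiv = sycl::malloc_device<std::int64_t>(stride_ipiv * batch, q);
        // ... fill a, then factorize with oneapi::mkl::lapack::getrf_batch ...

        auto scratch_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<float>(
            q, n, lda, stride_a, stride_ipiv, batch);
        float *scratch = sycl::malloc_device<float>(scratch_size, q);

        auto done = oneapi::mkl::lapack::getri_batch(q, n, a, lda, stride_a, ipiv,
                                                     stride_ipiv, batch, scratch, scratch_size);
        done.wait();

        sycl::free(scratch, q);
        sycl::free(ipiv, q);
        sycl::free(a, q);
        return 0;
    }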