Skip to content

Commit

Permalink
[lapack][blas][cuda] Update host task impl to use enqueue_native_command (#572)
Browse files Browse the repository at this point in the history

Signed-off-by: JackAKirk <[email protected]>
  • Loading branch information
JackAKirk authored Oct 8, 2024
1 parent b2324f1 commit 7adfbcc
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 236 deletions.
32 changes: 31 additions & 1 deletion src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,21 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto b_ = sc.get_mem<cuTypeB *>(b_acc);
auto c_ = sc.get_mem<cuTypeC *>(c_acc);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
}
Expand Down Expand Up @@ -608,12 +617,21 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
return done;
Expand Down Expand Up @@ -687,6 +705,16 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
int64_t offset = 0;
cublasStatus_t err;
for (int64_t i = 0; i < group_count; i++) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
Expand All @@ -695,6 +723,7 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
offset += group_size[i];
}
});
Expand Down Expand Up @@ -792,12 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
CUBLAS_ERROR_FUNC_T_SYNC(
cublas_native_named_func(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);

offset += group_size[i];
}
});
Expand Down
27 changes: 27 additions & 0 deletions src/blas/backends/cublas/cublas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

// Calls `func(handle, <args...>)`, stores the returned status in `err`, and
// throws cublas_error (message prefix: `name` followed by " : ") when the
// status is not CUBLAS_STATUS_SUCCESS. This variant performs no stream
// synchronization after the call.
//
// The expansion is wrapped in a brace block so that it always behaves as one
// statement: without the braces, `if (cond) CUBLAS_ERROR_FUNC_T(...)` would
// guard only the assignment and run the status check unconditionally.
// A plain block (rather than do { } while (0)) is used on purpose so the macro
// keeps working both with and without a trailing semicolon at the call site.
#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...)                    \
    {                                                                        \
        err = func(handle, __VA_ARGS__);                                     \
        if (err != CUBLAS_STATUS_SUCCESS) {                                  \
            throw cublas_error(std::string(name) + std::string(" : "), err); \
        }                                                                    \
    }

#define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
Expand All @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

// Invokes a cuBLAS routine through the error-checking macros, selecting the
// variant at compile time:
//   - SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND defined: CUBLAS_ERROR_FUNC
//     (call + status check only);
//   - otherwise: CUBLAS_ERROR_FUNC_SYNC, which after the call also queries the
//     handle's stream with cublasGetStream and calls cuStreamSynchronize on it.
//
// Parameters:
//   func   - the cuBLAS routine to call; it is handed to the macros as their
//            first argument.
//   err    - status variable used by the macros. It is received BY VALUE, so
//            whatever status the macros store in it is not visible to the
//            caller.
//            NOTE(review): failures are presumably reported by an exception
//            thrown from the macros - confirm against CUBLAS_ERROR_FUNC.
//   handle - cuBLAS handle forwarded as the routine's first argument.
//   args   - remaining routine arguments, forwarded by value.
//
// NOTE(review): the macro receives the parameter name `func`, not the original
// routine name. If CUBLAS_ERROR_FUNC stringizes its first argument, error
// messages produced here will read "func" - verify against the macro
// definition.
template <class Func, class... Types>
inline void cublas_native_func(Func func, cublasStatus_t err,
                               cublasHandle_t handle, Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    CUBLAS_ERROR_FUNC(func, err, handle, args...)
#else
    CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...)
#endif
}

// Invokes a cuBLAS routine whose name is only known at run time (e.g. when the
// routine is passed in as a template/function argument) and reports failures
// under that name.
//
// The variant is selected at compile time:
//   - SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND defined: CUBLAS_ERROR_FUNC_T, which
//     calls func(handle, args...) and throws cublas_error prefixed with
//     `func_name` if the status is not CUBLAS_STATUS_SUCCESS;
//   - otherwise: CUBLAS_ERROR_FUNC_T_SYNC, which performs the same call and
//     check and then also queries the handle's stream with cublasGetStream and
//     calls cuStreamSynchronize on it.
//
// Parameters:
//   func_name - routine name used as the prefix of the error message.
//   func      - the cuBLAS routine to call.
//   err       - status variable assigned by the macros. It is received BY
//               VALUE, so the assigned status is not visible to the caller;
//               failure is signalled by the thrown cublas_error.
//   handle    - cuBLAS handle forwarded as the routine's first argument.
//   args      - remaining routine arguments, forwarded by value.
template <class Func, class... Types>
inline void cublas_native_named_func(const char *func_name, Func func,
                                     cublasStatus_t err, cublasHandle_t handle,
                                     Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
    CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...)
#else
    CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
#endif
}

inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N;
Expand Down
Loading

0 comments on commit 7adfbcc

Please sign in to comment.