diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 031c11a2f..9f198b653 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -168,23 +168,19 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran auto c_ = sc.get_mem(c_acc); cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, c_, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else - CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", - cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a_, - get_cublas_datatype(), lda, stride_a, b_, - get_cublas_datatype(), ldb, stride_b, &beta, - c_, get_cublas_datatype(), ldc, stride_c, - batch_size, get_cublas_datatype(), - cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #endif }); }); @@ -622,23 +618,19 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra auto handle = sc.get_handle(queue); cublasStatus_t err; #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, - err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, c, - get_cublas_datatype(), ldc, stride_c, batch_size, - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #else - CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx", - cublasGemmStridedBatchedEx, err, handle, - get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, &alpha, a, - get_cublas_datatype(), lda, stride_a, b, - get_cublas_datatype(), ldb, stride_b, &beta, - c, get_cublas_datatype(), ldc, stride_c, - batch_size, get_cublas_datatype(), - cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); #endif }); }); @@ -714,26 +706,23 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { #ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND - CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, - get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], - (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], - (const void *const *)(b + offset), - get_cublas_datatype(), (int)ldb[i], &beta[i], - (void *const *)(c + offset), get_cublas_datatype(), - (int)ldc[i], (int)group_size[i], - get_cublas_datatype(), cublas_gemm_algo); + CUBLAS_ERROR_FUNC_T( + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #else CUBLAS_ERROR_FUNC_T_SYNC( "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), - get_cublas_datatype(), (int)lda[i], - (const void *const *)(b + offset), get_cublas_datatype(), - (int)ldb[i], &beta[i], (void *const *)(c + offset), - get_cublas_datatype(), (int)ldc[i], (int)group_size[i], - get_cublas_datatype(), cublas_gemm_algo); + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); #endif offset += group_size[i]; } @@ -832,13 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); - cublas_native_named_func(func_name, func, err, handle, - get_cublas_side_mode(left_right[i]), - get_cublas_fill_mode(upper_lower[i]), - get_cublas_operation(trans[i]), - get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], - b_ + offset, (int)ldb[i], (int)group_size[i]); + cublas_native_named_func( + func_name, func, err, handle, get_cublas_side_mode(left_right[i]), + get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]), + get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i], + (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], + (int)group_size[i]); + offset += group_size[i]; } });