From d604f803d74e2de6a53da209ca66a020f942f624 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 8 Jul 2024 09:33:15 -0700 Subject: [PATCH] Enable OMP threading for host kernels --- CMakeLists.txt | 1 - include/targets/cuda/atomic_helper.h | 2 +- .../generic/block_reduction_kernel_host.h | 1 + include/targets/generic/kernel_host.h | 3 +++ include/targets/generic/reduction_kernel_host.h | 16 ++++++++++++---- include/targets/hip/atomic_helper.h | 2 +- lib/targets/cuda/target_cuda.cmake | 8 ++++++++ 7 files changed, 26 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f4e09a30ab..4f44a59233 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -280,7 +280,6 @@ mark_as_advanced(QUDA_RECONSTRUCT) mark_as_advanced(QUDA_CLOVER_CHOLESKY_PROMOTE) mark_as_advanced(QUDA_MULTIGRID_DSLASH_PROMOTE) mark_as_advanced(QUDA_CTEST_SEP_DSLASH_POLICIES) -mark_as_advanced(QUDA_OPENMP) mark_as_advanced(QUDA_BACKWARDS) diff --git a/include/targets/cuda/atomic_helper.h b/include/targets/cuda/atomic_helper.h index 43424e0f62..7620ee7cf2 100644 --- a/include/targets/cuda/atomic_helper.h +++ b/include/targets/cuda/atomic_helper.h @@ -81,7 +81,7 @@ namespace quda template struct atomic_fetch_abs_max_impl { template inline void operator()(T *addr, T val) { -#pragma omp atomic update +#pragma omp critical *addr = std::max(*addr, val); } }; diff --git a/include/targets/generic/block_reduction_kernel_host.h b/include/targets/generic/block_reduction_kernel_host.h index 1a356d4f1c..f0997a7d98 100644 --- a/include/targets/generic/block_reduction_kernel_host.h +++ b/include/targets/generic/block_reduction_kernel_host.h @@ -5,6 +5,7 @@ namespace quda { Functor t(arg); dim3 block(0, 0, 0); +#pragma omp parallel for for (block.y = 0; block.y < arg.grid_dim.y; block.y++) { for (block.x = 0; block.x < arg.grid_dim.x; block.x++) { t(block, dim3(0, 0, 0)); } } diff --git a/include/targets/generic/kernel_host.h b/include/targets/generic/kernel_host.h index 96523df955..1416b3a536 100644 --- a/include/targets/generic/kernel_host.h +++ b/include/targets/generic/kernel_host.h @@ -6,12 +6,14 @@ namespace quda template