From 17ce59fd685c9f8f5a0aa7217ba348987e5102aa Mon Sep 17 00:00:00 2001 From: Matti Kortelainen Date: Mon, 20 Feb 2023 15:52:02 +0100 Subject: [PATCH] Fix synchronization in ContextState test --- .../CUDACore/test/test_ScopedContext.cc | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/HeterogeneousCore/CUDACore/test/test_ScopedContext.cc b/HeterogeneousCore/CUDACore/test/test_ScopedContext.cc index 5352d96714393..f8e6e9779d83e 100644 --- a/HeterogeneousCore/CUDACore/test/test_ScopedContext.cc +++ b/HeterogeneousCore/CUDACore/test/test_ScopedContext.cc @@ -1,6 +1,7 @@ #include "catch.hpp" #include "CUDADataFormats/Common/interface/Product.h" +#include "FWCore/Concurrency/interface/FinalWaitingTask.h" #include "FWCore/Concurrency/interface/WaitingTask.h" #include "FWCore/ParameterSet/interface/ParameterSet.h" #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h" @@ -16,6 +17,9 @@ #include "test_ScopedContextKernels.h" +#include "oneapi/tbb/task_arena.h" +#include "oneapi/tbb/task_group.h" + namespace cms::cudatest { class TestScopedContext { public: @@ -71,20 +75,27 @@ TEST_CASE("Use of cms::cuda::ScopedContext", "[CUDACore]") { } SECTION("Storing state in cms::cuda::ContextState") { - cms::cuda::ContextState ctxstate; - { // acquire - std::unique_ptr> dataPtr = ctx.wrap(10); - const auto& data = *dataPtr; - tbb::task_group group; - edm::WaitingTaskWithArenaHolder dummy{group, edm::make_waiting_task([](std::exception_ptr const* iPtr) {})}; - cms::cuda::ScopedContextAcquire ctx2{data, std::move(dummy), ctxstate}; - } - - { // produce - cms::cuda::ScopedContextProduce ctx2{ctxstate}; - REQUIRE(cms::cuda::currentDevice() == ctx.device()); - REQUIRE(ctx2.stream() == ctx.stream()); - } + oneapi::tbb::task_arena arena(1); + arena.execute([&ctx]() { + cms::cuda::ContextState ctxstate; + { // acquire + std::unique_ptr> dataPtr = ctx.wrap(10); + const auto& data = *dataPtr; + oneapi::tbb::task_group group; + edm::FinalWaitingTask waitTask{group}; + { + edm::WaitingTaskWithArenaHolder dummy{group, &waitTask}; + cms::cuda::ScopedContextAcquire ctx2{data, dummy, ctxstate}; + } + waitTask.wait(); + } + + { // produce + cms::cuda::ScopedContextProduce ctx2{ctxstate}; + REQUIRE(cms::cuda::currentDevice() == ctx.device()); + REQUIRE(ctx2.stream() == ctx.stream()); + } + }); } SECTION("Joining multiple CUDA streams") {