From df6d7beb21bad171e138aa6c435e1da3e35dac0a Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Thu, 10 Feb 2022 01:18:04 +0100
Subject: [PATCH] Fix worker streams in OLS-eig executing in an unsafe order
 (#4539)

The latest version of the "eig" OLS solver has a bug producing garbage results
under some conditions. When at least one worker stream is used to run some
operations concurrently, for sufficiently large workset sizes, the memory
allocation in the main stream may finish later than the worker stream starts
to use it. This PR adds more ordering between the main and the worker streams,
fixing this and some other theoretically possible edge cases.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/cuml/pull/4539
---
 cpp/src_prims/linalg/lstsq.cuh | 57 ++++++++++++++++++----------------
 cpp/test/sg/ols.cu             | 38 ++++++++++++++++++++---
 2 files changed, 64 insertions(+), 31 deletions(-)

diff --git a/cpp/src_prims/linalg/lstsq.cuh b/cpp/src_prims/linalg/lstsq.cuh
index 8af85c7b13..2132d6164e 100644
--- a/cpp/src_prims/linalg/lstsq.cuh
+++ b/cpp/src_prims/linalg/lstsq.cuh
@@ -52,7 +52,7 @@ struct DeviceEvent {
   DeviceEvent(bool concurrent)
   {
     if (concurrent)
-      RAFT_CUDA_TRY(cudaEventCreate(&e));
+      RAFT_CUDA_TRY(cudaEventCreateWithFlags(&e, cudaEventDisableTiming));
     else
       e = nullptr;
   }
@@ -60,19 +60,14 @@ struct DeviceEvent {
   {
     if (e != nullptr) RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(e));
   }
-  operator cudaEvent_t() const { return e; }
   void record(cudaStream_t stream)
   {
     if (e != nullptr) RAFT_CUDA_TRY(cudaEventRecord(e, stream));
   }
-  void wait(cudaStream_t stream)
+  void wait_by(cudaStream_t stream)
   {
     if (e != nullptr) RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, e, 0u));
   }
-  void wait()
-  {
-    if (e != nullptr) raft::interruptible::synchronize(e);
-  }
   DeviceEvent& operator=(const DeviceEvent& other) = delete;
 };

@@ -259,27 +254,26 @@ void lstsqEig(const raft::handle_t& handle,
               cudaStream_t stream)
 {
   rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream);
-  rmm::cuda_stream_view multAbStream = mainStream;
-  bool concurrent = false;
-  {
-    int sp_size = handle.get_stream_pool_size();
-    if (sp_size > 0) {
-      multAbStream = handle.get_stream_from_stream_pool(0);
-      // check if the two streams can run concurrently
-      if (!are_implicitly_synchronized(mainStream, multAbStream)) {
-        concurrent = true;
-      } else if (sp_size > 1) {
-        mainStream = multAbStream;
-        multAbStream = handle.get_stream_from_stream_pool(1);
-        concurrent = true;
-      }
-    }
+  rmm::cuda_stream_view multAbStream = handle.get_next_usable_stream();
+  bool concurrent;
+  // Check if the two streams can run concurrently. This is needed because a legacy default stream
+  // would synchronize with other blocking streams. To avoid synchronization in such case, we try to
+  // use an additional stream from the pool.
+  if (!are_implicitly_synchronized(mainStream, multAbStream)) {
+    concurrent = true;
+  } else if (handle.get_stream_pool_size() > 1) {
+    mainStream = handle.get_next_usable_stream();
+    concurrent = true;
+  } else {
+    multAbStream = mainStream;
+    concurrent = false;
   }

-  // the event is created only if the given raft handle is capable of running
-  // at least two CUDA streams without implicit synchronization.
-  DeviceEvent multAbDone(concurrent);
   rmm::device_uvector<math_t> workset(n_cols * n_cols * 3 + n_cols * 2, mainStream);
+  // the event is created only if the given raft handle is capable of running
+  // at least two CUDA streams without implicit synchronization.
+  DeviceEvent worksetDone(concurrent);
+  worksetDone.record(mainStream);
   math_t* Q = workset.data();
   math_t* QS = Q + n_cols * n_cols;
   math_t* covA = QS + n_cols * n_cols;
@@ -304,7 +298,9 @@ void lstsqEig(const raft::handle_t& handle,
                            mainStream);

   // Ab <- A* b
+  worksetDone.wait_by(multAbStream);
   raft::linalg::gemv(handle, A, n_rows, n_cols, b, Ab, true, multAbStream);
+  DeviceEvent multAbDone(concurrent);
   multAbDone.record(multAbStream);

   // Q S Q* <- covA
@@ -329,9 +325,18 @@ void lstsqEig(const raft::handle_t& handle,
                        alpha,
                        beta,
                        mainStream);
-  multAbDone.wait(mainStream);
+
+  multAbDone.wait_by(mainStream);
   // w <- covA Ab == Q invS Q* A b == inv(A* A) A b
   raft::linalg::gemv(handle, covA, n_cols, n_cols, Ab, w, false, mainStream);
+
+  // This event is created only if we use two worker streams, and `stream` is not the legacy stream,
+  // and `mainStream` is not a non-blocking stream. In fact, with the current logic these conditions
+  // are impossible together, but it still makes sense to put this construct here to emphasize that
+  // `stream` must wait till the work here is done (for future refactorings).
+  DeviceEvent mainDone(!are_implicitly_synchronized(mainStream, stream));
+  mainDone.record(mainStream);
+  mainDone.wait_by(stream);
 }

 /** Solves the linear ordinary least squares problem `Aw = b`
diff --git a/cpp/test/sg/ols.cu b/cpp/test/sg/ols.cu
index 6998f500f0..ba3795028c 100644
--- a/cpp/test/sg/ols.cu
+++ b/cpp/test/sg/ols.cu
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -27,8 +28,29 @@ namespace GLM {

 using namespace MLCommon;

+enum class hconf { SINGLE, LEGACY_ONE, LEGACY_TWO, NON_BLOCKING_ONE, NON_BLOCKING_TWO };
+
+raft::handle_t create_handle(hconf type)
+{
+  switch (type) {
+    case hconf::LEGACY_ONE:
+      return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(1));
+    case hconf::LEGACY_TWO:
+      return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(2));
+    case hconf::NON_BLOCKING_ONE:
+      return raft::handle_t(rmm::cuda_stream_per_thread,
+                            std::make_shared<rmm::cuda_stream_pool>(1));
+    case hconf::NON_BLOCKING_TWO:
+      return raft::handle_t(rmm::cuda_stream_per_thread,
+                            std::make_shared<rmm::cuda_stream_pool>(2));
+    case hconf::SINGLE:
+    default: return raft::handle_t();
+  }
+}
+
 template <typename T>
 struct OlsInputs {
+  hconf hc;
   T tol;
   int n_row;
   int n_col;
@@ -41,6 +63,7 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
  public:
   OlsTest()
     : params(::testing::TestWithParam<OlsInputs<T>>::GetParam()),
+      handle(create_handle(params.hc)),
       stream(handle.get_stream()),
       coef(params.n_col, stream),
       coef2(params.n_col, stream),
@@ -216,10 +239,11 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
   }

  protected:
+  OlsInputs<T> params;
+
   raft::handle_t handle;
   cudaStream_t stream = 0;
-  OlsInputs<T> params;

   rmm::device_uvector<T> coef, coef_ref, pred, pred_ref;
   rmm::device_uvector<T> coef2, coef2_ref, pred2, pred2_ref;
   rmm::device_uvector<T> coef3, coef3_ref, pred3, pred3_ref;
@@ -228,11 +252,15 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
   T intercept, intercept2, intercept3;
 };

-const std::vector<OlsInputs<float>> inputsf2 = {
-  {0.001f, 4, 2, 2, 0}, {0.001f, 4, 2, 2, 1}, {0.001f, 4, 2, 2, 2}};
+const std::vector<OlsInputs<float>> inputsf2 = {{hconf::NON_BLOCKING_ONE, 0.001f, 4, 2, 2, 0},
+                                                {hconf::NON_BLOCKING_TWO, 0.001f, 4, 2, 2, 1},
+                                                {hconf::LEGACY_ONE, 0.001f, 4, 2, 2, 2},
+                                                {hconf::LEGACY_TWO, 0.001f, 4, 2, 2, 2},
+                                                {hconf::SINGLE, 0.001f, 4, 2, 2, 2}};

-const std::vector<OlsInputs<double>> inputsd2 = {
-  {0.001, 4, 2, 2, 0}, {0.001, 4, 2, 2, 1}, {0.001, 4, 2, 2, 2}};
+const std::vector<OlsInputs<double>> inputsd2 = {{hconf::SINGLE, 0.001, 4, 2, 2, 0},
+                                                 {hconf::LEGACY_ONE, 0.001, 4, 2, 2, 1},
+                                                 {hconf::LEGACY_TWO, 0.001, 4, 2, 2, 2}};

 typedef OlsTest<float> OlsTestF;
 TEST_P(OlsTestF, Fit)
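
Note (not part of the patch): the fix above boils down to the standard event-based ordering pattern between a main stream and a worker stream: record an event on the stream that prepares a buffer, make the consumer stream wait on it, and do the same in reverse before the producer's result is read back. The following standalone CUDA C++ sketch illustrates that same pattern with made-up kernels and buffer names (produce/consume/workset are hypothetical, not cuML code); it uses only plain CUDA runtime calls (cudaEventCreateWithFlags, cudaEventRecord, cudaStreamWaitEvent, cudaMallocAsync/cudaFreeAsync, CUDA 11.2+) and is meant only to show why worksetDone/multAbDone/mainDone are recorded and waited on where they are.

// standalone_stream_ordering_sketch.cu -- illustration only, assumes CUDA 11.2+
// Build: nvcc -o stream_ordering_sketch standalone_stream_ordering_sketch.cu
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                                            \
  do {                                                                              \
    cudaError_t err = (call);                                                       \
    if (err != cudaSuccess) {                                                       \
      std::printf("CUDA error %s at %s:%d\n", cudaGetErrorString(err), __FILE__, __LINE__); \
      std::exit(1);                                                                 \
    }                                                                               \
  } while (0)

// Hypothetical stand-ins for the work done on each stream (gemm/gemv in lstsqEig).
__global__ void produce(float* buf, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = static_cast<float>(i);
}

__global__ void consume(const float* buf, float* out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = buf[i] * 2.0f;
}

int main()
{
  cudaStream_t mainStream, workerStream;
  CHECK_CUDA(cudaStreamCreateWithFlags(&mainStream, cudaStreamNonBlocking));
  CHECK_CUDA(cudaStreamCreateWithFlags(&workerStream, cudaStreamNonBlocking));

  constexpr int n = 1 << 20;
  float *workset = nullptr, *out = nullptr;

  // The allocation and setup are ordered on mainStream only
  // (analogous to the rmm::device_uvector created on mainStream in the patch).
  CHECK_CUDA(cudaMallocAsync(&workset, n * sizeof(float), mainStream));
  CHECK_CUDA(cudaMallocAsync(&out, n * sizeof(float), mainStream));
  produce<<<(n + 255) / 256, 256, 0, mainStream>>>(workset, n);

  // "worksetDone": the worker stream must not touch the buffer before mainStream is done with it.
  cudaEvent_t worksetDone, workerDone;
  CHECK_CUDA(cudaEventCreateWithFlags(&worksetDone, cudaEventDisableTiming));
  CHECK_CUDA(cudaEventCreateWithFlags(&workerDone, cudaEventDisableTiming));
  CHECK_CUDA(cudaEventRecord(worksetDone, mainStream));
  CHECK_CUDA(cudaStreamWaitEvent(workerStream, worksetDone, 0));

  // Work submitted to the worker stream now starts only after the allocation/setup has completed.
  consume<<<(n + 255) / 256, 256, 0, workerStream>>>(workset, out, n);

  // "multAbDone"/"mainDone": mainStream must not read `out` before the worker stream has produced it.
  CHECK_CUDA(cudaEventRecord(workerDone, workerStream));
  CHECK_CUDA(cudaStreamWaitEvent(mainStream, workerDone, 0));

  // Free in stream order on mainStream (now ordered after the worker's use), then synchronize.
  CHECK_CUDA(cudaFreeAsync(workset, mainStream));
  CHECK_CUDA(cudaFreeAsync(out, mainStream));
  CHECK_CUDA(cudaStreamSynchronize(mainStream));

  CHECK_CUDA(cudaEventDestroy(worksetDone));
  CHECK_CUDA(cudaEventDestroy(workerDone));
  CHECK_CUDA(cudaStreamDestroy(mainStream));
  CHECK_CUDA(cudaStreamDestroy(workerStream));
  std::printf("done\n");
  return 0;
}

Without the first record/wait pair, the consume launch on workerStream could begin before the stream-ordered allocation and produce on mainStream have finished, which is exactly the failure mode the commit message describes for large workset sizes.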