From 1a49fc1bba8ccfb87c7a13b400665b337a1fcd66 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Tue, 8 Feb 2022 13:25:42 +0100 Subject: [PATCH] Interruptible execution (#433) ### Cooperative-style interruptible C++ threads. This proposal introduces `raft::interruptible`, which provides three functions: ```C++ static void synchronize(rmm::cuda_stream_view stream); static void yield(); static void cancel(std::thread::id thread_id); ``` `synchronize` and `yield` serve as cancellation points for the executing CPU thread. `cancel` allows throwing an async exception in a target CPU thread, which is observed at the nearest cancellation point. Altogether, these allow cancelling a long-running job without killing the OS process. The key to making this work is the observation that the CPU spends most of the time waiting on `cudaStreamSynchronize`. By replacing that with `interruptible::synchronize`, we introduce cancellation points in all critical places in code. If that is not enough in some edge cases (the cancellation points are too far apart), a developer can use `yield` to ensure that a cancellation request is received sooner rather than later. #### Implementation ##### C++ `raft::interruptible` keeps a `std::atomic_flag` in the thread-local storage in each thread, which tells whether the thread can continue executing (i.e., is in a non-cancelled state). [`cancel`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L122) clears this flag, and [`yield`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L194-L204) checks it and resets it to the signalled state (throwing a `raft::interrupted_exception` if necessary). 
[`synchronize`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L206-L217) implements a spinning lock querying the state of the stream and `yield`ing on each iteration. I also add an overload [`sync_stream`](https://github.com/rapidsai/raft/blob/ee99523ff6a8257ec213e5ad15292f2132a2a687/cpp/include/raft/handle.hpp#L133) to the raft handle type, to make it easier to modify the behavior of all synchronization calls in raft and cuml. ##### python This proposal adds a context manager [`cuda_interruptible`](https://github.com/rapidsai/raft/blob/36e8de5f73e9ec7e604b38a4290ac82bc35be4b7/python/raft/common/interruptible.pyx#L28) to handle Ctrl+C requests during C++ calls (using posix signals). `cuda_interruptible` simply calls `raft::interruptible::cancel` on the target C++ thread. #### Motivation See https://github.com/rapidsai/cuml/pull/4463 Resolves https://github.com/rapidsai/cuml/issues/4384 Authors: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/433 --- cpp/CMakeLists.txt | 14 +- cpp/cmake/modules/ConfigureCUDA.cmake | 6 +- cpp/include/raft/comms/detail/mpi_comms.hpp | 35 +-- cpp/include/raft/comms/detail/std_comms.hpp | 37 +-- cpp/include/raft/comms/detail/test.hpp | 18 +- cpp/include/raft/comms/detail/util.hpp | 40 ++- cpp/include/raft/handle.hpp | 14 +- cpp/include/raft/interruptible.hpp | 266 ++++++++++++++++++ .../raft/linalg/detail/cholesky_r1_update.hpp | 4 +- cpp/include/raft/linalg/detail/lanczos.hpp | 2 +- cpp/include/raft/linalg/detail/svd.hpp | 2 +- cpp/include/raft/mr/buffer_base.hpp | 2 +- .../sparse/hierarchy/detail/agglomerative.cuh | 2 +- .../raft/sparse/linalg/detail/spectral.cuh | 2 +- cpp/include/raft/sparse/op/detail/reduce.cuh | 2 +- .../selection/detail/connect_components.cuh | 2 +- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_adj.cu | 2 +- cpp/test/distance/distance_base.cuh | 2 +- cpp/test/handle.cpp | 2 +- cpp/test/interruptible.cu | 144 ++++++++++ cpp/test/linalg/add.cu | 2 +- cpp/test/linalg/binary_op.cu | 2 +- cpp/test/linalg/coalesced_reduction.cu | 2 +- cpp/test/linalg/divide.cu | 2 +- cpp/test/linalg/eig.cu | 2 +- cpp/test/linalg/eig_sel.cu | 4 +- cpp/test/linalg/eltwise.cu | 4 +- cpp/test/linalg/map_then_reduce.cu | 2 +- cpp/test/linalg/matrix_vector_op.cu | 2 +- cpp/test/linalg/multiply.cu | 2 +- cpp/test/linalg/norm.cu | 4 +- cpp/test/linalg/reduce.cu | 2 +- cpp/test/linalg/strided_reduction.cu | 2 +- cpp/test/linalg/subtract.cu | 2 +- cpp/test/linalg/svd.cu | 2 +- cpp/test/linalg/transpose.cu | 2 +- cpp/test/linalg/unary_op.cu | 4 +- cpp/test/matrix/math.cu | 2 +- cpp/test/matrix/matrix.cu | 2 +- cpp/test/random/rng_int.cu | 4 +- cpp/test/random/sample_without_replacement.cu | 2 +- cpp/test/sparse/connect_components.cu | 2 +- cpp/test/sparse/csr_row_slice.cu | 4 +- cpp/test/sparse/csr_transpose.cu | 2 +- cpp/test/sparse/knn_graph.cu | 2 +- 
cpp/test/sparse/linkage.cu | 2 +- cpp/test/sparse/symmetrize.cu | 2 +- cpp/test/spatial/haversine.cu | 2 +- cpp/test/spatial/knn.cu | 2 +- cpp/test/spatial/selection.cu | 2 +- cpp/test/stats/mean_center.cu | 2 +- cpp/test/stats/stddev.cu | 2 +- cpp/test/stats/sum.cu | 2 +- docs/source/cpp_api/core.rst | 8 + python/raft/common/interruptible.pxd | 34 +++ python/raft/common/interruptible.pyx | 84 ++++++ python/raft/test/test_interruptible.py | 54 ++++ python/setup.py | 4 +- 59 files changed, 719 insertions(+), 142 deletions(-) create mode 100644 cpp/include/raft/interruptible.hpp create mode 100644 cpp/test/interruptible.cu create mode 100644 python/raft/common/interruptible.pxd create mode 100644 python/raft/common/interruptible.pyx create mode 100644 python/raft/test/test_interruptible.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8acd9c0099..ea0ef2c2f1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -84,6 +84,13 @@ endif() ############################################################################## # - compiler options --------------------------------------------------------- +if (NOT DISABLE_OPENMP) + find_package(OpenMP) + if(OPENMP_FOUND) + message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}") + endif() +endif() + # * find CUDAToolkit package # * determine GPU architectures # * enable the CMake CUDA language @@ -97,13 +104,6 @@ include(cmake/modules/ConfigureCUDA.cmake) ############################################################################## # - Requirements ------------------------------------------------------------- -if (NOT DISABLE_OPENMP) - find_package(OpenMP) - if(OPENMP_FOUND) - message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}") - endif() -endif() - # add third party dependencies using CPM rapids_cpm_init() diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index a9163a474f..5984c424e7 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ 
b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,6 +38,10 @@ if(CUDA_ENABLE_LINEINFO) list(APPEND RAFT_CUDA_FLAGS -lineinfo) endif() +if(OpenMP_FOUND) + list(APPEND RAFT_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS}) +endif() + # Debug options if(CMAKE_BUILD_TYPE MATCHES Debug) message(VERBOSE "RAFT: Building with debugging flags") diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 3bfd72baf9..b0da532f0a 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -333,38 +333,7 @@ class mpi_comms : public comms_iface { stream)); } - status_t sync_stream(cudaStream_t stream) const - { - cudaError_t cudaErr; - ncclResult_t ncclErr, ncclAsyncErr; - while (1) { - cudaErr = cudaStreamQuery(stream); - if (cudaErr == cudaSuccess) return status_t::SUCCESS; - - if (cudaErr != cudaErrorNotReady) { - // An error occurred querying the status of the stream - return status_t::ERROR; - } - - ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); - if (ncclErr != ncclSuccess) { - // An error occurred retrieving the asynchronous error - return status_t::ERROR; - } - - if (ncclAsyncErr != ncclSuccess) { - // An asynchronous error happened. 
Stop the operation and destroy - // the communicator - ncclErr = ncclCommAbort(nccl_comm_); - if (ncclErr != ncclSuccess) - // Caller may abort with an exception or try to re-create a new communicator. - return status_t::ABORT; - } - - // Let other threads (including NCCL threads) use the CPU. - pthread_yield(); - } - }; + status_t sync_stream(cudaStream_t stream) const { return nccl_sync_stream(nccl_comm_, stream); } // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 758a9d3781..d8b0f2090c 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -441,38 +441,7 @@ class std_comms : public comms_iface { stream)); } - status_t sync_stream(cudaStream_t stream) const - { - cudaError_t cudaErr; - ncclResult_t ncclErr, ncclAsyncErr; - while (1) { - cudaErr = cudaStreamQuery(stream); - if (cudaErr == cudaSuccess) return status_t::SUCCESS; - - if (cudaErr != cudaErrorNotReady) { - // An error occurred querying the status of the stream_ - return status_t::ERROR; - } - - ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); - if (ncclErr != ncclSuccess) { - // An error occurred retrieving the asynchronous error - return status_t::ERROR; - } - - if (ncclAsyncErr != ncclSuccess) { - // An asynchronous error happened. Stop the operation and destroy - // the communicator - ncclErr = ncclCommAbort(nccl_comm_); - if (ncclErr != ncclSuccess) - // Caller may abort with an exception or try to re-create a new communicator. 
- return status_t::ABORT; - } - - // Let other threads (including NCCL threads) use the CPU. - std::this_thread::yield(); - } - } + status_t sync_stream(cudaStream_t stream) const { return nccl_sync_stream(nccl_comm_, stream); } // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const @@ -553,4 +522,4 @@ class std_comms : public comms_iface { }; } // namespace detail } // end namespace comms -} // end namespace raft \ No newline at end of file +} // end namespace raft diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index cd84d2becd..d81d7c80fb 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -53,7 +53,7 @@ bool test_collective_allreduce(const handle_t& handle, int root) int temp_h = 0; RAFT_CUDA_TRY(cudaMemcpyAsync(&temp_h, temp_d.data(), 1, cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -88,7 +88,7 @@ bool test_collective_broadcast(const handle_t& handle, int root) int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -121,7 +121,7 @@ bool test_collective_reduce(const handle_t& handle, int root) int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -158,7 +158,7 @@ bool test_collective_allgather(const handle_t& handle, int root) int temp_h[communicator.get_size()]; // Verify more than one byte is being sent RAFT_CUDA_TRY(cudaMemcpyAsync( &temp_h, recv_d.data(), sizeof(int) * communicator.get_size(), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -198,7 +198,7 @@ bool test_collective_gather(const handle_t& handle, int root) std::vector temp_h(communicator.get_size(), 0); RAFT_CUDA_TRY(cudaMemcpyAsync( temp_h.data(), recv_d.data(), sizeof(int) * temp_h.size(), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for 
(int i = 0; i < communicator.get_size(); i++) { if (temp_h[i] != i) return false; @@ -253,7 +253,7 @@ bool test_collective_gatherv(const handle_t& handle, int root) sizeof(int) * displacements.back(), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); for (int i = 0; i < communicator.get_size(); i++) { if (std::count_if(temp_h.begin() + displacements[i], @@ -292,7 +292,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root) int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -502,7 +502,7 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial std::vector h_received_data(communicator.get_size()); raft::update_host(h_received_data.data(), received_data.data(), received_data.size(), stream); - CUDA_TRY(cudaStreamSynchronize(stream)); + h.sync_stream(stream); for (int i = 0; i < communicator.get_size(); ++i) { if (h_received_data[i] != i) { ret = false; } } diff --git a/cpp/include/raft/comms/detail/util.hpp b/cpp/include/raft/comms/detail/util.hpp index 1c0d152016..7bd60cf8e1 100644 --- a/cpp/include/raft/comms/detail/util.hpp +++ b/cpp/include/raft/comms/detail/util.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,6 +16,8 @@ #pragma once +#include + #include #include #include @@ -109,6 +111,42 @@ get_nccl_op(const op_t op) default: throw "Unsupported datatype"; } } + +status_t nccl_sync_stream(ncclComm_t comm, cudaStream_t stream) +{ + cudaError_t cudaErr; + ncclResult_t ncclErr, ncclAsyncErr; + while (1) { + cudaErr = cudaStreamQuery(stream); + if (cudaErr == cudaSuccess) return status_t::SUCCESS; + + if (cudaErr != cudaErrorNotReady) { + // An error occurred querying the status of the stream_ + return status_t::ERROR; + } + + ncclErr = ncclCommGetAsyncError(comm, &ncclAsyncErr); + if (ncclErr != ncclSuccess) { + // An error occurred retrieving the asynchronous error + return status_t::ERROR; + } + + if (ncclAsyncErr != ncclSuccess || !interruptible::yield_no_throw()) { + // An asynchronous error happened. Stop the operation and destroy + // the communicator + ncclErr = ncclCommAbort(comm); + if (ncclErr != ncclSuccess) + // Caller may abort with an exception or try to re-create a new communicator. + return status_t::ABORT; + // TODO: shouldn't we place status_t::ERROR above under the condition, and + // status_t::ABORT below here (i.e. after successful ncclCommAbort)? + } + + // Let other threads (including NCCL threads) use the CPU. + std::this_thread::yield(); + } +} + }; // namespace detail }; // namespace comms }; // namespace raft diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 22e9e78ebe..015d422f9a 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -35,6 +35,7 @@ #include "cudart_utils.h" #include +#include #include #include #include @@ -127,10 +128,15 @@ class handle_t { rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; } + /** + * @brief synchronize a stream on the handle + */ + void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); } + /** * @brief synchronize main stream on the handle */ - void sync_stream() const { stream_view_.synchronize(); } + void sync_stream() const { sync_stream(stream_view_); } /** * @brief returns main stream on the handle @@ -199,7 +205,7 @@ class handle_t { void sync_stream_pool() const { for (std::size_t i = 0; i < get_stream_pool_size(); i++) { - stream_pool_->get_stream(i).synchronize(); + sync_stream(stream_pool_->get_stream(i)); } } @@ -212,7 +218,7 @@ class handle_t { { RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); for (const auto& stream_index : stream_indices) { - stream_pool_->get_stream(stream_index).synchronize(); + sync_stream(stream_pool_->get_stream(stream_index)); } } diff --git a/cpp/include/raft/interruptible.hpp b/cpp/include/raft/interruptible.hpp new file mode 100644 index 0000000000..7ff5ca0c88 --- /dev/null +++ b/cpp/include/raft/interruptible.hpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Exception thrown during `interruptible::synchronize` call when it detects a request + * to cancel the work performed in this CPU thread. + */ +struct interrupted_exception : public raft::exception { + using raft::exception::exception; +}; + +/** + * @brief Cooperative-style interruptible execution. + * + * This class provides facilities for interrupting execution of a C++ thread at designated points + * in code from outside of the thread. In particular, it provides an interruptible version of the + * blocking CUDA synchronization function, that allows dropping long-running GPU work. + * + * + * **Important:** Although CUDA synchronize calls serve as cancellation points, the interruptible + * machinery has nothing to do with CUDA streams or events. In other words, when you call `cancel`, + * it’s the CPU waiting function that is interrupted, not the GPU stream work. This means, when the + * `interrupted_exception` is raised, any unfinished GPU stream work continues to run. It’s the + * responsibility of the developer then to make sure the unfinished stream work does not affect the + * program in an undesirable way. + * + * + * What can happen to the CUDA stream when the `synchronize` is cancelled? If you catch the + * `interrupted_exception` immediately, you can safely wait on the stream again. + * Otherwise, some of the allocated resources may be released before the active kernel finishes + * using them, which will result in writing into deallocated or reallocated memory and undefined + * behavior in general. A dead-locked kernel may never finish (or may crash if you’re lucky). In + * practice, the outcome is usually acceptable for the use case of emergency program interruption + * (e.g., CTRL+C), but extra effort on the user side is required to allow safe interrupting and + * resuming of the GPU stream work. 
+ */ +class interruptible { + public: + /** + * @brief Synchronize the CUDA stream, subject to being interrupted by `interruptible::cancel` + * called on this CPU thread. + * + * @param [in] stream a CUDA stream. + * + * @throw raft::interrupted_exception if interruptible::cancel() was called on the current CPU + * thread before the currently captured work has been finished. + * @throw raft::cuda_error if another CUDA error happens. + */ + static inline void synchronize(rmm::cuda_stream_view stream) + { + get_token()->synchronize_impl(cudaStreamQuery, stream); + } + + /** + * @brief Synchronize the CUDA event, subject to being interrupted by `interruptible::cancel` + * called on this CPU thread. + * + * @param [in] event a CUDA event. + * + * @throw raft::interrupted_exception if interruptible::cancel() was called on the current CPU + * thread before the currently captured work has been finished. + * @throw raft::cuda_error if another CUDA error happens. + */ + static inline void synchronize(cudaEvent_t event) + { + get_token()->synchronize_impl(cudaEventQuery, event); + } + + /** + * @brief Check the thread state, whether the thread can continue execution or is interrupted by + * `interruptible::cancel`. + * + * This is a cancellation point for an interruptible thread. It's called in the internals of + * `interruptible::synchronize` in a loop. If two synchronize calls are far apart, it's + * recommended to call `interruptible::yield()` in between to make sure the thread does not become + * unresponsive for too long. + * + * Both `yield` and `yield_no_throw` reset the state to non-cancelled after execution. + * + * @throw raft::interrupted_exception if interruptible::cancel() was called on the current CPU + * thread. + */ + static inline void yield() { get_token()->yield_impl(); } + + /** + * @brief Check the thread state, whether the thread can continue execution or is interrupted by + * `interruptible::cancel`. 
+ * + * Same as `interruptible::yield`, but does not throw an exception if the thread is cancelled. + * + * Both `yield` and `yield_no_throw` reset the state to non-cancelled after execution. + * + * @return whether the thread can continue, i.e. `true` means continue, `false` means cancelled. + */ + static inline auto yield_no_throw() -> bool { return get_token()->yield_no_throw_impl(); } + + /** + * @brief Get a cancellation token for this CPU thread. + * + * @return an object that can be used to cancel the GPU work waited on this CPU thread. + */ + static inline auto get_token() -> std::shared_ptr + { + // NB: using static thread-local storage to keep the token alive once it is initialized + static thread_local std::shared_ptr s( + get_token_impl(std::this_thread::get_id())); + return s; + } + + /** + * @brief Get a cancellation token for a CPU thread given by its id. + * + * The returned token may live longer than the associated thread. In that case, using its + * `cancel` method has no effect. + * + * @param [in] thread_id an id of a C++ CPU thread. + * @return an object that can be used to cancel the GPU work waited on the given CPU thread. + */ + static inline auto get_token(std::thread::id thread_id) -> std::shared_ptr + { + return get_token_impl(thread_id); + } + + /** + * @brief Cancel any current or next call to `interruptible::synchronize` performed on the + * CPU thread given by the `thread_id` + * + * Note, this function uses a mutex to safely get a cancellation token that may be shared + * among multiple threads. If you plan to use it from a signal handler, consider the non-static + * `cancel()` instead. + * + * @param [in] thread_id a CPU thread, in which the work should be interrupted. + */ + static inline void cancel(std::thread::id thread_id) { get_token(thread_id)->cancel(); } + + /** + * @brief Cancel any current or next call to `interruptible::synchronize` performed on the + * CPU thread given by this `interruptible` token. 
+ * + * Note, this function does not involve thread synchronization/locks and does not throw any + * exceptions, so it's safe to call from a signal handler. + */ + inline void cancel() noexcept { continue_.clear(std::memory_order_relaxed); } + + // don't allow the token to leave the shared_ptr + interruptible(interruptible const&) = delete; + interruptible(interruptible&&) = delete; + auto operator=(interruptible const&) -> interruptible& = delete; + auto operator=(interruptible&&) -> interruptible& = delete; + + private: + /** Global registry of thread-local cancellation stores. */ + static inline std::unordered_map> registry_; + /** Protect the access to the registry. */ + static inline std::mutex mutex_; + + /** + * Create a new interruptible token or get an existing from the global registry_. + * + * Presumptions: + * + * 1. get_token_impl must be called at most once per thread. + * 2. When `Claim == true`, thread_id must be equal to std::this_thread::get_id(). + * 3. get_token_impl can be called as many times as needed, producing a valid + * token for any input thread_id, independent of whether a C++ thread with this + * id exists or not. + * + * @tparam Claim whether to bind the token to the given thread. + * @param [in] thread_id the id of the associated C++ thread. + * @return new or existing interruptible token. + */ + template + static auto get_token_impl(std::thread::id thread_id) -> std::shared_ptr + { + std::lock_guard guard_get(mutex_); + // the following constructs an empty shared_ptr if the key does not exist. + auto& weak_store = registry_[thread_id]; + auto thread_store = weak_store.lock(); + if (!thread_store || (Claim && thread_store->claimed_)) { + // Create a new thread_store in two cases: + // 1. It does not exist in the map yet + // 2. 
The previous store in the map has not yet been deleted + thread_store.reset(new interruptible(), [thread_id](auto ts) { + std::lock_guard guard_erase(mutex_); + auto found = registry_.find(thread_id); + if (found != registry_.end()) { + auto stored = found->second.lock(); + // thread_store is not moveable, thus retains its original location. + // Not equal pointers below imply the new store has been already placed + // in the registry_ by the same std::thread::id + if (!stored || stored.get() == ts) { registry_.erase(found); } + } + delete ts; + }); + std::weak_ptr(thread_store).swap(weak_store); + } + // The thread_store is "claimed" by the thread + if constexpr (Claim) { thread_store->claimed_ = true; } + return thread_store; + } + + /** + * Communicate whether the thread is in a cancelled state or can continue execution. + * + * `yield` checks this flag and always resets it to the signalled state; `cancel` clears it. + * These are the only two places where it's used. + */ + std::atomic_flag continue_; + /** This flag is set to true when the created token is placed into a thread-local storage. 
*/ + bool claimed_ = false; + + interruptible() noexcept { yield_no_throw_impl(); } + + void yield_impl() + { + if (!yield_no_throw_impl()) { + throw interrupted_exception("The work in this thread was cancelled."); + } + } + + auto yield_no_throw_impl() noexcept -> bool + { + return continue_.test_and_set(std::memory_order_relaxed); + } + + template + inline void synchronize_impl(Query query, Object object) + { + cudaError_t query_result; + while (true) { + yield_impl(); + query_result = query(object); + if (query_result != cudaErrorNotReady) { break; } + std::this_thread::yield(); + } + RAFT_CUDA_TRY(query_result); + } +}; + +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.hpp b/cpp/include/raft/linalg/detail/cholesky_r1_update.hpp index 335544e094..48993886a6 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.hpp +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.hpp @@ -112,7 +112,7 @@ void choleskyRank1Update(const raft::handle_t& handle, math_t L_22_host; raft::update_host(&s_host, s, 1, stream); raft::update_host(&L_22_host, L_22, 1, stream); // L_22 stores A_22 - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); L_22_host = std::sqrt(L_22_host - s_host); // Check for numeric error with sqrt. 
If the matrix is not positive definit or @@ -126,4 +126,4 @@ void choleskyRank1Update(const raft::handle_t& handle, } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/lanczos.hpp b/cpp/include/raft/linalg/detail/lanczos.hpp index c761c06c14..3d8fde7e68 100644 --- a/cpp/include/raft/linalg/detail/lanczos.hpp +++ b/cpp/include/raft/linalg/detail/lanczos.hpp @@ -273,7 +273,7 @@ int performLanczosIteration(handle_t const& handle, RAFT_CUBLAS_TRY(cublasscal(cublas_h, n, &alpha, lanczosVecs_dev + IDX(0, *iter, n), 1, stream)); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); return 0; } diff --git a/cpp/include/raft/linalg/detail/svd.hpp b/cpp/include/raft/linalg/detail/svd.hpp index 796adc89ff..5d349cd101 100644 --- a/cpp/include/raft/linalg/detail/svd.hpp +++ b/cpp/include/raft/linalg/detail/svd.hpp @@ -101,7 +101,7 @@ void svdQR(const raft::handle_t& handle, int dev_info; raft::update_host(&dev_info, devInfo.data(), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); ASSERT(dev_info == 0, "svd.cuh: svd couldn't converge to a solution. " "This usually occurs when some of the features do not vary enough."); diff --git a/cpp/include/raft/mr/buffer_base.hpp b/cpp/include/raft/mr/buffer_base.hpp index 6998c1f186..11724bed00 100644 --- a/cpp/include/raft/mr/buffer_base.hpp +++ b/cpp/include/raft/mr/buffer_base.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh index 4e78494e6b..31ebe38d85 100644 --- a/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh +++ b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh @@ -119,7 +119,7 @@ void build_dendrogram_host(const handle_t& handle, update_host(mst_dst_h.data(), cols, n_edges, stream); update_host(mst_weights_h.data(), data, n_edges, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); std::vector children_h(n_edges * 2); std::vector out_size_h(n_edges); diff --git a/cpp/include/raft/sparse/linalg/detail/spectral.cuh b/cpp/include/raft/sparse/linalg/detail/spectral.cuh index de62f25ffa..9d1741fab7 100644 --- a/cpp/include/raft/sparse/linalg/detail/spectral.cuh +++ b/cpp/include/raft/sparse/linalg/detail/spectral.cuh @@ -51,7 +51,7 @@ void fit_embedding(const raft::handle_t& handle, rmm::device_uvector eigVecs(n * (n_components + 1), stream); rmm::device_uvector labels(n, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); /** * Raft spectral clustering diff --git a/cpp/include/raft/sparse/op/detail/reduce.cuh b/cpp/include/raft/sparse/op/detail/reduce.cuh index 074a139ba9..ba728f54c8 100644 --- a/cpp/include/raft/sparse/op/detail/reduce.cuh +++ b/cpp/include/raft/sparse/op/detail/reduce.cuh @@ -147,7 +147,7 @@ void max_duplicates(const raft::handle_t& handle, // compute final size value_idx size = 0; raft::update_host(&size, diff.data() + (diff.size() - 1), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); size++; out.allocate(size, m, n, true, stream); diff --git a/cpp/include/raft/sparse/selection/detail/connect_components.cuh b/cpp/include/raft/sparse/selection/detail/connect_components.cuh index b56b2df02e..afbb7f17b3 100644 --- a/cpp/include/raft/sparse/selection/detail/connect_components.cuh +++ 
b/cpp/include/raft/sparse/selection/detail/connect_components.cuh @@ -415,7 +415,7 @@ void connect_components( // compute final size value_idx size = 0; raft::update_host(&size, out_index.data() + (out_index.size() - 1), 1, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); size++; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 369aac1e7c..a3df5c7a4b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable(test_raft test/eigen_solvers.cu test/handle.cpp test/integer_utils.cpp + test/interruptible.cu test/nvtx.cpp test/pow2_utils.cu test/label/label.cu diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 8637d1f6bb..3bfc70ccf0 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -128,7 +128,7 @@ class DistanceAdjTest : public ::testing::TestWithParam> { stream, isRowMajor, metric_arg); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/handle.cpp b/cpp/test/handle.cpp index 22816d0aad..118002dba0 100644 --- a/cpp/test/handle.cpp +++ b/cpp/test/handle.cpp @@ -47,7 +47,7 @@ TEST(Raft, Handle) rmm::cuda_stream_view stream_view(stream); handle_t handle(stream_view); ASSERT_EQ(stream_view, handle.get_stream()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } diff --git a/cpp/test/interruptible.cu b/cpp/test/interruptible.cu new file mode 100644 index 0000000000..92adfabd55 --- /dev/null +++ b/cpp/test/interruptible.cu @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +__global__ void gpu_wait(int millis) +{ + for (auto i = millis; i > 0; i--) { +#if __CUDA_ARCH__ >= 700 + __nanosleep(1000000); +#else + // For older CUDA devices: + // just do some random work that takes more or less the same time from run to run. + volatile double x = 0; + for (int i = 0; i < 10000; i++) { + x = x + double(i); + x /= 2.0; + __syncthreads(); + } +#endif + } +} + +TEST(Raft, InterruptibleBasic) +{ + ASSERT_TRUE(interruptible::yield_no_throw()); + + // Cancel using the token + interruptible::get_token()->cancel(); + ASSERT_FALSE(interruptible::yield_no_throw()); + ASSERT_TRUE(interruptible::yield_no_throw()); + + // Cancel by thread id + interruptible::cancel(std::this_thread::get_id()); + ASSERT_FALSE(interruptible::yield_no_throw()); + ASSERT_TRUE(interruptible::yield_no_throw()); +} + +TEST(Raft, InterruptibleRepeatedGetToken) +{ + auto i = std::this_thread::get_id(); + auto a1 = interruptible::get_token(); + auto count = a1.use_count(); + auto a2 = interruptible::get_token(); + ASSERT_LT(count, a1.use_count()); + count = a1.use_count(); + auto b1 = interruptible::get_token(i); + ASSERT_LT(count, a1.use_count()); + count = a1.use_count(); + auto b2 = interruptible::get_token(i); + ASSERT_LT(count, a1.use_count()); + + ASSERT_EQ(a1, a2); + ASSERT_EQ(a1, b2); + ASSERT_EQ(b1, b2); +} + +TEST(Raft, InterruptibleDelayedInit) +{ + std::thread([&]() { + auto a = 
interruptible::get_token(std::this_thread::get_id()); + ASSERT_EQ(a.use_count(), 1); // the only pointer here is [a] + auto b = interruptible::get_token(); + ASSERT_EQ(a.use_count(), 3); // [a, b, thread_local] + auto c = interruptible::get_token(); + ASSERT_EQ(a.use_count(), 4); // [a, b, c, thread_local] + ASSERT_EQ(a.get(), b.get()); + ASSERT_EQ(a.get(), c.get()); + }).join(); +} + +TEST(Raft, InterruptibleOpenMP) +{ + // number of threads must be smaller than max number of resident grids for GPU + const int n_threads = 10; + // 1 <= n_expected_succeed <= n_threads + const int n_expected_succeed = 5; + // How many milliseconds pass between thread i and thread i+1 finishing, + // i.e. thread i executes (C + i*thread_delay_millis) milliseconds in total. + const int thread_delay_millis = 20; + common::nvtx::range fun_scope("interruptible"); + + std::vector> thread_tokens(n_threads); + int n_finished = 0; + int n_cancelled = 0; + + omp_set_dynamic(0); + omp_set_num_threads(n_threads); +#pragma omp parallel reduction(+ : n_finished) reduction(+ : n_cancelled) num_threads(n_threads) + { + auto i = omp_get_thread_num(); + common::nvtx::range omp_scope("interruptible::thread-%d", i); + rmm::cuda_stream stream; + gpu_wait<<<1, 1, 0, stream.value()>>>(1); + interruptible::synchronize(stream); + thread_tokens[i] = interruptible::get_token(); + +#pragma omp barrier + try { + common::nvtx::range wait_scope("interruptible::wait-%d", i); + gpu_wait<<<1, 1, 0, stream.value()>>>((1 + i) * thread_delay_millis); + interruptible::synchronize(stream); + n_finished = 1; + } catch (interrupted_exception&) { + n_cancelled = 1; + } + if (i == n_expected_succeed - 1) { + common::nvtx::range cancel_scope("interruptible::cancel-%d", i); + for (auto token : thread_tokens) + token->cancel(); + } + +#pragma omp barrier + // clear the cancellation state to not disrupt other tests + interruptible::yield_no_throw(); + } + ASSERT_EQ(n_finished, n_expected_succeed); + ASSERT_EQ(n_cancelled, n_threads - 
n_expected_succeed); +} +} // namespace raft diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index c277db76ee..d5daef8d7b 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -47,7 +47,7 @@ class AddTest : public ::testing::TestWithParam> { r.uniform(in2.data(), len, InT(-1.0), InT(1.0), stream); naiveAddElem(out_ref.data(), in1.data(), in2.data(), len, stream); add(out.data(), in1.data(), in2.data(), len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void compare() diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index 55810a5ca0..d1b00da728 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -58,7 +58,7 @@ class BinaryOpTest : public ::testing::TestWithParam> { stream, tol, sweeps); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/eig_sel.cu b/cpp/test/linalg/eig_sel.cu index 4ae2653e47..7aab2c18c0 100644 --- a/cpp/test/linalg/eig_sel.cu +++ b/cpp/test/linalg/eig_sel.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -92,7 +92,7 @@ class EigSelTest : public ::testing::TestWithParam> { eig_vals.data(), EigVecMemUsage::OVERWRITE_INPUT, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/eltwise.cu b/cpp/test/linalg/eltwise.cu index 146d48e179..982dc21573 100644 --- a/cpp/test/linalg/eltwise.cu +++ b/cpp/test/linalg/eltwise.cu @@ -76,7 +76,7 @@ class ScalarMultiplyTest : public ::testing::TestWithParam> { r.uniform(in2, len, T(-1.0), T(1.0), stream); naiveAdd(out_ref, in1, in2, len, stream); eltwiseAdd(out, in1, in2, len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index 9875e2548f..a12bb6ff9d 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -87,7 +87,7 @@ class MapReduceTest : public ::testing::TestWithParam> { auto len = params.len; r.uniform(in.data(), len, InType(-1.0), InType(1.0), stream); mapReduceLaunch(out_ref.data(), out.data(), in.data(), len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 4ff5243826..1a97603430 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ b/cpp/test/linalg/matrix_vector_op.cu @@ -134,7 +134,7 @@ class MatVecOpTest : public ::testing::TestWithParam> params.bcastAlongRows, params.useTwoVectors, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/multiply.cu b/cpp/test/linalg/multiply.cu index ec0599eb1b..6341fa341d 100644 --- a/cpp/test/linalg/multiply.cu +++ b/cpp/test/linalg/multiply.cu @@ -45,7 +45,7 @@ class MultiplyTest : public ::testing::TestWithParam> { r.uniform(in.data(), len, T(-1.0), T(1.0), stream); naiveScale(out_ref.data(), in.data(), params.scalar, len, 
stream); multiplyScalar(out.data(), in.data(), params.scalar, len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 56e111d056..e574c52692 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -95,7 +95,7 @@ class RowNormTest : public ::testing::TestWithParam> { } else { rowNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: @@ -159,7 +159,7 @@ class ColNormTest : public ::testing::TestWithParam> { } else { colNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 14f34f142d..cb69dc0e81 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -96,7 +96,7 @@ class ReduceTest : public ::testing::TestWithParam reduceLaunch( dots_act.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, true, stream); } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index 6d33fbdef1..840889dee8 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -61,7 +61,7 @@ class stridedReductionTest : public ::testing::TestWithParam> { subtractScalar(out.data(), out.data(), T(1), len, stream); subtract(in1.data(), in1.data(), in2.data(), len, stream); subtractScalar(in1.data(), in1.data(), T(1), len, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index e9128bad93..e074197dec 100644 --- a/cpp/test/linalg/svd.cu +++ 
b/cpp/test/linalg/svd.cu @@ -91,7 +91,7 @@ class SvdTest : public ::testing::TestWithParam> { true, true, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 60db1ee82b..3c651bb8ee 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -63,7 +63,7 @@ class TransposeTest : public ::testing::TestWithParam> { transpose(handle, data.data(), data_trans.data(), params.n_row, params.n_col, stream); transpose(data.data(), params.n_row, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu index 050fed78ea..7a976ec336 100644 --- a/cpp/test/linalg/unary_op.cu +++ b/cpp/test/linalg/unary_op.cu @@ -59,7 +59,7 @@ class UnaryOpTest : public ::testing::TestWithParam(params.tolerance))); } diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index 3215df0d73..127e582145 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -177,7 +177,7 @@ class MathTest : public ::testing::TestWithParam> { update_device(out_smallzero_ref.data(), in_small_val_zero_ref_h.data(), 4, stream); setSmallValuesZero(out_smallzero.data(), in_smallzero.data(), 4, stream); setSmallValuesZero(in_smallzero.data(), 4, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 86b94fb011..fb2f6c6b15 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -63,7 +63,7 @@ class MatrixTest : public ::testing::TestWithParam> { rmm::device_uvector outTrunc(6, stream); truncZeroOrigin(in1.data(), params.n_row, outTrunc.data(), 3, 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu 
index 4b0f1f0a4f..02c8dc9f39 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -94,10 +94,10 @@ class RngTest : public ::testing::TestWithParam> { meanKernel<<>>( stats.data(), data.data(), params.len); update_host(h_stats, stats.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); h_stats[0] /= params.len; h_stats[1] = (h_stats[1] / params.len) - (h_stats[0] * h_stats[0]); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void getExpectedMeanVar(float meanvar[2]) diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index d3b1baf388..a8bba340fa 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -77,7 +77,7 @@ class SWoRTest : public ::testing::TestWithParam> { params.len, stream); update_host(&(h_outIdx[0]), outIdx.data(), params.sampledLen, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/sparse/connect_components.cu b/cpp/test/sparse/connect_components.cu index 648964fc57..e4b197d7f5 100644 --- a/cpp/test/sparse/connect_components.cu +++ b/cpp/test/sparse/connect_components.cu @@ -127,7 +127,7 @@ class ConnectComponentsTest false, false); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); // The sum of edges for both MST runs should be n_rows - 1 final_edges = output_mst.n_edges + mst_coo.n_edges; diff --git a/cpp/test/sparse/csr_row_slice.cu b/cpp/test/sparse/csr_row_slice.cu index e37827d18d..f0a245b432 100644 --- a/cpp/test/sparse/csr_row_slice.cu +++ b/cpp/test/sparse/csr_row_slice.cu @@ -98,7 +98,7 @@ class CSRRowSliceTest : public ::testing::TestWithParamrows(), out->cols(), out->vals(), out->nnz, sum.data()); sum_h = sum.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void TearDown() override { 
delete out; } diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/sparse/linkage.cu index cb09b9e7f5..7944d0ee1f 100644 --- a/cpp/test/sparse/linkage.cu +++ b/cpp/test/sparse/linkage.cu @@ -188,7 +188,7 @@ class LinkageTest : public ::testing::TestWithParam> { params.c, params.n_clusters); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); score = compute_rand_index(labels.data(), labels_ref.data(), params.n_row, stream); } diff --git a/cpp/test/sparse/symmetrize.cu b/cpp/test/sparse/symmetrize.cu index 9c766d2d05..9a2e35b0fe 100644 --- a/cpp/test/sparse/symmetrize.cu +++ b/cpp/test/sparse/symmetrize.cu @@ -111,7 +111,7 @@ class SparseSymmetrizeTest out.rows(), out.cols(), out.vals(), out.nnz, sum.data()); sum_h = sum.value(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/cpp/test/spatial/haversine.cu b/cpp/test/spatial/haversine.cu index 6b7402a7bd..f78c6c46da 100644 --- a/cpp/test/spatial/haversine.cu +++ b/cpp/test/spatial/haversine.cu @@ -94,7 +94,7 @@ class HaversineKNNTest : public ::testing::Test { k, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void SetUp() override { basicTest(); } diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index ee216ee434..54c3b55e5e 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -153,7 +153,7 @@ class KNNTest : public ::testing::TestWithParam { raft::copy(input_.data(), input_ptr, rows_ * cols_, stream); raft::copy(search_data_.data(), input_ptr, rows_ * cols_, stream); raft::copy(search_labels_.data(), labels_ptr, rows_, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } private: diff --git a/cpp/test/spatial/selection.cu b/cpp/test/spatial/selection.cu index 8ccf3b6b73..769406487a 100644 --- a/cpp/test/spatial/selection.cu +++ b/cpp/test/spatial/selection.cu @@ -111,7 +111,7 @@ class SparseSelectionTest k, stream); - 
RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void compare() diff --git a/cpp/test/stats/mean_center.cu b/cpp/test/stats/mean_center.cu index af6d7c8d7b..ddabe0e814 100644 --- a/cpp/test/stats/mean_center.cu +++ b/cpp/test/stats/mean_center.cu @@ -79,7 +79,7 @@ class MeanCenterTest : public ::testing::TestWithParam> { vars_act.resize(cols, stream); r.normal(data.data(), len, params.mean, params.stddev, stream); stdVarSGtest(data.data(), stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void stdVarSGtest(T* data, cudaStream_t stream) diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index fd656423ad..0df140b8b4 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -62,7 +62,7 @@ class SumTest : public ::testing::TestWithParam> { raft::update_device(data.data(), data_h, len, stream); sum(sum_act.data(), data.data(), cols, rows, false, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 13a4dca267..bae39e3282 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -11,3 +11,11 @@ handle_t .. doxygenclass:: raft::handle_t :project: RAFT :members: + + +interruptible +############# + +.. doxygenclass:: raft::interruptible + :project: RAFT + :members: diff --git a/python/raft/common/interruptible.pxd b/python/raft/common/interruptible.pxd new file mode 100644 index 0000000000..a73e8c1ac7 --- /dev/null +++ b/python/raft/common/interruptible.pxd @@ -0,0 +1,34 @@ +# +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libcpp.memory cimport shared_ptr +from rmm._lib.cuda_stream_view cimport cuda_stream_view + +cdef extern from "raft/interruptible.hpp" namespace "raft" nogil: + cdef cppclass interruptible: + void cancel() + +cdef extern from "raft/interruptible.hpp" \ + namespace "raft::interruptible" nogil: + cdef void inter_synchronize \ + "raft::interruptible::synchronize"(cuda_stream_view stream) except+ + cdef void inter_yield "raft::interruptible::yield"() except+ + cdef shared_ptr[interruptible] get_token() except+ diff --git a/python/raft/common/interruptible.pyx b/python/raft/common/interruptible.pyx new file mode 100644 index 0000000000..dfc95490ed --- /dev/null +++ b/python/raft/common/interruptible.pyx @@ -0,0 +1,84 @@ +# +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import contextlib +import signal +from cython.operator cimport dereference + +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from cuda.ccudart cimport cudaStream_t +from .cuda cimport Stream + + +@contextlib.contextmanager +def cuda_interruptible(): + ''' + Temporarily install a keyboard interrupt handler (Ctrl+C) + that cancels the enclosed interruptible C++ thread. + + Use this on a long-running C++ function imported via cython: + + .. code-block:: python + + with cuda_interruptible(): + my_long_running_function(...) + + It's also recommended to release the GIL during the call, to + make sure the handler has a chance to run: + + .. code-block:: python + + with cuda_interruptible(): + with nogil: + my_long_running_function(...) + + ''' + cdef shared_ptr[interruptible] token = get_token() + + def newhr(*args, **kwargs): + with nogil: + dereference(token).cancel() + + oldhr = signal.signal(signal.SIGINT, newhr) + try: + yield + finally: + signal.signal(signal.SIGINT, oldhr) + + +def synchronize(stream: Stream): + ''' + Same as cudaStreamSynchronize, but can be interrupted + if called within a `with cuda_interruptible()` block. + ''' + cdef cuda_stream_view c_stream = cuda_stream_view(stream.getStream()) + with nogil: + inter_synchronize(c_stream) + + +def cuda_yield(): + ''' + Check for an asynchronously received interrupted_exception. + Raises the exception if a user pressed Ctrl+C within a + `with cuda_interruptible()` block before. 
+ ''' + with nogil: + inter_yield() diff --git a/python/raft/test/test_interruptible.py b/python/raft/test/test_interruptible.py new file mode 100644 index 0000000000..81f4f99ed8 --- /dev/null +++ b/python/raft/test/test_interruptible.py @@ -0,0 +1,54 @@ + +import os +import pytest +import signal +import time +from raft.common.interruptible import cuda_interruptible, cuda_yield + + +def send_ctrl_c(): + # signal.raise_signal(signal.SIGINT) available only since python 3.8 + os.kill(os.getpid(), signal.SIGINT) + + +def test_should_cancel_via_interruptible(): + start_time = time.monotonic() + with pytest.raises(RuntimeError, match='this thread was cancelled'): + with cuda_interruptible(): + send_ctrl_c() + cuda_yield() + time.sleep(1.0) + end_time = time.monotonic() + assert end_time < start_time + 0.5, \ + "The process seems to have waited, while it shouldn't have." + + +def test_should_cancel_via_python(): + start_time = time.monotonic() + with pytest.raises(KeyboardInterrupt): + send_ctrl_c() + cuda_yield() + time.sleep(1.0) + end_time = time.monotonic() + assert end_time < start_time + 0.5, \ + "The process seems to have waited, while it shouldn't have." + + +def test_should_wait_no_interrupt(): + start_time = time.monotonic() + with cuda_interruptible(): + cuda_yield() + time.sleep(1.0) + end_time = time.monotonic() + assert end_time > start_time + 0.5, \ + "The process seems to be cancelled, while it shouldn't be." + + +def test_should_wait_no_yield(): + start_time = time.monotonic() + with cuda_interruptible(): + send_ctrl_c() + time.sleep(1.0) + end_time = time.monotonic() + assert end_time > start_time + 0.5, \ + "The process seems to be cancelled, while it shouldn't be." diff --git a/python/setup.py b/python/setup.py index f5b1e8bace..80f687a442 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -102,7 +102,7 @@ # - Cython extensions build and parameters ----------------------------------- -libs = ["nccl", "cusolver", "cusparse", "cublas"] +libs = ['cudart', "nccl", "cusolver", "cusparse", "cublas"] include_dirs = [cuda_include_dir, numpy.get_include(),