Interruptible execution (#433)
### Cooperative-style interruptible C++ threads.

This proposal introduces `raft::interruptible`, which provides three functions:
```C++
static void synchronize(rmm::cuda_stream_view stream);
static void yield();
static void cancel(std::thread::id thread_id);
```
`synchronize` and `yield` serve as cancellation points for the executing CPU thread. `cancel` requests an asynchronous exception in a target CPU thread, which is raised at that thread's nearest cancellation point. Together, these make it possible to cancel a long-running job without killing the OS process.

The key observation that makes this work is that the CPU spends most of its time waiting on `cudaStreamSynchronize`. By replacing that call with `interruptible::synchronize`, we introduce cancellation points in all critical places in the code. If that is not enough in some edge cases (the cancellation points are too far apart), a developer can use `yield` to ensure that a cancellation request is received sooner rather than later.
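To illustrate the idea of turning a blocking wait into a series of cancellation points, here is a minimal CPU-only sketch. It is not the raft code: `interruptible_wait`, `interrupted_error`, and `demo_cancel_long_wait` are hypothetical names, a `std::future` stands in for the CUDA stream, and a plain `std::atomic<bool>` stands in for the per-thread cancellation state.

```cpp
#include <atomic>
#include <chrono>
#include <future>
#include <stdexcept>
#include <thread>

// Hypothetical stand-in for raft::interrupted_exception.
struct interrupted_error : std::runtime_error {
  interrupted_error() : std::runtime_error("wait was cancelled") {}
};

// Polling replacement for a blocking wait.
// `work` plays the role of the stream being queried; `cancel_requested`
// plays the role of the cancellation state (an assumption of this sketch).
inline void interruptible_wait(const std::future<void>& work,
                               const std::atomic<bool>& cancel_requested)
{
  while (work.wait_for(std::chrono::seconds(0)) != std::future_status::ready) {
    if (cancel_requested.load()) { throw interrupted_error(); }  // cancellation point
    std::this_thread::yield();  // let other threads use the CPU
  }
}

// Starts a wait on work that never completes by itself, requests
// cancellation from another thread, and reports whether the waiting
// thread observed the request.
inline bool demo_cancel_long_wait()
{
  std::promise<void> never_done;
  std::future<void> work = never_done.get_future();
  std::atomic<bool> cancel_requested{false};
  std::atomic<bool> interrupted{false};

  std::thread waiter([&] {
    try {
      interruptible_wait(work, cancel_requested);
    } catch (const interrupted_error&) {
      interrupted = true;
    }
  });
  cancel_requested = true;  // the "cancel" call, issued from the controlling thread
  waiter.join();
  return interrupted.load();
}
```

A blocking wait would never return in `demo_cancel_long_wait`; the polling loop lets the waiting thread leave through an exception while the process keeps running.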

#### Implementation

##### C++

`raft::interruptible` keeps a `std::atomic_flag` in the thread-local storage of each thread; the flag tells whether the thread may continue executing (that is, whether it is in the non-cancelled state). [`cancel`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L122) clears this flag, and [`yield`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L194-L204) checks it and resets it to the signalled state (throwing a `raft::interrupted_exception` if necessary). [`synchronize`](https://github.com/rapidsai/raft/blob/6948cab96483ddc7047b1ae0a162574e32bcd8f0/cpp/include/raft/interruptible.hpp#L206-L217) implements a spin loop that queries the state of the stream and `yield`s on each iteration. I also add an overload [`sync_stream`](https://github.com/rapidsai/raft/blob/ee99523ff6a8257ec213e5ad15292f2132a2a687/cpp/include/raft/handle.hpp#L133) to the raft handle type, to make it easier to modify the behavior of all synchronization calls in raft and cuml.
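The flag protocol described above can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the code in `interruptible.hpp`: the `sketch` namespace, the helper names `token_for` and `local_token`, and the mutex-protected registry keyed by thread id are all assumptions made so that `cancel` can reach another thread's flag.

```cpp
#include <atomic>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <unordered_map>

namespace sketch {  // hypothetical namespace, to avoid any confusion with raft

struct interrupted_exception : std::runtime_error {
  using std::runtime_error::runtime_error;
};

class interruptible {
 public:
  // Cancellation point: throws if the calling thread has been cancelled.
  static void yield()
  {
    if (!yield_no_throw()) { throw interrupted_exception("thread was cancelled"); }
  }

  // Returns false if the calling thread has been cancelled.
  // test_and_set() returns the previous value and leaves the flag set,
  // so a cancellation request is observed once and then reset.
  static bool yield_no_throw() { return local_token()->continue_.test_and_set(); }

  // Request cancellation of the thread with the given id.
  static void cancel(std::thread::id id) { token_for(id)->continue_.clear(); }

 private:
  std::atomic_flag continue_;

  // A set flag means "may continue" (the non-cancelled state).
  interruptible() { continue_.test_and_set(); }

  // Assumed lookup: one token per thread id, created on first use.
  // This sketch never removes entries from the registry.
  static std::shared_ptr<interruptible> token_for(std::thread::id id)
  {
    static std::mutex mutex;
    static std::unordered_map<std::thread::id, std::shared_ptr<interruptible>> registry;
    std::lock_guard<std::mutex> lock(mutex);
    auto& slot = registry[id];
    if (!slot) { slot.reset(new interruptible()); }
    return slot;
  }

  // The calling thread caches its own token in thread-local storage,
  // so yield() takes the registry lock only on its first call.
  static interruptible* local_token()
  {
    thread_local std::shared_ptr<interruptible> token = token_for(std::this_thread::get_id());
    return token.get();
  }
};

}  // namespace sketch
```

With this layout, `cancel` may be called before the target thread has reached its first `yield`: the token is created in the cleared state from the caller's side, and the target thread picks it up later.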

##### Python
This proposal adds a context manager [`cuda_interruptible`](https://github.com/rapidsai/raft/blob/36e8de5f73e9ec7e604b38a4290ac82bc35be4b7/python/raft/common/interruptible.pyx#L28) to handle Ctrl+C requests during C++ calls (using POSIX signals). `cuda_interruptible` simply calls `raft::interruptible::cancel` on the target C++ thread.

#### Motivation
See rapidsai/cuml#4463

Resolves rapidsai/cuml#4384

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #433
achirkin authored Feb 8, 2022
1 parent 23e1650 commit 1a49fc1
Showing 59 changed files with 719 additions and 142 deletions.
14 changes: 7 additions & 7 deletions cpp/CMakeLists.txt
@@ -84,6 +84,13 @@ endif()
##############################################################################
# - compiler options ---------------------------------------------------------

if (NOT DISABLE_OPENMP)
find_package(OpenMP)
if(OPENMP_FOUND)
message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}")
endif()
endif()

# * find CUDAToolkit package
# * determine GPU architectures
# * enable the CMake CUDA language
@@ -97,13 +104,6 @@ include(cmake/modules/ConfigureCUDA.cmake)
##############################################################################
# - Requirements -------------------------------------------------------------

if (NOT DISABLE_OPENMP)
find_package(OpenMP)
if(OPENMP_FOUND)
message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}")
endif()
endif()

# add third party dependencies using CPM
rapids_cpm_init()

6 changes: 5 additions & 1 deletion cpp/cmake/modules/ConfigureCUDA.cmake
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,6 +38,10 @@ if(CUDA_ENABLE_LINEINFO)
list(APPEND RAFT_CUDA_FLAGS -lineinfo)
endif()

if(OpenMP_FOUND)
list(APPEND RAFT_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS})
endif()

# Debug options
if(CMAKE_BUILD_TYPE MATCHES Debug)
message(VERBOSE "RAFT: Building with debugging flags")
35 changes: 2 additions & 33 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -333,38 +333,7 @@ class mpi_comms : public comms_iface {
stream));
}

status_t sync_stream(cudaStream_t stream) const
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error
return status_t::ERROR;
}

if (ncclAsyncErr != ncclSuccess) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(nccl_comm_);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
}

// Let other threads (including NCCL threads) use the CPU.
pthread_yield();
}
};
status_t sync_stream(cudaStream_t stream) const { return nccl_sync_stream(nccl_comm_, stream); }

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const
37 changes: 3 additions & 34 deletions cpp/include/raft/comms/detail/std_comms.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -441,38 +441,7 @@ class std_comms : public comms_iface {
stream));
}

status_t sync_stream(cudaStream_t stream) const
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream_
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error
return status_t::ERROR;
}

if (ncclAsyncErr != ncclSuccess) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(nccl_comm_);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
}

// Let other threads (including NCCL threads) use the CPU.
std::this_thread::yield();
}
}
status_t sync_stream(cudaStream_t stream) const { return nccl_sync_stream(nccl_comm_, stream); }

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const
@@ -553,4 +522,4 @@ class std_comms : public comms_iface {
};
} // namespace detail
} // end namespace comms
} // end namespace raft
} // end namespace raft
18 changes: 9 additions & 9 deletions cpp/include/raft/comms/detail/test.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ bool test_collective_allreduce(const handle_t& handle, int root)

int temp_h = 0;
RAFT_CUDA_TRY(cudaMemcpyAsync(&temp_h, temp_d.data(), 1, cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
communicator.barrier();

std::cout << "Clique size: " << communicator.get_size() << std::endl;
@@ -88,7 +88,7 @@ bool test_collective_broadcast(const handle_t& handle, int root)
int temp_h = -1; // Verify more than one byte is being sent
RAFT_CUDA_TRY(
cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
communicator.barrier();

std::cout << "Clique size: " << communicator.get_size() << std::endl;
@@ -121,7 +121,7 @@ bool test_collective_reduce(const handle_t& handle, int root)
int temp_h = -1; // Verify more than one byte is being sent
RAFT_CUDA_TRY(
cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
communicator.barrier();

std::cout << "Clique size: " << communicator.get_size() << std::endl;
@@ -158,7 +158,7 @@ bool test_collective_allgather(const handle_t& handle, int root)
int temp_h[communicator.get_size()]; // Verify more than one byte is being sent
RAFT_CUDA_TRY(cudaMemcpyAsync(
&temp_h, recv_d.data(), sizeof(int) * communicator.get_size(), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
communicator.barrier();

std::cout << "Clique size: " << communicator.get_size() << std::endl;
@@ -198,7 +198,7 @@ bool test_collective_gather(const handle_t& handle, int root)
std::vector<int> temp_h(communicator.get_size(), 0);
RAFT_CUDA_TRY(cudaMemcpyAsync(
temp_h.data(), recv_d.data(), sizeof(int) * temp_h.size(), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);

for (int i = 0; i < communicator.get_size(); i++) {
if (temp_h[i] != i) return false;
@@ -253,7 +253,7 @@ bool test_collective_gatherv(const handle_t& handle, int root)
sizeof(int) * displacements.back(),
cudaMemcpyDeviceToHost,
stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);

for (int i = 0; i < communicator.get_size(); i++) {
if (std::count_if(temp_h.begin() + displacements[i],
@@ -292,7 +292,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root)
int temp_h = -1; // Verify more than one byte is being sent
RAFT_CUDA_TRY(
cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
communicator.barrier();

std::cout << "Clique size: " << communicator.get_size() << std::endl;
@@ -502,7 +502,7 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial

std::vector<int> h_received_data(communicator.get_size());
raft::update_host(h_received_data.data(), received_data.data(), received_data.size(), stream);
CUDA_TRY(cudaStreamSynchronize(stream));
h.sync_stream(stream);
for (int i = 0; i < communicator.get_size(); ++i) {
if (h_received_data[i] != i) { ret = false; }
}
40 changes: 39 additions & 1 deletion cpp/include/raft/comms/detail/util.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@

#pragma once

#include <raft/interruptible.hpp>

#include <nccl.h>
#include <raft/error.hpp>
#include <string>
@@ -109,6 +111,42 @@ get_nccl_op(const op_t op)
default: throw "Unsupported datatype";
}
}

status_t nccl_sync_stream(ncclComm_t comm, cudaStream_t stream)
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream_
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(comm, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error
return status_t::ERROR;
}

if (ncclAsyncErr != ncclSuccess || !interruptible::yield_no_throw()) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(comm);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
// TODO: shouldn't we place status_t::ERROR above under the condition, and
// status_t::ABORT below here (i.e. after successful ncclCommAbort)?
}

// Let other threads (including NCCL threads) use the CPU.
std::this_thread::yield();
}
}

}; // namespace detail
}; // namespace comms
}; // namespace raft
14 changes: 10 additions & 4 deletions cpp/include/raft/handle.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@
#include "cudart_utils.h"

#include <raft/comms/comms.hpp>
#include <raft/interruptible.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/sparse/cusparse_wrappers.h>
@@ -127,10 +128,15 @@ class handle_t {

rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; }

/**
* @brief synchronize a stream on the handle
*/
void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); }

/**
* @brief synchronize main stream on the handle
*/
void sync_stream() const { stream_view_.synchronize(); }
void sync_stream() const { sync_stream(stream_view_); }

/**
* @brief returns main stream on the handle
@@ -199,7 +205,7 @@ class handle_t {
void sync_stream_pool() const
{
for (std::size_t i = 0; i < get_stream_pool_size(); i++) {
stream_pool_->get_stream(i).synchronize();
sync_stream(stream_pool_->get_stream(i));
}
}

@@ -212,7 +218,7 @@ class handle_t {
{
RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
for (const auto& stream_index : stream_indices) {
stream_pool_->get_stream(stream_index).synchronize();
sync_stream(stream_pool_->get_stream(stream_index));
}
}
