Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interruptible execution #433

Merged
merged 40 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c18cab1
First take
achirkin Dec 16, 2021
34d4023
Merge branch 'branch-22.02' into fea-interruptible
achirkin Dec 21, 2021
e1a0c3a
Some refactoring and yield function
achirkin Dec 21, 2021
f6222dc
Fix a typo
achirkin Dec 21, 2021
ee99523
Added a python Ctrl+C handler wrapper
achirkin Dec 22, 2021
a07edae
Fix linter
achirkin Dec 22, 2021
b3119bb
Fix linter
achirkin Dec 22, 2021
54a0599
Initialize cuda primitives lazily and add a mutex-free non-static can…
achirkin Jan 10, 2022
db5adfd
Fix relative import
achirkin Jan 10, 2022
5539984
Fix deallocation issue with shared_ptr + unordered_map
achirkin Jan 11, 2022
4b95859
Refactor names
achirkin Jan 11, 2022
36e8de5
Merge branch 'branch-22.02' of https://github.com/rapidsai/raft into …
achirkin Jan 11, 2022
a2610d1
Make comms sync_stream interruptible
achirkin Jan 11, 2022
53155e9
Enable OpenMP in raft
achirkin Jan 12, 2022
396beda
Add gtests
achirkin Jan 12, 2022
636b529
add pytests
achirkin Jan 12, 2022
2b65798
Make clang-format happy
achirkin Jan 12, 2022
6b96f3b
Make flake8 happy
achirkin Jan 12, 2022
23b681d
Merge branch 'branch-22.02' into fea-interruptible
achirkin Jan 12, 2022
b579d72
Support python < 3.8
achirkin Jan 12, 2022
5405c12
Update cpp/include/raft/interruptible.hpp
achirkin Jan 13, 2022
0f8bc71
Change implementation: now it's a spinning lock
achirkin Jan 13, 2022
81828f6
Fix comms due to changed yield_no_throw semantics
achirkin Jan 13, 2022
6948cab
Account for the possibility of repeating std::thread::id
achirkin Jan 13, 2022
6e7aa24
Simplify the thread::id workaround (no more global seq_id)
achirkin Jan 14, 2022
02d95db
Merge branch 'branch-22.02' of https://github.com/rapidsai/raft into …
achirkin Jan 14, 2022
777d5ed
Merge branch 'branch-22.04' into fea-interruptible
cjnolet Jan 25, 2022
cbe44d8
Add synchronize(cudaEvent_t) and fix python bindings
achirkin Jan 26, 2022
9658ca4
Make stream pool interruptible as well
achirkin Jan 26, 2022
b1b8edf
Merge branch 'branch-22.04' into fea-interruptible
achirkin Jan 28, 2022
cf6c6ff
Merge branch 'branch-22.04' into fea-interruptible
achirkin Feb 2, 2022
47fad7b
Merge branch 'branch-22.04' into fea-interruptible
achirkin Feb 3, 2022
3e67ec0
Update docs
achirkin Feb 3, 2022
fc81823
Merge branch 'branch-22.04' into fea-interruptible
cjnolet Feb 4, 2022
c1a7070
Merge branch 'branch-22.04' into fea-interruptible
achirkin Feb 5, 2022
98c9035
Add 'cudart' to cython libs
achirkin Feb 5, 2022
dbcdcf0
Don't use __nanosleep on older archs
achirkin Feb 7, 2022
853b5c3
Add a comment about using thread-local storage.
achirkin Feb 7, 2022
d32f4df
Merge remote-tracking branch 'rapidsai/branch-22.04' into fea-interru…
achirkin Feb 8, 2022
e8b7b54
Replace more cudaStreamSynchronize with handle.sync_stream
achirkin Feb 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ endif()
##############################################################################
# - compiler options ---------------------------------------------------------

# Detect OpenMP before the CUDA language is configured, so that
# cmake/modules/ConfigureCUDA.cmake can append the OpenMP host-compiler
# flags (via -Xcompiler) when OpenMP is available.
# NOTE(review): this checks OPENMP_FOUND while ConfigureCUDA.cmake checks
# OpenMP_FOUND — FindOpenMP defines both, but confirm on the minimum
# supported CMake version.
if (NOT DISABLE_OPENMP)
find_package(OpenMP)
if(OPENMP_FOUND)
message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}")
endif()
endif()

# * find CUDAToolkit package
# * determine GPU architectures
# * enable the CMake CUDA language
Expand All @@ -95,13 +102,6 @@ include(cmake/modules/ConfigureCUDA.cmake)
##############################################################################
# - Requirements -------------------------------------------------------------

if (NOT DISABLE_OPENMP)
find_package(OpenMP)
if(OPENMP_FOUND)
message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}")
endif()
endif()

# add third party dependencies using CPM
rapids_cpm_init()

Expand Down
6 changes: 5 additions & 1 deletion cpp/cmake/modules/ConfigureCUDA.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,10 @@ if(CUDA_ENABLE_LINEINFO)
list(APPEND RAFT_CUDA_FLAGS -lineinfo)
endif()

# Forward the OpenMP flags to the host compiler when compiling CUDA sources,
# so device-code translation units may use OpenMP on the host side.
# Relies on OpenMP having been detected (find_package(OpenMP)) before this
# module is included.
if(OpenMP_FOUND)
list(APPEND RAFT_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS})
endif()

# Debug options
if(CMAKE_BUILD_TYPE MATCHES Debug)
message(VERBOSE "RAFT: Building with debugging flags")
Expand Down
35 changes: 2 additions & 33 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -333,38 +333,7 @@ class mpi_comms : public comms_iface {
stream));
}

status_t sync_stream(cudaStream_t stream) const
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error
return status_t::ERROR;
}

if (ncclAsyncErr != ncclSuccess) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(nccl_comm_);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
}

// Let other threads (including NCCL threads) use the CPU.
pthread_yield();
}
};
/**
 * @brief Wait for all work queued on the stream, interruptibly.
 *
 * Delegates to the shared NCCL-aware polling loop, which also watches the
 * communicator's asynchronous error state.
 */
status_t sync_stream(cudaStream_t stream) const
{
  return nccl_sync_stream(nccl_comm_, stream);
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const
Expand Down
37 changes: 3 additions & 34 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -441,38 +441,7 @@ class std_comms : public comms_iface {
stream));
}

status_t sync_stream(cudaStream_t stream) const
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream_
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error
return status_t::ERROR;
}

if (ncclAsyncErr != ncclSuccess) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(nccl_comm_);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
}

// Let other threads (including NCCL threads) use the CPU.
std::this_thread::yield();
}
}
/**
 * @brief Wait for all work queued on the stream, interruptibly.
 *
 * Delegates to the shared NCCL-aware polling loop, which also watches the
 * communicator's asynchronous error state.
 */
status_t sync_stream(cudaStream_t stream) const
{
  return nccl_sync_stream(nccl_comm_, stream);
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const
Expand Down Expand Up @@ -553,4 +522,4 @@ class std_comms : public comms_iface {
};
} // namespace detail
} // end namespace comms
} // end namespace raft
} // end namespace raft
40 changes: 39 additions & 1 deletion cpp/include/raft/comms/detail/util.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

#pragma once

#include <raft/interruptible.hpp>

#include <nccl.h>
#include <raft/error.hpp>
#include <string>
Expand Down Expand Up @@ -109,6 +111,42 @@ get_nccl_op(const op_t op)
default: throw "Unsupported datatype";
}
}

/**
 * @brief Spin-wait for completion of work on a CUDA stream, while monitoring the
 * NCCL communicator's asynchronous error state and the interruptible cancellation flag.
 *
 * Polls `cudaStreamQuery` in a loop; on every iteration it also queries
 * `ncclCommGetAsyncError` and calls `interruptible::yield_no_throw()` (which returns
 * false when this thread has been flagged for cancellation).
 *
 * @param comm   NCCL communicator to monitor; aborted via `ncclCommAbort` on an async
 *               error or on cancellation.
 * @param stream CUDA stream whose queued work is awaited.
 * @return status_t::SUCCESS when the stream's work completed cleanly;
 *         status_t::ERROR when querying the stream or the communicator fails;
 *         status_t::ABORT when `ncclCommAbort` itself fails.
 */
status_t nccl_sync_stream(ncclComm_t comm, cudaStream_t stream)
{
cudaError_t cudaErr;
ncclResult_t ncclErr, ncclAsyncErr;
while (1) {
cudaErr = cudaStreamQuery(stream);
if (cudaErr == cudaSuccess) return status_t::SUCCESS;

if (cudaErr != cudaErrorNotReady) {
// An error occurred querying the status of the stream.
return status_t::ERROR;
}

ncclErr = ncclCommGetAsyncError(comm, &ncclAsyncErr);
if (ncclErr != ncclSuccess) {
// An error occurred retrieving the asynchronous error.
return status_t::ERROR;
}

// Abort either on an asynchronous NCCL error or when this thread was
// interrupted (yield_no_throw() returns false on cancellation).
if (ncclAsyncErr != ncclSuccess || !interruptible::yield_no_throw()) {
// An asynchronous error happened. Stop the operation and destroy
// the communicator
ncclErr = ncclCommAbort(comm);
if (ncclErr != ncclSuccess)
// Caller may abort with an exception or try to re-create a new communicator.
return status_t::ABORT;
// TODO: shouldn't we place status_t::ERROR above under the condition, and
// status_t::ABORT below here (i.e. after successful ncclCommAbort)?
// NOTE(review): after a *successful* abort the loop keeps spinning;
// presumably a subsequent cudaStreamQuery or ncclCommGetAsyncError then
// reports an error and exits — confirm this terminates in all cases.
}

// Let other threads (including NCCL threads) use the CPU.
std::this_thread::yield();
}
}

}; // namespace detail
}; // namespace comms
}; // namespace raft
14 changes: 10 additions & 4 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@

#include "cudart_utils.h"
#include <raft/comms/comms.hpp>
#include <raft/interruptible.hpp>
#include <raft/linalg/cublas_wrappers.h>
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/sparse/cusparse_wrappers.h>
Expand Down Expand Up @@ -126,10 +127,15 @@ class handle_t {

rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; }

/**
 * @brief synchronize a stream on the handle
 *
 * Uses interruptible::synchronize so a pending cancellation request can cut
 * the wait short instead of blocking until stream completion.
 */
void sync_stream(rmm::cuda_stream_view stream) const
{
  interruptible::synchronize(stream);
}

/**
 * @brief synchronize main stream on the handle
 *
 * Forwards to the stream-taking overload, so the main stream gets the same
 * interruptible-synchronization behavior.
 */
void sync_stream() const
{
  sync_stream(stream_view_);
}

/**
* @brief returns main stream on the handle
Expand Down Expand Up @@ -198,7 +204,7 @@ class handle_t {
/**
 * @brief synchronize every stream in the handle's stream pool
 *
 * Each stream is waited on via the interruptible sync_stream overload.
 */
void sync_stream_pool() const
{
  const std::size_t pool_size = get_stream_pool_size();
  for (std::size_t idx = 0; idx != pool_size; ++idx) {
    sync_stream(stream_pool_->get_stream(idx));
  }
}

Expand All @@ -211,7 +217,7 @@ class handle_t {
{
RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
for (const auto& stream_index : stream_indices) {
stream_pool_->get_stream(stream_index).synchronize();
sync_stream(stream_pool_->get_stream(stream_index));
}
}

Expand Down
Loading