Interruptible execution #433

Merged 40 commits on Feb 8, 2022.
Changes from 7 commits.

Commits (40)
c18cab1 First take (achirkin, Dec 16, 2021)
34d4023 Merge branch 'branch-22.02' into fea-interruptible (achirkin, Dec 21, 2021)
e1a0c3a Some refactoring and yield function (achirkin, Dec 21, 2021)
f6222dc Fix a typo (achirkin, Dec 21, 2021)
ee99523 Added a python Ctrl+C handler wrapper (achirkin, Dec 22, 2021)
a07edae Fix linter (achirkin, Dec 22, 2021)
b3119bb Fix linter (achirkin, Dec 22, 2021)
54a0599 Initialize cuda primitives lazily and add a mutex-free non-static can… (achirkin, Jan 10, 2022)
db5adfd Fix relative import (achirkin, Jan 10, 2022)
5539984 Fix deallocation issue with shared_ptr + unordered_map (achirkin, Jan 11, 2022)
4b95859 Refactor names (achirkin, Jan 11, 2022)
36e8de5 Merge branch 'branch-22.02' of https://github.com/rapidsai/raft into … (achirkin, Jan 11, 2022)
a2610d1 Make comms sync_stream interruptible (achirkin, Jan 11, 2022)
53155e9 Enable OpenMP in raft (achirkin, Jan 12, 2022)
396beda Add gtests (achirkin, Jan 12, 2022)
636b529 add pytests (achirkin, Jan 12, 2022)
2b65798 Make clang-format happy (achirkin, Jan 12, 2022)
6b96f3b Make flake8 happy (achirkin, Jan 12, 2022)
23b681d Merge branch 'branch-22.02' into fea-interruptible (achirkin, Jan 12, 2022)
b579d72 Support python < 3.8 (achirkin, Jan 12, 2022)
5405c12 Update cpp/include/raft/interruptible.hpp (achirkin, Jan 13, 2022)
0f8bc71 Change implementation: now it's a spinning lock (achirkin, Jan 13, 2022)
81828f6 Fix comms due to changed yield_no_throw semantics (achirkin, Jan 13, 2022)
6948cab Account for the possibility of repeating std::thread::id (achirkin, Jan 13, 2022)
6e7aa24 Simplify the thread::id workaround (no more global seq_id) (achirkin, Jan 14, 2022)
02d95db Merge branch 'branch-22.02' of https://github.com/rapidsai/raft into … (achirkin, Jan 14, 2022)
777d5ed Merge branch 'branch-22.04' into fea-interruptible (cjnolet, Jan 25, 2022)
cbe44d8 Add synchronize(cudaEvent_t) and fix python bindings (achirkin, Jan 26, 2022)
9658ca4 Make stream pool interruptible as well (achirkin, Jan 26, 2022)
b1b8edf Merge branch 'branch-22.04' into fea-interruptible (achirkin, Jan 28, 2022)
cf6c6ff Merge branch 'branch-22.04' into fea-interruptible (achirkin, Feb 2, 2022)
47fad7b Merge branch 'branch-22.04' into fea-interruptible (achirkin, Feb 3, 2022)
3e67ec0 Update docs (achirkin, Feb 3, 2022)
fc81823 Merge branch 'branch-22.04' into fea-interruptible (cjnolet, Feb 4, 2022)
c1a7070 Merge branch 'branch-22.04' into fea-interruptible (achirkin, Feb 5, 2022)
98c9035 Add 'cudart' to cython libs (achirkin, Feb 5, 2022)
dbcdcf0 Don't use __nanosleep on older archs (achirkin, Feb 7, 2022)
853b5c3 Add a comment about using thread-local storage. (achirkin, Feb 7, 2022)
d32f4df Merge remote-tracking branch 'rapidsai/branch-22.04' into fea-interru… (achirkin, Feb 8, 2022)
e8b7b54 Replace more cudaStreamSynchronize with handle.sync_stream (achirkin, Feb 8, 2022)
8 changes: 7 additions & 1 deletion cpp/include/raft/handle.hpp
@@ -34,6 +34,7 @@

#include "cudart_utils.h"
#include <raft/comms/comms.hpp>
#include <raft/interruptible.hpp>
#include <raft/linalg/cublas_wrappers.h>
#include <raft/linalg/cusolver_wrappers.h>
#include <raft/sparse/cusparse_wrappers.h>
@@ -126,10 +127,15 @@ class handle_t {

rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; }

/**
* @brief synchronize a stream on the handle
*/
void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); }

/**
* @brief synchronize main stream on the handle
*/
void sync_stream() const { stream_view_.synchronize(); }
void sync_stream() const { sync_stream(stream_view_); }

/**
* @brief returns main stream on the handle
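With this change, all handle-level stream synchronization funnels through `interruptible::synchronize`, so a blocking wait can be cancelled from another CPU thread. A minimal usage sketch, assuming a configured `handle_t` with asynchronous work already enqueued on its main stream (the enqueueing itself is elided):

#include <raft/handle.hpp>

void wait_for_work(const raft::handle_t& handle)
{
  // ... asynchronous work has been enqueued on handle.get_stream() ...
  // Blocks until that work completes, but throws raft::interrupted
  // if interruptible::cancel() targets this CPU thread meanwhile.
  handle.sync_stream();
}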
159 changes: 159 additions & 0 deletions cpp/include/raft/interruptible.hpp
@@ -0,0 +1,159 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <functional>
#include <memory>
#include <mutex>
#include <raft/cudart_utils.h>
#include <raft/error.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <thread>
#include <unordered_map>

namespace raft {

/**
* @brief Exception thrown during `interruptible::synchronize` call when it detects a request
* to cancel the work performed in this CPU thread.
*/
struct interrupted : public raft::exception {
explicit interrupted(char const* const message) : raft::exception(message) {}
explicit interrupted(std::string const& message) : raft::exception(message) {}
};

class interruptible {
public:
/**
* @brief Synchronize the CUDA stream, subject to being interrupted by `interruptible::cancel`
* called on this CPU thread.
*
* @param [in] stream a CUDA stream.
*
* @throw raft::interrupted if interruptible::cancel() was called on the current CPU thread id
* before the currently captured work has been finished.
* @throw raft::cuda_error if another CUDA error happens.
*/
static void synchronize(rmm::cuda_stream_view stream) { store_.synchronize(stream); }

/**
 * @brief Check whether the thread has been interrupted by `interruptible::cancel`.
*
* This is a cancellation point for an interruptible thread. It's called in the internals of
* `interruptible::synchronize` in a loop. If two synchronize calls are far apart, it's
* recommended to call `interruptible::yield()` in between to make sure the thread does not become
* unresponsive for too long.
*
* @throw raft::interrupted if interruptible::cancel() was called on the current CPU thread.
*/
static void yield() { store_.yield(); }

/**
* @brief Cancel any current or next call to `interruptible::synchronize` performed on the
 * CPU thread given by the `thread_id`.
*
* @param [in] thread_id a CPU thread, in which the work should be interrupted.
*
* @throw raft::cuda_error if a CUDA error happens during recording the interruption event.
*/
static void cancel(std::thread::id thread_id) { store::cancel(thread_id); }

private:
/*
* Implementation-wise, the cancellation feature is bound to the CPU threads.
* Each thread has an associated thread-local state store comprising a boolean flag
* and a CUDA event. The event plays the role of a condition variable; it can be triggered
* either when the work captured in the stream+event is finished or when the cancellation request
* is issued to the current CPU thread. For the latter, we keep internally one global stream
* for recording the "interrupt" events in any CPU thread.
*/
static inline thread_local class store {
private:
/** Global registry of thread-local cancellation stores. */
static inline std::unordered_map<std::thread::id, store*> registry_;
/** Protect the access to the registry. */
static inline std::mutex mutex_;
/** The only purpose of this stream is to record the interruption events. */
static inline std::unique_ptr<cudaStream_t, std::function<void(cudaStream_t*)>>
cancellation_stream_{
[]() {
auto* stream = new cudaStream_t;
RAFT_CUDA_TRY(cudaStreamCreateWithFlags(stream, cudaStreamNonBlocking));
return stream;
}(),
[](cudaStream_t* stream) {
RAFT_CUDA_TRY(cudaStreamDestroy(*stream));
delete stream;
}};

/** The state of being in the process of cancelling. */
bool cancelled_ = false;
/** The main synchronization primitive for the current CPU thread on the CUDA side. */
cudaEvent_t wait_interrupt_ = nullptr;

public:
store()
{
std::lock_guard<std::mutex> guard(mutex_);
registry_[std::this_thread::get_id()] = this;
RAFT_CUDA_TRY(
cudaEventCreateWithFlags(&wait_interrupt_, cudaEventBlockingSync | cudaEventDisableTiming));
}
~store()
{
std::lock_guard<std::mutex> guard(mutex_);
registry_.erase(std::this_thread::get_id());
cudaEventDestroy(wait_interrupt_);
}

void yield()
{
if (cancelled_) {
cancelled_ = false;
throw interrupted("The work in this thread was cancelled.");
}
}

void synchronize(rmm::cuda_stream_view stream)
{
// This function synchronizes the CPU thread on the "interrupt" event instead of
// the given stream.
// Assuming that this method is called only on a thread-local store, there is no need for
// extra synchronization primitives to protect the state.
RAFT_CUDA_TRY(cudaEventRecord(wait_interrupt_, stream));
RAFT_CUDA_TRY(cudaEventSynchronize(wait_interrupt_));
yield();
}

void cancel()
{
// This method is supposed to be called from another thread;
// multiple calls to it just override each other, and that is ok - the cancellation request
// will be delivered (at least once).
cancelled_ = true;
RAFT_CUDA_TRY(cudaEventRecord(wait_interrupt_, *cancellation_stream_));
}

static void cancel(std::thread::id thread_id)
{
// The mutex here is needed to make sure the registry_ is not accessed during
// the registration of a new thread (when the registry_ is altered).
std::lock_guard<std::mutex> guard(mutex_);
registry_[thread_id]->cancel();
}
} store_;
};

} // namespace raft
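Taken together, the new class enables cross-thread cancellation: a worker thread hits cancellation points (`yield` or `synchronize`) while a controller thread calls `cancel` with the worker's `std::thread::id`. A minimal sketch under two stated assumptions: the actual work enqueued on the stream is elided, and the worker must have reached its first cancellation point (creating its thread-local store) before `cancel` fires, since this version of the code does not guard against cancelling an unregistered thread id:

#include <chrono>
#include <raft/interruptible.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <thread>

void worker_fn(rmm::cuda_stream_view stream)
{
  try {
    // ... enqueue asynchronous work on `stream` ...
    raft::interruptible::yield();              // cancellation point
    raft::interruptible::synchronize(stream);  // interruptible wait
  } catch (const raft::interrupted& e) {
    // Cancelled from another thread: clean up and return.
  }
}

int main()
{
  rmm::cuda_stream_view stream;  // default stream, for illustration only
  std::thread worker(worker_fn, stream);
  // Crude wait so the worker registers its store before we cancel it.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  raft::interruptible::cancel(worker.get_id());
  worker.join();
}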
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/cholesky_r1_update.cuh
@@ -210,7 +210,7 @@ void choleskyRank1Update(const raft::handle_t& handle,
math_t L_22_host;
raft::update_host(&s_host, s, 1, stream);
raft::update_host(&L_22_host, L_22, 1, stream); // L_22 stores A_22
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
L_22_host = std::sqrt(L_22_host - s_host);

// Check for numeric error with sqrt. If the matrix is not positive definite or
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/svd.cuh
@@ -126,7 +126,7 @@ void svdQR(const raft::handle_t& handle,

int dev_info;
raft::update_host(&dev_info, devInfo.data(), 1, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
ASSERT(dev_info == 0,
"svd.cuh: svd couldn't converge to a solution. "
"This usually occurs when some of the features do not vary enough.");
@@ -119,7 +119,7 @@ void build_dendrogram_host(const handle_t& handle,
update_host(mst_dst_h.data(), cols, n_edges, stream);
update_host(mst_weights_h.data(), data, n_edges, stream);

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);

std::vector<value_idx> children_h(n_edges * 2);
std::vector<value_idx> out_size_h(n_edges);
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/linalg/detail/spectral.cuh
@@ -51,7 +51,7 @@ void fit_embedding(const raft::handle_t& handle,
rmm::device_uvector<T> eigVecs(n * (n_components + 1), stream);
rmm::device_uvector<int> labels(n, stream);

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);

/**
* Raft spectral clustering
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/op/detail/reduce.cuh
@@ -147,7 +147,7 @@ void max_duplicates(const raft::handle_t& handle,
// compute final size
value_idx size = 0;
raft::update_host(&size, diff.data() + (diff.size() - 1), 1, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
size++;

out.allocate(size, m, n, true, stream);
@@ -415,7 +415,7 @@ void connect_components(
// compute final size
value_idx size = 0;
raft::update_host(&size, out_index.data() + (out_index.size() - 1), 1, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);

size++;

2 changes: 1 addition & 1 deletion cpp/test/distance/dist_adj.cu
@@ -128,7 +128,7 @@ class DistanceAdjTest : public ::testing::TestWithParam<DistanceAdjInputs<DataTy
fin_op,
stream,
isRowMajor);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

void TearDown() override {}
2 changes: 1 addition & 1 deletion cpp/test/distance/distance_base.cuh
@@ -463,7 +463,7 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
stream,
isRowMajor,
metric_arg);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/handle.cpp
@@ -47,7 +47,7 @@ TEST(Raft, Handle)
rmm::cuda_stream_view stream_view(stream);
handle_t handle(stream_view);
ASSERT_EQ(stream_view, handle.get_stream());
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
RAFT_CUDA_TRY(cudaStreamDestroy(stream));
}

2 changes: 1 addition & 1 deletion cpp/test/linalg/add.cu
@@ -47,7 +47,7 @@ class AddTest : public ::testing::TestWithParam<AddInputs<InT, OutT>> {
r.uniform(in2.data(), len, InT(-1.0), InT(1.0), stream);
naiveAddElem<InT, OutT>(out_ref.data(), in1.data(), in2.data(), len, stream);
add<InT, OutT>(out.data(), in1.data(), in2.data(), len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

void compare()
2 changes: 1 addition & 1 deletion cpp/test/linalg/binary_op.cu
@@ -58,7 +58,7 @@ class BinaryOpTest : public ::testing::TestWithParam<BinaryOpInputs<InType, IdxT
r.uniform(in2.data(), len, InType(-1.0), InType(1.0), stream);
naiveAdd(out_ref.data(), in1.data(), in2.data(), len);
binaryOpLaunch(out.data(), in1.data(), in2.data(), len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/coalesced_reduction.cu
@@ -75,7 +75,7 @@ class coalescedReductionTest : public ::testing::TestWithParam<coalescedReductio
// Add to result with inplace = true next
coalescedReductionLaunch(dots_act.data(), data.data(), cols, rows, stream, true);

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/divide.cu
@@ -61,7 +61,7 @@ class DivideTest : public ::testing::TestWithParam<raft::linalg::UnaryOpInputs<T
r.uniform(in.data(), len, T(-1.0), T(1.0), stream);
naiveDivide(out_ref.data(), in.data(), params.scalar, len, stream);
divideScalar(out.data(), in.data(), params.scalar, len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/eig.cu
@@ -134,7 +134,7 @@ class EigTest : public ::testing::TestWithParam<EigInputs<T>> {
stream,
tol,
sweeps);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/eig_sel.cu
@@ -92,7 +92,7 @@ class EigSelTest : public ::testing::TestWithParam<EigSelInputs<T>> {
eig_vals.data(),
EigVecMemUsage::OVERWRITE_INPUT,
stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
4 changes: 2 additions & 2 deletions cpp/test/linalg/eltwise.cu
@@ -76,7 +76,7 @@ class ScalarMultiplyTest : public ::testing::TestWithParam<ScalarMultiplyInputs<
r.uniform(in, len, T(-1.0), T(1.0), stream);
naiveScale(out_ref, in, scalar, len, stream);
scalarMultiply(out, in, scalar, len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
@@ -164,7 +164,7 @@ class EltwiseAddTest : public ::testing::TestWithParam<EltwiseAddInputs<T>> {
r.uniform(in2, len, T(-1.0), T(1.0), stream);
naiveAdd(out_ref, in1, in2, len, stream);
eltwiseAdd(out, in1, in2, len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/map_then_reduce.cu
@@ -87,7 +87,7 @@ class MapReduceTest : public ::testing::TestWithParam<MapReduceInputs<InType>> {
auto len = params.len;
r.uniform(in.data(), len, InType(-1.0), InType(1.0), stream);
mapReduceLaunch(out_ref.data(), out.data(), in.data(), len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/matrix_vector_op.cu
@@ -134,7 +134,7 @@ class MatVecOpTest : public ::testing::TestWithParam<MatVecOpInputs<T, IdxType>>
params.bcastAlongRows,
params.useTwoVectors,
stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/multiply.cu
@@ -45,7 +45,7 @@ class MultiplyTest : public ::testing::TestWithParam<UnaryOpInputs<T>> {
r.uniform(in.data(), len, T(-1.0), T(1.0), stream);
naiveScale(out_ref.data(), in.data(), params.scalar, len, stream);
multiplyScalar(out.data(), in.data(), params.scalar, len, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
4 changes: 2 additions & 2 deletions cpp/test/linalg/norm.cu
@@ -95,7 +95,7 @@ class RowNormTest : public ::testing::TestWithParam<NormInputs<T>> {
} else {
rowNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream);
}
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
@@ -159,7 +159,7 @@ class ColNormTest : public ::testing::TestWithParam<NormInputs<T>> {
} else {
colNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream);
}
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected:
2 changes: 1 addition & 1 deletion cpp/test/linalg/reduce.cu
@@ -96,7 +96,7 @@ class ReduceTest : public ::testing::TestWithParam<ReduceInputs<InType, OutType>
reduceLaunch(
dots_act.data(), data.data(), cols, rows, params.rowMajor, params.alongRows, true, stream);
}
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
handle.sync_stream(stream);
}

protected: