Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix worker streams in OLS-eig executing in an unsafe order #4539

Merged
merged 6 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions cpp/src_prims/linalg/lstsq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,22 @@ struct DeviceEvent {
DeviceEvent(bool concurrent)
{
if (concurrent)
RAFT_CUDA_TRY(cudaEventCreate(&e));
RAFT_CUDA_TRY(cudaEventCreateWithFlags(&e, cudaEventDisableTiming));
else
e = nullptr;
}
~DeviceEvent()
{
if (e != nullptr) RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(e));
}
operator cudaEvent_t() const { return e; }
void record(cudaStream_t stream)
{
if (e != nullptr) RAFT_CUDA_TRY(cudaEventRecord(e, stream));
}
void wait(cudaStream_t stream)
void wait_by(cudaStream_t stream)
{
if (e != nullptr) RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, e, 0u));
}
void wait()
{
if (e != nullptr) raft::interruptible::synchronize(e);
}
DeviceEvent& operator=(const DeviceEvent& other) = delete;
};

Expand Down Expand Up @@ -259,27 +254,26 @@ void lstsqEig(const raft::handle_t& handle,
cudaStream_t stream)
{
rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream);
rmm::cuda_stream_view multAbStream = mainStream;
bool concurrent = false;
{
int sp_size = handle.get_stream_pool_size();
if (sp_size > 0) {
multAbStream = handle.get_stream_from_stream_pool(0);
// check if the two streams can run concurrently
if (!are_implicitly_synchronized(mainStream, multAbStream)) {
concurrent = true;
} else if (sp_size > 1) {
mainStream = multAbStream;
multAbStream = handle.get_stream_from_stream_pool(1);
concurrent = true;
}
}
rmm::cuda_stream_view multAbStream = handle.get_next_usable_stream();
bool concurrent;
// Check if the two streams can run concurrently. This is needed because a legacy default stream
// would synchronize with other blocking streams. To avoid synchronization in such case, we try to
// use an additional stream from the pool.
if (!are_implicitly_synchronized(mainStream, multAbStream)) {
concurrent = true;
} else if (handle.get_stream_pool_size() > 1) {
mainStream = handle.get_next_usable_stream();
concurrent = true;
} else {
multAbStream = mainStream;
concurrent = false;
}
// the event is created only if the given raft handle is capable of running
// at least two CUDA streams without implicit synchronization.
DeviceEvent multAbDone(concurrent);

rmm::device_uvector<math_t> workset(n_cols * n_cols * 3 + n_cols * 2, mainStream);
// the event is created only if the given raft handle is capable of running
// at least two CUDA streams without implicit synchronization.
DeviceEvent worksetDone(concurrent);
worksetDone.record(mainStream);
math_t* Q = workset.data();
math_t* QS = Q + n_cols * n_cols;
math_t* covA = QS + n_cols * n_cols;
Expand All @@ -304,7 +298,9 @@ void lstsqEig(const raft::handle_t& handle,
mainStream);

// Ab <- A* b
worksetDone.wait_by(multAbStream);
raft::linalg::gemv(handle, A, n_rows, n_cols, b, Ab, true, multAbStream);
DeviceEvent multAbDone(concurrent);
multAbDone.record(multAbStream);

// Q S Q* <- covA
Expand All @@ -329,9 +325,18 @@ void lstsqEig(const raft::handle_t& handle,
alpha,
beta,
mainStream);
multAbDone.wait(mainStream);

multAbDone.wait_by(mainStream);
// w <- covA Ab == Q invS Q* A b == inv(A* A) A b
raft::linalg::gemv(handle, covA, n_cols, n_cols, Ab, w, false, mainStream);

// This event is created only if we use two worker streams, and `stream` is not the legacy stream,
// and `mainStream` is not a non-blocking stream. In fact, with the current logic these conditions
// are impossible together, but it still makes sense to put this construct here to emphasize that
// `stream` must wait till the work here is done (for future refactorings).
DeviceEvent mainDone(!are_implicitly_synchronized(mainStream, stream));
mainDone.record(mainStream);
mainDone.wait_by(stream);
}

/** Solves the linear ordinary least squares problem `Aw = b`
Expand Down
38 changes: 33 additions & 5 deletions cpp/test/sg/ols.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/mr/device/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <test_utils.h>
#include <vector>

Expand All @@ -27,8 +28,29 @@ namespace GLM {

using namespace MLCommon;

enum class hconf { SINGLE, LEGACY_ONE, LEGACY_TWO, NON_BLOCKING_ONE, NON_BLOCKING_TWO };

raft::handle_t create_handle(hconf type)
{
switch (type) {
case hconf::LEGACY_ONE:
return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(1));
case hconf::LEGACY_TWO:
return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(2));
case hconf::NON_BLOCKING_ONE:
return raft::handle_t(rmm::cuda_stream_per_thread,
std::make_shared<rmm::cuda_stream_pool>(1));
case hconf::NON_BLOCKING_TWO:
return raft::handle_t(rmm::cuda_stream_per_thread,
std::make_shared<rmm::cuda_stream_pool>(2));
case hconf::SINGLE:
default: return raft::handle_t();
}
}

template <typename T>
struct OlsInputs {
hconf hc;
T tol;
int n_row;
int n_col;
Expand All @@ -41,6 +63,7 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
public:
OlsTest()
: params(::testing::TestWithParam<OlsInputs<T>>::GetParam()),
handle(create_handle(params.hc)),
stream(handle.get_stream()),
coef(params.n_col, stream),
coef2(params.n_col, stream),
Expand Down Expand Up @@ -216,10 +239,11 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
}

protected:
OlsInputs<T> params;

raft::handle_t handle;
cudaStream_t stream = 0;

OlsInputs<T> params;
rmm::device_uvector<T> coef, coef_ref, pred, pred_ref;
rmm::device_uvector<T> coef2, coef2_ref, pred2, pred2_ref;
rmm::device_uvector<T> coef3, coef3_ref, pred3, pred3_ref;
Expand All @@ -228,11 +252,15 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
T intercept, intercept2, intercept3;
};

const std::vector<OlsInputs<float>> inputsf2 = {
{0.001f, 4, 2, 2, 0}, {0.001f, 4, 2, 2, 1}, {0.001f, 4, 2, 2, 2}};
const std::vector<OlsInputs<float>> inputsf2 = {{hconf::NON_BLOCKING_ONE, 0.001f, 4, 2, 2, 0},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the changes in this PR, are these assertions able to reliably reproduce the problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this bug seems to be very elusive. I managed to reproduce it only under some specific conditions in python, but then it disappeared again after I did some further changes to optimize preProcessData (perhaps, due to changing the pattern of calls using the main stream / rmm resources).
Yet I hope the changes in these tests will help to find other streams-related bugs if there are any more.

{hconf::NON_BLOCKING_TWO, 0.001f, 4, 2, 2, 1},
{hconf::LEGACY_ONE, 0.001f, 4, 2, 2, 2},
{hconf::LEGACY_TWO, 0.001f, 4, 2, 2, 2},
{hconf::SINGLE, 0.001f, 4, 2, 2, 2}};

const std::vector<OlsInputs<double>> inputsd2 = {
{0.001, 4, 2, 2, 0}, {0.001, 4, 2, 2, 1}, {0.001, 4, 2, 2, 2}};
const std::vector<OlsInputs<double>> inputsd2 = {{hconf::SINGLE, 0.001, 4, 2, 2, 0},
{hconf::LEGACY_ONE, 0.001, 4, 2, 2, 1},
{hconf::LEGACY_TWO, 0.001, 4, 2, 2, 2}};

typedef OlsTest<float> OlsTestF;
TEST_P(OlsTestF, Fit)
Expand Down