Skip to content

Commit

Permalink
Fix worker streams in OLS-eig executing in an unsafe order (#4539)
Browse files Browse the repository at this point in the history
The latest version of the "eig" OLS solver has a bug producing garbage results under some conditions. When at least one worker stream is used to run some operations concurrently, then for sufficiently large workset sizes, the memory allocation in the main stream may finish after the worker stream has already started to use that memory.

This PR adds more ordering between the main and the worker streams, fixing this and some other theoretically possible edge cases.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #4539
  • Loading branch information
achirkin authored Feb 10, 2022
1 parent 06783ff commit df6d7be
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 31 deletions.
57 changes: 31 additions & 26 deletions cpp/src_prims/linalg/lstsq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,22 @@ struct DeviceEvent {
DeviceEvent(bool concurrent)
{
if (concurrent)
RAFT_CUDA_TRY(cudaEventCreate(&e));
RAFT_CUDA_TRY(cudaEventCreateWithFlags(&e, cudaEventDisableTiming));
else
e = nullptr;
}
~DeviceEvent()
{
if (e != nullptr) RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(e));
}
operator cudaEvent_t() const { return e; }
void record(cudaStream_t stream)
{
if (e != nullptr) RAFT_CUDA_TRY(cudaEventRecord(e, stream));
}
void wait(cudaStream_t stream)
void wait_by(cudaStream_t stream)
{
if (e != nullptr) RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, e, 0u));
}
void wait()
{
if (e != nullptr) raft::interruptible::synchronize(e);
}
DeviceEvent& operator=(const DeviceEvent& other) = delete;
};

Expand Down Expand Up @@ -259,27 +254,26 @@ void lstsqEig(const raft::handle_t& handle,
cudaStream_t stream)
{
rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream);
rmm::cuda_stream_view multAbStream = mainStream;
bool concurrent = false;
{
int sp_size = handle.get_stream_pool_size();
if (sp_size > 0) {
multAbStream = handle.get_stream_from_stream_pool(0);
// check if the two streams can run concurrently
if (!are_implicitly_synchronized(mainStream, multAbStream)) {
concurrent = true;
} else if (sp_size > 1) {
mainStream = multAbStream;
multAbStream = handle.get_stream_from_stream_pool(1);
concurrent = true;
}
}
rmm::cuda_stream_view multAbStream = handle.get_next_usable_stream();
bool concurrent;
// Check if the two streams can run concurrently. This is needed because a legacy default stream
// would synchronize with other blocking streams. To avoid synchronization in such case, we try to
// use an additional stream from the pool.
if (!are_implicitly_synchronized(mainStream, multAbStream)) {
concurrent = true;
} else if (handle.get_stream_pool_size() > 1) {
mainStream = handle.get_next_usable_stream();
concurrent = true;
} else {
multAbStream = mainStream;
concurrent = false;
}
// the event is created only if the given raft handle is capable of running
// at least two CUDA streams without implicit synchronization.
DeviceEvent multAbDone(concurrent);

rmm::device_uvector<math_t> workset(n_cols * n_cols * 3 + n_cols * 2, mainStream);
// the event is created only if the given raft handle is capable of running
// at least two CUDA streams without implicit synchronization.
DeviceEvent worksetDone(concurrent);
worksetDone.record(mainStream);
math_t* Q = workset.data();
math_t* QS = Q + n_cols * n_cols;
math_t* covA = QS + n_cols * n_cols;
Expand All @@ -304,7 +298,9 @@ void lstsqEig(const raft::handle_t& handle,
mainStream);

// Ab <- A* b
worksetDone.wait_by(multAbStream);
raft::linalg::gemv(handle, A, n_rows, n_cols, b, Ab, true, multAbStream);
DeviceEvent multAbDone(concurrent);
multAbDone.record(multAbStream);

// Q S Q* <- covA
Expand All @@ -329,9 +325,18 @@ void lstsqEig(const raft::handle_t& handle,
alpha,
beta,
mainStream);
multAbDone.wait(mainStream);

multAbDone.wait_by(mainStream);
// w <- covA Ab == Q invS Q* A b == inv(A* A) A b
raft::linalg::gemv(handle, covA, n_cols, n_cols, Ab, w, false, mainStream);

// This event is created only if we use two worker streams, and `stream` is not the legacy stream,
// and `mainStream` is not a non-blocking stream. In fact, with the current logic these conditions
// are impossible together, but it still makes sense to put this construct here to emphasize that
// `stream` must wait till the work here is done (for future refactorings).
DeviceEvent mainDone(!are_implicitly_synchronized(mainStream, stream));
mainDone.record(mainStream);
mainDone.wait_by(stream);
}

/** Solves the linear ordinary least squares problem `Aw = b`
Expand Down
38 changes: 33 additions & 5 deletions cpp/test/sg/ols.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/mr/device/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <test_utils.h>
#include <vector>

Expand All @@ -27,8 +28,29 @@ namespace GLM {

using namespace MLCommon;

enum class hconf { SINGLE, LEGACY_ONE, LEGACY_TWO, NON_BLOCKING_ONE, NON_BLOCKING_TWO };

raft::handle_t create_handle(hconf type)
{
switch (type) {
case hconf::LEGACY_ONE:
return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(1));
case hconf::LEGACY_TWO:
return raft::handle_t(rmm::cuda_stream_legacy, std::make_shared<rmm::cuda_stream_pool>(2));
case hconf::NON_BLOCKING_ONE:
return raft::handle_t(rmm::cuda_stream_per_thread,
std::make_shared<rmm::cuda_stream_pool>(1));
case hconf::NON_BLOCKING_TWO:
return raft::handle_t(rmm::cuda_stream_per_thread,
std::make_shared<rmm::cuda_stream_pool>(2));
case hconf::SINGLE:
default: return raft::handle_t();
}
}

template <typename T>
struct OlsInputs {
hconf hc;
T tol;
int n_row;
int n_col;
Expand All @@ -41,6 +63,7 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
public:
OlsTest()
: params(::testing::TestWithParam<OlsInputs<T>>::GetParam()),
handle(create_handle(params.hc)),
stream(handle.get_stream()),
coef(params.n_col, stream),
coef2(params.n_col, stream),
Expand Down Expand Up @@ -216,10 +239,11 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
}

protected:
OlsInputs<T> params;

raft::handle_t handle;
cudaStream_t stream = 0;

OlsInputs<T> params;
rmm::device_uvector<T> coef, coef_ref, pred, pred_ref;
rmm::device_uvector<T> coef2, coef2_ref, pred2, pred2_ref;
rmm::device_uvector<T> coef3, coef3_ref, pred3, pred3_ref;
Expand All @@ -228,11 +252,15 @@ class OlsTest : public ::testing::TestWithParam<OlsInputs<T>> {
T intercept, intercept2, intercept3;
};

const std::vector<OlsInputs<float>> inputsf2 = {
{0.001f, 4, 2, 2, 0}, {0.001f, 4, 2, 2, 1}, {0.001f, 4, 2, 2, 2}};
const std::vector<OlsInputs<float>> inputsf2 = {{hconf::NON_BLOCKING_ONE, 0.001f, 4, 2, 2, 0},
{hconf::NON_BLOCKING_TWO, 0.001f, 4, 2, 2, 1},
{hconf::LEGACY_ONE, 0.001f, 4, 2, 2, 2},
{hconf::LEGACY_TWO, 0.001f, 4, 2, 2, 2},
{hconf::SINGLE, 0.001f, 4, 2, 2, 2}};

const std::vector<OlsInputs<double>> inputsd2 = {
{0.001, 4, 2, 2, 0}, {0.001, 4, 2, 2, 1}, {0.001, 4, 2, 2, 2}};
const std::vector<OlsInputs<double>> inputsd2 = {{hconf::SINGLE, 0.001, 4, 2, 2, 0},
{hconf::LEGACY_ONE, 0.001, 4, 2, 2, 1},
{hconf::LEGACY_TWO, 0.001, 4, 2, 2, 2}};

typedef OlsTest<float> OlsTestF;
TEST_P(OlsTestF, Fit)
Expand Down

0 comments on commit df6d7be

Please sign in to comment.