Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streams upgrade in RAFT handle (RMM backend + create handle from parent's pool) #148

Merged
merged 7 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <raft/comms/comms.hpp>
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/host/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include "cudart_utils.h"

namespace raft {
Expand All @@ -62,19 +63,45 @@ class handle_t {
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
num_streams_(n_streams),
streams_(n_streams),
device_allocator_(std::make_shared<mr::device::default_allocator>()),
host_allocator_(std::make_shared<mr::host::default_allocator>()) {
create_resources();
}

/**
 * @brief Build a lightweight handle derived from an existing one.
 *
 * The new handle takes its device id, device properties and device/host
 * allocators from `other`. The user stream, cuda handles, comms and worker
 * pool of `other` are not copied: the user stream of the new handle is set
 * to one of the worker streams of `other`, and the new handle gets its own
 * worker pool sized by `n_streams`.
 *
 * @param[in] other     handle to derive from; it is required (through
 *                      RAFT_EXPECTS) to have at least one worker stream
 * @param[in] stream_id id of the stream in the worker streams of `other`
 *                      that becomes the user stream of the new handle
 * @param[in] n_streams number of worker streams to be created for the new
 *                      handle
 */
handle_t(const handle_t& other, int stream_id,
         int n_streams = kNumDefaultWorkerStreams)
  : dev_id_(other.get_device()), streams_(n_streams) {
  // Validate first, before any resource is created for this handle.
  RAFT_EXPECTS(
    other.get_num_internal_streams() > 0,
    "ERROR: the main handle must have at least one worker stream\n");

  // Reuse the allocators of the source handle.
  device_allocator_ = other.get_device_allocator();
  host_allocator_ = other.get_host_allocator();

  // Take over the device properties and flag them as already initialized.
  prop_ = other.get_device_properties();
  device_prop_initialized_ = true;

  create_resources();

  // The user stream of this handle is a worker stream of the source handle.
  set_stream(other.get_internal_stream(stream_id));
}
afender marked this conversation as resolved.
Show resolved Hide resolved

/** Tears the handle down by delegating to destroy_resources(). */
virtual ~handle_t() {
  destroy_resources();
}

/** @return the device id stored in this handle (`dev_id_`) */
int get_device() const {
  return dev_id_;
}

/**
 * @brief Replace the user stream of this handle.
 * @param[in] stream stream stored as the new user stream
 */
void set_stream(cudaStream_t stream) {
  user_stream_ = stream;
}
/** @return the user stream currently stored in this handle */
cudaStream_t get_stream() const {
  return user_stream_;
}
/** @return the user stream wrapped in an rmm::cuda_stream_view */
rmm::cuda_stream_view get_stream_view() const {
  const rmm::cuda_stream_view user_view(user_stream_);
  return user_view;
}

void set_device_allocator(std::shared_ptr<mr::device::allocator> allocator) {
device_allocator_ = allocator;
Expand Down Expand Up @@ -126,26 +153,34 @@ class handle_t {
return cusparse_handle_;
}

cudaStream_t get_internal_stream(int sid) const { return streams_[sid]; }
int get_num_internal_streams() const { return num_streams_; }
/**
 * @brief Raw-stream accessor for a worker stream, kept for legacy
 *        compatibility with cuML.
 * @param[in] sid id of the worker stream in the pool
 * @return the `cudaStream_t` value of the selected pool stream
 */
cudaStream_t get_internal_stream(int sid) const {
  const rmm::cuda_stream_view pool_stream = streams_.get_stream(sid);
  return pool_stream.value();
}
/**
 * @brief Newer accessor for a worker stream that returns an
 *        rmm::cuda_stream_view.
 * @param[in] sid id of the worker stream in the pool
 * @return the view handed back by the pool for `sid`
 */
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
  const rmm::cuda_stream_view pool_stream = streams_.get_stream(sid);
  return pool_stream;
}

/** @return the size of the worker stream pool, as reported by the pool */
int get_num_internal_streams() const {
  // The conversion to int was implicit before; spell it out.
  return static_cast<int>(streams_.get_pool_size());
}
std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (auto s : streams_) {
int_streams_vec.push_back(s);
for (int i = 0; i < get_num_internal_streams(); i++) {
int_streams_vec.push_back(get_internal_stream(i));
}
return int_streams_vec;
}

void wait_on_user_stream() const {
CUDA_CHECK(cudaEventRecord(event_, user_stream_));
for (auto s : streams_) {
CUDA_CHECK(cudaStreamWaitEvent(s, event_, 0));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0));
}
}

void wait_on_internal_streams() const {
for (auto s : streams_) {
CUDA_CHECK(cudaEventRecord(event_, s));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i)));
CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0));
}
}
Expand Down Expand Up @@ -192,8 +227,7 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
const int num_streams_;
std::vector<cudaStream_t> streams_;
rmm::cuda_stream_pool streams_{0};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand All @@ -211,11 +245,6 @@ class handle_t {
mutable std::mutex mutex_;

void create_resources() {
for (int i = 0; i < num_streams_; ++i) {
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
streams_.push_back(stream);
}
CUDA_CHECK(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}

Expand All @@ -237,11 +266,6 @@ class handle_t {
//CUBLAS_CHECK_NO_THROW(cublasDestroy(cublas_handle_));
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
}
while (!streams_.empty()) {
//CUDA_CHECK_NO_THROW(cudaStreamDestroy(streams_.back()));
CUDA_CHECK(cudaStreamDestroy(streams_.back()));
streams_.pop_back();
}
//CUDA_CHECK_NO_THROW(cudaEventDestroy(event_));
CUDA_CHECK(cudaEventDestroy(event_));
}
Expand Down
36 changes: 36 additions & 0 deletions cpp/test/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <gtest/gtest.h>
#include <cstddef>
#include <iostream>
#include <memory>
#include <raft/handle.hpp>
Expand Down Expand Up @@ -49,4 +50,39 @@ TEST(Raft, GetInternalStreams) {
ASSERT_EQ(4U, streams.size());
}

// A handle constructed from another handle's pool uses the requested worker
// stream of that handle as its user stream, and reports the same device.
TEST(Raft, GetHandleFromPool) {
  handle_t pool_owner(4);
  handle_t derived(pool_owner, 2);

  // The user stream of the derived handle is worker stream 2 of the owner.
  ASSERT_EQ(pool_owner.get_internal_stream(2), derived.get_stream());
  // With n_streams left at its default the derived handle is expected to
  // have no worker streams (presumably kNumDefaultWorkerStreams is 0).
  ASSERT_EQ(0, derived.get_num_internal_streams());

  // Re-pointing the user stream to worker stream 3 must take effect.
  derived.set_stream(pool_owner.get_internal_stream(3));
  ASSERT_EQ(pool_owner.get_internal_stream(3), derived.get_stream());
  ASSERT_NE(pool_owner.get_internal_stream(2), derived.get_stream());

  ASSERT_EQ(pool_owner.get_device(), derived.get_device());
}

// Creating one derived handle per worker stream of a 100-stream handle, and
// calling wait_on_user_stream() on each, must stay within a fixed time budget.
TEST(Raft, GetHandleFromPoolPerf) {
  handle_t pool_owner(100);
  auto t_begin = curTimeMillis();
  for (int sid = 0; sid < pool_owner.get_num_internal_streams(); ++sid) {
    handle_t derived(pool_owner, sid);
    ASSERT_EQ(pool_owner.get_internal_stream(sid), derived.get_stream());
    derived.wait_on_user_stream();
  }
  // upper bound of 0.1ms per derived handle (100 handles -> 10)
  const auto elapsed = curTimeMillis() - t_begin;
  ASSERT_LE(elapsed, 10);
}

// The stream-view accessors of a derived handle and of the handle it was
// built from must agree, both as views and as raw stream values.
TEST(Raft, GetHandleStreamViews) {
  handle_t pool_owner(4);
  handle_t derived(pool_owner, 2);

  // Compare the views themselves ...
  ASSERT_EQ(pool_owner.get_internal_stream_view(2),
            derived.get_stream_view());
  // ... and the raw values they carry.
  ASSERT_EQ(pool_owner.get_internal_stream_view(2).value(),
            derived.get_stream_view().value());
  // The user stream of the derived handle must not be the default stream.
  EXPECT_FALSE(derived.get_stream_view().is_default());
}
} // namespace raft