Skip to content

Commit

Permalink
Merge pull request #148 from afender/streams_pool_upgrade
Browse files Browse the repository at this point in the history
Streams upgrade in RAFT handle (RMM backend + create handle from parent's pool)
  • Loading branch information
cjnolet authored Feb 22, 2021
2 parents 4f9fe93 + 30e341f commit 43ecde4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 21 deletions.
66 changes: 45 additions & 21 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <raft/comms/comms.hpp>
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/host/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include "cudart_utils.h"

namespace raft {
Expand All @@ -62,19 +63,45 @@ class handle_t {
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
num_streams_(n_streams),
streams_(n_streams),
device_allocator_(std::make_shared<mr::device::default_allocator>()),
host_allocator_(std::make_shared<mr::host::default_allocator>()) {
create_resources();
}

/**
* @brief Construct a light handle copy from another
* user stream, cuda handles, comms and worker pool are not copied
* The user_stream of the returned handle is set to the specified stream
* of the other handle worker pool
* @param[in] stream_id stream id in `other` worker streams
* to be set as user stream in the constructed handle
* @param[in] n_streams number worker streams to be created
*/
handle_t(const handle_t& other, int stream_id,
int n_streams = kNumDefaultWorkerStreams)
: dev_id_(other.get_device()), streams_(n_streams) {
RAFT_EXPECTS(
other.get_num_internal_streams() > 0,
"ERROR: the main handle must have at least one worker stream\n");
prop_ = other.get_device_properties();
device_prop_initialized_ = true;
device_allocator_ = other.get_device_allocator();
host_allocator_ = other.get_host_allocator();
create_resources();
set_stream(other.get_internal_stream(stream_id));
}

/**
 * @brief Destroys all held-up resources by calling destroy_resources().
 *
 * NOTE(review): the cleanup code appears to use the CUDA_CHECK / CUBLAS_CHECK
 * macros; confirm that these cannot throw out of this destructor.
 */
virtual ~handle_t() { destroy_resources(); }

/** @brief Returns the device id stored in this handle (`dev_id_`). */
int get_device() const { return dev_id_; }

/** @brief Stores `stream` as the user stream of this handle. */
void set_stream(cudaStream_t stream) { user_stream_ = stream; }
/** @brief Returns the user stream currently stored in this handle. */
cudaStream_t get_stream() const { return user_stream_; }
/**
 * @brief Returns the user stream as an `rmm::cuda_stream_view`.
 * @return `user_stream_` wrapped in an `rmm::cuda_stream_view`
 */
rmm::cuda_stream_view get_stream_view() const {
  rmm::cuda_stream_view user_view(user_stream_);
  return user_view;
}

void set_device_allocator(std::shared_ptr<mr::device::allocator> allocator) {
device_allocator_ = allocator;
Expand Down Expand Up @@ -126,26 +153,34 @@ class handle_t {
return cusparse_handle_;
}

cudaStream_t get_internal_stream(int sid) const { return streams_[sid]; }
int get_num_internal_streams() const { return num_streams_; }
// Legacy accessor kept for cuML compatibility: returns the raw cudaStream_t
// behind entry `sid` of the internal stream pool.
cudaStream_t get_internal_stream(int sid) const {
  const rmm::cuda_stream_view pooled = streams_.get_stream(sid);
  return pooled.value();
}
// Newer accessor: returns entry `sid` of the internal stream pool as an
// rmm::cuda_stream_view instead of a raw cudaStream_t.
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
  auto pooled_view = streams_.get_stream(sid);
  return pooled_view;
}

// Number of streams held by the internal stream pool.
int get_num_internal_streams() const {
  return static_cast<int>(streams_.get_pool_size());
}
// Collects the raw cudaStream_t of every internal stream into a newly built
// vector (one entry per internal stream) and returns it by value.
std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (auto s : streams_) {
int_streams_vec.push_back(s);
for (int i = 0; i < get_num_internal_streams(); i++) {
int_streams_vec.push_back(get_internal_stream(i));
}
return int_streams_vec;
}

// Makes every internal stream wait on the user stream: event_ is recorded on
// user_stream_, then each internal stream is told to wait on that event.
void wait_on_user_stream() const {
CUDA_CHECK(cudaEventRecord(event_, user_stream_));
for (auto s : streams_) {
CUDA_CHECK(cudaStreamWaitEvent(s, event_, 0));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0));
}
}

// Makes the user stream wait on every internal stream. The single event_ is
// re-recorded on each internal stream in turn, and user_stream_ is told to
// wait on it immediately after each record.
void wait_on_internal_streams() const {
for (auto s : streams_) {
CUDA_CHECK(cudaEventRecord(event_, s));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i)));
CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0));
}
}
Expand Down Expand Up @@ -192,8 +227,7 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
const int num_streams_;
std::vector<cudaStream_t> streams_;
rmm::cuda_stream_pool streams_{0};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand All @@ -211,11 +245,6 @@ class handle_t {
mutable std::mutex mutex_;

// Creates the CUDA resources owned directly by this handle.
void create_resources() {
for (int i = 0; i < num_streams_; ++i) {
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
streams_.push_back(stream);
}
// event_ is created with timing disabled; it is the event recorded and
// waited on by the wait_on_* helpers.
CUDA_CHECK(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}

Expand All @@ -237,11 +266,6 @@ class handle_t {
//CUBLAS_CHECK_NO_THROW(cublasDestroy(cublas_handle_));
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
}
while (!streams_.empty()) {
//CUDA_CHECK_NO_THROW(cudaStreamDestroy(streams_.back()));
CUDA_CHECK(cudaStreamDestroy(streams_.back()));
streams_.pop_back();
}
//CUDA_CHECK_NO_THROW(cudaEventDestroy(event_));
CUDA_CHECK(cudaEventDestroy(event_));
}
Expand Down
36 changes: 36 additions & 0 deletions cpp/test/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <gtest/gtest.h>
#include <cstddef>
#include <iostream>
#include <memory>
#include <raft/handle.hpp>
Expand Down Expand Up @@ -49,4 +50,39 @@ TEST(Raft, GetInternalStreams) {
ASSERT_EQ(4U, streams.size());
}

// A handle derived from a parent's pool uses the chosen pool stream as its
// user stream, has no worker streams of its own by default, can be re-pointed
// to another pool stream, and reports the same device as the parent.
TEST(Raft, GetHandleFromPool) {
  handle_t owner(4);

  handle_t derived(owner, 2);
  const auto slot2 = owner.get_internal_stream(2);
  ASSERT_EQ(slot2, derived.get_stream());
  ASSERT_EQ(0, derived.get_num_internal_streams());

  // Re-point the user stream at another stream of the parent's pool.
  const auto slot3 = owner.get_internal_stream(3);
  derived.set_stream(slot3);
  ASSERT_EQ(slot3, derived.get_stream());
  ASSERT_NE(slot2, derived.get_stream());

  ASSERT_EQ(owner.get_device(), derived.get_device());
}

// Builds one derived handle per stream of a 100-stream parent pool and checks
// that the whole loop stays within the time budget.
TEST(Raft, GetHandleFromPoolPerf) {
  handle_t owner(100);
  const auto t_begin = curTimeMillis();
  for (int slot = 0; slot < owner.get_num_internal_streams(); ++slot) {
    handle_t derived(owner, slot);
    ASSERT_EQ(owner.get_internal_stream(slot), derived.get_stream());
    derived.wait_on_user_stream();
  }
  // upper bound of 0.1ms per derived handle (100 handles -> 10)
  const auto elapsed = curTimeMillis() - t_begin;
  ASSERT_LE(elapsed, 10);
}

// The stream-view accessors of parent and derived handle must agree, both as
// views and as raw stream values, and the view must not be the default stream.
TEST(Raft, GetHandleStreamViews) {
  handle_t owner(4);
  handle_t derived(owner, 2);

  const auto pool_view = owner.get_internal_stream_view(2);
  const auto user_view = derived.get_stream_view();

  ASSERT_EQ(pool_view, user_view);
  ASSERT_EQ(pool_view.value(), user_view.value());
  EXPECT_FALSE(user_view.is_default());
}
} // namespace raft

0 comments on commit 43ecde4

Please sign in to comment.