Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streams upgrade in RAFT handle (RMM backend + create handle from parent's pool) #148

Merged
merged 7 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions cpp/include/raft/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <raft/comms/comms.hpp>
#include <raft/mr/device/allocator.hpp>
#include <raft/mr/host/allocator.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include "cudart_utils.h"

namespace raft {
Expand All @@ -62,11 +63,23 @@ class handle_t {
CUDA_CHECK(cudaGetDevice(&cur_dev));
return cur_dev;
}()),
num_streams_(n_streams),
streams_(n_streams),
device_allocator_(std::make_shared<mr::device::default_allocator>()),
host_allocator_(std::make_shared<mr::host::default_allocator>()) {
create_resources();
}
// Copy constructor: initializes only dev_id_, taken from h.get_device().
// NOTE(review): unlike the n_streams constructor, create_resources() is not
// called here, while the destructor calls destroy_resources() (which
// presumably destroys event_) - confirm that handles built this way are
// safe to destroy.
handle_t(const handle_t& h) : dev_id_(h.get_device()) {}
// Constructor taking a const rvalue reference: like the copy constructor, it
// initializes only dev_id_ from h.get_device().
// NOTE(review): the parameter is `const handle_t&&`, so nothing can be moved
// out of h - confirm whether a non-const `handle_t&&` was intended.
handle_t(const handle_t&& h) : dev_id_(h.get_device()) {}

/**
 * Light copy-assignment operator.
 *
 * Copies the device properties and the device/host allocators of `h` into
 * this handle. Streams, comms, and library handles are skipped on purpose.
 *
 * @param h handle to copy the properties and allocators from
 * @return reference to this handle
 */
handle_t& operator=(const handle_t& h) {
  prop_ = h.get_device_properties();
  device_prop_initialized_ = true;
  // Read the allocators from the source handle `h`. Calling the unqualified
  // getters here would return this handle's own allocators and turn the
  // assignment into a no-op.
  device_allocator_ = h.device_allocator_;
  host_allocator_ = h.host_allocator_;
  return *this;
}
afender marked this conversation as resolved.
Show resolved Hide resolved

/** Destroys all held-up resources */
virtual ~handle_t() { destroy_resources(); }
Expand All @@ -75,6 +88,9 @@ class handle_t {

void set_stream(cudaStream_t stream) { user_stream_ = stream; }
cudaStream_t get_stream() const { return user_stream_; }
/** Returns the user stream (user_stream_) wrapped in an rmm::cuda_stream_view. */
rmm::cuda_stream_view get_stream_view() const {
  rmm::cuda_stream_view view(user_stream_);
  return view;
}

void set_device_allocator(std::shared_ptr<mr::device::allocator> allocator) {
device_allocator_ = allocator;
Expand Down Expand Up @@ -126,26 +142,42 @@ class handle_t {
return cusparse_handle_;
}

cudaStream_t get_internal_stream(int sid) const { return streams_[sid]; }
int get_num_internal_streams() const { return num_streams_; }
/**
 * Legacy accessor kept for cuML compatibility: looks up entry `sid` in the
 * internal stream pool and returns its raw cudaStream_t value.
 *
 * NOTE(review): `sid` is not range-checked here; behavior for out-of-range
 * ids depends on rmm::cuda_stream_pool::get_stream - confirm.
 */
cudaStream_t get_internal_stream(int sid) const {
  const auto view = streams_.get_stream(sid);
  return view.value();
}
/**
 * Newer accessor: returns entry `sid` of the internal stream pool as an
 * rmm::cuda_stream_view.
 */
rmm::cuda_stream_view get_internal_stream_view(int sid) const {
  auto view = streams_.get_stream(sid);
  return view;
}

/** Number of streams in the internal pool, as reported by streams_.get_pool_size(). */
int get_num_internal_streams() const {
  const auto pool_size = streams_.get_pool_size();
  // the pool reports its size in its own type; callers of this API expect int
  return static_cast<int>(pool_size);
}
std::vector<cudaStream_t> get_internal_streams() const {
std::vector<cudaStream_t> int_streams_vec;
for (auto s : streams_) {
int_streams_vec.push_back(s);
for (int i = 0; i < get_num_internal_streams(); i++) {
int_streams_vec.push_back(get_internal_stream(i));
}
return int_streams_vec;
}

/**
 * Creates a new handle whose user stream is one of this handle's internal
 * streams.
 *
 * @param stream_id index of the internal stream of this handle that becomes
 *                  the new handle's user stream
 * @param n_streams value passed to the new handle's constructor
 * @return the newly built handle
 */
handle_t get_handle_from_internal_pool(
  int stream_id, int n_streams = kNumDefaultWorkerStreams) const {
  handle_t child(n_streams);
  // goes through handle_t::operator=, i.e. the light copy
  child = *this;
  const cudaStream_t parent_stream = this->get_internal_stream(stream_id);
  child.set_stream(parent_stream);
  return child;
}

void wait_on_user_stream() const {
CUDA_CHECK(cudaEventRecord(event_, user_stream_));
for (auto s : streams_) {
CUDA_CHECK(cudaStreamWaitEvent(s, event_, 0));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0));
}
}

void wait_on_internal_streams() const {
for (auto s : streams_) {
CUDA_CHECK(cudaEventRecord(event_, s));
for (int i = 0; i < get_num_internal_streams(); i++) {
CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i)));
CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0));
}
}
Expand Down Expand Up @@ -192,8 +224,7 @@ class handle_t {
std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;

const int dev_id_;
const int num_streams_;
std::vector<cudaStream_t> streams_;
rmm::cuda_stream_pool streams_{0};
mutable cublasHandle_t cublas_handle_;
mutable bool cublas_initialized_{false};
mutable cusolverDnHandle_t cusolver_dn_handle_;
Expand All @@ -211,11 +242,6 @@ class handle_t {
mutable std::mutex mutex_;

void create_resources() {
for (int i = 0; i < num_streams_; ++i) {
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
streams_.push_back(stream);
}
CUDA_CHECK(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
}

Expand All @@ -237,11 +263,6 @@ class handle_t {
//CUBLAS_CHECK_NO_THROW(cublasDestroy(cublas_handle_));
CUBLAS_CHECK(cublasDestroy(cublas_handle_));
}
while (!streams_.empty()) {
//CUDA_CHECK_NO_THROW(cudaStreamDestroy(streams_.back()));
CUDA_CHECK(cudaStreamDestroy(streams_.back()));
streams_.pop_back();
}
//CUDA_CHECK_NO_THROW(cudaEventDestroy(event_));
CUDA_CHECK(cudaEventDestroy(event_));
}
Expand Down
36 changes: 36 additions & 0 deletions cpp/test/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <gtest/gtest.h>
#include <cstddef>
#include <iostream>
#include <memory>
#include <raft/handle.hpp>
Expand Down Expand Up @@ -49,4 +50,39 @@ TEST(Raft, GetInternalStreams) {
ASSERT_EQ(4U, streams.size());
}

// A handle obtained from a parent's internal pool uses the requested parent
// stream as its user stream, has no internal streams of its own by default,
// can be re-pointed to another parent stream, and reports the same device.
TEST(Raft, GetHandleFromPool) {
  handle_t parent(4);

  // child bound to the parent's internal stream #2
  handle_t child = parent.get_handle_from_internal_pool(2);
  ASSERT_EQ(parent.get_internal_stream(2), child.get_stream());
  ASSERT_EQ(0, child.get_num_internal_streams());

  // re-point the child to the parent's internal stream #3
  child.set_stream(parent.get_internal_stream(3));
  ASSERT_EQ(parent.get_internal_stream(3), child.get_stream());
  ASSERT_NE(parent.get_internal_stream(2), child.get_stream());

  // both handles report the same device id
  ASSERT_EQ(parent.get_device(), child.get_device());
}

// Creates one child handle per internal stream of a 100-stream parent and
// checks the total elapsed time against a fixed bound.
// NOTE(review): the bound is wall-clock based, so it may be sensitive to
// machine load - confirm it is stable in CI.
TEST(Raft, GetHandleFromPoolPerf) {
  handle_t parent(100);
  const auto start = curTimeMillis();
  const int n_children = parent.get_num_internal_streams();
  for (int i = 0; i < n_children; ++i) {
    handle_t child = parent.get_handle_from_internal_pool(i);
    ASSERT_EQ(parent.get_internal_stream(i), child.get_stream());
    child.wait_on_user_stream();
  }
  // upper bound of 0.1ms per child handle (100 handles -> 10ms in total)
  ASSERT_LE(curTimeMillis() - start, 10);
}

// The stream view of a child handle must match the parent's internal stream
// view it was created from, both as a view and as a raw stream value, and
// must not be the default stream.
TEST(Raft, GetHandleStreamViews) {
  handle_t parent(4);

  handle_t child = parent.get_handle_from_internal_pool(2);
  // compare the views themselves
  ASSERT_EQ(parent.get_internal_stream_view(2), child.get_stream_view());
  // compare the underlying raw stream values
  ASSERT_EQ(parent.get_internal_stream_view(2).value(), child.get_stream_view().value());
  EXPECT_FALSE(child.get_stream_view().is_default());
}
} // namespace raft