diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index af53968653..dbe7e83189 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -38,6 +38,7 @@ #include #include #include +#include #include "cudart_utils.h" namespace raft { @@ -62,12 +63,35 @@ class handle_t { CUDA_CHECK(cudaGetDevice(&cur_dev)); return cur_dev; }()), - num_streams_(n_streams), + streams_(n_streams), device_allocator_(std::make_shared()), host_allocator_(std::make_shared()) { create_resources(); } + /** + * @brief Construct a light handle copy from another + * user stream, cuda handles, comms and worker pool are not copied + * The user_stream of the returned handle is set to the specified stream + * of the other handle worker pool + * @param[in] stream_id stream id in `other` worker streams + * to be set as user stream in the constructed handle + * @param[in] n_streams number worker streams to be created + */ + handle_t(const handle_t& other, int stream_id, + int n_streams = kNumDefaultWorkerStreams) + : dev_id_(other.get_device()), streams_(n_streams) { + RAFT_EXPECTS( + other.get_num_internal_streams() > 0, + "ERROR: the main handle must have at least one worker stream\n"); + prop_ = other.get_device_properties(); + device_prop_initialized_ = true; + device_allocator_ = other.get_device_allocator(); + host_allocator_ = other.get_host_allocator(); + create_resources(); + set_stream(other.get_internal_stream(stream_id)); + } + /** Destroys all held-up resources */ virtual ~handle_t() { destroy_resources(); } @@ -75,6 +99,9 @@ class handle_t { void set_stream(cudaStream_t stream) { user_stream_ = stream; } cudaStream_t get_stream() const { return user_stream_; } + rmm::cuda_stream_view get_stream_view() const { + return rmm::cuda_stream_view(user_stream_); + } void set_device_allocator(std::shared_ptr allocator) { device_allocator_ = allocator; @@ -126,26 +153,34 @@ class handle_t { return cusparse_handle_; } - cudaStream_t get_internal_stream(int sid) const { return streams_[sid]; } - int get_num_internal_streams() const { return num_streams_; } + // legacy compatibility for cuML + cudaStream_t get_internal_stream(int sid) const { + return streams_.get_stream(sid).value(); + } + // new accessor return rmm::cuda_stream_view + rmm::cuda_stream_view get_internal_stream_view(int sid) const { + return streams_.get_stream(sid); + } + + int get_num_internal_streams() const { return streams_.get_pool_size(); } std::vector get_internal_streams() const { std::vector int_streams_vec; - for (auto s : streams_) { - int_streams_vec.push_back(s); + for (int i = 0; i < get_num_internal_streams(); i++) { + int_streams_vec.push_back(get_internal_stream(i)); } return int_streams_vec; } void wait_on_user_stream() const { CUDA_CHECK(cudaEventRecord(event_, user_stream_)); - for (auto s : streams_) { - CUDA_CHECK(cudaStreamWaitEvent(s, event_, 0)); + for (int i = 0; i < get_num_internal_streams(); i++) { + CUDA_CHECK(cudaStreamWaitEvent(get_internal_stream(i), event_, 0)); } } void wait_on_internal_streams() const { - for (auto s : streams_) { - CUDA_CHECK(cudaEventRecord(event_, s)); + for (int i = 0; i < get_num_internal_streams(); i++) { + CUDA_CHECK(cudaEventRecord(event_, get_internal_stream(i))); CUDA_CHECK(cudaStreamWaitEvent(user_stream_, event_, 0)); } } @@ -192,8 +227,7 @@ class handle_t { std::unordered_map> subcomms_; const int dev_id_; - const int num_streams_; - std::vector streams_; + rmm::cuda_stream_pool streams_{0}; mutable cublasHandle_t cublas_handle_; mutable bool cublas_initialized_{false}; mutable cusolverDnHandle_t cusolver_dn_handle_; @@ -211,11 +245,6 @@ class handle_t { mutable std::mutex mutex_; void create_resources() { - for (int i = 0; i < num_streams_; ++i) { - cudaStream_t stream; - CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - streams_.push_back(stream); - } CUDA_CHECK(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } @@ -237,11 +266,6 @@ class handle_t { //CUBLAS_CHECK_NO_THROW(cublasDestroy(cublas_handle_)); CUBLAS_CHECK(cublasDestroy(cublas_handle_)); } - while (!streams_.empty()) { - //CUDA_CHECK_NO_THROW(cudaStreamDestroy(streams_.back())); - CUDA_CHECK(cudaStreamDestroy(streams_.back())); - streams_.pop_back(); - } //CUDA_CHECK_NO_THROW(cudaEventDestroy(event_)); CUDA_CHECK(cudaEventDestroy(event_)); } diff --git a/cpp/test/handle.cpp b/cpp/test/handle.cpp index 5f6f3ceece..4cb9809844 100644 --- a/cpp/test/handle.cpp +++ b/cpp/test/handle.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -49,4 +50,39 @@ TEST(Raft, GetInternalStreams) { ASSERT_EQ(4U, streams.size()); } +TEST(Raft, GetHandleFromPool) { + handle_t parent(4); + + handle_t child(parent, 2); + ASSERT_EQ(parent.get_internal_stream(2), child.get_stream()); + ASSERT_EQ(0, child.get_num_internal_streams()); + + child.set_stream(parent.get_internal_stream(3)); + ASSERT_EQ(parent.get_internal_stream(3), child.get_stream()); + ASSERT_NE(parent.get_internal_stream(2), child.get_stream()); + + ASSERT_EQ(parent.get_device(), child.get_device()); +} + +TEST(Raft, GetHandleFromPoolPerf) { + handle_t parent(100); + auto start = curTimeMillis(); + for (int i = 0; i < parent.get_num_internal_streams(); i++) { + handle_t child(parent, i); + ASSERT_EQ(parent.get_internal_stream(i), child.get_stream()); + child.wait_on_user_stream(); + } + // upperbound on 0.1ms per child handle + ASSERT_LE(curTimeMillis() - start, 10); +} + +TEST(Raft, GetHandleStreamViews) { + handle_t parent(4); + + handle_t child(parent, 2); + ASSERT_EQ(parent.get_internal_stream_view(2), child.get_stream_view()); + ASSERT_EQ(parent.get_internal_stream_view(2).value(), + child.get_stream_view().value()); + EXPECT_FALSE(child.get_stream_view().is_default()); +} } // namespace raft