diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index c925669530..190062e92f 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -61,8 +61,10 @@ class handle_t { int cur_dev = -1; CUDA_CHECK(cudaGetDevice(&cur_dev)); return cur_dev; - }()), - streams_(n_streams) { + }()) { + if (n_streams != 0) { + streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams); + } create_resources(); thrust_policy_ = std::make_unique(user_stream_); } @@ -78,10 +80,13 @@ class handle_t { */ handle_t(const handle_t& other, int stream_id, int n_streams = kNumDefaultWorkerStreams) - : dev_id_(other.get_device()), streams_(n_streams) { + : dev_id_(other.get_device()) { RAFT_EXPECTS( other.get_num_internal_streams() > 0, "ERROR: the main handle must have at least one worker stream\n"); + if (n_streams != 0) { + streams_ = std::make_unique<rmm::cuda_stream_pool>(n_streams); + } prop_ = other.get_device_properties(); device_prop_initialized_ = true; create_resources(); @@ -140,14 +145,23 @@ class handle_t { // legacy compatibility for cuML cudaStream_t get_internal_stream(int sid) const { - return streams_.get_stream(sid).value(); + RAFT_EXPECTS( + streams_.get() != nullptr, + "ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value"); + return streams_->get_stream(sid).value(); } // new accessor return rmm::cuda_stream_view rmm::cuda_stream_view get_internal_stream_view(int sid) const { - return streams_.get_stream(sid); + RAFT_EXPECTS( + streams_.get() != nullptr, + "ERROR: rmm::cuda_stream_pool was not initialized with a non-zero value"); + return streams_->get_stream(sid); + } + + int get_num_internal_streams() const { + return streams_.get() != nullptr ? 
streams_->get_pool_size() : 0; } - int get_num_internal_streams() const { return streams_.get_pool_size(); } std::vector get_internal_streams() const { std::vector int_streams_vec; for (int i = 0; i < get_num_internal_streams(); i++) { @@ -212,7 +226,7 @@ class handle_t { std::unordered_map> subcomms_; const int dev_id_; - rmm::cuda_stream_pool streams_{0}; + std::unique_ptr<rmm::cuda_stream_pool> streams_{nullptr}; mutable cublasHandle_t cublas_handle_; mutable bool cublas_initialized_{false}; mutable cusolverDnHandle_t cusolver_dn_handle_; diff --git a/cpp/test/cluster_solvers.cu b/cpp/test/cluster_solvers.cu index d280b3e95c..06b246d9a1 100644 --- a/cpp/test/cluster_solvers.cu +++ b/cpp/test/cluster_solvers.cu @@ -58,7 +58,6 @@ TEST(Raft, ModularitySolvers) { using value_type = double; handle_t h; - ASSERT_EQ(0, h.get_num_internal_streams()); ASSERT_EQ(0, h.get_device()); index_type neigvs{10}; diff --git a/cpp/test/eigen_solvers.cu b/cpp/test/eigen_solvers.cu index 15794ef568..ede790b38c 100644 --- a/cpp/test/eigen_solvers.cu +++ b/cpp/test/eigen_solvers.cu @@ -31,7 +31,6 @@ TEST(Raft, EigenSolvers) { using value_type = double; handle_t h; - ASSERT_EQ(0, h.get_num_internal_streams()); ASSERT_EQ(0, h.get_device()); index_type* ro{nullptr}; @@ -73,7 +72,6 @@ TEST(Raft, SpectralSolvers) { using value_type = double; handle_t h; - ASSERT_EQ(0, h.get_num_internal_streams()); ASSERT_EQ(0, h.get_device()); index_type neigvs{10}; diff --git a/cpp/test/handle.cpp b/cpp/test/handle.cpp index 4cb9809844..3e27789078 100644 --- a/cpp/test/handle.cpp +++ b/cpp/test/handle.cpp @@ -24,7 +24,6 @@ namespace raft { TEST(Raft, HandleDefault) { handle_t h; - ASSERT_EQ(0, h.get_num_internal_streams()); ASSERT_EQ(0, h.get_device()); ASSERT_EQ(nullptr, h.get_stream()); ASSERT_NE(nullptr, h.get_cublas_handle()); @@ -55,7 +54,6 @@ TEST(Raft, GetHandleFromPool) { handle_t child(parent, 2); ASSERT_EQ(parent.get_internal_stream(2), child.get_stream()); - ASSERT_EQ(0, child.get_num_internal_streams()); 
child.set_stream(parent.get_internal_stream(3)); ASSERT_EQ(parent.get_internal_stream(3), child.get_stream()); @@ -64,18 +62,6 @@ TEST(Raft, GetHandleFromPool) { ASSERT_EQ(parent.get_device(), child.get_device()); } -TEST(Raft, GetHandleFromPoolPerf) { - handle_t parent(100); - auto start = curTimeMillis(); - for (int i = 0; i < parent.get_num_internal_streams(); i++) { - handle_t child(parent, i); - ASSERT_EQ(parent.get_internal_stream(i), child.get_stream()); - child.wait_on_user_stream(); - } - // upperbound on 0.1ms per child handle - ASSERT_LE(curTimeMillis() - start, 10); -} - TEST(Raft, GetHandleStreamViews) { handle_t parent(4); diff --git a/cpp/test/spectral_matrix.cu b/cpp/test/spectral_matrix.cu index b85d35e3f8..388ad56f2d 100644 --- a/cpp/test/spectral_matrix.cu +++ b/cpp/test/spectral_matrix.cu @@ -38,7 +38,6 @@ TEST(Raft, SpectralMatrices) { using value_type = double; handle_t h; - ASSERT_EQ(0, h.get_num_internal_streams()); ASSERT_EQ(0, h.get_device()); csr_view_t csr_v{nullptr, nullptr, nullptr, 0, 0};