From 8f18e7f24e0a11a0d26d7a41cb73c0ab71bfa8d1 Mon Sep 17 00:00:00 2001 From: Alex Fender Date: Mon, 8 Feb 2021 16:59:26 -0600 Subject: [PATCH] Adds a linear accessor to RMM cuda stream pool (#696) Adds `rmm::cuda_stream_pool::get_stream(stream_id)` and `rmm::cuda_stream_pool::get_pool_size()` accessors which allow legacy compatibility in cuML and immediate adoption of `rmm::cuda_stream_pool` in RAFT and cuGraph. This co-exist with the current features in `rmm::cuda_stream_pool`. close #689 Authors: - Alex Fender (@afender) Approvers: - Jake Hemstad (@jrhemstad) - Mark Harris (@harrism) - Rong Ou (@rongou) URL: https://github.com/rapidsai/rmm/pull/696 --- include/rmm/cuda_stream_pool.hpp | 25 +++++++++++++++++++++++++ tests/cuda_stream_pool_tests.cpp | 18 +++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/include/rmm/cuda_stream_pool.hpp b/include/rmm/cuda_stream_pool.hpp index 803c0474e..2e77f2047 100644 --- a/include/rmm/cuda_stream_pool.hpp +++ b/include/rmm/cuda_stream_pool.hpp @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -61,6 +62,30 @@ class cuda_stream_pool { return streams_[(next_stream++) % streams_.size()].view(); } + /** + * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`. + * Equivalent values of `stream_id` return a stream_view to the same underlying stream. + * + * This function is thread safe with respect to other calls to the same function. + * + * @param stream_id Unique identifier for the desired stream + * + * @return rmm::cuda_stream_view + */ + rmm::cuda_stream_view get_stream(std::size_t stream_id) const + { + return streams_[stream_id % streams_.size()].view(); + } + + /** + * @brief Get the number of streams in the pool. + * + * This function is thread safe with respect to other calls to the same function. + * + * @return the number of streams in the pool + */ + size_t get_pool_size() const noexcept { return streams_.size(); } + private: std::vector streams_; mutable std::atomic_size_t next_stream{}; diff --git a/tests/cuda_stream_pool_tests.cpp b/tests/cuda_stream_pool_tests.cpp index 9ca2a4188..1e14e2abf 100644 --- a/tests/cuda_stream_pool_tests.cpp +++ b/tests/cuda_stream_pool_tests.cpp @@ -54,7 +54,23 @@ TEST_F(CudaStreamPoolTest, ValidStreams) RMM_CUDA_TRY(cudaMemsetAsync(v.data(), 0xcc, 100, stream_a.value())); stream_a.synchronize(); - auto v2 = rmm::device_uvector{v, stream_b}; + auto v2 = rmm::device_uvector{v, stream_b}; auto x = v2.front_element(stream_b); EXPECT_EQ(x, 0xcc); } + +TEST_F(CudaStreamPoolTest, PoolSize) { EXPECT_GE(this->pool.get_pool_size(), 1); } + +TEST_F(CudaStreamPoolTest, OutOfBoundLinearAccess) +{ + auto const stream_a = this->pool.get_stream(0); + auto const stream_b = this->pool.get_stream(this->pool.get_pool_size()); + EXPECT_EQ(stream_a, stream_b); +} + +TEST_F(CudaStreamPoolTest, ValidLinearAccess) +{ + auto const stream_a = this->pool.get_stream(0); + auto const stream_b = this->pool.get_stream(1); + EXPECT_NE(stream_a, stream_b); +}