Skip to content

Commit

Permalink
Adds a linear accessor to RMM cuda stream pool (#696)
Browse files Browse the repository at this point in the history
Adds  `rmm::cuda_stream_pool::get_stream(stream_id)` and  `rmm::cuda_stream_pool::get_pool_size()` accessors which allow legacy compatibility in cuML and immediate adoption of `rmm::cuda_stream_pool` in RAFT and cuGraph. This co-exist with the current features in `rmm::cuda_stream_pool`.

close #689

Authors:
  - Alex Fender (@afender)

Approvers:
  - Jake Hemstad (@jrhemstad)
  - Mark Harris (@harrism)
  - Rong Ou (@rongou)

URL: #696
  • Loading branch information
afender authored Feb 8, 2021
1 parent 31604e7 commit 8f18e7f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
25 changes: 25 additions & 0 deletions include/rmm/cuda_stream_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>

#include <atomic>
#include <vector>
Expand Down Expand Up @@ -61,6 +62,30 @@ class cuda_stream_pool {
return streams_[(next_stream++) % streams_.size()].view();
}

/**
* @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
* Equivalent values of `stream_id` return a stream_view to the same underlying stream.
*
* This function is thread safe with respect to other calls to the same function.
*
* @param stream_id Unique identifier for the desired stream
*
* @return rmm::cuda_stream_view
*/
rmm::cuda_stream_view get_stream(std::size_t stream_id) const
{
return streams_[stream_id % streams_.size()].view();
}

/**
* @brief Get the number of streams in the pool.
*
* This function is thread safe with respect to other calls to the same function.
*
* @return the number of streams in the pool
*/
size_t get_pool_size() const noexcept { return streams_.size(); }

private:
std::vector<rmm::cuda_stream> streams_;
mutable std::atomic_size_t next_stream{};
Expand Down
18 changes: 17 additions & 1 deletion tests/cuda_stream_pool_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,23 @@ TEST_F(CudaStreamPoolTest, ValidStreams)
RMM_CUDA_TRY(cudaMemsetAsync(v.data(), 0xcc, 100, stream_a.value()));
stream_a.synchronize();

auto v2 = rmm::device_uvector<uint8_t>{v, stream_b};
auto v2 = rmm::device_uvector<std::uint8_t>{v, stream_b};
auto x = v2.front_element(stream_b);
EXPECT_EQ(x, 0xcc);
}

TEST_F(CudaStreamPoolTest, PoolSize) { EXPECT_GE(this->pool.get_pool_size(), 1); }

TEST_F(CudaStreamPoolTest, OutOfBoundLinearAccess)
{
auto const stream_a = this->pool.get_stream(0);
auto const stream_b = this->pool.get_stream(this->pool.get_pool_size());
EXPECT_EQ(stream_a, stream_b);
}

TEST_F(CudaStreamPoolTest, ValidLinearAccess)
{
auto const stream_a = this->pool.get_stream(0);
auto const stream_b = this->pool.get_stream(1);
EXPECT_NE(stream_a, stream_b);
}

0 comments on commit 8f18e7f

Please sign in to comment.