diff --git a/include/rmm/cuda_stream_pool.hpp b/include/rmm/cuda_stream_pool.hpp index c0ef1dbce..ee94a9850 100644 --- a/include/rmm/cuda_stream_pool.hpp +++ b/include/rmm/cuda_stream_pool.hpp @@ -39,11 +39,15 @@ class cuda_stream_pool { static constexpr std::size_t default_size{16}; ///< Default stream pool size /** - * @brief Construct a new cuda stream pool object of the given size + * @brief Construct a new cuda stream pool object of the given non-zero size * + * @throws logic_error if `pool_size` is zero * @param pool_size The number of streams in the pool */ - explicit cuda_stream_pool(std::size_t pool_size = default_size) : streams_(pool_size) {} + explicit cuda_stream_pool(std::size_t pool_size = default_size) : streams_(pool_size) + { + RMM_EXPECTS(pool_size > 0, "Stream pool size must be greater than zero"); + } ~cuda_stream_pool() = default; cuda_stream_pool(cuda_stream_pool&&) = delete; diff --git a/tests/cuda_stream_pool_tests.cpp b/tests/cuda_stream_pool_tests.cpp index 4fddb2da6..de17f8c3c 100644 --- a/tests/cuda_stream_pool_tests.cpp +++ b/tests/cuda_stream_pool_tests.cpp @@ -26,6 +26,11 @@ struct CudaStreamPoolTest : public ::testing::Test { rmm::cuda_stream_pool pool{}; }; +TEST_F(CudaStreamPoolTest, ZeroSizePoolException) +{ + EXPECT_THROW(rmm::cuda_stream_pool pool{0}, rmm::logic_error); +} + TEST_F(CudaStreamPoolTest, Unequal) { auto const stream_a = this->pool.get_stream();