From ac53a0fdc35fe36a11a3ac4debd4cc9a4076fe7f Mon Sep 17 00:00:00 2001 From: Micka Date: Tue, 10 Sep 2024 17:20:52 +0200 Subject: [PATCH] [BUG] Fix bitset function visibility (#2429) `raft::ceildiv` is also being replaced with `raft::div_rounding_up_safe` to avoid including CUDA headers when not needed. Authors: - Micka (https://github.com/lowener) Approvers: - rhdong (https://github.com/rhdong) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2429 --- cpp/include/raft/core/bitmap.cuh | 6 +++--- cpp/include/raft/core/bitset.cuh | 24 ++++++------------------ cpp/include/raft/core/bitset.hpp | 11 +++++++++-- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index cafd1977ab..024b1244a6 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -35,9 +35,9 @@ _RAFT_HOST_DEVICE inline bool bitmap_view::test(const index_t } template -_RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, - const index_t col, - bool new_value) const +_RAFT_DEVICE void bitmap_view::set(const index_t row, + const index_t col, + bool new_value) const { set(row * cols_ + col, new_value); } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 0cdb4c1fb6..b6e6128eca 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -46,8 +46,8 @@ _RAFT_HOST_DEVICE bool bitset_view::operator[](const index_t } template -_RAFT_HOST_DEVICE void bitset_view::set(const index_t sample_index, - bool set_value) const +_RAFT_DEVICE void bitset_view::set(const index_t sample_index, + bool set_value) const { const index_t bit_element = sample_index / bitset_element_size; const index_t bit_index = sample_index % bitset_element_size; @@ -60,18 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view::set(const index_t sample_ } } -template -_RAFT_HOST_DEVICE inline index_t bitset_view::n_elements() const -{ - return raft::ceildiv(bitset_len_, bitset_element_size); -} - template bitset::bitset(const raft::resources& res, raft::device_vector_view mask_index, index_t bitset_len, bool default_value) - : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + : bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, bitset_len_{bitset_len} { @@ -83,26 +77,20 @@ template bitset::bitset(const raft::resources& res, index_t bitset_len, bool default_value) - : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + : bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, bitset_len_{bitset_len} { reset(res, default_value); } -template -index_t bitset::n_elements() const -{ - return raft::ceildiv(bitset_len_, bitset_element_size); -} - template void bitset::resize(const raft::resources& res, index_t new_bitset_len, bool default_value) { - auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); - auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); + auto old_size = raft::div_rounding_up_safe(bitset_len_, bitset_element_size); + auto new_size = raft::div_rounding_up_safe(new_bitset_len, bitset_element_size); bitset_.resize(new_size); bitset_len_ = new_bitset_len; if (old_size < new_size) { diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp index 0df12f25e6..3608ee43fa 100644 --- a/cpp/include/raft/core/bitset.hpp +++ b/cpp/include/raft/core/bitset.hpp @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::core { /** @@ -89,7 +90,10 @@ struct bitset_view { /** * @brief Get the number of elements used by the bitset representation. */ - inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t; + inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t + { + return raft::div_rounding_up_safe(bitset_len_, bitset_element_size); + } inline auto to_mdspan() -> raft::device_vector_view { @@ -173,7 +177,10 @@ struct bitset { /** * @brief Get the number of elements used by the bitset representation. */ - inline auto n_elements() const -> index_t; + inline auto n_elements() const -> index_t + { + return raft::div_rounding_up_safe(bitset_len_, bitset_element_size); + } /** @brief Get an mdspan view of the current bitset */ inline auto to_mdspan() -> raft::device_vector_view