Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix bitset function visibility #2429

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ _RAFT_HOST_DEVICE inline bool bitmap_view<bitmap_t, index_t>::test(const index_t
}

template <typename bitmap_t, typename index_t>
_RAFT_HOST_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
const index_t col,
bool new_value) const
_RAFT_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
const index_t col,
bool new_value) const
{
set(row * cols_ + col, new_value);
}
Expand Down
24 changes: 6 additions & 18 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ _RAFT_HOST_DEVICE bool bitset_view<bitset_t, index_t>::operator[](const index_t
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
_RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
{
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
Expand All @@ -60,18 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_
}
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline index_t bitset_view<bitset_t, index_t>::n_elements() const
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
index_t bitset_len,
bool default_value)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
: bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
{
Expand All @@ -83,26 +77,20 @@ template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
index_t bitset_len,
bool default_value)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
: bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
{
reset(res, default_value);
}

template <typename bitset_t, typename index_t>
index_t bitset<bitset_t, index_t>::n_elements() const
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}

template <typename bitset_t, typename index_t>
void bitset<bitset_t, index_t>::resize(const raft::resources& res,
index_t new_bitset_len,
bool default_value)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
auto old_size = raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
auto new_size = raft::div_rounding_up_safe(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
Expand Down
11 changes: 9 additions & 2 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::core {
/**
Expand Down Expand Up @@ -89,7 +90,10 @@ struct bitset_view {
/**
* @brief Get the number of elements used by the bitset representation.
*/
inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t;
inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t
{
return raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
}

inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
{
Expand Down Expand Up @@ -173,7 +177,10 @@ struct bitset {
/**
* @brief Get the number of elements used by the bitset representation.
*/
inline auto n_elements() const -> index_t;
inline auto n_elements() const -> index_t
{
return raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
}

/** @brief Get an mdspan view of the current bitset */
inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
Expand Down
Loading