Skip to content

Commit

Permalink
[FEA] Improvements on bitset class (rapidsai#1877)
Browse files Browse the repository at this point in the history
Related to rapidsai#1866.
This PR adds useful operations on bitsets: `count()`, `any()`, ...

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - William Hicks (https://github.com/wphicks)
  - Divye Gala (https://github.com/divyegala)

URL: rapidsai#1877
  • Loading branch information
lowener authored Oct 24, 2023
1 parent 945355d commit 53c2539
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 48 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/prims/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct bitset_bench : public fixture {
loop_on_state(state, [this]() {
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
my_bitset.test(this->res, raft::make_const_mdspan(queries.view()), outputs.view());
});
}

Expand Down
174 changes: 132 additions & 42 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

#pragma once

#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/util/device_atomics.cuh>
#include <thrust/for_each.h>

Expand All @@ -39,7 +42,7 @@ namespace raft::core {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
Expand Down Expand Up @@ -69,12 +72,40 @@ struct bitset_view {
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
return is_bit_set;
}
/**
* @brief Device function to test if a given index is set in the bitset.
*
* @param sample_index Single index to test
* @return bool True if index has not been unset in the bitset
*/
inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool
{
return test(sample_index);
}
/**
* @brief Device function to set a given index to set_value in the bitset.
*
* @param sample_index index to set
* @param set_value Value to set the bit to (true or false)
*/
inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const
{
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr_ + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr_ + bit_element, bitmask2);
}
}

/**
* @brief Get the device pointer to the bitset.
*/
inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand Down Expand Up @@ -114,7 +145,7 @@ struct bitset_view {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

/**
* @brief Construct a new bitset object with a list of indices to unset.
Expand All @@ -130,13 +161,9 @@ struct bitset {
bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(res, default_value);
set(res, mask_index, !default_value);
}

Expand All @@ -150,13 +177,9 @@ struct bitset {
bitset(const raft::resources& res, index_t bitset_len, bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(res, default_value);
}
// Disable copy constructor
bitset(const bitset&) = delete;
Expand All @@ -181,8 +204,8 @@ struct bitset {
/**
* @brief Get the device pointer to the bitset.
*/
inline auto data_handle() -> bitset_t* { return bitset_.data(); }
inline auto data_handle() const -> const bitset_t* { return bitset_.data(); }
inline auto data() -> bitset_t* { return bitset_.data(); }
inline auto data() const -> const bitset_t* { return bitset_.data(); }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand All @@ -207,19 +230,24 @@ struct bitset {
}

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
* the default value. */
void resize(const raft::resources& res, index_t new_bitset_len)
* the default value.
* @param res RAFT resources
* @param new_bitset_len new size of the bitset
* @param default_value default value to initialize the new bits to
*/
void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
cudaMemsetAsync(bitset_.data() + old_size,
default_value_ ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res));

thrust::fill_n(resource::get_thrust_policy(res),
bitset_.data() + old_size,
new_size - old_size,
default_value ? ~bitset_t{0} : bitset_t{0});
}
}

Expand Down Expand Up @@ -255,25 +283,16 @@ struct bitset {
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
{
auto* bitset_ptr = this->data_handle();
auto this_bitset_view = view();
thrust::for_each_n(resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
[this_bitset_view, set_value] __device__(const index_t sample_index) {
this_bitset_view.set(sample_index, set_value);
});
}
/**
* @brief Flip all the bits in a bitset.
*
* @param res RAFT resources
*/
void flip(const raft::resources& res)
Expand All @@ -289,19 +308,90 @@ struct bitset {
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
* @param default_value Value to set the bits to (true or false)
*/
void reset(const raft::resources& res)
void reset(const raft::resources& res, bool default_value = true)
{
cudaMemsetAsync(bitset_.data(),
default_value_ ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
thrust::fill_n(resource::get_thrust_policy(res),
bitset_.data(),
n_elements(),
default_value ? ~bitset_t{0} : bitset_t{0});
}
/**
* @brief Returns the number of bits set to true in count_gpu_scalar.
*
* @param[in] res RAFT resources
* @param[out] count_gpu_scalar Device scalar to store the count
*/
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar)
{
auto n_elements_ = n_elements();
auto count_gpu =
raft::make_device_vector_view<index_t, index_t>(count_gpu_scalar.data_handle(), 1);
auto bitset_matrix_view = raft::make_device_matrix_view<const bitset_t, index_t, col_major>(
bitset_.data(), n_elements_, 1);

bitset_t n_last_element = (bitset_len_ % bitset_element_size);
bitset_t last_element_mask =
n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0};
raft::linalg::coalesced_reduction(
res,
bitset_matrix_view,
count_gpu,
index_t{0},
false,
[last_element_mask, n_elements_] __device__(bitset_t element, index_t index) {
index_t result = 0;
if constexpr (bitset_element_size == 64) {
if (index == n_elements_ - 1)
result = index_t(raft::detail::popc(element & last_element_mask));
else
result = index_t(raft::detail::popc(element));
} else { // Needed because popc is not overloaded for 16 and 8 bit elements
if (index == n_elements_ - 1)
result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask));
else
result = index_t(raft::detail::popc(uint32_t{element}));
}

return result;
});
}
/**
* @brief Returns the number of bits set to true.
*
* @param res RAFT resources
* @return index_t Number of bits set to true
*/
auto count(const raft::resources& res) -> index_t
{
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, 0.0);
count(res, count_gpu_scalar.view());
index_t count_cpu = 0;
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
return count_cpu;
}
/**
* @brief Checks if any of the bits are set to true in the bitset.
* @param res RAFT resources
*/
bool any(const raft::resources& res) { return count(res) > 0; }
/**
* @brief Checks if all of the bits are set to true in the bitset.
* @param res RAFT resources
*/
bool all(const raft::resources& res) { return count(res) == bitset_len_; }
/**
* @brief Checks if none of the bits are set to true in the bitset.
* @param res RAFT resources
*/
bool none(const raft::resources& res) { return count(res) == 0; }

private:
raft::device_uvector<bitset_t> bitset_;
index_t bitset_len_;
bool default_value_;
};

/** @} */
Expand Down
24 changes: 19 additions & 5 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/init.cuh>
#include <raft/random/rng.cuh>

#include <gtest/gtest.h>
Expand All @@ -43,7 +44,7 @@ auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream&
template <typename bitset_t, typename index_t>
void add_cpu_bitset(std::vector<bitset_t>& bitset, const std::vector<index_t>& mask_idx)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < mask_idx.size(); i++) {
auto idx = mask_idx[i];
bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size));
Expand All @@ -64,7 +65,7 @@ void test_cpu_bitset(const std::vector<bitset_t>& bitset,
const std::vector<index_t>& queries,
std::vector<uint8_t>& result)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < queries.size(); i++) {
result[i] = uint8_t((bitset[queries[i] / bitset_element_size] &
(bitset_t{1} << (queries[i] % bitset_element_size))) != 0);
Expand Down Expand Up @@ -111,7 +112,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
// calculate the results
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len));
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

// calculate the reference
create_cpu_bitset(bitset_ref, mask_cpu);
Expand All @@ -138,18 +139,31 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
my_bitset.set(res, mask_device.view());
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Test count() operations
my_bitset.reset(res, false);
ASSERT_EQ(my_bitset.any(res), false);
ASSERT_EQ(my_bitset.none(res), true);
raft::linalg::range(query_device.data_handle(), query_device.size(), stream);
my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true);
bitset_count = my_bitset.count(res);
ASSERT_EQ(bitset_count, query_device.size());
ASSERT_EQ(my_bitset.any(res), true);
ASSERT_EQ(my_bitset.none(res), false);
}
};

Expand Down

0 comments on commit 53c2539

Please sign in to comment.