Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Improvements on bitset class #1877

Merged
merged 10 commits into from
Oct 24, 2023
2 changes: 1 addition & 1 deletion cpp/bench/prims/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct bitset_bench : public fixture {
loop_on_state(state, [this]() {
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
my_bitset.test(this->res, raft::make_const_mdspan(queries.view()), outputs.view());
});
}

Expand Down
174 changes: 132 additions & 42 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

#pragma once

#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/util/device_atomics.cuh>
#include <thrust/for_each.h>

Expand All @@ -39,7 +42,7 @@ namespace raft::core {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
Expand Down Expand Up @@ -69,12 +72,40 @@ struct bitset_view {
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
return is_bit_set;
}
/**
 * @brief Subscript-style device accessor for a single bit.
 *
 * Forwards to test() so that `view[i]` and `view.test(i)` give the same answer.
 *
 * @param sample_index Single index to query
 * @return bool Result of test(sample_index)
 */
inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool
{
  const bool bit_state = test(sample_index);
  return bit_state;
}
/**
* @brief Device function to set a given index to set_value in the bitset.
*
* The update is done with atomicOr / atomicAnd on the storage element that holds the bit.
* Although the method is const, it writes through the stored bitset_ptr_.
* No bounds check on sample_index is performed here.
*
* @param sample_index index to set
* @param set_value Value to set the bit to (true or false)
*/
inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const
{
// Locate the storage element holding the bit and the bit's position inside that element.
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
// OR with the single-bit mask: turns the target bit on, leaves the other bits as they are.
atomicOr(bitset_ptr_ + bit_element, bitmask);
} else {
// AND with the inverted mask: clears only the target bit.
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr_ + bit_element, bitmask2);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
}
}

/**
* @brief Get the device pointer to the bitset.
*/
inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand Down Expand Up @@ -114,7 +145,7 @@ struct bitset_view {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

/**
* @brief Construct a new bitset object with a list of indices to unset.
Expand All @@ -130,13 +161,9 @@ struct bitset {
bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(res, default_value);
set(res, mask_index, !default_value);
}

Expand All @@ -150,13 +177,9 @@ struct bitset {
bitset(const raft::resources& res, index_t bitset_len, bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(res, default_value);
}
// Disable copy constructor
bitset(const bitset&) = delete;
Expand All @@ -181,8 +204,8 @@ struct bitset {
/**
* @brief Get the device pointer to the bitset.
*/
inline auto data_handle() -> bitset_t* { return bitset_.data(); }
inline auto data_handle() const -> const bitset_t* { return bitset_.data(); }
inline auto data() -> bitset_t* { return bitset_.data(); }
inline auto data() const -> const bitset_t* { return bitset_.data(); }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand All @@ -207,19 +230,24 @@ struct bitset {
}

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
* the default value. */
void resize(const raft::resources& res, index_t new_bitset_len)
* the default value.
* @param res RAFT resources
* @param new_bitset_len new size of the bitset
* @param default_value default value to initialize the new bits to
*/
void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
cudaMemsetAsync(bitset_.data() + old_size,
default_value_ ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res));

thrust::fill_n(resource::get_thrust_policy(res),
bitset_.data() + old_size,
new_size - old_size,
default_value ? ~bitset_t{0} : bitset_t{0});
}
}

Expand Down Expand Up @@ -255,25 +283,16 @@ struct bitset {
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
{
auto* bitset_ptr = this->data_handle();
auto this_bitset_view = view();
thrust::for_each_n(resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
[this_bitset_view, set_value] __device__(const index_t sample_index) {
this_bitset_view.set(sample_index, set_value);
});
}
/**
* @brief Flip all the bits in a bitset.
*
* @param res RAFT resources
*/
void flip(const raft::resources& res)
Expand All @@ -289,19 +308,90 @@ struct bitset {
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
* @param default_value Value to set the bits to (true or false)
*/
void reset(const raft::resources& res)
void reset(const raft::resources& res, bool default_value = true)
{
cudaMemsetAsync(bitset_.data(),
default_value_ ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
thrust::fill_n(resource::get_thrust_policy(res),
bitset_.data(),
n_elements(),
default_value ? ~bitset_t{0} : bitset_t{0});
}
/**
 * @brief Count the bits set to true and write the total into count_gpu_scalar.
 *
 * Each storage element contributes its population count. The final element is masked first so
 * that bits beyond bitset_len_ are not counted.
 *
 * @param[in] res RAFT resources
 * @param[out] count_gpu_scalar Device scalar receiving the count
 */
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar)
{
  auto element_count = n_elements();

  // View the storage as an (element_count x 1) column-major matrix for the reduction.
  auto elements_as_matrix = raft::make_device_matrix_view<const bitset_t, index_t, col_major>(
    bitset_.data(), element_count, 1);
  // View the output scalar as a vector of length one.
  auto total_as_vector =
    raft::make_device_vector_view<index_t, index_t>(count_gpu_scalar.data_handle(), 1);

  // Number of meaningful bits in the final element; zero means the final element is fully used.
  bitset_t tail_bits = (bitset_len_ % bitset_element_size);
  bitset_t tail_mask = ~bitset_t{0};
  if (tail_bits) { tail_mask = static_cast<bitset_t>((bitset_t{1} << tail_bits) - bitset_t{1}); }

  // NOTE(review): the `false` argument is presumably the in-place/accumulate flag of
  // coalesced_reduction -- confirm against the RAFT API.
  raft::linalg::coalesced_reduction(
    res,
    elements_as_matrix,
    total_as_vector,
    index_t{0},
    false,
    [tail_mask, element_count] __device__(bitset_t element, index_t index) {
      const bool is_tail = (index == element_count - 1);
      index_t n_set      = 0;
      if constexpr (bitset_element_size == 64) {
        if (is_tail) {
          n_set = index_t(raft::detail::popc(element & tail_mask));
        } else {
          n_set = index_t(raft::detail::popc(element));
        }
      } else {
        // popc has no overload for 8- and 16-bit elements, so widen to 32 bits first.
        if (is_tail) {
          n_set = index_t(raft::detail::popc(uint32_t{element} & tail_mask));
        } else {
          n_set = index_t(raft::detail::popc(uint32_t{element}));
        }
      }
      return n_set;
    });
}
/**
 * @brief Returns the number of bits set to true.
 *
 * Creates a scalar with make_device_scalar, fills it through the count() overload that takes a
 * device scalar view, copies the value to the host on the stream obtained from `res`, and
 * synchronizes that stream before returning.
 *
 * @param res RAFT resources
 * @return index_t Number of bits set to true
 */
auto count(const raft::resources& res) -> index_t
{
  // Use an integral zero of the element type; the previous `0.0` literal relied on an implicit
  // floating-point-to-integer conversion.
  auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, index_t{0});
  count(res, count_gpu_scalar.view());
  index_t count_cpu = 0;
  raft::update_host(
    &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
  // The copy above is issued on a stream; wait for it before reading count_cpu.
  resource::sync_stream(res);
  return count_cpu;
}
/**
 * @brief Tells whether at least one bit of the bitset is set to true.
 *
 * Delegates to count(res) and compares the result against zero.
 *
 * @param res RAFT resources
 * @return bool True when count(res) is greater than zero
 */
bool any(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set > 0;
}
/**
 * @brief Tells whether every bit of the bitset is set to true.
 *
 * Delegates to count(res) and compares the result against bitset_len_.
 *
 * @param res RAFT resources
 * @return bool True when count(res) equals bitset_len_
 */
bool all(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set == bitset_len_;
}
/**
 * @brief Tells whether no bit of the bitset is set to true.
 *
 * Delegates to count(res) and compares the result against zero.
 *
 * @param res RAFT resources
 * @return bool True when count(res) equals zero
 */
bool none(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set == 0;
}

private:
raft::device_uvector<bitset_t> bitset_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what this type is, is this from rmm? Why not use raft::device_vector or rmm::device_uvector?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is buried in the container policy code for RAFT and I'd prefer using RAFT types, even if they are just direct wrappers around thrust or rmm types (it provides us a safe facade to maintain api compatibility even if the underlying impls should somehow change or need to be modified).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raft::device_vector is the way to go here!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, I chose raft::device_uvector because it is resizable. raft::device_vector isn't, and this resizable feature is very helpful for incremental addition to IVF indexes (and soon CAGRA)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own education, won't raft::device_vector initialize the underlying memory? Would it be worth exposing raft::device_uvector more publicly specifically for cases like this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@divyegala You're quite right! Had rmm::device_vector in my mind instead of raft::device_vector. raft::device_vector uses rmm::device_uvector for its container policy, so no initialization.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding my two cents on the resize question, I'd be quite squeamish about using the resources from the constructor. Not only do we not know if that resources object has gone out of scope, but it also makes it much harder to track and ensure stream safety if we are not explicitly passing the stream we want to use when a resize might occur. For example:

auto res1 = device_resources{};  // Some underlying stream
auto arr = device_mdarray{res1, ...};
res1.sync_stream();
auto res2 = device_resources();
// modify underlying data of arr using the stream from res2
res2.sync_stream();
// Is it now safe to copy back data to the host on res2? Maybe. If a resize was triggered in between, we'd need another res1.sync_stream() first

Alternatively, if we sync the stream from res1 in the resize method, we might still have an issue if the data are actively being modified on another stream during the resize.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wphicks thanks for typing that out. You are right, and that clarifies for me again why we didn't end up supporting this in the first place.

raft::device_uvector is supposed to be an internal/detail type and I do not believe it should be used directly. If resize() is needed then we should switch the type to rmm::device_uvector. Thoughts on this @cjnolet @wphicks?

Copy link
Member

@cjnolet cjnolet Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the original suggestion to use a raft type instead of directly using the rmm type in public API code was from me. To make the point more clear, this is about consistency and safeguarding our implementations from changes that we cannot control. This is in a very similar vein to our diligent wrapping of the CUDA math libraries APIs so that we can centralize those calls and change any underlying details should the rug get pulled out from under us.

Upon closer inspection of Mickael's changes, however, I very much agree that we should not be storing the raft::resources instance as an object member, and Will's example above is one of the very reasons we want to avoid this. We wouldn't otherwise be doing it in the device_container_policy if it weren't for the fact that we needed the deferred allocation. To Divye's point, the device_uvector had been kept an implementation detail until recently. I would prefer that we fix that problem at some point soon, rather than continue using this pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed raft::resources instance from the bitset class.

index_t bitset_len_;
bool default_value_;
};

/** @} */
Expand Down
24 changes: 19 additions & 5 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/init.cuh>
#include <raft/random/rng.cuh>

#include <gtest/gtest.h>
Expand All @@ -43,7 +44,7 @@ auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream&
template <typename bitset_t, typename index_t>
void add_cpu_bitset(std::vector<bitset_t>& bitset, const std::vector<index_t>& mask_idx)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < mask_idx.size(); i++) {
auto idx = mask_idx[i];
bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size));
Expand All @@ -64,7 +65,7 @@ void test_cpu_bitset(const std::vector<bitset_t>& bitset,
const std::vector<index_t>& queries,
std::vector<uint8_t>& result)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < queries.size(); i++) {
result[i] = uint8_t((bitset[queries[i] / bitset_element_size] &
(bitset_t{1} << (queries[i] % bitset_element_size))) != 0);
Expand Down Expand Up @@ -111,7 +112,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
// calculate the results
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len));
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

// calculate the reference
create_cpu_bitset(bitset_ref, mask_cpu);
Expand All @@ -138,18 +139,31 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
my_bitset.set(res, mask_device.view());
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Test count() operations
my_bitset.reset(res, false);
ASSERT_EQ(my_bitset.any(res), false);
ASSERT_EQ(my_bitset.none(res), true);
raft::linalg::range(query_device.data_handle(), query_device.size(), stream);
my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true);
bitset_count = my_bitset.count(res);
ASSERT_EQ(bitset_count, query_device.size());
ASSERT_EQ(my_bitset.any(res), true);
ASSERT_EQ(my_bitset.none(res), false);
}
};

Expand Down