[FEA] Improvements on bitset class #1877
Merged: 10 commits, Oct 24, 2023
2 changes: 1 addition & 1 deletion cpp/bench/prims/core/bitset.cu
@@ -44,7 +44,7 @@ struct bitset_bench : public fixture {
loop_on_state(state, [this]() {
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
my_bitset.test(raft::make_const_mdspan(queries.view()), outputs.view());
});
}

2 changes: 1 addition & 1 deletion cpp/bench/prims/neighbors/cagra_bench.cuh
@@ -85,7 +85,7 @@ struct CagraBench : public fixture {
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
removed_indices_bitset_.set(removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
220 changes: 171 additions & 49 deletions cpp/include/raft/core/bitset.cuh
@@ -16,10 +16,13 @@

#pragma once

#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/util/device_atomics.cuh>
#include <thrust/for_each.h>

@@ -39,7 +42,7 @@ namespace raft::core {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
@@ -69,6 +72,34 @@ struct bitset_view {
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
return is_bit_set;
}
/**
* @brief Device function to test if a given index is set in the bitset.
*
* @param sample_index Single index to test
* @return bool True if index has not been unset in the bitset
*/
inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool
{
return test(sample_index);
}
/**
* @brief Device function to set a given index to set_value in the bitset.
*
* @param sample_index index to set
* @param set_value Value to set the bit to (true or false)
*/
inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const
{
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr_ + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr_ + bit_element, bitmask2);
}
}
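As a point of reference, a minimal sketch of how these device functions might be called from a kernel once a view has been obtained via bitset::view(); the kernel name and launch configuration are hypothetical, and bitset_t/index_t are assumed to be uint32_t:

// Hypothetical kernel: each thread clears the bit for its own index if it is currently set.
__global__ void clear_if_set(raft::core::bitset_view<uint32_t, uint32_t> view, uint32_t n)
{
  const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && view.test(i)) {
    view.set(i, false);  // resolves to an atomicAnd on the containing word
  }
}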

/**
* @brief Get the device pointer to the bitset.
@@ -114,7 +145,7 @@ struct bitset_view {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

/**
* @brief Construct a new bitset object with a list of indices to unset.
@@ -131,13 +162,10 @@ struct bitset {
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
res_{res}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
set(res, mask_index, !default_value);
reset(default_value);
set(mask_index, !default_value);
}

/**
@@ -151,12 +179,9 @@
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
res_{res}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(default_value);
}
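For illustration, a hedged sketch of constructing a bitset the same way the benchmark at the top of this diff does; the variable names n_removed and bitset_len are assumptions:

// All bits default to true, then the indices listed in `removed` are cleared.
auto removed = raft::make_device_vector<uint32_t, uint32_t>(res, n_removed);
// ... fill `removed` with the indices to exclude ...
auto my_bitset = raft::core::bitset<uint32_t, uint32_t>(
  res, raft::make_const_mdspan(removed.view()), bitset_len);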
// Disable copy constructor
bitset(const bitset&) = delete;
@@ -208,100 +233,197 @@ struct bitset {

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
* the default value. */
void resize(const raft::resources& res, index_t new_bitset_len)
void resize(index_t new_bitset_len, bool default_value = true)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
cudaMemsetAsync(bitset_.data() + old_size,
default_value_ ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res));
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size,
default_value ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res_)));
}
}

/**
* @brief Test a list of indices in a bitset.
*
* @tparam output_t Output type of the test. Default is bool.
* @param res RAFT resources
* @param queries List of indices to test
* @param output List of outputs
*/
template <typename output_t = bool>
void test(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> queries,
void test(raft::device_vector_view<const index_t, index_t> queries,
raft::device_vector_view<output_t, index_t> output) const
{
RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
auto bitset_view = view();
raft::linalg::map(
res,
res_,
output,
[bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); },
queries);
}
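A short usage sketch mirroring the bench change at the top of this diff (the resources argument is gone; variable names here are assumptions):

auto queries = raft::make_device_vector<uint32_t, uint32_t>(res, n_queries);
auto flags   = raft::make_device_vector<bool, uint32_t>(res, n_queries);
// ... fill `queries` ...
my_bitset.test(raft::make_const_mdspan(queries.view()), flags.view());  // flags[i] = bit state at queries[i]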
/**
* @brief Set a list of indices in a bitset to set_value.
*
* @param res RAFT resources
* @param mask_index indices to remove from the bitset
* @param set_value Value to set the bits to (true or false)
*/
void set(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
void set(raft::device_vector_view<const index_t, index_t> mask_index, bool set_value = false)
{
auto* bitset_ptr = this->data_handle();
thrust::for_each_n(resource::get_thrust_policy(res),
auto this_bitset_view = view();
thrust::for_each_n(resource::get_thrust_policy(res_),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
[this_bitset_view, set_value] __device__(const index_t sample_index) {
this_bitset_view.set(sample_index, set_value);
});
}
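The call pattern matches the cagra_bench change above; a hedged one-liner, with set_value left at its default of false so the listed indices are cleared:

removed_indices_bitset_.set(removed_indices.view());  // clears the bits at these indices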
/**
* @brief Flip all the bits in a bitset.
*
* @param res RAFT resources
*/
void flip(const raft::resources& res)
void flip()
{
auto bitset_span = this->to_mdspan();
raft::linalg::map(
res,
res_,
bitset_span,
[] __device__(bitset_t element) { return bitset_t(~element); },
raft::make_const_mdspan(bitset_span));
}
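A brief sketch of how flip() interacts with count(); assuming size() reports the logical bit count (bitset_len_), and since count() masks the unused tail bits of the last word, the identity in the comment should hold:

auto before = my_bitset.count();
my_bitset.flip();                 // every stored word is bitwise-negated
auto after  = my_bitset.count();  // expected: after == my_bitset.size() - before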
/**
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
* @param default_value Value to set the bits to (true or false)
*/
void reset(bool default_value = true)
{
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res_)));
}
/**
* @brief Returns the number of bits set to true in count_gpu_scalar.
*
* @param[out] count_gpu_scalar Device scalar to store the count
*/
void reset(const raft::resources& res)
void count(raft::device_scalar_view<index_t> count_gpu_scalar)
{
cudaMemsetAsync(bitset_.data(),
default_value_ ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
auto n_elements_ = n_elements();
auto count_gpu =
raft::make_device_vector_view<index_t, index_t>(count_gpu_scalar.data_handle(), 1);
auto bitset_matrix_view = raft::make_device_matrix_view<const bitset_t, index_t, col_major>(
bitset_.data(), n_elements_, 1);

bitset_t n_last_element = (bitset_len_ % bitset_element_size);
bitset_t last_element_mask =
n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0};
raft::linalg::coalesced_reduction(
res_,
bitset_matrix_view,
count_gpu,
index_t{0},
false,
[last_element_mask, n_elements_] __device__(bitset_t element, index_t index) {
index_t result = 0;
if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit
if (index == n_elements_ - 1)
result = index_t(raft::detail::native_popc<uint64_t>(element & last_element_mask));
else
result = index_t(raft::detail::native_popc<uint64_t>(element));
} else {
if (index == n_elements_ - 1)
result = index_t(__popc(element & last_element_mask));
else
result = index_t(__popc(element));
}

return result;
});
}
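A worked example of the tail mask for assumed values, bitset_t = uint32_t and bitset_len_ = 70:

// n_elements()      = ceildiv(70, 32) = 3 words
// n_last_element    = 70 % 32         = 6
// last_element_mask = (1u << 6) - 1   = 0x3F
// The final word therefore contributes at most 6 bits to the count;
// if bitset_len_ were a multiple of 32, the mask would be ~0u (all bits valid).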
/**
* @brief Returns the number of bits set to true.
*
* @return index_t Number of bits set to true
*/
auto count() -> index_t
{
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res_, 0.0);
count(count_gpu_scalar.view());
index_t count_cpu = 0;
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res_));
resource::sync_stream(res_);
return count_cpu;
}
/**
* @brief Checks if any of the bits are set to true in the bitset.
*/
bool any() { return count() > 0; }
/**
* @brief Checks if all of the bits are set to true in the bitset.
*/
bool all() { return count() == bitset_len_; }
/**
* @brief Checks if none of the bits are set to true in the bitset.
*/
bool none() { return count() == 0; }

bitset<bitset_t, index_t>& operator|=(const bitset<bitset_t, index_t>& other)
{
RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
auto this_span = to_mdspan();
auto other_span = other.to_mdspan();
raft::linalg::map(
res_,
this_span,
[] __device__(bitset_t this_element, bitset_t other_element) {
return this_element | other_element;
},
raft::make_const_mdspan(this_span),
other_span);
return *this;
}
bitset<bitset_t, index_t>& operator&=(const bitset<bitset_t, index_t>& other)
{
RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
auto this_span = to_mdspan();
auto other_span = other.to_mdspan();
raft::linalg::map(
res_,
this_span,
[] __device__(bitset_t this_element, bitset_t other_element) {
return this_element & other_element;
},
raft::make_const_mdspan(this_span),
other_span);
return *this;
}
bitset<bitset_t, index_t>& operator^=(const bitset<bitset_t, index_t>& other)
{
RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
auto this_span = to_mdspan();
auto other_span = other.to_mdspan();
raft::linalg::map(
res_,
this_span,
[] __device__(bitset_t this_element, bitset_t other_element) {
return this_element ^ other_element;
},
raft::make_const_mdspan(this_span),
other_span);
return *this;
}
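A hedged sketch combining two equally sized bitsets with these operators and the count helpers (the bitset names are assumptions):

bitset_a |= bitset_b;           // per-word OR: union of the set bits
bitset_a &= bitset_c;           // per-word AND: intersection
auto n_set = bitset_a.count();  // number of bits still set
if (bitset_a.none()) { /* every bit is cleared */ }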

private:
raft::device_uvector<bitset_t> bitset_;
Review discussion on this member:

Member: Not sure what this type is, is this from rmm? Why not use raft::device_vector or rmm::device_uvector?

Member: This is buried in the container policy code for RAFT and I'd prefer using RAFT types, even if they are just direct wrappers around thrust or rmm types (it provides us a safe facade to maintain API compatibility even if the underlying impls should somehow change or need to be modified).

Member: raft::device_vector is the way to go here!

Contributor Author: Not really, I chose raft::device_uvector because it is resizable. raft::device_vector isn't, and this resizable feature is very helpful for incremental addition to IVF indexes (and soon CAGRA).

Contributor: For my own education, won't raft::device_vector initialize the underlying memory? Would it be worth exposing raft::device_uvector more publicly specifically for cases like this?

Contributor: @divyegala You're quite right! I had rmm::device_vector in mind instead of raft::device_vector. raft::device_vector uses rmm::device_uvector for its container policy, so no initialization.

Contributor: Adding my two cents on the resize question, I'd be quite squeamish about using the resources from the constructor. Not only do we not know whether that resources object has gone out of scope, but it also makes it much harder to track and ensure stream safety if we are not explicitly passing the stream we want to use when a resize might occur. For example:

auto res1 = device_resources{};  // Some underlying stream
auto arr = device_mdarray{res1, ...};
res1.sync_stream();
auto res2 = device_resources();
// modify underlying data of arr using the stream from res2
res2.sync_stream();
// Is it now safe to copy back data to the host on res2? Maybe. If a resize was triggered in between, we'd need another res1.sync_stream() first

Alternatively, if we sync the stream from res1 in the resize method, we might still have an issue if the data are actively being modified on another stream during the resize.

Member: @wphicks thanks for typing that out; you are right, and that clarifies for me again why we didn't end up supporting this in the first place. raft::device_uvector is supposed to be an internal/detail type and I do not believe it should be used directly. If resize() is needed then we should switch the type to rmm::device_uvector. Thoughts on this @cjnolet @wphicks?

Member (@cjnolet, Oct 19, 2023): Actually, the original suggestion to use a raft type instead of directly using the rmm type in public API code was from me. To make the point more clear, this is about consistency and safeguarding our implementations from changes that we cannot control. This is in a very similar vein to our diligent wrapping of the CUDA math library APIs so that we can centralize those calls and change any underlying details should the rug get pulled out from under us.

Upon closer inspection of Mickael's changes, however, I very much agree that we should not be storing the raft::resources instance as an object member, and Will's example above is one of the very reasons we want to avoid this. We wouldn't otherwise be doing it in the device_container_policy if it weren't for the fact that we needed the deferred allocation. To Divye's point, device_uvector had been kept an implementation detail until recently. I would prefer that we fix that problem at some point soon, rather than continue using this pattern.

Contributor Author: I removed the raft::resources instance from the bitset class.

index_t bitset_len_;
bool default_value_;
const raft::resources& res_;
};

/** @} */