From 8e7bb8773cf3e092e4ea662af736ea88d8cfd921 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 6 Oct 2023 11:56:44 +0200 Subject: [PATCH 1/5] Improvements on bitset class --- cpp/include/raft/core/bitset.cuh | 139 ++++++++++++++++++++++++++----- cpp/test/core/bitset.cu | 19 ++++- 2 files changed, 134 insertions(+), 24 deletions(-) diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 6747c5fab0..d75957817a 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -16,10 +16,13 @@ #pragma once +#include // native_popc #include +#include #include #include #include +#include #include #include @@ -39,7 +42,7 @@ namespace raft::core { */ template struct bitset_view { - index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} @@ -69,6 +72,34 @@ struct bitset_view { const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; return is_bit_set; } + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool + { + return test(sample_index); + } + /** + * @brief Device function to set a given index to set_value in the bitset. + * + * @param sample_index index to set + * @param set_value Value to set the bit to (true or false) + */ + inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const + { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr_ + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr_ + bit_element, bitmask2); + } + } /** * @brief Get the device pointer to the bitset. @@ -114,7 +145,7 @@ struct bitset_view { */ template struct bitset { - index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; /** * @brief Construct a new bitset object with a list of indices to unset. @@ -130,8 +161,7 @@ struct bitset { bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - default_value_{default_value} + bitset_len_{bitset_len} { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, @@ -150,8 +180,7 @@ struct bitset { bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - default_value_{default_value} + bitset_len_{bitset_len} { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, @@ -208,7 +237,7 @@ struct bitset { /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to * the default value. */ - void resize(const raft::resources& res, index_t new_bitset_len) + void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) { auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); @@ -217,7 +246,7 @@ struct bitset { if (old_size < new_size) { // If the new size is larger, set the new bits to the default value cudaMemsetAsync(bitset_.data() + old_size, - default_value_ ? 0xff : 0x00, + default_value ? 0xff : 0x00, (new_size - old_size) * sizeof(bitset_t), resource::get_cuda_stream(res)); } @@ -255,20 +284,12 @@ struct bitset { raft::device_vector_view mask_index, bool set_value = false) { - auto* bitset_ptr = this->data_handle(); + auto this_bitset_view = view(); thrust::for_each_n(resource::get_thrust_policy(res), mask_index.data_handle(), mask_index.extent(0), - [bitset_ptr, set_value] __device__(const index_t sample_index) { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; - if (set_value) { - atomicOr(bitset_ptr + bit_element, bitmask); - } else { - const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr + bit_element, bitmask2); - } + [this_bitset_view, set_value] __device__(const index_t sample_index) { + this_bitset_view.set(sample_index, set_value); }); } /** @@ -289,19 +310,93 @@ struct bitset { * @brief Reset the bits in a bitset. * * @param res RAFT resources + * @param default_value Value to set the bits to (true or false) */ - void reset(const raft::resources& res) + void reset(const raft::resources& res, bool default_value = true) { cudaMemsetAsync(bitset_.data(), - default_value_ ? 0xff : 0x00, + default_value ? 0xff : 0x00, n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); } + /** + * @brief Returns the number of bits set to true in count_gpu_scalar. + * + * @param[in] res RAFT resources + * @param[out] count_gpu_scalar Device scalar to store the count + */ + void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) + { + auto n_elements_ = n_elements(); + auto count_gpu = + raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); + auto bitset_matrix_view = raft::make_device_matrix_view( + bitset_.data(), n_elements_, 1); + + bitset_t n_last_element = (bitset_len_ % bitset_element_size); + bitset_t last_element_mask = + n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; + raft::linalg::coalesced_reduction( + res, + bitset_matrix_view, + count_gpu, + index_t{0}, + false, + [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { + index_t res = 0; + if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit + if (index == n_elements_ - 1) + res = index_t(raft::detail::native_popc(element & last_element_mask)); + else + res = index_t(raft::detail::native_popc(element)); + } else { + if (index == n_elements_ - 1) + res = index_t(__popc(element & last_element_mask)); + else + res = index_t(__popc(element)); + } + + return res; + }); + } + /** + * @brief Returns the number of bits set to true. + * + * @param[in] res RAFT resources + * @return index_t Number of bits set to true + */ + auto count(const raft::resources& res) -> index_t + { + auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); + count(res, count_gpu_scalar.view()); + index_t count_cpu = 0; + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + return count_cpu; + } + /** + * @brief Checks if any of the bits are set to true in the bitset. + * + * @param res RAFT resources + */ + bool any(const raft::resources& res) { return count(res) > 0; } + /** + * @brief Checks if all of the bits are set to true in the bitset. + * + * @param res RAFT resources + */ + bool all(const raft::resources& res) { return count(res) == bitset_len_; } + /** + * @brief Checks if none of the bits are set to true in the bitset. + * + * @param res RAFT resources + */ + bool none(const raft::resources& res) { return count(res) == 0; } private: raft::device_uvector bitset_; index_t bitset_len_; - bool default_value_; }; /** @} */ diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index 215de98aaf..9d12f04891 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -43,7 +44,7 @@ auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream& template void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) { - static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + constexpr size_t bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < mask_idx.size(); i++) { auto idx = mask_idx[i]; bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size)); @@ -64,7 +65,7 @@ void test_cpu_bitset(const std::vector& bitset, const std::vector& queries, std::vector& result) { - static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + constexpr size_t bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < queries.size(); i++) { result[i] = uint8_t((bitset[queries[i] / bitset_element_size] & (bitset_t{1} << (queries[i] % bitset_element_size))) != 0); @@ -145,11 +146,25 @@ class BitsetTest : public testing::TestWithParam { ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test + auto bitset_count = my_bitset.count(res); my_bitset.flip(res); + ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count); update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + my_bitset.reset(res, false); + ASSERT_EQ(my_bitset.any(res), false); + ASSERT_EQ(my_bitset.none(res), true); + raft::linalg::range(query_device.data_handle(), query_device.size(), stream); + my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true); + bitset_count = my_bitset.count(res); + ASSERT_EQ(bitset_count, query_device.size()); + ASSERT_EQ(my_bitset.any(res), true); + ASSERT_EQ(my_bitset.none(res), false); + + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); } }; From 82a7575b938cdfb4e2013d68675fc064730e19e3 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 6 Oct 2023 19:03:37 +0200 Subject: [PATCH 2/5] Add operators --- cpp/bench/prims/core/bitset.cu | 2 +- cpp/bench/prims/neighbors/cagra_bench.cuh | 2 +- cpp/include/raft/core/bitset.cuh | 143 +++++++++++++--------- cpp/test/core/bitset.cu | 43 ++++--- 4 files changed, 117 insertions(+), 73 deletions(-) diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu index 5f44aa9af5..85e24a3d37 100644 --- a/cpp/bench/prims/core/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -44,7 +44,7 @@ struct bitset_bench : public fixture { loop_on_state(state, [this]() { auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(raft::make_const_mdspan(queries.view()), outputs.view()); }); } diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index 63f6c14686..0748177dff 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -85,7 +85,7 @@ struct CagraBench : public fixture { resource::get_thrust_policy(handle), thrust::device_pointer_cast(removed_indices.data_handle()), thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(handle, removed_indices.view()); + removed_indices_bitset_.set(removed_indices.view()); index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index d75957817a..bfb3364c07 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -161,13 +161,11 @@ struct bitset { bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} + bitset_len_{bitset_len}, + res_{res} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); - set(res, mask_index, !default_value); + reset(default_value); + set(mask_index, !default_value); } /** @@ -180,12 +178,10 @@ struct bitset { bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} + bitset_len_{bitset_len}, + res_{res} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + reset(default_value); } // Disable copy constructor bitset(const bitset&) = delete; @@ -237,7 +233,7 @@ struct bitset { /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to * the default value. */ - void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) + void resize(index_t new_bitset_len, bool default_value = true) { auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); @@ -245,10 +241,10 @@ struct bitset { bitset_len_ = new_bitset_len; if (old_size < new_size) { // If the new size is larger, set the new bits to the default value - cudaMemsetAsync(bitset_.data() + old_size, - default_value ? 0xff : 0x00, - (new_size - old_size) * sizeof(bitset_t), - resource::get_cuda_stream(res)); + RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size, + default_value ? 0xff : 0x00, + (new_size - old_size) * sizeof(bitset_t), + resource::get_cuda_stream(res_))); } } @@ -261,14 +257,13 @@ struct bitset { * @param output List of outputs */ template - void test(const raft::resources& res, - raft::device_vector_view queries, + void test(raft::device_vector_view queries, raft::device_vector_view output) const { RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); auto bitset_view = view(); raft::linalg::map( - res, + res_, output, [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, queries); @@ -280,12 +275,10 @@ struct bitset { * @param mask_index indices to remove from the bitset * @param set_value Value to set the bits to (true or false) */ - void set(const raft::resources& res, - raft::device_vector_view mask_index, - bool set_value = false) + void set(raft::device_vector_view mask_index, bool set_value = false) { auto this_bitset_view = view(); - thrust::for_each_n(resource::get_thrust_policy(res), + thrust::for_each_n(resource::get_thrust_policy(res_), mask_index.data_handle(), mask_index.extent(0), [this_bitset_view, set_value] __device__(const index_t sample_index) { @@ -294,14 +287,12 @@ struct bitset { } /** * @brief Flip all the bits in a bitset. - * - * @param res RAFT resources */ - void flip(const raft::resources& res) + void flip() { auto bitset_span = this->to_mdspan(); raft::linalg::map( - res, + res_, bitset_span, [] __device__(bitset_t element) { return bitset_t(~element); }, raft::make_const_mdspan(bitset_span)); @@ -309,23 +300,21 @@ struct bitset { /** * @brief Reset the bits in a bitset. * - * @param res RAFT resources * @param default_value Value to set the bits to (true or false) */ - void reset(const raft::resources& res, bool default_value = true) + void reset(bool default_value = true) { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(), + default_value ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res_))); } /** * @brief Returns the number of bits set to true in count_gpu_scalar. * - * @param[in] res RAFT resources * @param[out] count_gpu_scalar Device scalar to store the count */ - void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) + void count(raft::device_scalar_view count_gpu_scalar) { auto n_elements_ = n_elements(); auto count_gpu = @@ -337,66 +326,106 @@ struct bitset { bitset_t last_element_mask = n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; raft::linalg::coalesced_reduction( - res, + res_, bitset_matrix_view, count_gpu, index_t{0}, false, [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t res = 0; + index_t result = 0; if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit if (index == n_elements_ - 1) - res = index_t(raft::detail::native_popc(element & last_element_mask)); + result = index_t(raft::detail::native_popc(element & last_element_mask)); else - res = index_t(raft::detail::native_popc(element)); + result = index_t(raft::detail::native_popc(element)); } else { if (index == n_elements_ - 1) - res = index_t(__popc(element & last_element_mask)); + result = index_t(__popc(element & last_element_mask)); else - res = index_t(__popc(element)); + result = index_t(__popc(element)); } - return res; + return result; }); } /** * @brief Returns the number of bits set to true. * - * @param[in] res RAFT resources * @return index_t Number of bits set to true */ - auto count(const raft::resources& res) -> index_t + auto count() -> index_t { - auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); - count(res, count_gpu_scalar.view()); + auto count_gpu_scalar = raft::make_device_scalar(res_, 0.0); + count(count_gpu_scalar.view()); index_t count_cpu = 0; raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res_)); + resource::sync_stream(res_); return count_cpu; } /** * @brief Checks if any of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool any(const raft::resources& res) { return count(res) > 0; } + bool any() { return count() > 0; } /** * @brief Checks if all of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool all(const raft::resources& res) { return count(res) == bitset_len_; } + bool all() { return count() == bitset_len_; } /** * @brief Checks if none of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool none(const raft::resources& res) { return count(res) == 0; } + bool none() { return count() == 0; } + + bitset& operator|=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element | other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } + bitset& operator&=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element & other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } + bitset& operator^=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element ^ other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } private: raft::device_uvector bitset_; index_t bitset_len_; + const raft::resources& res_; }; /** @} */ diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index 9d12f04891..edda1884b3 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -128,7 +128,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); + my_bitset.test(raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); @@ -138,7 +138,7 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - my_bitset.set(res, mask_device.view()); + my_bitset.set(mask_device.view()); update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); @@ -146,25 +146,40 @@ class BitsetTest : public testing::TestWithParam { ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - auto bitset_count = my_bitset.count(res); - my_bitset.flip(res); - ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count); + auto bitset_count = my_bitset.count(); + my_bitset.flip(); + ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count); update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); - my_bitset.reset(res, false); - ASSERT_EQ(my_bitset.any(res), false); - ASSERT_EQ(my_bitset.none(res), true); + // Test count() operations + my_bitset.reset(false); + ASSERT_EQ(my_bitset.any(), false); + ASSERT_EQ(my_bitset.none(), true); raft::linalg::range(query_device.data_handle(), query_device.size(), stream); - my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true); - bitset_count = my_bitset.count(res); + my_bitset.set(raft::make_const_mdspan(query_device.view()), true); + bitset_count = my_bitset.count(); ASSERT_EQ(bitset_count, query_device.size()); - ASSERT_EQ(my_bitset.any(res), true); - ASSERT_EQ(my_bitset.none(res), false); - - ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + ASSERT_EQ(my_bitset.any(), true); + ASSERT_EQ(my_bitset.none(), false); + + // Test operators + auto my_bitset_2 = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); + auto my_bitset_3 = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); + my_bitset_2 ^= my_bitset; + ASSERT_FALSE(devArrMatch(my_bitset_2.data_handle(), + my_bitset_3.data_handle(), + my_bitset.n_elements(), + raft::Compare())); + my_bitset_2 ^= my_bitset; + ASSERT_TRUE(devArrMatch(my_bitset_2.data_handle(), + my_bitset_3.data_handle(), + my_bitset.n_elements(), + raft::Compare())); } }; From 587e49ee7db746d0672d6511143b4f25d88f6ced Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 13 Oct 2023 12:31:08 +0200 Subject: [PATCH 3/5] Update doc --- cpp/include/raft/core/bitset.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index bfb3364c07..f315b0a65e 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -252,7 +252,6 @@ struct bitset { * @brief Test a list of indices in a bitset. * * @tparam output_t Output type of the test. Default is bool. - * @param res RAFT resources * @param queries List of indices to test * @param output List of outputs */ @@ -271,7 +270,6 @@ struct bitset { /** * @brief Set a list of indices in a bitset to set_value. * - * @param res RAFT resources * @param mask_index indices to remove from the bitset * @param set_value Value to set the bits to (true or false) */ From d59d08bc0531f91a92d4c50aa0d47cf6a8913116 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 18 Oct 2023 16:04:58 +0200 Subject: [PATCH 4/5] Fix popc, data() and cuda calls as suggested in reviews --- cpp/include/raft/core/bitset.cuh | 37 ++++++++++++++++---------------- cpp/test/core/bitset.cu | 18 ++++++---------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index f315b0a65e..cd13ed9ca4 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -104,8 +104,8 @@ struct bitset_view { /** * @brief Get the device pointer to the bitset. */ - inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; } - inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; } /** * @brief Get the number of bits of the bitset representation. */ @@ -206,8 +206,8 @@ struct bitset { /** * @brief Get the device pointer to the bitset. */ - inline auto data_handle() -> bitset_t* { return bitset_.data(); } - inline auto data_handle() const -> const bitset_t* { return bitset_.data(); } + inline auto data() -> bitset_t* { return bitset_.data(); } + inline auto data() const -> const bitset_t* { return bitset_.data(); } /** * @brief Get the number of bits of the bitset representation. */ @@ -241,10 +241,11 @@ struct bitset { bitset_len_ = new_bitset_len; if (old_size < new_size) { // If the new size is larger, set the new bits to the default value - RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size, - default_value ? 0xff : 0x00, - (new_size - old_size) * sizeof(bitset_t), - resource::get_cuda_stream(res_))); + + thrust::fill_n(resource::get_thrust_policy(res_), + bitset_.data() + old_size, + new_size - old_size, + default_value ? ~bitset_t{0} : bitset_t{0}); } } @@ -302,10 +303,10 @@ struct bitset { */ void reset(bool default_value = true) { - RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res_))); + thrust::fill_n(resource::get_thrust_policy(res_), + bitset_.data(), + n_elements(), + default_value ? ~bitset_t{0} : bitset_t{0}); } /** * @brief Returns the number of bits set to true in count_gpu_scalar. @@ -331,16 +332,16 @@ struct bitset { false, [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { index_t result = 0; - if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit + if constexpr (bitset_element_size == 64) { if (index == n_elements_ - 1) - result = index_t(raft::detail::native_popc(element & last_element_mask)); + result = index_t(raft::detail::popc(element & last_element_mask)); else - result = index_t(raft::detail::native_popc(element)); - } else { + result = index_t(raft::detail::popc(element)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements if (index == n_elements_ - 1) - result = index_t(__popc(element & last_element_mask)); + result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); else - result = index_t(__popc(element)); + result = index_t(raft::detail::popc(uint32_t{element})); } return result; diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index edda1884b3..eb1306680c 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -112,7 +112,7 @@ class BitsetTest : public testing::TestWithParam { // calculate the results auto my_bitset = raft::core::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); @@ -139,7 +139,7 @@ class BitsetTest : public testing::TestWithParam { update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); my_bitset.set(mask_device.view()); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); @@ -149,7 +149,7 @@ class BitsetTest : public testing::TestWithParam { auto bitset_count = my_bitset.count(); my_bitset.flip(); ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); @@ -171,15 +171,11 @@ class BitsetTest : public testing::TestWithParam { auto my_bitset_3 = raft::core::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); my_bitset_2 ^= my_bitset; - ASSERT_FALSE(devArrMatch(my_bitset_2.data_handle(), - my_bitset_3.data_handle(), - my_bitset.n_elements(), - raft::Compare())); + ASSERT_FALSE(devArrMatch( + my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare())); my_bitset_2 ^= my_bitset; - ASSERT_TRUE(devArrMatch(my_bitset_2.data_handle(), - my_bitset_3.data_handle(), - my_bitset.n_elements(), - raft::Compare())); + ASSERT_TRUE(devArrMatch( + my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare())); } }; From c5ddf0bd9f0ca172662bc74d1934ece24dfafcaf Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 19 Oct 2023 19:26:37 +0200 Subject: [PATCH 5/5] Remove raft::resources from bitset class storage --- cpp/bench/prims/core/bitset.cu | 2 +- cpp/bench/prims/neighbors/cagra_bench.cuh | 2 +- cpp/include/raft/core/bitset.cuh | 117 ++++++++-------------- cpp/test/core/bitset.cu | 36 +++---- 4 files changed, 56 insertions(+), 101 deletions(-) diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu index 85e24a3d37..ce3136bcd5 100644 --- a/cpp/bench/prims/core/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -44,7 +44,7 @@ struct bitset_bench : public fixture { loop_on_state(state, [this]() { auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - my_bitset.test(raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(this->res, raft::make_const_mdspan(queries.view()), outputs.view()); }); } diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index c7c5a2da64..07e93a3473 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -85,7 +85,7 @@ struct CagraBench : public fixture { resource::get_thrust_policy(handle), thrust::device_pointer_cast(removed_indices.data_handle()), thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(removed_indices.view()); + removed_indices_bitset_.set(handle, removed_indices.view()); index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index cd13ed9ca4..552c2e9ac5 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -161,11 +161,10 @@ struct bitset { bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - res_{res} + bitset_len_{bitset_len} { - reset(default_value); - set(mask_index, !default_value); + reset(res, default_value); + set(res, mask_index, !default_value); } /** @@ -178,10 +177,9 @@ struct bitset { bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - res_{res} + bitset_len_{bitset_len} { - reset(default_value); + reset(res, default_value); } // Disable copy constructor bitset(const bitset&) = delete; @@ -232,8 +230,12 @@ struct bitset { } /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to - * the default value. */ - void resize(index_t new_bitset_len, bool default_value = true) + * the default value. + * @param res RAFT resources + * @param new_bitset_len new size of the bitset + * @param default_value default value to initialize the new bits to + */ + void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) { auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); @@ -242,7 +244,7 @@ struct bitset { if (old_size < new_size) { // If the new size is larger, set the new bits to the default value - thrust::fill_n(resource::get_thrust_policy(res_), + thrust::fill_n(resource::get_thrust_policy(res), bitset_.data() + old_size, new_size - old_size, default_value ? ~bitset_t{0} : bitset_t{0}); @@ -253,17 +255,19 @@ struct bitset { * @brief Test a list of indices in a bitset. * * @tparam output_t Output type of the test. Default is bool. + * @param res RAFT resources * @param queries List of indices to test * @param output List of outputs */ template - void test(raft::device_vector_view queries, + void test(const raft::resources& res, + raft::device_vector_view queries, raft::device_vector_view output) const { RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); auto bitset_view = view(); raft::linalg::map( - res_, + res, output, [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, queries); @@ -271,13 +275,16 @@ struct bitset { /** * @brief Set a list of indices in a bitset to set_value. * + * @param res RAFT resources * @param mask_index indices to remove from the bitset * @param set_value Value to set the bits to (true or false) */ - void set(raft::device_vector_view mask_index, bool set_value = false) + void set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value = false) { auto this_bitset_view = view(); - thrust::for_each_n(resource::get_thrust_policy(res_), + thrust::for_each_n(resource::get_thrust_policy(res), mask_index.data_handle(), mask_index.extent(0), [this_bitset_view, set_value] __device__(const index_t sample_index) { @@ -286,12 +293,13 @@ struct bitset { } /** * @brief Flip all the bits in a bitset. + * @param res RAFT resources */ - void flip() + void flip(const raft::resources& res) { auto bitset_span = this->to_mdspan(); raft::linalg::map( - res_, + res, bitset_span, [] __device__(bitset_t element) { return bitset_t(~element); }, raft::make_const_mdspan(bitset_span)); @@ -299,11 +307,12 @@ struct bitset { /** * @brief Reset the bits in a bitset. * + * @param res RAFT resources * @param default_value Value to set the bits to (true or false) */ - void reset(bool default_value = true) + void reset(const raft::resources& res, bool default_value = true) { - thrust::fill_n(resource::get_thrust_policy(res_), + thrust::fill_n(resource::get_thrust_policy(res), bitset_.data(), n_elements(), default_value ? ~bitset_t{0} : bitset_t{0}); @@ -311,9 +320,10 @@ struct bitset { /** * @brief Returns the number of bits set to true in count_gpu_scalar. * + * @param[in] res RAFT resources * @param[out] count_gpu_scalar Device scalar to store the count */ - void count(raft::device_scalar_view count_gpu_scalar) + void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) { auto n_elements_ = n_elements(); auto count_gpu = @@ -325,7 +335,7 @@ struct bitset { bitset_t last_element_mask = n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; raft::linalg::coalesced_reduction( - res_, + res, bitset_matrix_view, count_gpu, index_t{0}, @@ -350,81 +360,38 @@ struct bitset { /** * @brief Returns the number of bits set to true. * + * @param res RAFT resources * @return index_t Number of bits set to true */ - auto count() -> index_t + auto count(const raft::resources& res) -> index_t { - auto count_gpu_scalar = raft::make_device_scalar(res_, 0.0); - count(count_gpu_scalar.view()); + auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); + count(res, count_gpu_scalar.view()); index_t count_cpu = 0; raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res_)); - resource::sync_stream(res_); + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); return count_cpu; } /** * @brief Checks if any of the bits are set to true in the bitset. + * @param res RAFT resources */ - bool any() { return count() > 0; } + bool any(const raft::resources& res) { return count(res) > 0; } /** * @brief Checks if all of the bits are set to true in the bitset. + * @param res RAFT resources */ - bool all() { return count() == bitset_len_; } + bool all(const raft::resources& res) { return count(res) == bitset_len_; } /** * @brief Checks if none of the bits are set to true in the bitset. + * @param res RAFT resources */ - bool none() { return count() == 0; } - - bitset& operator|=(const bitset& other) - { - RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); - auto this_span = to_mdspan(); - auto other_span = other.to_mdspan(); - raft::linalg::map( - res_, - this_span, - [] __device__(bitset_t this_element, bitset_t other_element) { - return this_element | other_element; - }, - raft::make_const_mdspan(this_span), - other_span); - return *this; - } - bitset& operator&=(const bitset& other) - { - RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); - auto this_span = to_mdspan(); - auto other_span = other.to_mdspan(); - raft::linalg::map( - res_, - this_span, - [] __device__(bitset_t this_element, bitset_t other_element) { - return this_element & other_element; - }, - raft::make_const_mdspan(this_span), - other_span); - return *this; - } - bitset& operator^=(const bitset& other) - { - RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); - auto this_span = to_mdspan(); - auto other_span = other.to_mdspan(); - raft::linalg::map( - res_, - this_span, - [] __device__(bitset_t this_element, bitset_t other_element) { - return this_element ^ other_element; - }, - raft::make_const_mdspan(this_span), - other_span); - return *this; - } + bool none(const raft::resources& res) { return count(res) == 0; } private: raft::device_uvector bitset_; index_t bitset_len_; - const raft::resources& res_; }; /** @} */ diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index eb1306680c..b799297e8c 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -128,7 +128,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - my_bitset.test(raft::make_const_mdspan(query_device.view()), result_device.view()); + my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); @@ -138,7 +138,7 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - my_bitset.set(mask_device.view()); + my_bitset.set(res, mask_device.view()); update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); @@ -146,36 +146,24 @@ class BitsetTest : public testing::TestWithParam { ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - auto bitset_count = my_bitset.count(); - my_bitset.flip(); - ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count); + auto bitset_count = my_bitset.count(res); + my_bitset.flip(res); + ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count); update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Test count() operations - my_bitset.reset(false); - ASSERT_EQ(my_bitset.any(), false); - ASSERT_EQ(my_bitset.none(), true); + my_bitset.reset(res, false); + ASSERT_EQ(my_bitset.any(res), false); + ASSERT_EQ(my_bitset.none(res), true); raft::linalg::range(query_device.data_handle(), query_device.size(), stream); - my_bitset.set(raft::make_const_mdspan(query_device.view()), true); - bitset_count = my_bitset.count(); + my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true); + bitset_count = my_bitset.count(res); ASSERT_EQ(bitset_count, query_device.size()); - ASSERT_EQ(my_bitset.any(), true); - ASSERT_EQ(my_bitset.none(), false); - - // Test operators - auto my_bitset_2 = raft::core::bitset( - res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); - auto my_bitset_3 = raft::core::bitset( - res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); - my_bitset_2 ^= my_bitset; - ASSERT_FALSE(devArrMatch( - my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare())); - my_bitset_2 ^= my_bitset; - ASSERT_TRUE(devArrMatch( - my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare())); + ASSERT_EQ(my_bitset.any(res), true); + ASSERT_EQ(my_bitset.none(res), false); } };