From 25a85e0f6cbabe46f664c3b7c664dc51687de5c8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 4 Sep 2023 13:50:28 +0200 Subject: [PATCH 01/16] Add bitset --- cpp/include/raft/util/bitset.cuh | 165 +++++++++++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/util/bitset.cu | 153 ++++++++++++++++++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 cpp/include/raft/util/bitset.cuh create mode 100644 cpp/test/util/bitset.cu diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh new file mode 100644 index 0000000000..e9313fdfda --- /dev/null +++ b/cpp/include/raft/util/bitset.cuh @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::utils { +namespace detail { + +/** + * @brief Unset bits in bitset already created + * + * @tparam IdxT + * @param bitset + * @param sample_index_ptr + * @param sample_len + */ +template +__global__ void unset_kernel(uint32_t* bitset, const IdxT* sample_index_ptr, const IdxT sample_len) +{ + for (IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; tid <= sample_len; + tid += blockDim.x * gridDim.x) { + IdxT sample_index = sample_index_ptr[tid]; + const IdxT bit_element = sample_index / 32; + const IdxT bit_index = sample_index % 32; + const uint32_t bitmask = 1 << bit_index; + atomicAnd(bitset + bit_element, ~bitmask); + } +} + +/** + * @brief Create bitset from list of indices to unset + * + * @tparam IdxT + * @tparam TPB + * @param bitset + * @param bitset_size + * @param index_ptr + * @param index_len + */ +template +__global__ void create_bitset_kernel(uint32_t* bitset, + const IdxT bitset_size, + const IdxT* index_ptr, + const IdxT index_len) +{ + extern __shared__ std::uint32_t shared_mem[]; + + // Create bitset in shmem + for (IdxT tid = threadIdx.x; tid < bitset_size; tid += TPB) { + shared_mem[tid] = 0xffffffff; + } + + __syncthreads(); + + for (IdxT tid = threadIdx.x; tid < index_len; tid += TPB) { + const IdxT sample_index = index_ptr[tid]; + const IdxT bit_element = sample_index / 32; + const IdxT bit_index = sample_index % 32; + const std::uint32_t bitmask = 1 << bit_index; + atomicAnd(shared_mem + bit_element, ~bitmask); + } + + __syncthreads(); + // Output bitset + for (IdxT tid = threadIdx.x; tid < bitset_size; tid += TPB) { + bitset[tid] = shared_mem[tid]; + } +} +} // namespace detail + +template +struct bitset_view { + using BitsetT = uint32_t; + IdxT bitset_size = sizeof(BitsetT) * 8; + + _RAFT_HOST_DEVICE bitset_view(BitsetT* bitset_ptr, IdxT bitset_len) + : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} + { + } + _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) + : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)} + { + } + + inline _RAFT_DEVICE bool test(const IdxT sample_index) const + { + const IdxT bit_element = bitset_ptr_[sample_index / bitset_size]; + const IdxT bit_index = sample_index % bitset_size; + const bool is_bit_set = (bit_element & (1ULL << bit_index)) != 0; + return is_bit_set; + } + inline _RAFT_HOST_DEVICE auto get_bitset_ptr() -> BitsetT* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto get_bitset_ptr() const -> const BitsetT* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto get_bitset_len() const -> IdxT { return bitset_len_; } + + private: + BitsetT* bitset_ptr_; + IdxT bitset_len_; +}; + +template +struct bitset { + using BitsetT = uint32_t; + IdxT bitset_size = sizeof(BitsetT) * 8; + + bitset(const raft::resources& res, + raft::device_vector_view mask_index, + IdxT bitset_len) + : bitset_{raft::make_device_vector(res, raft::ceildiv(bitset_len, bitset_size))} + { + static const size_t TPB_X = 128; + dim3 blocks(raft::ceildiv(size_t(raft::ceildiv(bitset_len, bitset_size)), TPB_X)); + dim3 threads(TPB_X); + + detail::create_bitset_kernel + <<>>( + bitset_.data_handle(), bitset_.extent(0), mask_index.data_handle(), mask_index.extent(0)); + } + // Disable copy constructor + bitset(const bitset&) = delete; + bitset(bitset&&) = default; + bitset& operator=(const bitset&) = delete; + bitset& operator=(bitset&&) = default; + + inline auto view() -> bitset_view { return bitset_view(bitset_.view()); } + [[nodiscard]] inline auto view() const -> bitset_view + { + return bitset_view(bitset_.view()); + } + + private: + raft::device_vector bitset_; +}; + +template +void unset_bitset(const raft::resources& res, + bitset_view& bitset_view_, + raft::device_vector_view mask_index) +{ + static const size_t TPB_X = 128; + dim3 blocks(raft::ceildiv(mask_index.extents(), TPB_X)); + dim3 threads(TPB_X); + detail::unset_kernel<<>>( + bitset_view_.get_bitset_ptr(), mask_index.data_handle(), mask_index.extent(0)); +} +} // namespace raft::utils \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index db4c59c807..57a45c557c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -423,6 +423,7 @@ if(BUILD_TESTS) PATH test/core/seive.cu test/util/bitonic_sort.cu + test/util/bitset.cu test/util/cudart_utils.cpp test/util/device_atomics.cu test/util/integer_utils.cpp diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu new file mode 100644 index 0000000000..10307b3b93 --- /dev/null +++ b/cpp/test/util/bitset.cu @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include + +#include + +#include +#include + +namespace raft::utils { + +struct test_spec { + int bitset_len; + int mask_len; + int query_len; +}; + +auto operator<<(std::ostream& os, const test_spec& ss) -> std::ostream& +{ + os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len << "}"; + return os; +} + +template +void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +{ + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = 0xffffffff; + } + for (size_t i = 0; i < mask_idx.size(); i++) { + auto idx = mask_idx[i]; + bitset[idx / 32] &= ~(1 << (idx % 32)); + } +} + +template +void test_cpu_bitset(const std::vector& bitset, + const std::vector& queries, + std::vector& result) +{ + for (size_t i = 0; i < queries.size(); i++) { + result[i] = uint8_t((bitset[queries[i] / 32] & (1 << (queries[i] % 32))) != 0); + } +} + +template +__global__ void test_gpu_bitset(bitset_view bitset, + const IdxT* queries, + uint8_t* result, + IdxT n_queries) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < n_queries; + tid += blockDim.x * gridDim.x) { + auto query = queries[tid]; + result[tid] = (uint8_t)bitset.test(query); + } +} + +template +class BitsetTest : public testing::TestWithParam { + protected: + const test_spec spec; + std::vector bitset_result; + std::vector bitset_ref; + raft::resources res; + + public: + explicit BitsetTest() + : spec(testing::TestWithParam::GetParam()), + bitset_result(raft::ceildiv(spec.bitset_len, 32)), + bitset_ref(raft::ceildiv(spec.bitset_len, 32)) + { + } + + void run() + { + auto stream = resource::get_cuda_stream(res); + + // generate input and mask + raft::random::RngState rng(42); + auto mask_device = raft::make_device_vector(res, spec.mask_len); + std::vector mask_cpu(spec.mask_len); + raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len)); + update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + + // calculate the results + auto test_bitset = + raft::utils::bitset(res, raft::make_const_mdspan(mask_device.view()), T(spec.bitset_len)); + update_host( + bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream); + + // calculate the reference + create_cpu_bitset(bitset_ref, mask_cpu); + + // make sure the results are available on host + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + auto query_device = raft::make_device_vector(res, spec.query_len); + auto result_device = raft::make_device_vector(res, spec.query_len); + auto query_cpu = std::vector(spec.query_len); + auto result_cpu = std::vector(spec.query_len); + auto result_ref = std::vector(spec.query_len); + + raft::random::uniformInt(res, rng, query_device.view(), T(0), T(spec.bitset_len)); + update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); + test_gpu_bitset<<>>(test_bitset.view(), + query_device.data_handle(), + result_device.data_handle(), + query_device.extent(0)); + update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); + test_cpu_bitset(bitset_ref, query_cpu, result_ref); + + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + } +}; + +auto inputs = ::testing::Values(test_spec{1 << 25, 1 << 23, 1 << 24}, + test_spec{32, 5, 10}, + test_spec{100, 30, 10}, + test_spec{1024, 55, 100}, + test_spec{10000, 1000, 1000}, + test_spec{1 << 15, 1 << 3, 1 << 12}, + test_spec{1 << 15, 1 << 14, 1 << 13}); + +using Uint32 = BitsetTest; +TEST_P(Uint32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32, inputs); + +using Uint64 = BitsetTest; +TEST_P(Uint64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64, inputs); + +} // namespace raft::utils From dbabbd2fe5367ea5fcd94fbb0b1218f5c9d09a8f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Sep 2023 01:22:23 +0200 Subject: [PATCH 02/16] Update naming and use raft::linalg::map --- cpp/include/raft/util/bitset.cuh | 48 ++++++++++++++++++++----------- cpp/test/util/bitset.cu | 49 ++++++++++++++++---------------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index e9313fdfda..231551ab81 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -18,6 +18,7 @@ #include #include +#include namespace raft::utils { namespace detail { @@ -31,9 +32,11 @@ namespace detail { * @param sample_len */ template -__global__ void unset_kernel(uint32_t* bitset, const IdxT* sample_index_ptr, const IdxT sample_len) +__global__ void bitset_unset_kernel(uint32_t* bitset, + const IdxT* sample_index_ptr, + const IdxT sample_len) { - for (IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; tid <= sample_len; + for (IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; tid < sample_len; tid += blockDim.x * gridDim.x) { IdxT sample_index = sample_index_ptr[tid]; const IdxT bit_element = sample_index / 32; @@ -54,7 +57,7 @@ __global__ void unset_kernel(uint32_t* bitset, const IdxT* sample_index_ptr, con * @param index_len */ template -__global__ void create_bitset_kernel(uint32_t* bitset, +__global__ void bitset_create_kernel(uint32_t* bitset, const IdxT bitset_size, const IdxT* index_ptr, const IdxT index_len) @@ -86,8 +89,8 @@ __global__ void create_bitset_kernel(uint32_t* bitset, template struct bitset_view { - using BitsetT = uint32_t; - IdxT bitset_size = sizeof(BitsetT) * 8; + using BitsetT = uint32_t; + IdxT bitset_element_size = sizeof(BitsetT) * 8; _RAFT_HOST_DEVICE bitset_view(BitsetT* bitset_ptr, IdxT bitset_len) : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} @@ -100,8 +103,8 @@ struct bitset_view { inline _RAFT_DEVICE bool test(const IdxT sample_index) const { - const IdxT bit_element = bitset_ptr_[sample_index / bitset_size]; - const IdxT bit_index = sample_index % bitset_size; + const IdxT bit_element = bitset_ptr_[sample_index / bitset_element_size]; + const IdxT bit_index = sample_index % bitset_element_size; const bool is_bit_set = (bit_element & (1ULL << bit_index)) != 0; return is_bit_set; } @@ -116,19 +119,21 @@ struct bitset_view { template struct bitset { - using BitsetT = uint32_t; - IdxT bitset_size = sizeof(BitsetT) * 8; + using BitsetT = uint32_t; + IdxT bitset_element_size = sizeof(BitsetT) * 8; bitset(const raft::resources& res, raft::device_vector_view mask_index, IdxT bitset_len) - : bitset_{raft::make_device_vector(res, raft::ceildiv(bitset_len, bitset_size))} + : bitset_{raft::make_device_vector( + res, raft::ceildiv(bitset_len, bitset_element_size))} { + RAFT_EXPECTS(mask_index.extent(0) <= bitset_len, "Mask index cannot be larger than bitset len"); static const size_t TPB_X = 128; - dim3 blocks(raft::ceildiv(size_t(raft::ceildiv(bitset_len, bitset_size)), TPB_X)); + dim3 blocks(raft::ceildiv(size_t(bitset_.extent(0)), TPB_X)); dim3 threads(TPB_X); - detail::create_bitset_kernel + detail::bitset_create_kernel << -void unset_bitset(const raft::resources& res, - bitset_view& bitset_view_, +void bitset_unset(const raft::resources& res, + bitset_view bitset_view_, raft::device_vector_view mask_index) { static const size_t TPB_X = 128; - dim3 blocks(raft::ceildiv(mask_index.extents(), TPB_X)); + dim3 blocks(raft::ceildiv(size_t(mask_index.extent(0)), TPB_X)); dim3 threads(TPB_X); - detail::unset_kernel<<>>( + detail::bitset_unset_kernel<<>>( bitset_view_.get_bitset_ptr(), mask_index.data_handle(), mask_index.extent(0)); } + +template +void bitset_test(const raft::resources& res, + const bitset_view bitset_view_, + raft::device_vector_view queries, + raft::device_vector_view output) +{ + RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); + raft::linalg::map( + res, output, [=] __device__(IdxT query) { return OutputT(bitset_view_.test(query)); }, queries); +} } // namespace raft::utils \ No newline at end of file diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index 10307b3b93..0ca6be1c48 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -39,17 +39,23 @@ auto operator<<(std::ostream& os, const test_spec& ss) -> std::ostream& } template -void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) { - for (size_t i = 0; i < bitset.size(); i++) { - bitset[i] = 0xffffffff; - } for (size_t i = 0; i < mask_idx.size(); i++) { auto idx = mask_idx[i]; bitset[idx / 32] &= ~(1 << (idx % 32)); } } +template +void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +{ + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = 0xffffffff; + } + add_cpu_bitset(bitset, mask_idx); +} + template void test_cpu_bitset(const std::vector& bitset, const std::vector& queries, @@ -60,19 +66,6 @@ void test_cpu_bitset(const std::vector& bitset, } } -template -__global__ void test_gpu_bitset(bitset_view bitset, - const IdxT* queries, - uint8_t* result, - IdxT n_queries) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < n_queries; - tid += blockDim.x * gridDim.x) { - auto query = queries[tid]; - result[tid] = (uint8_t)bitset.test(query); - } -} - template class BitsetTest : public testing::TestWithParam { protected: @@ -109,8 +102,6 @@ class BitsetTest : public testing::TestWithParam { // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); - - // make sure the results are available on host resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); @@ -120,17 +111,27 @@ class BitsetTest : public testing::TestWithParam { auto result_cpu = std::vector(spec.query_len); auto result_ref = std::vector(spec.query_len); + // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), T(0), T(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - test_gpu_bitset<<>>(test_bitset.view(), - query_device.data_handle(), - result_device.data_handle(), - query_device.extent(0)); + raft::utils::bitset_test( + res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + + // Add more sample to the bitset and re-test + raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len)); + update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + raft::utils::bitset_unset(res, test_bitset.view(), mask_device.view()); + update_host( + bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream); + add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); - ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); } }; From 4906c3a1bd1411577e65b9db4c3fbf4ee4cc6244 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Sep 2023 02:46:49 +0200 Subject: [PATCH 03/16] Fix shared mem allocation --- cpp/include/raft/util/bitset.cuh | 40 +++++++++++++++++++------------- cpp/test/util/bitset.cu | 6 ++--- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 231551ab81..83c7849d08 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -58,31 +58,36 @@ __global__ void bitset_unset_kernel(uint32_t* bitset, */ template __global__ void bitset_create_kernel(uint32_t* bitset, - const IdxT bitset_size, + const IdxT bitset_len, + const IdxT bitset_len_per_block, const IdxT* index_ptr, const IdxT index_len) { extern __shared__ std::uint32_t shared_mem[]; + const IdxT bitset_start = blockIdx.x * blockDim.x; - // Create bitset in shmem - for (IdxT tid = threadIdx.x; tid < bitset_size; tid += TPB) { + // Create local bitset in shmem + for (IdxT tid = threadIdx.x; tid < bitset_len_per_block; tid += TPB) { shared_mem[tid] = 0xffffffff; } __syncthreads(); for (IdxT tid = threadIdx.x; tid < index_len; tid += TPB) { - const IdxT sample_index = index_ptr[tid]; - const IdxT bit_element = sample_index / 32; + const IdxT sample_index = index_ptr[tid]; + const IdxT bit_element = sample_index / 32; + if (bit_element < bitset_start || (bit_element >= (bitset_start + bitset_len_per_block))) + continue; const IdxT bit_index = sample_index % 32; - const std::uint32_t bitmask = 1 << bit_index; - atomicAnd(shared_mem + bit_element, ~bitmask); + const std::uint32_t bitmask = ~(1 << bit_index); + atomicAnd(shared_mem + bit_element - bitset_start, bitmask); } __syncthreads(); - // Output bitset - for (IdxT tid = threadIdx.x; tid < bitset_size; tid += TPB) { - bitset[tid] = shared_mem[tid]; + // Output global bitset + for (IdxT tid = threadIdx.x; tid < bitset_len_per_block && (tid + bitset_start) < bitset_len; + tid += TPB) { + bitset[tid + bitset_start] = shared_mem[tid]; } } } // namespace detail @@ -129,16 +134,18 @@ struct bitset { res, raft::ceildiv(bitset_len, bitset_element_size))} { RAFT_EXPECTS(mask_index.extent(0) <= bitset_len, "Mask index cannot be larger than bitset len"); - static const size_t TPB_X = 128; + static const size_t TPB_X = 512; dim3 blocks(raft::ceildiv(size_t(bitset_.extent(0)), TPB_X)); dim3 threads(TPB_X); + auto bitset_len_per_block = TPB_X; detail::bitset_create_kernel - <<>>( - bitset_.data_handle(), bitset_.extent(0), mask_index.data_handle(), mask_index.extent(0)); + <<>>( + bitset_.data_handle(), + bitset_.extent(0), + bitset_len_per_block, + mask_index.data_handle(), + mask_index.extent(0)); } // Disable copy constructor bitset(const bitset&) = delete; @@ -164,6 +171,7 @@ void bitset_unset(const raft::resources& res, static const size_t TPB_X = 128; dim3 blocks(raft::ceildiv(size_t(mask_index.extent(0)), TPB_X)); dim3 threads(TPB_X); + // TODO thrust::for_each? detail::bitset_unset_kernel<<>>( bitset_view_.get_bitset_ptr(), mask_index.data_handle(), mask_index.extent(0)); } diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index 0ca6be1c48..dc048a5b43 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -135,13 +135,13 @@ class BitsetTest : public testing::TestWithParam { } }; -auto inputs = ::testing::Values(test_spec{1 << 25, 1 << 23, 1 << 24}, - test_spec{32, 5, 10}, +auto inputs = ::testing::Values(test_spec{32, 5, 10}, test_spec{100, 30, 10}, test_spec{1024, 55, 100}, test_spec{10000, 1000, 1000}, test_spec{1 << 15, 1 << 3, 1 << 12}, - test_spec{1 << 15, 1 << 14, 1 << 13}); + test_spec{1 << 15, 1 << 14, 1 << 13}, + test_spec{1 << 25, 1 << 23, 1 << 14}); using Uint32 = BitsetTest; TEST_P(Uint32, Run) { run(); } From 463d7f544004de334dfc7c1e021648049d4ea310 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Sep 2023 02:47:50 +0200 Subject: [PATCH 04/16] Fix copyright --- cpp/test/util/bitset.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index dc048a5b43..f251159a8c 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 5227a95ada951b7eeccad088019db4ffb90a0021 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 5 Sep 2023 14:19:35 +0200 Subject: [PATCH 05/16] Add documentation --- cpp/include/raft/util/bitset.cuh | 127 +++++++++++++++------- cpp/include/raft/util/memory_pool-inl.hpp | 5 + docs/source/cpp_api.rst | 3 +- docs/source/cpp_api/utils.rst | 33 ++++++ 4 files changed, 129 insertions(+), 39 deletions(-) create mode 100644 docs/source/cpp_api/utils.rst diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 83c7849d08..3212039c1e 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -17,35 +17,13 @@ #pragma once #include +#include #include #include +#include namespace raft::utils { namespace detail { - -/** - * @brief Unset bits in bitset already created - * - * @tparam IdxT - * @param bitset - * @param sample_index_ptr - * @param sample_len - */ -template -__global__ void bitset_unset_kernel(uint32_t* bitset, - const IdxT* sample_index_ptr, - const IdxT sample_len) -{ - for (IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; tid < sample_len; - tid += blockDim.x * gridDim.x) { - IdxT sample_index = sample_index_ptr[tid]; - const IdxT bit_element = sample_index / 32; - const IdxT bit_index = sample_index % 32; - const uint32_t bitmask = 1 << bit_index; - atomicAnd(bitset + bit_element, ~bitmask); - } -} - /** * @brief Create bitset from list of indices to unset * @@ -90,8 +68,20 @@ __global__ void bitset_create_kernel(uint32_t* bitset, bitset[tid + bitset_start] = shared_mem[tid]; } } -} // namespace detail +} // end namespace detail +/** + * @defgroup bitset Bitset + * @{ + */ +/** + * @brief View of a RAFT Bitset. + * + * This lightweight structure stores a pointer to a bitset in device memory with it's length. + * It provides a test() device function to check if a given index is set in the bitset. + * + * @tparam IdxT Indexing type used. Default is uint32_t. + */ template struct bitset_view { using BitsetT = uint32_t; @@ -101,20 +91,36 @@ struct bitset_view { : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} { } + /** + * @brief Create a bitset view from a device vector view of the bitset. + * + * @param bitset_span Device vector view of the bitset + */ _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)} { } - - inline _RAFT_DEVICE bool test(const IdxT sample_index) const + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto test(const IdxT sample_index) const -> bool { const IdxT bit_element = bitset_ptr_[sample_index / bitset_element_size]; const IdxT bit_index = sample_index % bitset_element_size; const bool is_bit_set = (bit_element & (1ULL << bit_index)) != 0; return is_bit_set; } + /** + * @brief Get the device pointer to the bitset. + */ inline _RAFT_HOST_DEVICE auto get_bitset_ptr() -> BitsetT* { return bitset_ptr_; } inline _RAFT_HOST_DEVICE auto get_bitset_ptr() const -> const BitsetT* { return bitset_ptr_; } + /** + * @brief Get the length of the bitset representation. + */ inline _RAFT_HOST_DEVICE auto get_bitset_len() const -> IdxT { return bitset_len_; } private: @@ -122,11 +128,28 @@ struct bitset_view { IdxT bitset_len_; }; +/** + * @brief RAFT Bitset. + * + * This structure encapsulates a bitset in device memory. It provides a view() method to get a + * device-usable lightweight view of the bitset. + * Each index is represented by a single bit in the bitset. The total number of bytes used is + * ceil(bitset_len / 4). + * The underlying type of the bitset array is uint32_t. + * @tparam IdxT Indexing type used. Default is uint32_t. + */ template struct bitset { using BitsetT = uint32_t; IdxT bitset_element_size = sizeof(BitsetT) * 8; + /** + * @brief Construct a new bitset object + * + * @param res RAFT resources + * @param mask_index List of indices to unset in the bitset + * @param bitset_len Length of the bitset + */ bitset(const raft::resources& res, raft::device_vector_view mask_index, IdxT bitset_len) @@ -153,8 +176,13 @@ struct bitset { bitset& operator=(const bitset&) = delete; bitset& operator=(bitset&&) = default; - inline auto view() -> bitset_view { return bitset_view(bitset_.view()); } - [[nodiscard]] inline auto view() const -> bitset_view + /** + * @brief Create a device-usable view of the bitset. + * + * @return bitset_view + */ + inline auto view() -> raft::utils::bitset_view { return bitset_view(bitset_.view()); } + [[nodiscard]] inline auto view() const -> raft::utils::bitset_view { return bitset_view(bitset_.view()); } @@ -163,22 +191,44 @@ struct bitset { raft::device_vector bitset_; }; +/** + * @brief Function to unset a list of indices in a bitset. + * + * @tparam IdxT Indexing type used. Default is uint32_t. + * @param res RAFT resources + * @param bitset_view_ View of the bitset + * @param mask_index indices to remove from the bitset + */ template void bitset_unset(const raft::resources& res, - bitset_view bitset_view_, + raft::utils::bitset_view bitset_view_, raft::device_vector_view mask_index) { - static const size_t TPB_X = 128; - dim3 blocks(raft::ceildiv(size_t(mask_index.extent(0)), TPB_X)); - dim3 threads(TPB_X); - // TODO thrust::for_each? - detail::bitset_unset_kernel<<>>( - bitset_view_.get_bitset_ptr(), mask_index.data_handle(), mask_index.extent(0)); + auto* bitset_ptr = bitset_view_.get_bitset_ptr(); + thrust::for_each_n(resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [bitset_ptr] __device__(const IdxT sample_index) { + const IdxT bit_element = sample_index / 32; + const IdxT bit_index = sample_index % 32; + const uint32_t bitmask = ~(1 << bit_index); + atomicAnd(bitset_ptr + bit_element, bitmask); + }); } +/** + * @brief Function to test a list of indices in a bitset. + * + * @tparam IdxT Indexing type + * @tparam OutputT Output type of the test. Default is bool. + * @param res RAFT resources + * @param bitset_view_ View of the bitset + * @param queries List of indices to test + * @param output List of outputs + */ template void bitset_test(const raft::resources& res, - const bitset_view bitset_view_, + const raft::utils::bitset_view bitset_view_, raft::device_vector_view queries, raft::device_vector_view output) { @@ -186,4 +236,5 @@ void bitset_test(const raft::resources& res, raft::linalg::map( res, output, [=] __device__(IdxT query) { return OutputT(bitset_view_.test(query)); }, queries); } -} // namespace raft::utils \ No newline at end of file +/** @} */ +} // end namespace raft::utils diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp index 070c8f4e30..ad94ee0096 100644 --- a/cpp/include/raft/util/memory_pool-inl.hpp +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -25,6 +25,10 @@ namespace raft { +/** + * @defgroup memory_pool Memory Pool + * @{ + */ /** * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned * unique pointer. @@ -73,4 +77,5 @@ RAFT_INLINE_CONDITIONAL std::unique_ptr get_poo return pool_res; } +/** @} */ } // namespace raft diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index 0e82d81e35..e60ef4e697 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -18,4 +18,5 @@ C++ API cpp_api/random.rst cpp_api/solver.rst cpp_api/sparse.rst - cpp_api/stats.rst \ No newline at end of file + cpp_api/stats.rst + cpp_api/utils.rst \ No newline at end of file diff --git a/docs/source/cpp_api/utils.rst b/docs/source/cpp_api/utils.rst new file mode 100644 index 0000000000..ccdb9919ac --- /dev/null +++ b/docs/source/cpp_api/utils.rst @@ -0,0 +1,33 @@ +Utilities +========= + +RAFT contains numerous utility functions and primitives that are easily usable. +This page provides C++ API references for the publicly-exposed utility functions. + +.. role:: py(code) + :language: c++ + :class: highlight + +Bitset +------ + +``#include `` + +namespace *raft::utils* + +.. doxygengroup:: bitset + :project: RAFT + :members: + :content-only: + +Memory Pool +----------- + +``#include `` + +namespace *raft* + +.. doxygengroup:: memory_pool + :project: RAFT + :members: + :content-only: From 75b0f82e7e7873b18a50d2b7ec86778577d17265 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 6 Sep 2023 21:42:33 +0200 Subject: [PATCH 06/16] Simplify code and remove shared mem kernel --- cpp/include/raft/util/bitset.cuh | 80 +++++++++----------------------- 1 file changed, 21 insertions(+), 59 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 3212039c1e..d3d3ccfd9e 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -23,53 +23,6 @@ #include namespace raft::utils { -namespace detail { -/** - * @brief Create bitset from list of indices to unset - * - * @tparam IdxT - * @tparam TPB - * @param bitset - * @param bitset_size - * @param index_ptr - * @param index_len - */ -template -__global__ void bitset_create_kernel(uint32_t* bitset, - const IdxT bitset_len, - const IdxT bitset_len_per_block, - const IdxT* index_ptr, - const IdxT index_len) -{ - extern __shared__ std::uint32_t shared_mem[]; - const IdxT bitset_start = blockIdx.x * blockDim.x; - - // Create local bitset in shmem - for (IdxT tid = threadIdx.x; tid < bitset_len_per_block; tid += TPB) { - shared_mem[tid] = 0xffffffff; - } - - __syncthreads(); - - for (IdxT tid = threadIdx.x; tid < index_len; tid += TPB) { - const IdxT sample_index = index_ptr[tid]; - const IdxT bit_element = sample_index / 32; - if (bit_element < bitset_start || (bit_element >= (bitset_start + bitset_len_per_block))) - continue; - const IdxT bit_index = sample_index % 32; - const std::uint32_t bitmask = ~(1 << bit_index); - atomicAnd(shared_mem + bit_element - bitset_start, bitmask); - } - - __syncthreads(); - // Output global bitset - for (IdxT tid = threadIdx.x; tid < bitset_len_per_block && (tid + bitset_start) < bitset_len; - tid += TPB) { - bitset[tid + bitset_start] = shared_mem[tid]; - } -} -} // end namespace detail - /** * @defgroup bitset Bitset * @{ @@ -144,7 +97,7 @@ struct bitset { IdxT bitset_element_size = sizeof(BitsetT) * 8; /** - * @brief Construct a new bitset object + * @brief Construct a new bitset object with a list of indices to unset. * * @param res RAFT resources * @param mask_index List of indices to unset in the bitset @@ -157,18 +110,27 @@ struct bitset { res, raft::ceildiv(bitset_len, bitset_element_size))} { RAFT_EXPECTS(mask_index.extent(0) <= bitset_len, "Mask index cannot be larger than bitset len"); - static const size_t TPB_X = 512; - dim3 blocks(raft::ceildiv(size_t(bitset_.extent(0)), TPB_X)); - dim3 threads(TPB_X); - auto bitset_len_per_block = TPB_X; + cudaMemsetAsync(bitset_.data_handle(), + 0xff, + bitset_.size() * sizeof(BitsetT), + resource::get_cuda_stream(res)); + bitset_unset(res, view(), mask_index); + } - detail::bitset_create_kernel - <<>>( - bitset_.data_handle(), - bitset_.extent(0), - bitset_len_per_block, - mask_index.data_handle(), - mask_index.extent(0)); + /** + * @brief Construct a new bitset object + * + * @param res RAFT resources + * @param bitset_len Length of the bitset + */ + bitset(const raft::resources& res, IdxT bitset_len) + : bitset_{raft::make_device_vector( + res, raft::ceildiv(bitset_len, bitset_element_size))} + { + cudaMemsetAsync(bitset_.data_handle(), + 0xff, + bitset_.size() * sizeof(BitsetT), + resource::get_cuda_stream(res)); } // Disable copy constructor bitset(const bitset&) = delete; From f834b1878e5195764eb0316caad7528f8cead8ae Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 11 Sep 2023 15:59:00 +0200 Subject: [PATCH 07/16] Enable bitset type template and add flip function --- cpp/include/raft/util/bitset.cuh | 193 +++++++++++++++++++++---------- cpp/test/util/bitset.cu | 123 +++++++++++++------- 2 files changed, 211 insertions(+), 105 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index d3d3ccfd9e..a13ad724eb 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace raft::utils { @@ -33,14 +34,14 @@ namespace raft::utils { * This lightweight structure stores a pointer to a bitset in device memory with it's length. * It provides a test() device function to check if a given index is set in the bitset. * - * @tparam IdxT Indexing type used. Default is uint32_t. + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. */ -template +template struct bitset_view { - using BitsetT = uint32_t; - IdxT bitset_element_size = sizeof(BitsetT) * 8; + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; - _RAFT_HOST_DEVICE bitset_view(BitsetT* bitset_ptr, IdxT bitset_len) + _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} { } @@ -49,7 +50,7 @@ struct bitset_view { * * @param bitset_span Device vector view of the bitset */ - _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) + _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)} { } @@ -59,26 +60,38 @@ struct bitset_view { * @param sample_index Single index to test * @return bool True if index has not been unset in the bitset */ - inline _RAFT_DEVICE auto test(const IdxT sample_index) const -> bool + inline _RAFT_DEVICE auto test(const index_t sample_index) const -> bool { - const IdxT bit_element = bitset_ptr_[sample_index / bitset_element_size]; - const IdxT bit_index = sample_index % bitset_element_size; - const bool is_bit_set = (bit_element & (1ULL << bit_index)) != 0; + const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size]; + const index_t bit_index = sample_index % bitset_element_size; + const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; return is_bit_set; } /** * @brief Get the device pointer to the bitset. */ - inline _RAFT_HOST_DEVICE auto get_bitset_ptr() -> BitsetT* { return bitset_ptr_; } - inline _RAFT_HOST_DEVICE auto get_bitset_ptr() const -> const BitsetT* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; } /** - * @brief Get the length of the bitset representation. + * @brief Get the number of bits of the bitset representation. */ - inline _RAFT_HOST_DEVICE auto get_bitset_len() const -> IdxT { return bitset_len_; } + inline _RAFT_HOST_DEVICE auto size() const -> index_t + { + return bitset_len_ * bitset_element_size; + } + + inline auto to_mdspan() -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + } + inline auto to_mdspan() const -> raft::device_vector_view + { + return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + } private: - BitsetT* bitset_ptr_; - IdxT bitset_len_; + bitset_t* bitset_ptr_; + index_t bitset_len_; }; /** @@ -88,13 +101,12 @@ struct bitset_view { * device-usable lightweight view of the bitset. * Each index is represented by a single bit in the bitset. The total number of bytes used is * ceil(bitset_len / 4). - * The underlying type of the bitset array is uint32_t. - * @tparam IdxT Indexing type used. Default is uint32_t. + * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. */ -template +template struct bitset { - using BitsetT = uint32_t; - IdxT bitset_element_size = sizeof(BitsetT) * 8; + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; /** * @brief Construct a new bitset object with a list of indices to unset. @@ -102,19 +114,20 @@ struct bitset { * @param res RAFT resources * @param mask_index List of indices to unset in the bitset * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. */ bitset(const raft::resources& res, - raft::device_vector_view mask_index, - IdxT bitset_len) - : bitset_{raft::make_device_vector( + raft::device_vector_view mask_index, + index_t bitset_len, + bool default_value = true) + : bitset_{raft::make_device_vector( res, raft::ceildiv(bitset_len, bitset_element_size))} { - RAFT_EXPECTS(mask_index.extent(0) <= bitset_len, "Mask index cannot be larger than bitset len"); cudaMemsetAsync(bitset_.data_handle(), - 0xff, - bitset_.size() * sizeof(BitsetT), + default_value ? 0xff : 0x00, + bitset_.size() * sizeof(bitset_t), resource::get_cuda_stream(res)); - bitset_unset(res, view(), mask_index); + bitset_set(res, view(), mask_index, !default_value); } /** @@ -122,14 +135,15 @@ struct bitset { * * @param res RAFT resources * @param bitset_len Length of the bitset + * @param default_value Default value to set the bits to. Default is true. */ - bitset(const raft::resources& res, IdxT bitset_len) - : bitset_{raft::make_device_vector( + bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) + : bitset_{raft::make_device_vector( res, raft::ceildiv(bitset_len, bitset_element_size))} { cudaMemsetAsync(bitset_.data_handle(), - 0xff, - bitset_.size() * sizeof(BitsetT), + default_value ? 0xff : 0x00, + bitset_.size() * sizeof(bitset_t), resource::get_cuda_stream(res)); } // Disable copy constructor @@ -141,62 +155,119 @@ struct bitset { /** * @brief Create a device-usable view of the bitset. * - * @return bitset_view + * @return bitset_view */ - inline auto view() -> raft::utils::bitset_view { return bitset_view(bitset_.view()); } - [[nodiscard]] inline auto view() const -> raft::utils::bitset_view + inline auto view() -> raft::utils::bitset_view { - return bitset_view(bitset_.view()); + return bitset_view(bitset_.view()); + } + [[nodiscard]] inline auto view() const -> raft::utils::bitset_view + { + return bitset_view(bitset_.view()); + } + + /** + * @brief Get the device pointer to the bitset. + */ + inline auto data_handle() -> bitset_t* { return bitset_.data_handle(); } + inline auto data_handle() const -> const bitset_t* { return bitset_.data_handle(); } + /** + * @brief Get the number of bits of the bitset representation. + */ + inline auto size() const -> index_t { return bitset_.size() * bitset_element_size; } + + inline auto view_mdspan() -> raft::device_vector_view + { + return bitset_.view(); + } + [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view + { + return bitset_.view(); } private: - raft::device_vector bitset_; + raft::device_vector bitset_; }; /** - * @brief Function to unset a list of indices in a bitset. + * @brief Set a list of indices in a bitset to set_value. * - * @tparam IdxT Indexing type used. Default is uint32_t. + * @tparam bitset_t Underlying type of the bitset array + * @tparam index_t Indexing type used. * @param res RAFT resources * @param bitset_view_ View of the bitset * @param mask_index indices to remove from the bitset + * @param set_value Value to set the bits to (true or false) */ -template -void bitset_unset(const raft::resources& res, - raft::utils::bitset_view bitset_view_, - raft::device_vector_view mask_index) +template +void bitset_set(const raft::resources& res, + raft::utils::bitset_view bitset_view_, + raft::device_vector_view mask_index, + bool set_value = false) { - auto* bitset_ptr = bitset_view_.get_bitset_ptr(); - thrust::for_each_n(resource::get_thrust_policy(res), - mask_index.data_handle(), - mask_index.extent(0), - [bitset_ptr] __device__(const IdxT sample_index) { - const IdxT bit_element = sample_index / 32; - const IdxT bit_index = sample_index % 32; - const uint32_t bitmask = ~(1 << bit_index); - atomicAnd(bitset_ptr + bit_element, bitmask); - }); + auto* bitset_ptr = bitset_view_.data_handle(); + constexpr auto bitset_element_size = + raft::utils::bitset_view::bitset_element_size; + thrust::for_each_n( + resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [bitset_ptr, set_value, bitset_element_size] __device__(const index_t sample_index) { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr + bit_element, bitmask2); + } + }); } /** - * @brief Function to test a list of indices in a bitset. + * @brief Test a list of indices in a bitset. * - * @tparam IdxT Indexing type - * @tparam OutputT Output type of the test. Default is bool. + * @tparam bitset_t Underlying type of the bitset array + * @tparam index_t Indexing type + * @tparam output_t Output type of the test. Default is bool. * @param res RAFT resources * @param bitset_view_ View of the bitset * @param queries List of indices to test * @param output List of outputs */ -template +template void bitset_test(const raft::resources& res, - const raft::utils::bitset_view bitset_view_, - raft::device_vector_view queries, - raft::device_vector_view output) + const raft::utils::bitset_view bitset_view_, + raft::device_vector_view queries, + raft::device_vector_view output) { RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); raft::linalg::map( - res, output, [=] __device__(IdxT query) { return OutputT(bitset_view_.test(query)); }, queries); + res, + output, + [=] __device__(index_t query) { return output_t(bitset_view_.test(query)); }, + queries); +} + +/** + * @brief Flip all the bit in a bitset. + * + * @tparam bitset_t Underlying type of the bitset array + * @tparam index_t Indexing type + * @param res RAFT resources + * @param bitset_view_ View of the bitset + */ +template +void bitset_flip(const raft::resources& res, + raft::utils::bitset_view bitset_view_) +{ + auto bitset_span = bitset_view_.to_mdspan(); + raft::linalg::map( + res, + bitset_span, + [] __device__(bitset_t element) { return bitset_t(~element); }, + raft::make_const_mdspan(bitset_span)); } /** @} */ } // end namespace raft::utils diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index f251159a8c..3e789b70bc 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -27,58 +27,71 @@ namespace raft::utils { struct test_spec { - int bitset_len; - int mask_len; - int query_len; + uint64_t bitset_len; + uint64_t mask_len; + uint64_t query_len; }; auto operator<<(std::ostream& os, const test_spec& ss) -> std::ostream& { - os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len << "}"; + os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len + << ", query_len: " << ss.query_len << "}"; return os; } -template -void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +template +void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) { + static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < mask_idx.size(); i++) { auto idx = mask_idx[i]; - bitset[idx / 32] &= ~(1 << (idx % 32)); + bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size)); } } -template -void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) +template +void create_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) { for (size_t i = 0; i < bitset.size(); i++) { - bitset[i] = 0xffffffff; + bitset[i] = ~bitset_t(0x00); } - add_cpu_bitset(bitset, mask_idx); + add_cpu_bitset(bitset, mask_idx); } -template -void test_cpu_bitset(const std::vector& bitset, - const std::vector& queries, +template +void test_cpu_bitset(const std::vector& bitset, + const std::vector& queries, std::vector& result) { + static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < queries.size(); i++) { - result[i] = uint8_t((bitset[queries[i] / 32] & (1 << (queries[i] % 32))) != 0); + result[i] = uint8_t((bitset[queries[i] / bitset_element_size] & + (bitset_t{1} << (queries[i] % bitset_element_size))) != 0); } } -template +template +void flip_cpu_bitset(std::vector& bitset) +{ + for (size_t i = 0; i < bitset.size(); i++) { + bitset[i] = ~bitset[i]; + } +} + +template class BitsetTest : public testing::TestWithParam { protected: + index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; const test_spec spec; - std::vector bitset_result; - std::vector bitset_ref; + std::vector bitset_result; + std::vector bitset_ref; raft::resources res; public: explicit BitsetTest() : spec(testing::TestWithParam::GetParam()), - bitset_result(raft::ceildiv(spec.bitset_len, 32)), - bitset_ref(raft::ceildiv(spec.bitset_len, 32)) + bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))), + bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))) { } @@ -88,50 +101,56 @@ class BitsetTest : public testing::TestWithParam { // generate input and mask raft::random::RngState rng(42); - auto mask_device = raft::make_device_vector(res, spec.mask_len); - std::vector mask_cpu(spec.mask_len); - raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len)); + auto mask_device = raft::make_device_vector(res, spec.mask_len); + std::vector mask_cpu(spec.mask_len); + raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); // calculate the results - auto test_bitset = - raft::utils::bitset(res, raft::make_const_mdspan(mask_device.view()), T(spec.bitset_len)); + auto test_bitset = raft::utils::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); update_host( - bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream); + bitset_result.data(), test_bitset.view().data_handle(), bitset_result.size(), stream); // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); - ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); - auto query_device = raft::make_device_vector(res, spec.query_len); - auto result_device = raft::make_device_vector(res, spec.query_len); - auto query_cpu = std::vector(spec.query_len); + auto query_device = raft::make_device_vector(res, spec.query_len); + auto result_device = raft::make_device_vector(res, spec.query_len); + auto query_cpu = std::vector(spec.query_len); auto result_cpu = std::vector(spec.query_len); auto result_ref = std::vector(spec.query_len); // Create queries and verify the test results - raft::random::uniformInt(res, rng, query_device.view(), T(0), T(spec.bitset_len)); + raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); raft::utils::bitset_test( res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); - ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); // Add more sample to the bitset and re-test - raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len)); + raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - raft::utils::bitset_unset(res, test_bitset.view(), mask_device.view()); - update_host( - bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream); + raft::utils::bitset_set(res, test_bitset.view(), mask_device.view()); + update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); - ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + // Flip the bitset and re-test + raft::utils::bitset_flip(res, test_bitset.view()); + update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); + flip_cpu_bitset(bitset_ref); + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); } }; @@ -140,15 +159,31 @@ auto inputs = ::testing::Values(test_spec{32, 5, 10}, test_spec{1024, 55, 100}, test_spec{10000, 1000, 1000}, test_spec{1 << 15, 1 << 3, 1 << 12}, - test_spec{1 << 15, 1 << 14, 1 << 13}, + test_spec{1 << 15, 1 << 24, 1 << 13}, test_spec{1 << 25, 1 << 23, 1 << 14}); -using Uint32 = BitsetTest; -TEST_P(Uint32, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32, inputs); +using Uint16_32 = BitsetTest; +TEST_P(Uint16_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint16_32, inputs); + +using Uint32_32 = BitsetTest; +TEST_P(Uint32_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_32, inputs); + +using Uint64_32 = BitsetTest; +TEST_P(Uint64_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_32, inputs); + +using Uint8_64 = BitsetTest; +TEST_P(Uint8_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint8_64, inputs); + +using Uint32_64 = BitsetTest; +TEST_P(Uint32_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_64, inputs); -using Uint64 = BitsetTest; -TEST_P(Uint64, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64, inputs); +using Uint64_64 = BitsetTest; +TEST_P(Uint64_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs); } // namespace raft::utils From d74bbb23a666d3938c660dc5818a6ab637e026c5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 11 Sep 2023 18:18:18 +0200 Subject: [PATCH 08/16] Add bench for bitset --- build.sh | 2 +- cpp/bench/prims/CMakeLists.txt | 2 + cpp/bench/prims/util/bitset.cu | 75 ++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 cpp/bench/prims/util/bitset.cu diff --git a/build.sh b/build.sh index 071820ba93..1fa1abbee5 100755 --- a/build.sh +++ b/build.sh @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" +BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" CACHE_ARGS="" NVTX=ON diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index e8d4739384..1690d6f320 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -155,4 +155,6 @@ if(BUILD_PRIMS_BENCH) LIB EXPLICIT_INSTANTIATE_ONLY ) + + ConfigureBench(NAME UTIL_BENCH PATH bench/prims/util/bitset.cu bench/prims/main.cpp) endif() diff --git a/cpp/bench/prims/util/bitset.cu b/cpp/bench/prims/util/bitset.cu new file mode 100644 index 0000000000..47561c8052 --- /dev/null +++ b/cpp/bench/prims/util/bitset.cu @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace raft::bench::utils { + +struct bitset_inputs { + uint32_t bitset_len; + uint32_t mask_len; + uint32_t query_len; +}; // struct bitset_inputs + +template +struct bitset_bench : public fixture { + bitset_bench(const bitset_inputs& p) + : params(p), + mask{raft::make_device_vector(res, p.mask_len)}, + queries{raft::make_device_vector(res, p.query_len)}, + outputs{raft::make_device_vector(res, p.query_len)} + { + raft::random::RngState state{42}; + raft::random::uniformInt(res, state, mask.view(), index_t{0}, index_t{p.bitset_len}); + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + auto my_bitset = raft::utils::bitset( + this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); + raft::utils::bitset_test( + res, my_bitset.view(), raft::make_const_mdspan(queries.view()), outputs.view()); + }); + } + + private: + raft::resources res; + bitset_inputs params; + raft::device_vector mask, queries; + raft::device_vector outputs; +}; // struct bitset + +const std::vector bitset_input_vecs{ + {256 * 1024 * 1024, 64 * 1024 * 1024, 256 * 1024 * 1024}, // Standard Bench + {256 * 1024 * 1024, 64 * 1024 * 1024, 1024 * 1024 * 1024}, // Extra queries + {128 * 1024 * 1024, 1024 * 1024 * 1024, 256 * 1024 * 1024}, // Extra mask to test atomics impact +}; + +using Uint8_32 = bitset_bench; +using Uint16_64 = bitset_bench; +using Uint32_32 = bitset_bench; +using Uint32_64 = bitset_bench; + +RAFT_BENCH_REGISTER(Uint8_32, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs); +RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs); + +} // namespace raft::bench::utils From d075184c2f2005eabf8c982f02330bdaa05932ed Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 12 Sep 2023 16:37:25 +0200 Subject: [PATCH 09/16] Fix namespace typo --- cpp/bench/prims/util/bitset.cu | 8 ++++---- cpp/include/raft/util/bitset.cuh | 16 ++++++++-------- cpp/test/util/bitset.cu | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/bench/prims/util/bitset.cu b/cpp/bench/prims/util/bitset.cu index 47561c8052..c7cba797f4 100644 --- a/cpp/bench/prims/util/bitset.cu +++ b/cpp/bench/prims/util/bitset.cu @@ -19,7 +19,7 @@ #include #include -namespace raft::bench::utils { +namespace raft::bench::util { struct bitset_inputs { uint32_t bitset_len; @@ -42,9 +42,9 @@ struct bitset_bench : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto my_bitset = raft::utils::bitset( + auto my_bitset = raft::util::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - raft::utils::bitset_test( + raft::util::bitset_test( res, my_bitset.view(), raft::make_const_mdspan(queries.view()), outputs.view()); }); } @@ -72,4 +72,4 @@ RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs); -} // namespace raft::bench::utils +} // namespace raft::bench::util diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index a13ad724eb..5dbf4c00e8 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -23,7 +23,7 @@ #include #include -namespace raft::utils { +namespace raft::util { /** * @defgroup bitset Bitset * @{ @@ -157,11 +157,11 @@ struct bitset { * * @return bitset_view */ - inline auto view() -> raft::utils::bitset_view + inline auto view() -> raft::util::bitset_view { return bitset_view(bitset_.view()); } - [[nodiscard]] inline auto view() const -> raft::utils::bitset_view + [[nodiscard]] inline auto view() const -> raft::util::bitset_view { return bitset_view(bitset_.view()); } @@ -201,13 +201,13 @@ struct bitset { */ template void bitset_set(const raft::resources& res, - raft::utils::bitset_view bitset_view_, + raft::util::bitset_view bitset_view_, raft::device_vector_view mask_index, bool set_value = false) { auto* bitset_ptr = bitset_view_.data_handle(); constexpr auto bitset_element_size = - raft::utils::bitset_view::bitset_element_size; + raft::util::bitset_view::bitset_element_size; thrust::for_each_n( resource::get_thrust_policy(res), mask_index.data_handle(), @@ -238,7 +238,7 @@ void bitset_set(const raft::resources& res, */ template void bitset_test(const raft::resources& res, - const raft::utils::bitset_view bitset_view_, + const raft::util::bitset_view bitset_view_, raft::device_vector_view queries, raft::device_vector_view output) { @@ -260,7 +260,7 @@ void bitset_test(const raft::resources& res, */ template void bitset_flip(const raft::resources& res, - raft::utils::bitset_view bitset_view_) + raft::util::bitset_view bitset_view_) { auto bitset_span = bitset_view_.to_mdspan(); raft::linalg::map( @@ -270,4 +270,4 @@ void bitset_flip(const raft::resources& res, raft::make_const_mdspan(bitset_span)); } /** @} */ -} // end namespace raft::utils +} // end namespace raft::util diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index 3e789b70bc..69ade45432 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -24,7 +24,7 @@ #include #include -namespace raft::utils { +namespace raft::util { struct test_spec { uint64_t bitset_len; @@ -108,7 +108,7 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); // calculate the results - auto test_bitset = raft::utils::bitset( + auto test_bitset = raft::util::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); update_host( bitset_result.data(), test_bitset.view().data_handle(), bitset_result.size(), stream); @@ -127,7 +127,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - raft::utils::bitset_test( + raft::util::bitset_test( res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); @@ -138,7 +138,7 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - raft::utils::bitset_set(res, test_bitset.view(), mask_device.view()); + raft::util::bitset_set(res, test_bitset.view(), mask_device.view()); update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); @@ -146,7 +146,7 @@ class BitsetTest : public testing::TestWithParam { ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - raft::utils::bitset_flip(res, test_bitset.view()); + raft::util::bitset_flip(res, test_bitset.view()); update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); @@ -186,4 +186,4 @@ using Uint64_64 = BitsetTest; TEST_P(Uint64_64, Run) { run(); } INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs); -} // namespace raft::utils +} // namespace raft::util From 1f5472082950196e678c1e039557679a00ea01bf Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 15:30:43 +0200 Subject: [PATCH 10/16] Use raft device_uvector in bitset --- build.sh | 3 +- cpp/include/raft/util/bitset.cuh | 60 +++++++++++++++++++++++--------- cpp/test/util/bitset.cu | 1 + 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/build.sh b/build.sh index 1fa1abbee5..7461b3ca27 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" CACHE_ARGS="" @@ -324,6 +324,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 5dbf4c00e8..25e46c4514 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include #include @@ -120,12 +120,14 @@ struct bitset { raft::device_vector_view mask_index, index_t bitset_len, bool default_value = true) - : bitset_{raft::make_device_vector( - res, raft::ceildiv(bitset_len, bitset_element_size))} + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + raft::resource::get_cuda_stream(res)}, + bitset_len_{bitset_len}, + default_value_{default_value} { - cudaMemsetAsync(bitset_.data_handle(), + cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - bitset_.size() * sizeof(bitset_t), + raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), resource::get_cuda_stream(res)); bitset_set(res, view(), mask_index, !default_value); } @@ -138,12 +140,14 @@ struct bitset { * @param default_value Default value to set the bits to. Default is true. */ bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) - : bitset_{raft::make_device_vector( - res, raft::ceildiv(bitset_len, bitset_element_size))} + : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), + resource::get_cuda_stream(res)}, + bitset_len_{bitset_len}, + default_value_{default_value} { - cudaMemsetAsync(bitset_.data_handle(), + cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - bitset_.size() * sizeof(bitset_t), + raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), resource::get_cuda_stream(res)); } // Disable copy constructor @@ -159,34 +163,56 @@ struct bitset { */ inline auto view() -> raft::util::bitset_view { - return bitset_view(bitset_.view()); + return bitset_view(view_mdspan()); } [[nodiscard]] inline auto view() const -> raft::util::bitset_view { - return bitset_view(bitset_.view()); + return bitset_view(view_mdspan()); } /** * @brief Get the device pointer to the bitset. */ - inline auto data_handle() -> bitset_t* { return bitset_.data_handle(); } - inline auto data_handle() const -> const bitset_t* { return bitset_.data_handle(); } + inline auto data_handle() -> bitset_t* { return bitset_.data(); } + inline auto data_handle() const -> const bitset_t* { return bitset_.data(); } /** * @brief Get the number of bits of the bitset representation. */ - inline auto size() const -> index_t { return bitset_.size() * bitset_element_size; } + inline auto size() const -> index_t { return bitset_len_; } + /** @brief Get an mdspan view of the current bitset */ inline auto view_mdspan() -> raft::device_vector_view { - return bitset_.view(); + return raft::make_device_vector_view( + bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); } [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view { - return bitset_.view(); + return raft::make_device_vector_view( + bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); + } + + /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to + * the default value. */ + void resize(const raft::resources& res, index_t new_bitset_len) + { + auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); + auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); + bitset_.resize(new_size); + bitset_len_ = new_bitset_len; + if (old_size < new_size) { + // If the new size is larger, set the new bits to the default value + cudaMemsetAsync(bitset_.data() + old_size, + default_value_ ? 0xff : 0x00, + (new_size - old_size) * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } } private: - raft::device_vector bitset_; + raft::device_uvector bitset_; + index_t bitset_len_; + bool default_value_; }; /** diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index 69ade45432..f71fb48936 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" +#include #include #include From e0c1d2484279967f1adbc6e10f8a7f3bb982650b Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 15:37:13 +0200 Subject: [PATCH 11/16] Fix build.sh --- build.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/build.sh b/build.sh index 7461b3ca27..1fa1abbee5 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" CACHE_ARGS="" @@ -324,7 +324,6 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"DISTANCE_TEST"* || \ $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ - $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \ From 8d30c9fc3a7df580c11f2a06277ecf36c6232588 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 17:56:19 +0200 Subject: [PATCH 12/16] Add n_elements fix size --- cpp/include/raft/util/bitset.cuh | 36 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 25e46c4514..cd2106dba2 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -51,7 +51,8 @@ struct bitset_view { * @param bitset_span Device vector view of the bitset */ _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) - : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)} + : bitset_ptr_{bitset_span.data_handle()}, + bitset_len_{bitset_span.extent(0) * bitset_element_size} { } /** @@ -75,18 +76,23 @@ struct bitset_view { /** * @brief Get the number of bits of the bitset representation. */ - inline _RAFT_HOST_DEVICE auto size() const -> index_t + inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t { - return bitset_len_ * bitset_element_size; + return raft::ceildiv(bitset_len_, bitset_element_size); } inline auto to_mdspan() -> raft::device_vector_view { - return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + return raft::make_device_vector_view(bitset_ptr_, n_elements()); } inline auto to_mdspan() const -> raft::device_vector_view { - return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + return raft::make_device_vector_view(bitset_ptr_, n_elements()); } private: @@ -100,7 +106,7 @@ struct bitset_view { * This structure encapsulates a bitset in device memory. It provides a view() method to get a * device-usable lightweight view of the bitset. * Each index is represented by a single bit in the bitset. The total number of bytes used is - * ceil(bitset_len / 4). + * ceil(bitset_len / 8). * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. * @tparam index_t Indexing type used. Default is uint32_t. */ @@ -127,7 +133,7 @@ struct bitset { { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), + n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); bitset_set(res, view(), mask_index, !default_value); } @@ -147,7 +153,7 @@ struct bitset { { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), + n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); } // Disable copy constructor @@ -180,16 +186,22 @@ struct bitset { */ inline auto size() const -> index_t { return bitset_len_; } + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t + { + return raft::ceildiv(bitset_len_, bitset_element_size); + } + /** @brief Get an mdspan view of the current bitset */ inline auto view_mdspan() -> raft::device_vector_view { - return raft::make_device_vector_view( - bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); + return raft::make_device_vector_view(bitset_.data(), n_elements()); } [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view { - return raft::make_device_vector_view( - bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); + return raft::make_device_vector_view(bitset_.data(), n_elements()); } /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to From 480d20ead6d2c987e7ee2a01acf36a0341fdd3d5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 22:21:47 +0200 Subject: [PATCH 13/16] Fix naming --- cpp/test/util/bitset.cu | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index f71fb48936..10f82a480c 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -27,13 +27,13 @@ namespace raft::util { -struct test_spec { +struct test_spec_bitset { uint64_t bitset_len; uint64_t mask_len; uint64_t query_len; }; -auto operator<<(std::ostream& os, const test_spec& ss) -> std::ostream& +auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream& { os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len << ", query_len: " << ss.query_len << "}"; @@ -80,17 +80,17 @@ void flip_cpu_bitset(std::vector& bitset) } template -class BitsetTest : public testing::TestWithParam { +class BitsetTest : public testing::TestWithParam { protected: index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; - const test_spec spec; + const test_spec_bitset spec; std::vector bitset_result; std::vector bitset_ref; raft::resources res; public: explicit BitsetTest() - : spec(testing::TestWithParam::GetParam()), + : spec(testing::TestWithParam::GetParam()), bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))), bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))) { @@ -155,13 +155,13 @@ class BitsetTest : public testing::TestWithParam { } }; -auto inputs = ::testing::Values(test_spec{32, 5, 10}, - test_spec{100, 30, 10}, - test_spec{1024, 55, 100}, - test_spec{10000, 1000, 1000}, - test_spec{1 << 15, 1 << 3, 1 << 12}, - test_spec{1 << 15, 1 << 24, 1 << 13}, - test_spec{1 << 25, 1 << 23, 1 << 14}); +auto inputs = ::testing::Values(test_spec_bitset{32, 5, 10}, + test_spec_bitset{100, 30, 10}, + test_spec_bitset{1024, 55, 100}, + test_spec_bitset{10000, 1000, 1000}, + test_spec_bitset{1 << 15, 1 << 3, 1 << 12}, + test_spec_bitset{1 << 15, 1 << 24, 1 << 13}, + test_spec_bitset{1 << 25, 1 << 23, 1 << 14}); using Uint16_32 = BitsetTest; TEST_P(Uint16_32, Run) { run(); } From 240af3e299d4817a7fdcf9e84224d22a9ec634c8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 19 Sep 2023 14:52:05 +0200 Subject: [PATCH 14/16] Fix naming of inputs --- cpp/test/util/bitset.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/test/util/bitset.cu b/cpp/test/util/bitset.cu index 10f82a480c..4793dde2f1 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/util/bitset.cu @@ -155,36 +155,36 @@ class BitsetTest : public testing::TestWithParam { } }; -auto inputs = ::testing::Values(test_spec_bitset{32, 5, 10}, - test_spec_bitset{100, 30, 10}, - test_spec_bitset{1024, 55, 100}, - test_spec_bitset{10000, 1000, 1000}, - test_spec_bitset{1 << 15, 1 << 3, 1 << 12}, - test_spec_bitset{1 << 15, 1 << 24, 1 << 13}, - test_spec_bitset{1 << 25, 1 << 23, 1 << 14}); +auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10}, + test_spec_bitset{100, 30, 10}, + test_spec_bitset{1024, 55, 100}, + test_spec_bitset{10000, 1000, 1000}, + test_spec_bitset{1 << 15, 1 << 3, 1 << 12}, + test_spec_bitset{1 << 15, 1 << 24, 1 << 13}, + test_spec_bitset{1 << 25, 1 << 23, 1 << 14}); using Uint16_32 = BitsetTest; TEST_P(Uint16_32, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint16_32, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint16_32, inputs_bitset); using Uint32_32 = BitsetTest; TEST_P(Uint32_32, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_32, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_32, inputs_bitset); using Uint64_32 = BitsetTest; TEST_P(Uint64_32, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_32, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_32, inputs_bitset); using Uint8_64 = BitsetTest; TEST_P(Uint8_64, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint8_64, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint8_64, inputs_bitset); using Uint32_64 = BitsetTest; TEST_P(Uint32_64, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_64, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32_64, inputs_bitset); using Uint64_64 = BitsetTest; TEST_P(Uint64_64, Run) { run(); } -INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs); +INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs_bitset); } // namespace raft::util From 08c43d0dbec98c9680b122df1c30f6c1f0645a47 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 19 Sep 2023 19:58:36 +0200 Subject: [PATCH 15/16] Fix bitset_len in bitset_view --- cpp/include/raft/util/bitset.cuh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index cd2106dba2..af5ef79588 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -49,10 +49,11 @@ struct bitset_view { * @brief Create a bitset view from a device vector view of the bitset. * * @param bitset_span Device vector view of the bitset + * @param bitset_len Number of bits in the bitset */ - _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) - : bitset_ptr_{bitset_span.data_handle()}, - bitset_len_{bitset_span.extent(0) * bitset_element_size} + _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span, + index_t bitset_len) + : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len} { } /** @@ -169,11 +170,11 @@ struct bitset { */ inline auto view() -> raft::util::bitset_view { - return bitset_view(view_mdspan()); + return bitset_view(view_mdspan(), bitset_len_); } [[nodiscard]] inline auto view() const -> raft::util::bitset_view { - return bitset_view(view_mdspan()); + return bitset_view(view_mdspan(), bitset_len_); } /** From 9115bbb31fc6c0902bb9d5c8c15e360d41da4c97 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 20 Sep 2023 16:31:57 +0200 Subject: [PATCH 16/16] Move bitset to core and match std::bitset --- build.sh | 2 +- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/{util => core}/bitset.cu | 11 +- cpp/include/raft/{util => core}/bitset.cuh | 176 ++++++++++----------- cpp/test/CMakeLists.txt | 2 +- cpp/test/{util => core}/bitset.cu | 22 ++- docs/source/cpp_api/core.rst | 3 +- docs/source/cpp_api/core_bitset.rst | 15 ++ docs/source/cpp_api/utils.rst | 12 -- 9 files changed, 121 insertions(+), 124 deletions(-) rename cpp/bench/prims/{util => core}/bitset.cu (89%) rename cpp/include/raft/{util => core}/bitset.cuh (68%) rename cpp/test/{util => core}/bitset.cu (90%) create mode 100644 docs/source/cpp_api/core_bitset.rst diff --git a/build.sh b/build.sh index 1fa1abbee5..5543faaebe 100755 --- a/build.sh +++ b/build.sh @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" +BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" NVTX=ON diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 1690d6f320..ca4b0f099d 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH) NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) + ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp) ConfigureBench( NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu @@ -156,5 +157,4 @@ if(BUILD_PRIMS_BENCH) EXPLICIT_INSTANTIATE_ONLY ) - ConfigureBench(NAME UTIL_BENCH PATH bench/prims/util/bitset.cu bench/prims/main.cpp) endif() diff --git a/cpp/bench/prims/util/bitset.cu b/cpp/bench/prims/core/bitset.cu similarity index 89% rename from cpp/bench/prims/util/bitset.cu rename to cpp/bench/prims/core/bitset.cu index c7cba797f4..5f44aa9af5 100644 --- a/cpp/bench/prims/util/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -15,11 +15,11 @@ */ #include +#include #include -#include #include -namespace raft::bench::util { +namespace raft::bench::core { struct bitset_inputs { uint32_t bitset_len; @@ -42,10 +42,9 @@ struct bitset_bench : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto my_bitset = raft::util::bitset( + auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - raft::util::bitset_test( - res, my_bitset.view(), raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); }); } @@ -72,4 +71,4 @@ RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs); -} // namespace raft::bench::util +} // namespace raft::bench::core diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/core/bitset.cuh similarity index 68% rename from cpp/include/raft/util/bitset.cuh rename to cpp/include/raft/core/bitset.cuh index af5ef79588..6747c5fab0 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -23,7 +23,7 @@ #include #include -namespace raft::util { +namespace raft::core { /** * @defgroup bitset Bitset * @{ @@ -69,6 +69,7 @@ struct bitset_view { const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; return is_bit_set; } + /** * @brief Get the device pointer to the bitset. */ @@ -82,7 +83,7 @@ struct bitset_view { /** * @brief Get the number of elements used by the bitset representation. */ - inline auto n_elements() const -> index_t + inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t { return raft::ceildiv(bitset_len_, bitset_element_size); } @@ -136,7 +137,7 @@ struct bitset { default_value ? 0xff : 0x00, n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); - bitset_set(res, view(), mask_index, !default_value); + set(res, mask_index, !default_value); } /** @@ -168,13 +169,13 @@ struct bitset { * * @return bitset_view */ - inline auto view() -> raft::util::bitset_view + inline auto view() -> raft::core::bitset_view { - return bitset_view(view_mdspan(), bitset_len_); + return bitset_view(to_mdspan(), bitset_len_); } - [[nodiscard]] inline auto view() const -> raft::util::bitset_view + [[nodiscard]] inline auto view() const -> raft::core::bitset_view { - return bitset_view(view_mdspan(), bitset_len_); + return bitset_view(to_mdspan(), bitset_len_); } /** @@ -196,11 +197,11 @@ struct bitset { } /** @brief Get an mdspan view of the current bitset */ - inline auto view_mdspan() -> raft::device_vector_view + inline auto to_mdspan() -> raft::device_vector_view { return raft::make_device_vector_view(bitset_.data(), n_elements()); } - [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view + [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view { return raft::make_device_vector_view(bitset_.data(), n_elements()); } @@ -222,91 +223,86 @@ struct bitset { } } + /** + * @brief Test a list of indices in a bitset. + * + * @tparam output_t Output type of the test. Default is bool. + * @param res RAFT resources + * @param queries List of indices to test + * @param output List of outputs + */ + template + void test(const raft::resources& res, + raft::device_vector_view queries, + raft::device_vector_view output) const + { + RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); + auto bitset_view = view(); + raft::linalg::map( + res, + output, + [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, + queries); + } + /** + * @brief Set a list of indices in a bitset to set_value. + * + * @param res RAFT resources + * @param mask_index indices to remove from the bitset + * @param set_value Value to set the bits to (true or false) + */ + void set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value = false) + { + auto* bitset_ptr = this->data_handle(); + thrust::for_each_n(resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [bitset_ptr, set_value] __device__(const index_t sample_index) { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr + bit_element, bitmask2); + } + }); + } + /** + * @brief Flip all the bits in a bitset. + * + * @param res RAFT resources + */ + void flip(const raft::resources& res) + { + auto bitset_span = this->to_mdspan(); + raft::linalg::map( + res, + bitset_span, + [] __device__(bitset_t element) { return bitset_t(~element); }, + raft::make_const_mdspan(bitset_span)); + } + /** + * @brief Reset the bits in a bitset. + * + * @param res RAFT resources + */ + void reset(const raft::resources& res) + { + cudaMemsetAsync(bitset_.data(), + default_value_ ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } + private: raft::device_uvector bitset_; index_t bitset_len_; bool default_value_; }; -/** - * @brief Set a list of indices in a bitset to set_value. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type used. - * @param res RAFT resources - * @param bitset_view_ View of the bitset - * @param mask_index indices to remove from the bitset - * @param set_value Value to set the bits to (true or false) - */ -template -void bitset_set(const raft::resources& res, - raft::util::bitset_view bitset_view_, - raft::device_vector_view mask_index, - bool set_value = false) -{ - auto* bitset_ptr = bitset_view_.data_handle(); - constexpr auto bitset_element_size = - raft::util::bitset_view::bitset_element_size; - thrust::for_each_n( - resource::get_thrust_policy(res), - mask_index.data_handle(), - mask_index.extent(0), - [bitset_ptr, set_value, bitset_element_size] __device__(const index_t sample_index) { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; - if (set_value) { - atomicOr(bitset_ptr + bit_element, bitmask); - } else { - const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr + bit_element, bitmask2); - } - }); -} - -/** - * @brief Test a list of indices in a bitset. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type - * @tparam output_t Output type of the test. Default is bool. - * @param res RAFT resources - * @param bitset_view_ View of the bitset - * @param queries List of indices to test - * @param output List of outputs - */ -template -void bitset_test(const raft::resources& res, - const raft::util::bitset_view bitset_view_, - raft::device_vector_view queries, - raft::device_vector_view output) -{ - RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); - raft::linalg::map( - res, - output, - [=] __device__(index_t query) { return output_t(bitset_view_.test(query)); }, - queries); -} - -/** - * @brief Flip all the bit in a bitset. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type - * @param res RAFT resources - * @param bitset_view_ View of the bitset - */ -template -void bitset_flip(const raft::resources& res, - raft::util::bitset_view bitset_view_) -{ - auto bitset_span = bitset_view_.to_mdspan(); - raft::linalg::map( - res, - bitset_span, - [] __device__(bitset_t element) { return bitset_t(~element); }, - raft::make_const_mdspan(bitset_span)); -} /** @} */ -} // end namespace raft::util +} // end namespace raft::core diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 57a45c557c..a9b387008f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -105,6 +105,7 @@ if(BUILD_TESTS) NAME CORE_TEST PATH + test/core/bitset.cu test/core/device_resources_manager.cpp test/core/device_setter.cpp test/core/logger.cpp @@ -423,7 +424,6 @@ if(BUILD_TESTS) PATH test/core/seive.cu test/util/bitonic_sort.cu - test/util/bitset.cu test/util/cudart_utils.cpp test/util/device_atomics.cu test/util/integer_utils.cpp diff --git a/cpp/test/util/bitset.cu b/cpp/test/core/bitset.cu similarity index 90% rename from cpp/test/util/bitset.cu rename to cpp/test/core/bitset.cu index 4793dde2f1..215de98aaf 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -16,16 +16,16 @@ #include "../test_utils.cuh" +#include #include #include -#include #include #include #include -namespace raft::util { +namespace raft::core { struct test_spec_bitset { uint64_t bitset_len; @@ -109,10 +109,9 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); // calculate the results - auto test_bitset = raft::util::bitset( + auto my_bitset = raft::core::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); - update_host( - bitset_result.data(), test_bitset.view().data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); @@ -128,8 +127,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - raft::util::bitset_test( - res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view()); + my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); @@ -139,16 +137,16 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - raft::util::bitset_set(res, test_bitset.view(), mask_device.view()); - update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); + my_bitset.set(res, mask_device.view()); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - raft::util::bitset_flip(res, test_bitset.view()); - update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); + my_bitset.flip(res); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); @@ -187,4 +185,4 @@ using Uint64_64 = BitsetTest; TEST_P(Uint64_64, Run) { run(); } INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs_bitset); -} // namespace raft::util +} // namespace raft::core diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 7e69f92948..39e57fd69a 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -20,4 +20,5 @@ expose in public APIs. core_nvtx.rst core_interruptible.rst core_operators.rst - core_math.rst \ No newline at end of file + core_math.rst + core_bitset.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_bitset.rst b/docs/source/cpp_api/core_bitset.rst new file mode 100644 index 0000000000..af1cff6d37 --- /dev/null +++ b/docs/source/cpp_api/core_bitset.rst @@ -0,0 +1,15 @@ +Bitset +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: bitset + :project: RAFT + :members: + :content-only: \ No newline at end of file diff --git a/docs/source/cpp_api/utils.rst b/docs/source/cpp_api/utils.rst index ccdb9919ac..4471093c8b 100644 --- a/docs/source/cpp_api/utils.rst +++ b/docs/source/cpp_api/utils.rst @@ -8,18 +8,6 @@ This page provides C++ API references for the publicly-exposed utility functions :language: c++ :class: highlight -Bitset ------- - -``#include `` - -namespace *raft::utils* - -.. doxygengroup:: bitset - :project: RAFT - :members: - :content-only: - Memory Pool -----------