diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index 2c23a77e47..cafd1977ab 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -39,7 +39,7 @@ _RAFT_HOST_DEVICE void bitmap_view::set(const index_t row, const index_t col, bool new_value) const { - set(row * cols_ + col, &new_value); + set(row * cols_ + col, new_value); } } // end namespace raft::core diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp index 5c77866164..86b2d77478 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.hpp @@ -41,6 +41,9 @@ namespace raft::core { */ template struct bitmap_view : public bitset_view { + using bitset_view::set; + using bitset_view::test; + static_assert((std::is_same::type, uint32_t>::value || std::is_same::type, uint64_t>::value), "The bitmap_t must be uint32_t or uint64_t."); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ed923fb1db..08541ad135 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -113,6 +113,7 @@ if(BUILD_TESTS) NAME CORE_TEST PATH + core/bitmap.cu core/bitset.cu core/device_resources_manager.cpp core/device_setter.cpp diff --git a/cpp/test/core/bitmap.cu b/cpp/test/core/bitmap.cu new file mode 100644 index 0000000000..358c08a50f --- /dev/null +++ b/cpp/test/core/bitmap.cu @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft::core { + +template +struct test_spec_bitmap { + index_t rows; + index_t cols; + index_t mask_len; + index_t query_len; +}; + +template +auto operator<<(std::ostream& os, const test_spec_bitmap& ss) -> std::ostream& +{ + os << "bitmap{rows: " << ss.rows << ", cols: " << ss.cols << ", mask_len: " << ss.mask_len + << ", query_len: " << ss.query_len << "}"; + return os; +} + +template +void create_cpu_bitmap(std::vector& bitmap, + std::vector& mask_idx, + const index_t rows, + const index_t cols) +{ + for (size_t i = 0; i < bitmap.size(); i++) { + bitmap[i] = ~bitmap_t(0x00); + } + constexpr size_t bitmap_element_size = sizeof(bitmap_t) * 8; + for (size_t i = 0; i < mask_idx.size(); i++) { + auto row = mask_idx[i] / cols; + auto col = mask_idx[i] % cols; + auto idx = row * cols + col; + bitmap[idx / bitmap_element_size] &= ~(bitmap_t{1} << (idx % bitmap_element_size)); + } +} + +template +void test_cpu_bitmap(const std::vector& bitmap, + const std::vector& queries, + std::vector& result, + index_t rows, + index_t cols) +{ + constexpr size_t bitmap_element_size = sizeof(bitmap_t) * 8; + for (size_t i = 0; i < queries.size(); i++) { + auto row = queries[i] / cols; + auto col = queries[i] % cols; + auto idx = row * cols + col; + result[i] = uint8_t( + (bitmap[idx / bitmap_element_size] & (bitmap_t{1} << (idx % bitmap_element_size))) != 0); + } +} + +template +class BitmapTest : public testing::TestWithParam> { + protected: + index_t static constexpr const bitmap_element_size = sizeof(bitmap_t) * 8; + const test_spec_bitmap spec; + std::vector bitmap_result; + std::vector bitmap_ref; + raft::resources res; + + public: + explicit BitmapTest() + : spec(testing::TestWithParam>::GetParam()), + bitmap_result(raft::ceildiv(spec.rows * spec.cols, index_t(bitmap_element_size))), + bitmap_ref(raft::ceildiv(spec.rows * spec.cols, index_t(bitmap_element_size))) + { + } + + void run() + { + auto stream = resource::get_cuda_stream(res); + + raft::random::RngState rng(42); + auto mask_device = raft::make_device_vector(res, spec.mask_len); + std::vector mask_cpu(spec.mask_len); + raft::random::uniformInt( + res, rng, mask_device.view(), index_t(0), index_t(spec.rows * spec.cols)); + raft::update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + + create_cpu_bitmap(bitmap_ref, mask_cpu, spec.rows, spec.cols); + + auto bitset_d = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.rows * spec.cols)); + + auto bitmap_view_d = + raft::core::bitmap_view(bitset_d.data(), spec.rows, spec.cols); + + ASSERT_EQ(bitmap_view_d.get_n_rows(), spec.rows); + ASSERT_EQ(bitmap_view_d.get_n_cols(), spec.cols); + + auto query_device = raft::make_device_vector(res, spec.query_len); + auto result_device = raft::make_device_vector(res, spec.query_len); + auto query_cpu = std::vector(spec.query_len); + auto result_cpu = std::vector(spec.query_len); + auto result_ref = std::vector(spec.query_len); + + raft::random::uniformInt( + res, rng, query_device.view(), index_t(0), index_t(spec.rows * spec.cols)); + raft::update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); + + auto queries_device_view = + raft::make_device_vector_view(query_device.data_handle(), spec.query_len); + + raft::linalg::map( + res, + result_device.view(), + [bitmap_view_d] __device__(index_t query) { + auto row = query / bitmap_view_d.get_n_cols(); + auto col = query % bitmap_view_d.get_n_cols(); + return (uint8_t)(bitmap_view_d.test(row, col)); + }, + queries_device_view); + + raft::update_host(result_cpu.data(), result_device.data_handle(), query_device.size(), stream); + resource::sync_stream(res, stream); + + test_cpu_bitmap(bitmap_ref, query_cpu, result_ref, spec.rows, spec.cols); + + ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare())); + + raft::random::uniformInt( + res, rng, mask_device.view(), index_t(0), index_t(spec.rows * spec.cols)); + raft::update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); + resource::sync_stream(res, stream); + + thrust::for_each_n(raft::resource::get_thrust_policy(res), + mask_device.data_handle(), + mask_device.extent(0), + [bitmap_view_d] __device__(const index_t sample_index) { + auto row = sample_index / bitmap_view_d.get_n_cols(); + auto col = sample_index % bitmap_view_d.get_n_cols(); + bitmap_view_d.set(row, col, false); + }); + + raft::update_host(bitmap_result.data(), bitmap_view_d.data(), bitmap_result.size(), stream); + + for (size_t i = 0; i < mask_cpu.size(); i++) { + auto row = mask_cpu[i] / spec.cols; + auto col = mask_cpu[i] % spec.cols; + auto idx = row * spec.cols + col; + bitmap_ref[idx / bitmap_element_size] &= ~(bitmap_t{1} << (idx % bitmap_element_size)); + } + resource::sync_stream(res, stream); + ASSERT_TRUE(hostVecMatch(bitmap_ref, bitmap_result, raft::Compare())); + } +}; + +template +auto inputs_bitmap = + ::testing::Values(test_spec_bitmap{32, 32, 5, 10}, + test_spec_bitmap{100, 100, 30, 10}, + test_spec_bitmap{1024, 1024, 55, 100}, + test_spec_bitmap{10000, 10000, 1000, 1000}, + test_spec_bitmap{1 << 15, 1 << 15, 1 << 3, 1 << 12}, + test_spec_bitmap{1 << 15, 1 << 15, 1 << 24, 1 << 13}); + +using BitmapTest_Uint32_32 = BitmapTest; +TEST_P(BitmapTest_Uint32_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitmapTest, BitmapTest_Uint32_32, inputs_bitmap); + +using BitmapTest_Uint64_32 = BitmapTest; +TEST_P(BitmapTest_Uint64_32, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitmapTest, BitmapTest_Uint64_32, inputs_bitmap); + +using BitmapTest_Uint32_64 = BitmapTest; +TEST_P(BitmapTest_Uint32_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitmapTest, BitmapTest_Uint32_64, inputs_bitmap); + +using BitmapTest_Uint64_64 = BitmapTest; +TEST_P(BitmapTest_Uint64_64, Run) { run(); } +INSTANTIATE_TEST_CASE_P(BitmapTest, BitmapTest_Uint64_64, inputs_bitmap); + +} // namespace raft::core