From b29b6089090927d7a6238c1a4677d6d919d95e15 Mon Sep 17 00:00:00 2001 From: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:35:29 -0700 Subject: [PATCH] Add a bitset validatation test for `cuco::arrow_filter_policy` (#633) This PR adds a tests to validate the bitset from inserting specific keys to a `cuco::bloom_filter` with `cuco::arrow_filter_policy` against the one generated by inserting the same keys to the implementation in Arrow. Related to #625. Part of https://github.com/rapidsai/cudf/issues/17164. Reference bitset gen with arrow here: https://godbolt.org/z/ebdddezbP --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/CMakeLists.txt | 4 +- tests/bloom_filter/arrow_policy_test.cu | 164 ++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 tests/bloom_filter/arrow_policy_test.cu diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bc7cc697f..05ceca69d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -142,4 +142,6 @@ ConfigureTest(HYPERLOGLOG_TEST ################################################################################################### # - bloom_filter ---------------------------------------------------------------------------------- ConfigureTest(BLOOM_FILTER_TEST - bloom_filter/unique_sequence_test.cu) + bloom_filter/unique_sequence_test.cu + bloom_filter/arrow_policy_test.cu + ) diff --git a/tests/bloom_filter/arrow_policy_test.cu b/tests/bloom_filter/arrow_policy_test.cu new file mode 100644 index 000000000..1b7f5384f --- /dev/null +++ b/tests/bloom_filter/arrow_policy_test.cu @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +template +thrust::device_vector get_arrow_filter_reference_bitset() +{ + static std::vector> const reference_bitsets{ + {4294752255, + 928963967, + 4227333887, + 3183462382, + 3892030683, + 3481206270, + 3513757613, + 3220961761, + 3186616955, + 4026531705, + 4110408887, + 804913147, + 1039007726, + 4286569403, + 2675948542, + 3688689479}, // type = int32, blocks = 2, num_keys = 100 + {2290897413, 3368027184, 2432735301, 2013315170, 610406792, 35787348, 43061541, + 1145143906, 238486532, 2840527950, 241188878, 624061504, 759830680, 184694210, + 2282459916, 3232258264, 285316692, 3284142851, 2760958614, 2974341265, 38749317, + 2655160577, 2193666087, 261196816, 411328595, 5391621, 2308014147, 2550892738, + 1224755395, 1396835974, 3227911200, 307324929}, // type = int64, blocks = 4, num_keys = 50 + {3037098621, 1001208422, 3070541682, 3611620780, 372254302, 2869772027, 2629135999, + 3332804862, 2832966981, 1225184253, 1315442262, 211922492, 1020510327, 2725704195, + 2909038118, 2783622989, 4214109798, 535934391, 2385459605, 4109595381, 3219664733, + 3164400602, 1995984498, 2917029602, 3047576211, 2212973933, 1672737343, 300902378, + 3000318461, 1561320274, 2710202091, 3067275349, 2734901244, 2638172076, 3669981206, + 3719000395, 793729452, 2258222966, 4111863618, 2391109497, 240119500, 855317864, + 2893522276, 1103034386, 738173080, 4098968587, 1271241025, 499361504, 4174530401, + 3259956170, 3823469907, 578271374, 3168397042, 3890816473, 431898609, 1583427570, + 1835797371, 2078281027, 2741410265, 2639785266, 3422606831, 1589476610, 3972396492, + 3611525326} // type = float, blocks = 8, num_keys = 200 + }; + + if constexpr (std::is_same_v) { + return reference_bitsets[0]; // int32 + } else if constexpr (std::is_same_v) { + return reference_bitsets[1]; // int64 + } else if constexpr (std::is_same_v) { + return reference_bitsets[2]; // float + } else { + throw std::invalid_argument("Reference bitsets available for int32, int64, float only.\n\n"); + } +} + +template +std::pair get_arrow_filter_test_settings() +{ + static std::vector> const test_settings = { + {2, 100}, // type = int32, blocks = 2, num_keys = 100 + {4, 50}, // type = int64, blocks = 4, num_keys = 50 + {8, 200} // type = float, blocks = 8, num_keys = 200 + }; + + if constexpr (std::is_same_v) { + return test_settings[0]; // int32 + } else if constexpr (std::is_same_v) { + return test_settings[1]; // int64 + } else if constexpr (std::is_same_v) { + return test_settings[2]; // float + } else { + throw std::invalid_argument("Test settings available for int32, int64, float only.\n\n"); + } +} + +template +std::vector random_values(size_t size) +{ + std::vector values(size); + + using uniform_distribution = + typename std::conditional_t, + std::bernoulli_distribution, + std::conditional_t, + std::uniform_real_distribution, + std::uniform_int_distribution>>; + + static constexpr auto seed = 0xf00d; + static std::mt19937 engine{seed}; + static uniform_distribution dist{}; + std::generate_n(values.begin(), size, [&]() { return Key{dist(engine)}; }); + + return values; +} + +} // namespace + +template +void test_filter_bitset(Filter& filter, size_t num_keys) +{ + using key_type = typename Filter::key_type; + using word_type = typename Filter::word_type; + + // Generate keys + auto const h_keys = random_values(num_keys); + thrust::device_vector d_keys(h_keys.begin(), h_keys.end()); + + // Insert to the bloom filter + filter.add(d_keys.begin(), d_keys.begin() + num_keys); + + // Get reference words device_vector + auto const reference_words = get_arrow_filter_reference_bitset(); + + // Number of words in the filter + auto const num_words = filter.block_extent() * filter.words_per_block; + + // Get the bitset + thrust::device_vector filter_words(filter.data(), filter.data() + num_words); + + REQUIRE(cuco::test::equal( + filter_words.begin(), + filter_words.end(), + reference_words.begin(), + cuda::proclaim_return_type([] __device__(auto const& filter_word, auto const& ref_word) { + return filter_word == ref_word; + }))); +} + +TEMPLATE_TEST_CASE_SIG( + "Arrow filter policy bitset validation", "", (class Key), (int32_t), (int64_t), (float)) +{ + // Get test settings + auto const [sub_filters, num_keys] = get_arrow_filter_test_settings(); + + using policy_type = cuco::arrow_filter_policy; + cuco::bloom_filter, cuda::thread_scope_device, policy_type> filter{ + sub_filters}; + + test_filter_bitset(filter, num_keys); +}