From 5b4a80e08c711386edb7c19eefe5f787828abb6b Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Sat, 2 Nov 2024 13:15:17 -0700 Subject: [PATCH] Add consistent `for_each` APIs for cuco hash tables (#632) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds host and device `for_each` APIs for all cuco hash tables. --------- Co-authored-by: Daniel Jünger Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../cuco/detail/static_map/static_map_ref.inl | 7 +- .../static_multimap/static_multimap.inl | 67 ++++++ .../static_multiset/static_multiset.inl | 150 +++++++++---- include/cuco/detail/static_set/static_set.inl | 60 ++++++ .../cuco/detail/static_set/static_set_ref.inl | 69 ++++++ include/cuco/static_map.cuh | 8 +- include/cuco/static_multimap.cuh | 68 ++++++ include/cuco/static_multiset.cuh | 200 ++++++++++++------ include/cuco/static_set.cuh | 72 ++++++- tests/CMakeLists.txt | 1 + tests/static_set/for_each_test.cu | 103 +++++++++ 11 files changed, 685 insertions(+), 120 deletions(-) create mode 100644 tests/static_set/for_each_test.cu diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 3a1ed5ddd..662667b3e 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -21,9 +21,8 @@ #include #include -#include #include -#include +#include #include @@ -1335,7 +1334,7 @@ class operator_impl< { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(key, std::forward(callback_op)); + ref_.impl_.for_each(key, cuda::std::forward(callback_op)); } /** @@ -1363,7 +1362,7 @@ class operator_impl< { // CRTP: cast `this` to the actual ref type auto const& ref_ = static_cast(*this); - ref_.impl_.for_each(group, key, std::forward(callback_op)); + ref_.impl_.for_each(group, key, cuda::std::forward(callback_op)); } }; diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 7048a5426..965e14f3d 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -315,6 +315,73 @@ void static_multimapfind_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_multimap::for_each( + CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); + stream.wait(); +} + +template +template +void static_multimap:: + for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); +} + +template +template +void static_multimap::for_each( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); + stream.wait(); +} + +template +template +void static_multimap:: + for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream) const noexcept +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); +} + template -template -std::pair -static_multiset::retrieve( - InputProbeIt first, - InputProbeIt last, - OutputProbeIt output_probe, - OutputMatchIt output_match, - cuda::stream_ref stream) const +template +void static_multiset::for_each( + CallbackOp&& callback_op, cuda::stream_ref stream) const { - return this->impl_->retrieve( - first, last, output_probe, output_match, this->ref(op::retrieve), stream); + impl_->for_each_async(std::forward(callback_op), stream); + stream.wait(); } template -template -std::pair -static_multiset::retrieve( - InputProbeIt first, - InputProbeIt last, - ProbeEqual const& probe_equal, - ProbeHash const& probe_hash, - OutputProbeIt output_probe, - OutputMatchIt output_match, - cuda::stream_ref stream) const +template +void static_multiset:: + for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const { - auto const probe_ref = - this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); - return this->impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream); + impl_->for_each_async(std::forward(callback_op), stream); } template -template -std::pair -static_multiset::retrieve_outer( - InputProbeIt first, - InputProbeIt last, - ProbeEqual const& probe_equal, - ProbeHash const& probe_hash, - OutputProbeIt output_probe, - OutputMatchIt output_match, - cuda::stream_ref stream) const +template +void static_multiset::for_each( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const { - auto const probe_ref = - this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); - return this->impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream); + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); + stream.wait(); +} + +template +template +void static_multiset:: + for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream) const noexcept +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); } template stream); } +template +template +std::pair +static_multiset::retrieve( + InputProbeIt first, + InputProbeIt last, + OutputProbeIt output_probe, + OutputMatchIt output_match, + cuda::stream_ref stream) const +{ + return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream); +} + +template +template +std::pair +static_multiset::retrieve( + InputProbeIt first, + InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, + OutputProbeIt output_probe, + OutputMatchIt output_match, + cuda::stream_ref stream) const +{ + auto const probe_ref = + this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); + return impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream); +} + +template +template +std::pair +static_multiset::retrieve_outer( + InputProbeIt first, + InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, + OutputProbeIt output_probe, + OutputMatchIt output_match, + cuda::stream_ref stream) const +{ + auto const probe_ref = + this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); + return impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream); +} + template impl_->find_async(first, last, output_begin, ref(op::find), stream); } +template +template +void static_set::for_each( + CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); + stream.wait(); +} + +template +template +void static_set::for_each_async( + CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async(std::forward(callback_op), stream); +} + +template +template +void static_set::for_each( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); + stream.wait(); +} + +template +template +void static_set::for_each_async( + InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept +{ + impl_->for_each_async( + first, last, std::forward(callback_op), ref(op::for_each), stream); +} + template #include +#include #include @@ -629,6 +630,74 @@ class operator_impl +class operator_impl> { + using base_type = static_set_ref; + using ref_type = static_set_ref; + using key_type = typename base_type::key_type; + using value_type = typename base_type::value_type; + using iterator = typename base_type::iterator; + using const_iterator = typename base_type::const_iterator; + + static constexpr auto cg_size = base_type::cg_size; + static constexpr auto window_size = base_type::window_size; + + public: + /** + * @brief For a given key, applies the function object `callback_op` to its match found in the + * container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param key The key to search for + * @param callback_op Function to apply to the copy of the matched slot + */ + template + __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(key, cuda::std::forward(callback_op)); + } + + /** + * @brief For a given key, applies the function object `callback_op` to its match found in the + * container. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching slot. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to the copy of the matched slot + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile const& group, + ProbeKey const& key, + CallbackOp&& callback_op) const noexcept + { + // CRTP: cast `this` to the actual ref type + auto const& ref_ = static_cast(*this); + ref_.impl_.for_each(group, key, cuda::std::forward(callback_op)); + } +}; + template @@ -789,7 +789,7 @@ class static_map { * * @tparam CallbackOp Type of unary callback function object * - * @param callback_op Function to apply to the copy of the matched key-value pair + * @param callback_op Function to apply to the copy of the filled slot * @param stream CUDA stream used for this operation */ template @@ -806,7 +806,7 @@ class static_map { * * @param first Beginning of the sequence of keys * @param last End of the sequence of keys - * @param callback_op Function to apply to the copy of the matched key-value pair + * @param callback_op Function to apply to the copy of the matched slot * @param stream CUDA stream used for this operation */ template @@ -826,7 +826,7 @@ class static_map { * * @param first Beginning of the sequence of keys * @param last End of the sequence of keys - * @param callback_op Function to apply to the copy of the matched key-value pair + * @param callback_op Function to apply to the copy of the matched slot * @param stream CUDA stream used for this operation */ template diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 1cae596aa..f8067405b 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -518,6 +518,74 @@ class static_multimap { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief Applies the given function object `callback_op` to the copy of every filled slot in the + * container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief Asynchronously applies the given function object `callback_op` to the copy of every + * filled slot in the container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), applies the function object `callback_op` to + * the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), asynchronously applies the function object + * `callback_op` to the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const noexcept; + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap * diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh index 943465c51..4cd5277d5 100644 --- a/include/cuco/static_multiset.cuh +++ b/include/cuco/static_multiset.cuh @@ -482,6 +482,140 @@ class static_multiset { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief Applies the given function object `callback_op` to the copy of every filled slot in the + * container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief Asynchronously applies the given function object `callback_op` to the copy of every + * filled slot in the container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), applies the function object `callback_op` to + * the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), asynchronously applies the function object + * `callback_op` to the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const noexcept; + + /** + * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset + * + * @note This function synchronizes the given stream. + * + * @tparam Input Device accessible input iterator + * + * @param first Beginning of the sequence of keys to count + * @param last End of the sequence of keys to count + * @param stream CUDA stream used for count + * + * @return The sum of total occurrences of all keys in `[first, last)` + */ + template + size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; + + /** + * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset + * + * @note This function synchronizes the given stream. + * + * @tparam Input Device accessible input iterator + * @tparam ProbeKeyEqual Binary callable + * @tparam ProbeHash Unary hash callable + * + * @param first Beginning of the sequence of keys to count + * @param last End of the sequence of keys to count + * @param probe_key_equal Binary callable to compare two keys for equality + * @param probe_hash Unary callable to hash a given key + * @param stream CUDA stream used for count + * + * @return The sum of total occurrences of all keys in `[first, last)` + */ + template + size_type count(InputIt first, + InputIt last, + ProbeKeyEqual const& probe_key_equal, + ProbeHash const& probe_hash, + cuda::stream_ref stream = {}) const; + + /** + * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset + * + * @note This function synchronizes the given stream. + * @note If a given key has no matches, its occurrence is 1. + * + * @tparam Input Device accessible input iterator + * @tparam ProbeKeyEqual Binary callable + * @tparam ProbeHash Unary hash callable + * + * @param first Beginning of the sequence of keys to count + * @param last End of the sequence of keys to count + * @param probe_key_equal Binary callable to compare two keys for equality + * @param probe_hash Unary callable to hash a given key + * @param stream CUDA stream used for count + * + * @return The sum of total occurrences of all keys in `[first, last)` where keys have no matches + * are considered to have a single occurrence. + */ + template + size_type count_outer(InputIt first, + InputIt last, + ProbeKeyEqual const& probe_key_equal, + ProbeHash const& probe_hash, + cuda::stream_ref stream = {}) const; + /** * @brief Retrieves all the slots corresponding to all keys in the range `[first, last)`. * @@ -604,72 +738,6 @@ class static_multiset { OutputMatchIt output_match, cuda::stream_ref stream = {}) const; - /** - * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset - * - * @note This function synchronizes the given stream. - * - * @tparam Input Device accessible input iterator - * - * @param first Beginning of the sequence of keys to count - * @param last End of the sequence of keys to count - * @param stream CUDA stream used for count - * - * @return The sum of total occurrences of all keys in `[first, last)` - */ - template - size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const; - - /** - * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset - * - * @note This function synchronizes the given stream. - * - * @tparam Input Device accessible input iterator - * @tparam ProbeKeyEqual Binary callable - * @tparam ProbeHash Unary hash callable - * - * @param first Beginning of the sequence of keys to count - * @param last End of the sequence of keys to count - * @param probe_key_equal Binary callable to compare two keys for equality - * @param probe_hash Unary callable to hash a given key - * @param stream CUDA stream used for count - * - * @return The sum of total occurrences of all keys in `[first, last)` - */ - template - size_type count(InputIt first, - InputIt last, - ProbeKeyEqual const& probe_key_equal, - ProbeHash const& probe_hash, - cuda::stream_ref stream = {}) const; - - /** - * @brief Counts the occurrences of keys in `[first, last)` contained in the multiset - * - * @note This function synchronizes the given stream. - * @note If a given key has no matches, its occurrence is 1. - * - * @tparam Input Device accessible input iterator - * @tparam ProbeKeyEqual Binary callable - * @tparam ProbeHash Unary hash callable - * - * @param first Beginning of the sequence of keys to count - * @param last End of the sequence of keys to count - * @param probe_key_equal Binary callable to compare two keys for equality - * @param probe_hash Unary callable to hash a given key - * @param stream CUDA stream used for count - * - * @return The sum of total occurrences of all keys in `[first, last)` where keys have no matches - * are considered to have a single occurrence. - */ - template - size_type count_outer(InputIt first, - InputIt last, - ProbeKeyEqual const& probe_key_equal, - ProbeHash const& probe_hash, - cuda::stream_ref stream = {}) const; - /** * @brief Gets the number of elements in the container. * diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index 8da360d75..d5f7acb95 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -351,7 +351,7 @@ class static_set { * * @tparam InputIt Device accessible random access input iterator * @tparam FoundIt Device accessible random access output iterator whose `value_type` - * is constructible from `map::iterator` type + * is constructible from `set::iterator` type * @tparam InsertedIt Device accessible random access output iterator whose `value_type` * is constructible from `bool` * @@ -379,7 +379,7 @@ class static_set { * * @tparam InputIt Device accessible random access input iterator * @tparam FoundIt Device accessible random access output iterator whose `value_type` - * is constructible from `map::iterator` type + * is constructible from `set::iterator` type * @tparam InsertedIt Device accessible random access output iterator whose `value_type` * is constructible from `bool` * @@ -590,6 +590,74 @@ class static_set { OutputIt output_begin, cuda::stream_ref stream = {}) const; + /** + * @brief Applies the given function object `callback_op` to the copy of every filled slot in the + * container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief Asynchronously applies the given function object `callback_op` to the copy of every + * filled slot in the container + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam CallbackOp Type of unary callback function object + * + * @param callback_op Function to apply to the copy of the filled slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), applies the function object `callback_op` to + * the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const; + + /** + * @brief For each key in the range [first, last), asynchronously applies the function object + * `callback_op` to the copy of all corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam CallbackOp Type of unary callback function object + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param callback_op Function to apply to the copy of the matched slot + * @param stream CUDA stream used for this operation + */ + template + void for_each_async(InputIt first, + InputIt last, + CallbackOp&& callback_op, + cuda::stream_ref stream = {}) const noexcept; + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the set * diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 05ceca69d..dfe478251 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -60,6 +60,7 @@ ConfigureTest(UTILITY_TEST # - static_set tests ------------------------------------------------------------------------------ ConfigureTest(STATIC_SET_TEST static_set/capacity_test.cu + static_set/for_each_test.cu static_set/heterogeneous_lookup_test.cu static_set/insert_and_find_test.cu static_set/large_input_test.cu diff --git a/tests/static_set/for_each_test.cu b/tests/static_set/for_each_test.cu new file mode 100644 index 000000000..b854c0cf6 --- /dev/null +++ b/tests/static_set/for_each_test.cu @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +using size_type = std::size_t; + +template +void test_for_each(Set& set, size_type num_keys) +{ + using Key = typename Set::key_type; + + REQUIRE(num_keys % 2 == 0); + + cuda::stream_ref stream{}; + + // Insert keys + auto keys_begin = thrust::make_transform_iterator( + thrust::counting_iterator{0}, cuda::proclaim_return_type([] __device__(auto i) { + // generates a sequence of 1, 2, 1, 2, ... + return static_cast(i); + })); + set.insert(keys_begin, keys_begin + num_keys, stream); + + using Allocator = cuco::cuda_allocator>; + cuco::detail::counter_storage counter_storage( + Allocator{}); + counter_storage.reset(stream); + + // count the sum of all even keys + set.for_each( + [counter = counter_storage.data()] __device__(auto const slot) { + if (slot % 2 == 0) { counter->fetch_add(slot, cuda::memory_order_relaxed); } + }, + stream); + REQUIRE(counter_storage.load_to_host(stream) == 249'500); + + counter_storage.reset(stream); + + // count the sum of all odd keys + set.for_each( + thrust::counting_iterator(0), + thrust::counting_iterator(2 * num_keys), // test for false-positives + [counter = counter_storage.data()] __device__(auto const slot) { + if (!(slot % 2 == 0)) { counter->fetch_add(slot, cuda::memory_order_relaxed); } + }, + stream); + REQUIRE(counter_storage.load_to_host(stream) == 250'000); +} + +TEMPLATE_TEST_CASE_SIG( + "static_set for_each tests", + "", + ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), + (int32_t, cuco::test::probe_sequence::double_hashing, 1), + (int32_t, cuco::test::probe_sequence::double_hashing, 2), + (int64_t, cuco::test::probe_sequence::double_hashing, 1), + (int64_t, cuco::test::probe_sequence::double_hashing, 2), + (int32_t, cuco::test::probe_sequence::linear_probing, 1), + (int32_t, cuco::test::probe_sequence::linear_probing, 2), + (int64_t, cuco::test::probe_sequence::linear_probing, 1), + (int64_t, cuco::test::probe_sequence::linear_probing, 2)) +{ + constexpr size_type num_keys{1'000}; + using probe = std::conditional_t< + Probe == cuco::test::probe_sequence::linear_probing, + cuco::linear_probing>, + cuco::double_hashing, cuco::murmurhash3_32>>; + + using set_t = cuco::static_set, + cuda::thread_scope_device, + thrust::equal_to, + probe, + cuco::cuda_allocator, + cuco::storage<2>>; + + auto set = set_t{num_keys, cuco::empty_key{-1}}; + test_for_each(set, num_keys); +}