From 0cd4da08be0289b20306ec44a68044668730c0a9 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 15 Sep 2023 10:37:40 -0700 Subject: [PATCH] Clean up ref implementations with `has_payload` flag (#368) #356 introduces the `HasPayload` template boolean to distinguish code paths between map and set implementations thus the key input for base ref insert functions becomes redundant. This PR cleans up the base ref implementations by removing the key input and fixes a logical issue in #356: set doesn't have payload while map has. --- .../cuco/detail/open_addressing_ref_impl.cuh | 55 +++++++++++++------ .../cuco/detail/static_map/static_map_ref.inl | 16 +++--- .../cuco/detail/static_set/static_set_ref.inl | 16 +++--- 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 46ef2bfd7..213d35af1 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -159,18 +159,23 @@ class open_addressing_ref_impl { * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * * @return True if the given element is successfully inserted */ template - __device__ bool insert(key_type const& key, - value_type const& value, - Predicate const& predicate) noexcept + __device__ bool insert(value_type const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -202,7 +207,6 @@ class open_addressing_ref_impl { * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -210,10 +214,16 @@ class open_addressing_ref_impl { */ template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - key_type const& key, value_type const& value, Predicate const& predicate) noexcept { + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -269,7 +279,6 @@ class open_addressing_ref_impl { * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -277,11 +286,18 @@ class open_addressing_ref_impl { * insertion is successful or not. */ template - __device__ thrust::pair insert_and_find(key_type const& key, - value_type const& value, + __device__ thrust::pair insert_and_find(value_type const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -326,7 +342,6 @@ class open_addressing_ref_impl { * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find - * @param key Key of the element to insert * @param value The element to insert * @param predicate Predicate used to compare slot content against `key` * @@ -336,10 +351,16 @@ class open_addressing_ref_impl { template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, - key_type const& key, value_type const& value, Predicate const& predicate) noexcept { + auto const key = [&]() { + if constexpr (HasPayload) { + return value.first; + } else { + return value; + } + }(); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -710,11 +731,11 @@ class open_addressing_ref_impl { auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { if constexpr (HasPayload) { - // If it's a set implementation, compare the whole slot content - return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); - } else { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); + } else { + // If it's a set implementation, compare the whole slot content + return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); } }(); if (inserted) { @@ -723,11 +744,11 @@ class open_addressing_ref_impl { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare auto const res = [&]() { if constexpr (HasPayload) { - // If it's a set implementation, compare the whole slot content - return predicate.equal_to(*old_ptr, value); - } else { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); + } else { + // If it's a set implementation, compare the whole slot content + return predicate.equal_to(*old_ptr, value); } }(); return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 13fc2ce47..250c84feb 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -210,8 +210,8 @@ class operator_impl< __device__ bool insert(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -225,8 +225,8 @@ class operator_impl< value_type const& value) noexcept { auto& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(group, value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -454,8 +454,8 @@ class operator_impl< __device__ thrust::pair insert_and_find(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -475,8 +475,8 @@ class operator_impl< cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3131f3764..2bb7f0c6f 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -101,8 +101,8 @@ class operator_impl(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -117,8 +117,8 @@ class operator_impl(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(group, value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -182,8 +182,8 @@ class operator_impl insert_and_find(value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -203,8 +203,8 @@ class operator_impl const& group, value_type const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } };