diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing_ref_impl.cuh index 99187cc51..4aa701759 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing_ref_impl.cuh @@ -132,6 +132,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -140,7 +141,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -158,7 +159,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return false; } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - switch (attempt_insert( + switch (attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -173,6 +174,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert @@ -182,7 +184,7 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, key_type const& key, value_type const& value, @@ -214,9 +216,10 @@ class open_addressing_ref_impl { auto const src_lane = __ffs(group_contains_empty) - 1; auto const status = (group.thread_rank() == src_lane) - ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, - value, - predicate) + ? attempt_insert( + (storage_ref_.data() + *probing_iter)->data() + intra_window_index, + value, + predicate) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -237,6 +240,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param key Key of the element to insert @@ -246,7 +250,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find(key_type const& key, value_type const& value, Predicate const& predicate) noexcept @@ -266,7 +270,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value, predicate); + return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); } @@ -292,6 +296,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param group The Cooperative Group used to perform group insert_and_find @@ -302,7 +307,7 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, key_type const& key, @@ -343,7 +348,7 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value, predicate); + return packed_cas(slot_ptr, value, predicate); } else { return cas_dependent_write(slot_ptr, value, predicate); } @@ -649,6 +654,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -657,20 +663,37 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, value_type const& value, Predicate const& predicate) noexcept { - auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); - auto* old_ptr = reinterpret_cast(&old); - if (cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_)) { + auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value); + auto* old_ptr = reinterpret_cast(&old); + auto const inserted = [&]() { + if constexpr (HasPayload) { + // If it's a set implementation, compare the whole slot content + return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_); + } else { + // If it's a map implementation, compare keys only + return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); + } + }(); + if (inserted) { return insert_result::SUCCESS; } else { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - return predicate.equal_to(*old_ptr, value) == detail::equal_result::EQUAL - ? insert_result::DUPLICATE - : insert_result::CONTINUE; + auto const res = [&]() { + if constexpr (HasPayload) { + // If it's a set implementation, compare the whole slot content + return predicate.equal_to(*old_ptr, value); + } else { + // If it's a map implementation, compare keys only + return predicate.equal_to(old_ptr->first, value.first); + } + }(); + return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE + : insert_result::CONTINUE; } } @@ -761,6 +784,7 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * + * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Predicate Predicate type * * @param slot Pointer to the slot in memory @@ -769,13 +793,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, value_type const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value, predicate); + return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) return cas_dependent_write(slot, value, predicate); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index f3c412924..536973b20 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -209,8 +209,9 @@ class operator_impl< */ __device__ bool insert(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert(value.first, value, ref_.predicate_); } /** @@ -223,8 +224,9 @@ class operator_impl< __device__ bool insert(cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - return ref_.impl_.insert(group, value.first, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert(group, value.first, value, ref_.predicate_); } }; @@ -289,8 +291,9 @@ class operator_impl< */ __device__ thrust::pair insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(value.first, value, ref_.predicate_); } /** @@ -309,8 +312,9 @@ class operator_impl< __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = false; + return ref_.impl_.insert_and_find(group, value.first, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3482738cc..3131f3764 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -100,8 +100,9 @@ class operator_impl(*this); - return ref_.impl_.insert(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert(value, value, ref_.predicate_); } /** @@ -115,8 +116,9 @@ class operator_impl const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - return ref_.impl_.insert(group, value, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert(group, value, value, ref_.predicate_); } }; @@ -179,8 +181,9 @@ class operator_impl insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(value, value, ref_.predicate_); } /** @@ -199,8 +202,9 @@ class operator_impl insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + auto constexpr has_payload = true; + return ref_.impl_.insert_and_find(group, value, value, ref_.predicate_); } };