Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
esoha-nvidia committed Sep 27, 2023
1 parent aa16f99 commit faa7810
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
4 changes: 2 additions & 2 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ struct equal_wrapper {
*
* @return `EQUAL` if `lhs` and `rhs` are equivalent. `UNEQUAL` otherwise.
*/
template <typename U>
__device__ constexpr equal_result equal_to(T const& lhs, U const& rhs) const noexcept
template <typename local_T, typename U>
__device__ constexpr equal_result equal_to(local_T const& lhs, U const& rhs) const noexcept
{
return equal_(lhs, rhs) ? equal_result::EQUAL : equal_result::UNEQUAL;
}
Expand Down
14 changes: 7 additions & 7 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value,
template <bool HasPayload, typename Predicate, typename local_value_type>
__device__ thrust::pair<iterator, bool> insert_and_find(local_value_type const& value,
Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
Expand Down Expand Up @@ -346,10 +346,10 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Predicate, typename local_value_type>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value,
local_value_type const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
Expand Down Expand Up @@ -720,12 +720,12 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Predicate>
template <bool HasPayload, typename Predicate, typename local_value_type>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
value_type const& value,
local_value_type const& value,
Predicate const& predicate) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, value);
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
Expand Down
3 changes: 2 additions & 1 deletion include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,9 @@ class operator_impl<op::insert_and_find_tag,
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <typename local_value_type>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
cooperative_groups::thread_block_tile<cg_size> const& group, local_value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
Expand Down

0 comments on commit faa7810

Please sign in to comment.