Skip to content

Commit

Permalink
Fix key equality check order (#479)
Browse files Browse the repository at this point in the history
Closes #474 

This PR makes the reference value always the right-hand side for key
equality checks.

The updates for heterogeneous lookup tests indicate that it will be
probably a breaking change for libcudf byte pair encoding.
  • Loading branch information
PointKernel authored May 7, 2024
1 parent a41c046 commit d4c1613
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 36 deletions.
10 changes: 5 additions & 5 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ struct equal_wrapper {
/**
* @brief Order-sensitive equality operator.
*
* @note This function always compares the left-hand side element against sentinel values first
* @note This function always compares the right-hand side element against sentinel values first
* then performs a equality check with the given `equal_` callable, i.e., `equal_(lhs, rhs)`.
* @note Container (like set or map) keys MUST be always on the left-hand side.
* @note Container (like set or map) buckets MUST be always on the right-hand side.
*
* @tparam IsInsert Flag indicating whether it's an insert equality check or not. Insert probing
* stops when it's an empty or erased slot while query probing stops only when it's empty.
Expand All @@ -96,12 +96,12 @@ struct equal_wrapper {
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
if constexpr (IsInsert == is_insert::YES) {
return (cuco::detail::bitwise_compare(lhs, empty_sentinel_) or
cuco::detail::bitwise_compare(lhs, erased_sentinel_))
return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or
cuco::detail::bitwise_compare(rhs, erased_sentinel_))
? equal_result::AVAILABLE
: this->equal_to(lhs, rhs);
} else {
return cuco::detail::bitwise_compare(lhs, empty_sentinel_) ? equal_result::EMPTY
return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
}
}
Expand Down
4 changes: 2 additions & 2 deletions include/cuco/detail/open_addressing/functors.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ struct slot_is_filled {
return slot;
}
}();
return not(cuco::detail::bitwise_compare(empty_sentinel_, key) or
cuco::detail::bitwise_compare(erased_sentinel_, key));
return not(cuco::detail::bitwise_compare(key, empty_sentinel_) or
cuco::detail::bitwise_compare(key, erased_sentinel_));
}
};

Expand Down
26 changes: 13 additions & 13 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class open_addressing_ref_impl {

for (auto& slot_content : window_slots) {
auto const eq_res =
this->predicate_.operator()<is_insert::YES>(this->extract_key(slot_content), key);
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(slot_content));

if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
Expand Down Expand Up @@ -422,7 +422,7 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
for (auto i = 0; i < window_size; ++i) {
switch (
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key)) {
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]))) {
case detail::equal_result::AVAILABLE:
return window_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EQUAL: {
Expand Down Expand Up @@ -506,7 +506,7 @@ class open_addressing_ref_impl {

for (auto i = 0; i < window_size; ++i) {
auto const eq_res =
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key);
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]));
auto* window_ptr = (storage_ref_.data() + *probing_iter)->data();

// If the key is already in the container, return false
Expand Down Expand Up @@ -579,7 +579,7 @@ class open_addressing_ref_impl {
auto res = detail::equal_result::UNEQUAL;
for (auto i = 0; i < window_size; ++i) {
res =
this->predicate_.operator()<is_insert::YES>(this->extract_key(window_slots[i]), key);
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(window_slots[i]));
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
}
// returns dummy index `-1` for UNEQUAL
Expand Down Expand Up @@ -662,7 +662,7 @@ class open_addressing_ref_impl {

for (auto& slot_content : window_slots) {
auto const eq_res =
this->predicate_.operator()<is_insert::NO>(this->extract_key(slot_content), key);
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content));

// Key doesn't exist, return false
if (eq_res == detail::equal_result::EMPTY) { return false; }
Expand Down Expand Up @@ -704,7 +704,7 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto i = 0; i < window_size; ++i) {
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key);
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]));
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
}
// returns dummy index `-1` for UNEQUAL
Expand Down Expand Up @@ -758,7 +758,7 @@ class open_addressing_ref_impl {
auto const window_slots = storage_ref_[*probing_iter];

for (auto& slot_content : window_slots) {
switch (this->predicate_.operator()<is_insert::NO>(this->extract_key(slot_content), key)) {
switch (this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot_content))) {
case detail::equal_result::UNEQUAL: continue;
case detail::equal_result::EMPTY: return false;
case detail::equal_result::EQUAL: return true;
Expand Down Expand Up @@ -793,7 +793,7 @@ class open_addressing_ref_impl {
auto const state = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto& slot : window_slots) {
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(slot), key);
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(slot));
if (res != detail::equal_result::UNEQUAL) { return res; }
}
return res;
Expand Down Expand Up @@ -830,7 +830,7 @@ class open_addressing_ref_impl {

for (auto i = 0; i < window_size; ++i) {
switch (
this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key)) {
this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]))) {
case detail::equal_result::EMPTY: {
return this->end();
}
Expand Down Expand Up @@ -869,7 +869,7 @@ class open_addressing_ref_impl {
auto const [state, intra_window_index] = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto i = 0; i < window_size; ++i) {
res = this->predicate_.operator()<is_insert::NO>(this->extract_key(window_slots[i]), key);
res = this->predicate_.operator()<is_insert::NO>(key, this->extract_key(window_slots[i]));
if (res != detail::equal_result::UNEQUAL) { return window_probing_results{res, i}; }
}
// returns dummy index `-1` for UNEQUAL
Expand Down Expand Up @@ -1097,7 +1097,7 @@ class open_addressing_ref_impl {
if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) {
return insert_result::SUCCESS;
} else {
return this->predicate_.equal_to(this->extract_key(*old_ptr), this->extract_key(desired)) ==
return this->predicate_.equal_to(this->extract_key(desired), this->extract_key(*old_ptr)) ==
detail::equal_result::EQUAL
? insert_result::DUPLICATE
: insert_result::CONTINUE;
Expand Down Expand Up @@ -1144,7 +1144,7 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(desired.first, *old_key_ptr) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand Down Expand Up @@ -1183,7 +1183,7 @@ class open_addressing_ref_impl {

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(desired.first, *old_key_ptr) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand Down
6 changes: 3 additions & 3 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ class operator_impl<

for (auto& slot_content : window_slots) {
auto const eq_res =
ref_.impl_.predicate_.operator()<is_insert::YES>(slot_content.first, key);
ref_.impl_.predicate_.operator()<is_insert::YES>(key, slot_content.first);

// If the key is already in the container, update the payload and return
if (eq_res == detail::equal_result::EQUAL) {
Expand Down Expand Up @@ -449,7 +449,7 @@ class operator_impl<
auto const [state, intra_window_index] = [&]() {
auto res = detail::equal_result::UNEQUAL;
for (auto i = 0; i < window_size; ++i) {
res = ref_.impl_.predicate_.operator()<is_insert::YES>(window_slots[i].first, key);
res = ref_.impl_.predicate_.operator()<is_insert::YES>(key, window_slots[i].first);
if (res != detail::equal_result::UNEQUAL) {
return detail::window_probing_results{res, i};
}
Expand Down Expand Up @@ -514,7 +514,7 @@ class operator_impl<

// if key success or key was already present in the map
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or
(ref_.impl_.predicate().equal_to(*old_key_ptr, value.first) ==
(ref_.impl_.predicate().equal_to(value.first, *old_key_ptr) ==
detail::equal_result::EQUAL)) {
// Update payload
ref_.impl_.atomic_store(&slot->second, value.second);
Expand Down
4 changes: 2 additions & 2 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ namespace cuco {
* construction.
*
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
* @note cuCollections data structures always place the slot keys on the left-hand side when
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
* @note cuCollections data structures always place the slot keys on the right-hand side when
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
* `KeyEqual` should be used with caution.
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.
Expand Down
4 changes: 2 additions & 2 deletions include/cuco/static_multiset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ namespace cuco {
* construction.
*
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
* @note cuCollections data structures always place the slot keys on the left-hand side when
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
* @note cuCollections data structures always place the slot keys on the right-hand side when
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
* `KeyEqual` should be used with caution.
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.
Expand Down
4 changes: 2 additions & 2 deletions include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ namespace cuco {
* construction.
*
* @note Allows constant time concurrent modify or lookup operations from threads in device code.
* @note cuCollections data structures always place the slot keys on the left-hand side when
* invoking the key comparison predicate, i.e., `pred(slot_key, query_key)`. Order-sensitive
* @note cuCollections data structures always place the slot keys on the right-hand side when
* invoking the key comparison predicate, i.e., `pred(query_key, slot_key)`. Order-sensitive
* `KeyEqual` should be used with caution.
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.
Expand Down
8 changes: 4 additions & 4 deletions tests/static_map/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ struct custom_hasher {
};
};

// User-defined device key equality
// User-defined device key equality, Slot key always on the right-hand side
struct custom_key_equal {
template <typename SlotKey, typename InputKey>
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
template <typename InputKey, typename SlotKey>
__device__ bool operator()(InputKey const& lhs, SlotKey const& rhs) const
{
return lhs == rhs.a;
return lhs.a == rhs;
}
};

Expand Down
6 changes: 3 additions & 3 deletions tests/static_set/heterogeneous_lookup_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ struct custom_hasher {

// User-defined device key equality
struct custom_key_equal {
template <typename SlotKey, typename InputKey>
__device__ bool operator()(SlotKey const& lhs, InputKey const& rhs) const
template <typename InsertKey, typename SlotKey>
__device__ bool operator()(InsertKey const& lhs, SlotKey const& rhs) const
{
return lhs == rhs.a;
return lhs.a == rhs;
}
};

Expand Down

0 comments on commit d4c1613

Please sign in to comment.