Skip to content

Commit

Permalink
Fix hash table full occupancy bug (#647)
Browse files Browse the repository at this point in the history
The existing hash table implementation relies on empty slots to
terminate the probing sequence, so inserting into or querying a fully
occupied hash table never terminates (the kernel hangs). This PR resolves
the issue by recording the initial slot index of each probing sequence and
ensuring the sequence terminates once it has looped through all slots and
returned to that starting index. On wrap-around, the operation reports
failure or "not found" (e.g. `insert` returns `false`, `find` returns
`end()`, and `count` returns the matches accumulated so far). Benchmark
tests confirm that this change has no performance impact on non-fully
occupied hash tables.
  • Loading branch information
PointKernel authored Dec 3, 2024
1 parent a4fb985 commit 681cf95
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 27 deletions.
86 changes: 59 additions & 27 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ class open_addressing_ref_impl {
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -411,6 +412,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -428,9 +430,10 @@ class open_addressing_ref_impl {
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -483,6 +486,7 @@ class open_addressing_ref_impl {
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}
}
Expand Down Expand Up @@ -513,9 +517,10 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -554,6 +559,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return {this->end(), false}; }
};
}

Expand Down Expand Up @@ -584,9 +590,10 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -653,6 +660,7 @@ class open_addressing_ref_impl {
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return {this->end(), false}; }
}
}
}
Expand All @@ -671,7 +679,8 @@ class open_addressing_ref_impl {
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -696,6 +705,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -713,7 +723,8 @@ class open_addressing_ref_impl {
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key) noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -750,6 +761,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -769,7 +781,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -783,6 +796,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -803,7 +817,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ bool contains(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -821,6 +836,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -840,7 +856,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -859,6 +876,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return this->end(); }
}
}

Expand All @@ -879,7 +897,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ const_iterator find(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -908,6 +927,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); }

++probing_iter;
if (*probing_iter == init_idx) { return this->end(); }
}
}

Expand All @@ -926,8 +946,9 @@ class open_addressing_ref_impl {
if constexpr (not allows_duplicates) {
return static_cast<size_type>(this->contains(key));
} else {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
size_type count = 0;
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
size_type count = 0;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -942,6 +963,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return count; }
}
}
}
Expand All @@ -960,8 +982,9 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ size_type count(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
size_type count = 0;
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
size_type count = 0;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -978,6 +1001,7 @@ class open_addressing_ref_impl {

if (group.any(state == detail::equal_result::EMPTY)) { return count; }
++probing_iter;
if (*probing_iter == init_idx) { return count; }
}
}

Expand Down Expand Up @@ -1177,6 +1201,7 @@ class open_addressing_ref_impl {
auto const& probe_key = *(input_probe + idx);
auto probing_iter =
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

bool running = true;
[[maybe_unused]] bool found_match = false;
Expand Down Expand Up @@ -1277,6 +1302,7 @@ class open_addressing_ref_impl {

// onto the next probing bucket
++probing_iter;
if (*probing_iter == init_idx) { running = false; }
} // while running
} // if active_flag

Expand Down Expand Up @@ -1305,7 +1331,8 @@ class open_addressing_ref_impl {
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1325,6 +1352,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down Expand Up @@ -1352,8 +1380,9 @@ class open_addressing_ref_impl {
ProbeKey const& key,
CallbackOp&& callback_op) const noexcept
{
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
bool empty = false;
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
bool empty = false;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1378,6 +1407,7 @@ class open_addressing_ref_impl {
if (group.any(empty)) { return; }

++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down Expand Up @@ -1414,8 +1444,9 @@ class open_addressing_ref_impl {
CallbackOp&& callback_op,
SyncOp&& sync_op) const noexcept
{
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
bool empty = false;
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
bool empty = false;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1441,6 +1472,7 @@ class open_addressing_ref_impl {
if (group.any(empty)) { return; }

++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down
8 changes: 8 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref[*probing_iter];
Expand All @@ -514,6 +515,7 @@ class operator_impl<
}
}
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand All @@ -539,6 +541,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref[*probing_iter];
Expand Down Expand Up @@ -578,6 +581,7 @@ class operator_impl<
if (group.shfl(status, src_lane)) { return; }
} else {
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}
}
Expand Down Expand Up @@ -855,6 +859,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;
auto const empty_value = ref_.empty_value_sentinel();

// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
Expand Down Expand Up @@ -894,6 +899,7 @@ class operator_impl<
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand Down Expand Up @@ -929,6 +935,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;
auto const empty_value = ref_.empty_value_sentinel();

// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
Expand Down Expand Up @@ -987,6 +994,7 @@ class operator_impl<
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}
}
Expand Down

0 comments on commit 681cf95

Please sign in to comment.