diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index 9e328898d..98c08e720 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -1000,8 +1000,12 @@ class static_multimap::device_view_ if (*flushing_cg_counter + flushing_cg.size() * vector_width() > buffer_size) { flush_output_buffer( flushing_cg, *flushing_cg_counter, output_buffer, num_matches, output_begin); + // Everyone in the group reads the counter when flushing, so + // sync before writing. + flushing_cg.sync(); // First lane reset warp-level counter if (flushing_cg.thread_rank() == 0) { *flushing_cg_counter = 0; } + flushing_cg.sync(); } current_slot = next_slot(current_slot); @@ -1092,8 +1096,12 @@ class static_multimap::device_view_ // Flush if the next iteration won't fit into buffer if ((*cg_counter + g.size()) > buffer_size) { flush_output_buffer(g, *cg_counter, output_buffer, num_matches, output_begin); + // Everyone in the group reads the counter when flushing, so + // sync before writing. + g.sync(); // First lane reset CG-level counter if (lane_id == 0) { *cg_counter = 0; } + g.sync(); } current_slot = next_slot(current_slot); } // while running @@ -1428,8 +1436,12 @@ class static_multimap::device_view_ num_matches, probe_output_begin, contained_output_begin); + // Everyone in the group reads the counter when flushing, so + // sync before writing. + flushing_cg.sync(); // First lane reset warp-level counter if (flushing_cg.thread_rank() == 0) { *flushing_cg_counter = 0; } + flushing_cg.sync(); } current_slot = next_slot(current_slot); @@ -1539,8 +1551,12 @@ class static_multimap::device_view_ num_matches, probe_output_begin, contained_output_begin); + // Everyone in the group reads the counter when flushing, so + // sync before writing. + g.sync(); // First lane reset CG-level counter if (lane_id == 0) { *cg_counter = 0; } + g.sync(); } current_slot = next_slot(current_slot); } // while running diff --git a/include/cuco/detail/static_multimap/kernels.cuh b/include/cuco/detail/static_multimap/kernels.cuh index ca5f898a5..67fb36045 100644 --- a/include/cuco/detail/static_multimap/kernels.cuh +++ b/include/cuco/detail/static_multimap/kernels.cuh @@ -387,6 +387,8 @@ __global__ void retrieve(InputIt first, if (flushing_cg.thread_rank() == 0) { flushing_cg_counter[flushing_cg_id] = 0; } + flushing_cg.sync(); + while (flushing_cg.any(idx < n)) { bool active_flag = idx < n; auto active_flushing_cg = cg::binary_partition(flushing_cg, active_flag); @@ -416,6 +418,7 @@ __global__ void retrieve(InputIt first, idx += loop_stride; } + flushing_cg.sync(); // Final flush of output buffer if (flushing_cg_counter[flushing_cg_id] > 0) { view.flush_output_buffer(flushing_cg, @@ -499,6 +502,8 @@ __global__ void pair_retrieve(InputIt first, if (flushing_cg.thread_rank() == 0) { flushing_cg_counter[flushing_cg_id] = 0; } + flushing_cg.sync(); + while (flushing_cg.any(idx < n)) { bool active_flag = idx < n; auto active_flushing_cg = cg::binary_partition(flushing_cg, active_flag); @@ -532,6 +537,7 @@ __global__ void pair_retrieve(InputIt first, idx += loop_stride; } + flushing_cg.sync(); // Final flush of output buffer if (flushing_cg_counter[flushing_cg_id] > 0) { view.flush_output_buffer(flushing_cg,