Skip to content

Commit

Permalink
Add block synchronisation after reset of CG-level counters (#337)
Browse files Browse the repository at this point in the history
This is needed because, if the enclosing loop runs another iteration, some
threads in the group might read the counter before it has been reset.

- Tentatively closes #336.
  • Loading branch information
wence- authored Jul 31, 2023
1 parent ff16201 commit fd7263c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/cuco/detail/static_multimap/device_view_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,12 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
if (*flushing_cg_counter + flushing_cg.size() * vector_width() > buffer_size) {
flush_output_buffer(
flushing_cg, *flushing_cg_counter, output_buffer, num_matches, output_begin);
// Everyone in the group reads the counter when flushing, so
// sync before writing.
flushing_cg.sync();
// First lane reset warp-level counter
if (flushing_cg.thread_rank() == 0) { *flushing_cg_counter = 0; }
flushing_cg.sync();
}

current_slot = next_slot(current_slot);
Expand Down Expand Up @@ -1092,8 +1096,12 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
// Flush if the next iteration won't fit into buffer
if ((*cg_counter + g.size()) > buffer_size) {
flush_output_buffer(g, *cg_counter, output_buffer, num_matches, output_begin);
// Everyone in the group reads the counter when flushing, so
// sync before writing.
g.sync();
// First lane reset CG-level counter
if (lane_id == 0) { *cg_counter = 0; }
g.sync();
}
current_slot = next_slot(current_slot);
} // while running
Expand Down Expand Up @@ -1428,8 +1436,12 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
num_matches,
probe_output_begin,
contained_output_begin);
// Everyone in the group reads the counter when flushing, so
// sync before writing.
flushing_cg.sync();
// First lane reset warp-level counter
if (flushing_cg.thread_rank() == 0) { *flushing_cg_counter = 0; }
flushing_cg.sync();
}

current_slot = next_slot(current_slot);
Expand Down Expand Up @@ -1539,8 +1551,12 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
num_matches,
probe_output_begin,
contained_output_begin);
// Everyone in the group reads the counter when flushing, so
// sync before writing.
g.sync();
// First lane reset CG-level counter
if (lane_id == 0) { *cg_counter = 0; }
g.sync();
}
current_slot = next_slot(current_slot);
} // while running
Expand Down
6 changes: 6 additions & 0 deletions include/cuco/detail/static_multimap/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ __global__ void retrieve(InputIt first,

if (flushing_cg.thread_rank() == 0) { flushing_cg_counter[flushing_cg_id] = 0; }

flushing_cg.sync();

while (flushing_cg.any(idx < n)) {
bool active_flag = idx < n;
auto active_flushing_cg = cg::binary_partition<flushing_cg_size>(flushing_cg, active_flag);
Expand Down Expand Up @@ -416,6 +418,7 @@ __global__ void retrieve(InputIt first,
idx += loop_stride;
}

flushing_cg.sync();
// Final flush of output buffer
if (flushing_cg_counter[flushing_cg_id] > 0) {
view.flush_output_buffer(flushing_cg,
Expand Down Expand Up @@ -499,6 +502,8 @@ __global__ void pair_retrieve(InputIt first,

if (flushing_cg.thread_rank() == 0) { flushing_cg_counter[flushing_cg_id] = 0; }

flushing_cg.sync();

while (flushing_cg.any(idx < n)) {
bool active_flag = idx < n;
auto active_flushing_cg = cg::binary_partition<flushing_cg_size>(flushing_cg, active_flag);
Expand Down Expand Up @@ -532,6 +537,7 @@ __global__ void pair_retrieve(InputIt first,
idx += loop_stride;
}

flushing_cg.sync();
// Final flush of output buffer
if (flushing_cg_counter[flushing_cg_id] > 0) {
view.flush_output_buffer(flushing_cg,
Expand Down

0 comments on commit fd7263c

Please sign in to comment.