Skip to content

Commit

Permalink
fusedL2NN: Fix updateReDucedVal with >2 rows/warp
Browse files Browse the repository at this point in the history
In updateReDucedVal, a single warp can contain multiple rows (in
registers). A single thread within the warp uses the first element of
each row to update an output array (atomically).

In the previous implementation, a shuffle was used to move the head of
each row into the first thread of the warp. Unfortunately, this would
overwrite the value all other rows. This strategy, however, worked when
the number of rows per warp equalled 2. Hence, the bug never triggered.

In a recent commit, the number of rows per warp was increased to four in
certain situations (skinny matrices). Hence, this bug triggered.

In the new implementation, the values are not shuffled into the first
thread of the warp any more. Instead, the threads that contain the first
element of a row update the output in sequential order. The sequential
ordering is necessary to avoid deadlock on Pascal architecture.
  • Loading branch information
ahendriksen committed Sep 1, 2022
1 parent 495a9d2 commit bb3c8dd
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ DI void updateReducedVal(
const auto lid = threadIdx.x % raft::WarpSize;
const auto accrowid = threadIdx.x / P::AccThCols;

// for now have first lane from each warp update a unique output row. This
// will resolve hang issues with pre-Volta architectures
// Update each output row in order within a warp. This will resolve hang
// issues with pre-Volta architectures
#pragma unroll
for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) {
if (lid == 0) {
if (lid == j * P::AccThCols) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rid = gridStrideY + accrowid + j + i * P::AccThRows;
auto rid = gridStrideY + accrowid + i * P::AccThRows;
if (rid < m) {
auto value = val[i];
while (atomicCAS(mutex + rid, 0, 1) == 1)
Expand All @@ -111,14 +111,6 @@ DI void updateReducedVal(
}
}
}
if (j < (raft::WarpSize / P::AccThCols) - 1) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols);
auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols);
val[i] = {tmpkey, tmpvalue};
}
}
}
}

Expand Down

0 comments on commit bb3c8dd

Please sign in to comment.