fusedL2NN: Preventatively reduce shfl_sync width
In the current implementation, it looks like values from different rows
are mixed together in what should be a row-wise warp reduce. All tests
do pass, however.

Just in case, I have added a width parameter to the shuffle so that it
only shuffles within a row within the warp.
ahendriksen committed Sep 1, 2022
1 parent bb3c8dd commit da8fa38
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -195,15 +195,15 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
   ReduceOpT red_op(redOp);

   const auto accrowid = threadIdx.x / P::AccThCols;
-  const auto lid = raft::laneId();
+  const auto lid = threadIdx.x % P::AccThCols;

   // reduce
 #pragma unroll
   for (int i = 0; i < P::AccRowsPerTh; ++i) {
 #pragma unroll
     for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
-      auto tmpkey = raft::shfl(val[i].key, lid + j);
-      auto tmpvalue = raft::shfl(val[i].value, lid + j);
+      auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols);
+      auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
       KVPair tmp = {tmpkey, tmpvalue};
       val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]);
     }
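
To illustrate what the width argument changes, here is a minimal host-side C++ model of the reduction loop. It is a sketch, not RAFT code: `kWarpSize`, `Warp`, `shfl_source`, and `row_min_reduce` are hypothetical names, a plain integer `min` stands in for the key/value reduction, and `cols` plays the role of `P::AccThCols`. The model assumes the documented CUDA behaviour that a shuffle with a given width reads lane `srcLane % width` inside the caller's own width-sized segment of the warp.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;  // hypothetical constant for this model
using Warp = std::array<int, kWarpSize>;

// Lane that `lane` actually reads from when it requests `src_lane` with the
// given `width` (assumed semantics: modulo width, within the caller's segment).
inline int shfl_source(int lane, int src_lane, int width) {
  assert(width > 0 && width <= kWarpSize && (width & (width - 1)) == 0);
  const int segment_start = (lane / width) * width;
  return segment_start + (src_lane % width);
}

// Row-wise min reduction over a warp that holds kWarpSize / cols rows.
// restrict_width == false models the code before this commit
//   (lid = lane id, full-warp shuffle);
// restrict_width == true models the code after it
//   (lid = lane % cols, shuffle width = cols).
inline Warp row_min_reduce(Warp vals, int cols, bool restrict_width) {
  const int width = restrict_width ? cols : kWarpSize;
  for (int j = cols / 2; j > 0; j >>= 1) {
    Warp next = vals;  // all lanes exchange values in the same step
    for (int lane = 0; lane < kWarpSize; ++lane) {
      const int lid = restrict_width ? lane % cols : lane;
      const int src = shfl_source(lane, lid + j, width);
      next[lane] = std::min(vals[lane], vals[src]);
    }
    vals = next;
  }
  return vals;
}
```

In this model, with 16 columns per row, the full-warp variant still leaves the correct row minimum in the first lane of each row, while other lanes (lane 8, for example) pick up values from the neighbouring row. If the kernel only consumes the first lane of each row, which is an assumption here and not something the commit states, that would be consistent with the tests passing before the change. With the width restricted to `cols`, every lane ends up holding the minimum of its own row only.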