From da8fa388b25b98a45f93d9477f9a2fd18b503471 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 1 Sep 2022 14:13:19 +0200 Subject: [PATCH] fusedL2NN: Preventatively reduce shfl_sync width In the current implementation, it looks like values from different rows are mixed together in what should be a row-wise warp reduce. All tests do pass, however. Just in case, I have added a width parameter to the shuffle so that each shuffle stays within a single row of the warp. --- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 78becc1b5b..159c4649aa 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -195,15 +195,15 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ReduceOpT red_op(redOp); const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); + const auto lid = threadIdx.x % P::AccThCols; // reduce #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); KVPair tmp = {tmpkey, tmpvalue}; val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); }