From 87319de8d0925f6f72258ad12f7eaee3c1c594eb Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 1 Sep 2022 14:13:19 +0200 Subject: [PATCH] fusedL2NN: Preventatively reduce shfl_sync width In the current implementation, it looks like values from different rows are mixed together in what should be a row-wise warp reduce. All tests do pass, however. Just in case, I have added a width parameter to the shuffle so that it only shuffles within a row (a group of P::AccThCols lanes) within the warp. --- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 78becc1b5b..6a51bdcf1a 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -202,8 +202,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); KVPair tmp = {tmpkey, tmpvalue}; val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); }