diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 78becc1b5b..159c4649aa 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -195,15 +195,15 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ReduceOpT red_op(redOp); const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); + const auto lid = threadIdx.x % P::AccThCols; // reduce #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); KVPair tmp = {tmpkey, tmpvalue}; val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); }