From c1b077edb13111b0d2a596aced516f63f68491bc Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 23 Nov 2022 01:00:19 +0100 Subject: [PATCH] Fix fusedL2NN bug that can happen when the same point appears in both x and y (#1040) Solves #1036 Even when computing a sum of squares, the distance from a point to itself can apparently be `-0.0` in which case the square root is `nan` and comparisons are broken. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1040 --- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 1385d0aa09..e8c2648c2e 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -174,7 +174,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + auto acc_ij = acc[i][j]; + acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; } } }