From 0da719fa6672a42903c0f7b481c72e766c00cb41 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 8 Mar 2023 23:35:39 -0800 Subject: [PATCH] Numerical stability fixes for l2 pairwise distance (#1319) When computing L2 distance, we could sometimes see the distances go slightly negative. This is because of numerical instability, especially when computing self distance. For L2SqrtExpanded, this would end up taking the sqrt of this negative value - which introduced NaN values. Fix by clamping distances to 0. This matches the behaviour found in the fused_l2 distance calculation: https://github.com/rapidsai/raft/blob/3ca7eacc5cb411facdfb08ff27663a3402486bc4/cpp/include/raft/distance/detail/fused_l2_nn.cuh#L179 And is similar to that found in the sparse l2 distance calculation: https://github.com/rapidsai/raft/blob/3ca7eacc5cb411facdfb08ff27663a3402486bc4/cpp/include/raft/sparse/distance/detail/l2_distance.cuh#L84-L85 Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1319 --- cpp/include/raft/distance/detail/euclidean.cuh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 1a2db63f5c..7c5fdb912c 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -34,6 +34,10 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + // outVal could be negative due to numerical instability, especially when + // calculating self distance. + // clamp to 0 to avoid potential NaN in sqrt + outVal = outVal * (outVal > DataT(0.0)); return sqrt ? raft::sqrt(outVal) : outVal; } @@ -122,7 +126,8 @@ void euclideanExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + acc[i][j] = val * (val > DataT(0.0)); } } if (sqrt) {