diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh
index 1a2db63f5c..7c5fdb912c 100644
--- a/cpp/include/raft/distance/detail/euclidean.cuh
+++ b/cpp/include/raft/distance/detail/euclidean.cuh
@@ -34,6 +34,12 @@ struct L2ExpandedOp {
   __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
   {
     AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
+    // outVal can come out slightly negative due to floating-point cancellation,
+    // especially when computing self distance. Clamp it to zero so that sqrt
+    // never receives a negative input. A select is used instead of multiplying
+    // by the comparison result: the product yields -0.0 for negative inputs and
+    // NaN for -inf, whereas the select yields +0.0. NaN inputs still propagate.
+    outVal = (outVal < AccT(0.0)) ? AccT(0.0) : outVal;
     return sqrt ? raft::sqrt(outVal) : outVal;
   }
 
@@ -122,7 +128,10 @@ void euclideanExpImpl(const DataT* x,
     for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
 #pragma unroll
       for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
-        acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
+        DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
+        // Clamp negative results (rounding error) to +0.0. The select keeps NaN
+        // intact and avoids the -0.0 / NaN that val * (val > 0) can produce.
+        acc[i][j] = (val < DataT(0.0)) ? DataT(0.0) : val;
       }
     }
     if (sqrt) {