Skip to content

Commit

Permalink
Numerical stability fixes for l2 pairwise distance (rapidsai#1319)
Browse files Browse the repository at this point in the history
When computing L2 distance, we could sometimes see the distances go slightly negative. This is because of numerical instability, especially when computing self distance. For L2SqrtExpanded, this would end up taking the sqrt of this negative value - which introduced NaN values.

Fix by clamping distances to 0.

This matches the behaviour found in the fused_l2 distance calculation: https://github.com/rapidsai/raft/blob/3ca7eacc5cb411facdfb08ff27663a3402486bc4/cpp/include/raft/distance/detail/fused_l2_nn.cuh#L179

And is similar to that found in the sparse l2 distance calculation: https://github.com/rapidsai/raft/blob/3ca7eacc5cb411facdfb08ff27663a3402486bc4/cpp/include/raft/sparse/distance/detail/l2_distance.cuh#L84-L85

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1319
  • Loading branch information
benfred authored and lowener committed Mar 15, 2023
1 parent b779fb7 commit 0da719f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion cpp/include/raft/distance/detail/euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ struct L2ExpandedOp {
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
// outVal could be negative due to numerical instability, especially when
// calculating self distance.
// clamp to 0 to avoid potential NaN in sqrt
outVal = outVal * (outVal > DataT(0.0));
return sqrt ? raft::sqrt(outVal) : outVal;
}

Expand Down Expand Up @@ -122,7 +126,8 @@ void euclideanExpImpl(const DataT* x,
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
acc[i][j] = val * (val > DataT(0.0));
}
}
if (sqrt) {
Expand Down

0 comments on commit 0da719f

Please sign in to comment.