Numerical stability fixes for l2 pairwise distance #1319
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
When computing L2 distance, we could sometimes see the distances go slightly negative. This is because of numerical instability, especially when computing self distance. For L2SqrtExpanded, this would end up taking the sqrt of this negative value - which introduced NaN values.
Fix by clamping distances to 0.
This matches the behaviour found in the fused_l2 distance calculation:
raft/cpp/include/raft/distance/detail/fused_l2_nn.cuh
Line 179 in 3ca7eac
And is similar to that found in the sparse l2 distance calculation:
raft/cpp/include/raft/sparse/distance/detail/l2_distance.cuh
Lines 84 to 85 in 3ca7eac