From d1fd927bc4ec67bfd765620b5fa93f17c54cfa70 Mon Sep 17 00:00:00 2001 From: Micka <9810050+lowener@users.noreply.github.com> Date: Thu, 1 Apr 2021 17:02:03 +0200 Subject: [PATCH] Adjust Hellinger pairwise distance to vaoid NaNs (#189) This change will fix NaNs that arise in Hellinger pairwise distance when the input to sqrt is negative. These kind of inputs can happen even when the inputs are normalized row-wise to 1 due to accumulation of rounding approximation. Authors: - Micka (https://github.com/lowener) Approvers: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/189 --- cpp/include/raft/sparse/distance/l2_distance.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 2432a9cc1cc..e77ab36dc69 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -270,7 +270,11 @@ class hellinger_expanded_distances_t : public distances_t { raft::linalg::unaryOp( out_dists, out_dists, config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return sqrt(1 - input); }, + [=] __device__(value_t input) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + bool rectifier = abs(1 - input) > 1e-8; + return rectifier * sqrt(1 - input); + }, config_->stream); }