From 941550498b8bd6c63eae18be81a9cf2b3d7c61fc Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 30 Mar 2021 12:25:08 -0700 Subject: [PATCH 1/2] Adjust Hellinger pairwise distance to vaoid NaNs --- cpp/include/raft/sparse/distance/l2_distance.cuh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 2432a9cc1c..f591e5f380 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -270,7 +270,11 @@ class hellinger_expanded_distances_t : public distances_t { raft::linalg::unaryOp( out_dists, out_dists, config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return sqrt(1 - input); }, + [=] __device__(value_t input) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + value_t rectifier = (1 - input) > 0; + return rectifier * sqrt(1 - input); + }, config_->stream); } From df899879ce24cc19a4f00c43e17ee5c5d5a7876d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 1 Apr 2021 07:12:17 -0700 Subject: [PATCH 2/2] Change the rectifier to only correct numerical instabilities --- cpp/include/raft/sparse/distance/l2_distance.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index f591e5f380..e77ab36dc6 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -272,7 +272,7 @@ class hellinger_expanded_distances_t : public distances_t { out_dists, out_dists, config_->a_nrows * config_->b_nrows, [=] __device__(value_t input) { // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - value_t rectifier = (1 - input) > 0; + bool rectifier = abs(1 - input) > 1e-8; return rectifier * sqrt(1 - input); }, config_->stream);