diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 2432a9cc1c..e77ab36dc6 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -270,7 +270,11 @@ class hellinger_expanded_distances_t : public distances_t { raft::linalg::unaryOp( out_dists, out_dists, config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return sqrt(1 - input); }, + [=] __device__(value_t input) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + bool rectifier = abs(1 - input) > 1e-8; + return rectifier * sqrt(1 - input); + }, config_->stream); }