From 46250b2f88a244c2de5bed7c573b3bf24135d13a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 29 Mar 2021 08:37:37 -0700 Subject: [PATCH 1/2] Dice formula correction --- cpp/include/raft/sparse/distance/bin_distance.cuh | 8 +++++--- cpp/include/raft/sparse/distance/distance.cuh | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/bin_distance.cuh index ef605e1fe0..95814eebff 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/bin_distance.cuh @@ -155,7 +155,7 @@ class jaccard_expanded_distances_t : public distances_t { /** * Dice distance using the expanded form: - * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k)^2 + sum(y_k)^2)) + * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k) + sum(y_k))) */ template class dice_expanded_distances_t : public distances_t { @@ -183,8 +183,10 @@ class dice_expanded_distances_t : public distances_t { b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, config_->handle, config_->allocator, config_->stream, [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t q_r_union = (q_norm * q_norm) + (r_norm * r_norm); - return (2 * dot) / q_r_union; + value_t q_r_union = q_norm + r_norm; + // deal with potential for 0 in denominator by forcing 0/1 instead + return 1 - + ((q_r_union != 0) * (2 * dot)) / ((q_r_union == 0) + q_r_union); }); } diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index c6ec40cb03..0cd0be11be 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -119,6 +119,7 @@ void pairwiseDistance(value_t *out, break; case raft::distance::DistanceType::DiceExpanded: dice_expanded_distances_t(input_config).compute(out); + break; default: THROW("Unsupported distance: %d", metric); From 45b70075347d810b41af2b8daab78fef39ce4f00 Mon Sep 17 
00:00:00 2001 From: Mickael Ide Date: Mon, 29 Mar 2021 13:21:56 -0700 Subject: [PATCH 2/2] Fixed NaNs to match SciPy in dice distance --- cpp/include/raft/sparse/distance/bin_distance.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/bin_distance.cuh index 95814eebff..ae5cbdf9d3 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/bin_distance.cuh @@ -184,9 +184,9 @@ class dice_expanded_distances_t : public distances_t { config_->handle, config_->allocator, config_->stream, [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { value_t q_r_union = q_norm + r_norm; - // deal with potential for 0 in denominator by forcing 0/1 instead - return 1 - - ((q_r_union != 0) * (2 * dot)) / ((q_r_union == 0) + q_r_union); + value_t dice = (2 * dot) / q_r_union; + bool both_empty = q_r_union == 0; + return 1 - ((!both_empty * dice) + both_empty); }); }