diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/bin_distance.cuh index ef605e1fe0..ae5cbdf9d3 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/bin_distance.cuh @@ -155,7 +155,7 @@ class jaccard_expanded_distances_t : public distances_t { /** * Dice distance using the expanded form: - * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k)^2 + sum(y_k)^2)) + * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k) + sum(y_k))) */ template class dice_expanded_distances_t : public distances_t { @@ -183,8 +183,10 @@ class dice_expanded_distances_t : public distances_t { b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, config_->handle, config_->allocator, config_->stream, [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t q_r_union = (q_norm * q_norm) + (r_norm * r_norm); - return (2 * dot) / q_r_union; + value_t q_r_union = q_norm + r_norm; + value_t dice = (2 * dot) / q_r_union; + bool both_empty = q_r_union == 0; + return 1 - ((!both_empty * dice) + both_empty); }); } diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index c6ec40cb03..0cd0be11be 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -119,6 +119,7 @@ void pairwiseDistance(value_t *out, break; case raft::distance::DistanceType::DiceExpanded: dice_expanded_distances_t(input_config).compute(out); + break; default: THROW("Unsupported distance: %d", metric);