diff --git a/cpp/include/raft/linalg/distance_type.h b/cpp/include/raft/linalg/distance_type.h index 3dd038630e..f3a22a07ed 100644 --- a/cpp/include/raft/linalg/distance_type.h +++ b/cpp/include/raft/linalg/distance_type.h @@ -52,12 +52,12 @@ enum DistanceType : unsigned short { Haversine = 13, /** Bray-Curtis distance **/ BrayCurtis = 14, - /** Jensen-Shannon distance**/ + /** Jensen-Shannon distance **/ JensenShannon = 15, - + /** Dice-Sorensen distance **/ + DiceExpanded = 16, /** Precomputed (special value) **/ Precomputed = 100 }; - }; // namespace distance }; // end namespace raft diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 92492dc37a..c6ec40cb03 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -55,7 +55,8 @@ static const std::unordered_set supportedDistance{ raft::distance::DistanceType::LpUnexpanded, raft::distance::DistanceType::JaccardExpanded, raft::distance::DistanceType::CosineExpanded, - raft::distance::DistanceType::HellingerExpanded}; + raft::distance::DistanceType::HellingerExpanded, + raft::distance::DistanceType::DiceExpanded}; /** * Compute pairwise distances between A and B, using the provided @@ -116,6 +117,8 @@ void pairwiseDistance(value_t *out, hellinger_expanded_distances_t(input_config) .compute(out); break; + case raft::distance::DistanceType::DiceExpanded: + dice_expanded_distances_t(input_config).compute(out); default: THROW("Unsupported distance: %d", metric);