diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index 4060147f1d..d17ef358ee 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -74,8 +74,6 @@ inline bool is_min_close(DistanceType metric) bool select_min; switch (metric) { case DistanceType::InnerProduct: - case DistanceType::CosineExpanded: - case DistanceType::CorrelationExpanded: // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 // {perform_k_selection}) diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index 7bedec9830..f9f07c13ca 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -355,8 +355,7 @@ class sparse_knn_t { // want to adjust k. value_idx n_neighbors = std::min(static_cast(k), batch_cols); - bool ascending = true; - if (metric == raft::distance::DistanceType::InnerProduct) ascending = false; + bool ascending = raft::distance::is_min_close(metric); // kernel to slice first (min) k cols and copy into batched merge buffer raft::spatial::knn::select_k(batch_dists, @@ -425,4 +424,4 @@ class sparse_knn_t { raft::device_resources const& handle; }; -}; // namespace raft::sparse::neighbors::detail \ No newline at end of file +}; // namespace raft::sparse::neighbors::detail