From c7a72bea63b4c57c40cc545ce95fb9c0252d1995 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 17 Apr 2023 11:34:47 -0700 Subject: [PATCH] Fix is_min_close (#1419) Correlation and Cosine distance both return (1 - similarity) in the pairwise distances apis, meaning that is_min_close is returning the wrong sort order for them. Fix. Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1419 --- cpp/include/raft/distance/distance_types.hpp | 2 -- cpp/include/raft/sparse/neighbors/detail/knn.cuh | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index 4060147f1d..d17ef358ee 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -74,8 +74,6 @@ inline bool is_min_close(DistanceType metric) bool select_min; switch (metric) { case DistanceType::InnerProduct: - case DistanceType::CosineExpanded: - case DistanceType::CorrelationExpanded: // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 // {perform_k_selection}) diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index 7bedec9830..f9f07c13ca 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -355,8 +355,7 @@ class sparse_knn_t { // want to adjust k. value_idx n_neighbors = std::min(static_cast(k), batch_cols); - bool ascending = true; - if (metric == raft::distance::DistanceType::InnerProduct) ascending = false; + bool ascending = raft::distance::is_min_close(metric); // kernel to slice first (min) k cols and copy into batched merge buffer raft::spatial::knn::select_k(batch_dists, @@ -425,4 +424,4 @@ class sparse_knn_t { raft::device_resources const& handle; }; -}; // namespace raft::sparse::neighbors::detail \ No newline at end of file +}; // namespace raft::sparse::neighbors::detail