From 0d3bd3da5a2eb77a5ca2f7f9b9ed367030811c06 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 28 Mar 2023 01:00:57 -0700 Subject: [PATCH] add a distance epilogue function to the bfknn call (#1371) Add the ability for a user to specify an epilogue function to run after the distance computation in the brute_force::knn call. This lets us remove faiss from cuml, by updating the hdbscan reachability code (https://github.com/rapidsai/cuml/pull/5293) Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1371 --- cpp/include/raft/neighbors/brute_force.cuh | 32 ++++++------ .../raft/neighbors/detail/knn_brute_force.cuh | 51 ++++++++++++++----- .../neighbors/specializations/brute_force.cuh | 3 +- .../raft/spatial/knn/detail/ball_cover.cuh | 1 - .../brute_force_knn_int64_t_float.cu | 11 +--- .../brute_force_knn_impl_long_float_int.cu | 3 +- .../brute_force_knn_impl_long_float_uint.cu | 3 +- .../brute_force_knn_impl_uint_float_int.cu | 3 +- .../brute_force_knn_impl_uint_float_uint.cu | 3 +- cpp/test/neighbors/ball_cover.cu | 1 - cpp/test/neighbors/knn.cu | 2 +- 11 files changed, 69 insertions(+), 44 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4891cc5f8d..dac1a29c7f 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -122,9 +122,8 @@ inline void knn_merge_parts( * * raft::raft::device_resources handle; * ... 
- * int k = 10; * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, k, metric); + * brute_force::knn(handle, index, search, indices, distances, metric); * @endcode * * @param[in] handle: the cuml handle to use @@ -132,28 +131,31 @@ inline void knn_merge_parts( * @param[in] search: matrix (size n*d) to be used for searching the index * @param[out] indices: matrix (size n*k) to store output knn indices * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] k: the number of nearest neighbors to return * @param[in] metric: distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. * @param[in] global_id_offset: optional starting global id mapping for the local partition * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. 
*/ template + typename search_layout, + typename epilogue_op = raft::identity_op> void knn(raft::device_resources const& handle, std::vector> index, raft::device_matrix_view search, raft::device_matrix_view indices, raft::device_matrix_view distances, - value_int k, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt) + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) { RAFT_EXPECTS(index[0].extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); @@ -161,15 +163,14 @@ void knn(raft::device_resources const& handle, RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS( - indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); bool rowMajorIndex = std::is_same_v; bool rowMajorQuery = std::is_same_v; std::vector inputs; - std::vector sizes; + std::vector sizes; for (std::size_t i = 0; i < index.size(); ++i) { inputs.push_back(const_cast(index[i].data_handle())); sizes.push_back(index[i].extent(0)); @@ -183,18 +184,19 @@ void knn(raft::device_resources const& handle, raft::neighbors::detail::brute_force_knn_impl(handle, inputs, sizes, - static_cast(index[0].extent(1)), + index[0].extent(1), // TODO: This is unfortunate. Need to fix. 
const_cast(search.data_handle()), - static_cast(search.extent(0)), + search.extent(0), indices.data_handle(), distances.data_handle(), - k, + indices.extent(1), rowMajorIndex, rowMajorQuery, trans_arg, metric, - metric_arg.value_or(2.0f)); + metric_arg.value_or(2.0f), + distance_epilogue); } /** diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 875fc3b37c..a776ce2586 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -47,7 +47,9 @@ using namespace raft::spatial::knn; * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::device_resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 0.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -152,25 +155,41 @@ void tiled_brute_force_knn(const raft::device_resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data() + i; - auto col_norms = index_norms.data() + j; + auto row_norms = search_norms.data(); + auto col_norms = index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( handle, raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType i) { - 
IndexType row = i / current_centroid_size, col = i % current_centroid_size; + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); - auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx]; // due to numerical instability (especially around self-distance) // the distances here could be slightly negative, which will // cause NaN values in the subsequent sqrt. Clamp to 0 val = val * (val >= 0.0001); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + val = distance_epilogue(val, row, col); return val; }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); + }); + } } select_k(temp_distances.data(), @@ -250,7 +269,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. 
Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl( raft::device_resources const& handle, std::vector& input, @@ -265,7 +287,8 @@ void brute_force_knn_impl( bool rowMajorQuery = true, std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) + float metricArg = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { auto userStream = handle.get_stream(); @@ -355,6 +378,7 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + std::is_same_v && (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || @@ -424,7 +448,10 @@ void brute_force_knn_impl( out_d_ptr, out_i_ptr, tiled_metric, - metricArg); + metricArg, + 0, + 0, + distance_epilogue); break; } } diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index d418d40185..1337beb68a 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -36,7 +36,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); RAFT_INST(long, float, unsigned int); RAFT_INST(uint32_t, float, int); diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 99d688e232..c8fc6eefda 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle, make_device_matrix_view(query_pts, 
n_query_pts, inputs[0].extent(1)), make_device_matrix_view(R_knn_inds, n_query_pts, k), make_device_matrix_view(R_knn_dists, n_query_pts, k), - k, index.get_metric()); } diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index 585084fc97..88545b3607 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -38,15 +38,8 @@ namespace raft::runtime::neighbors::brute_force { { \ std::vector> vec; \ vec.push_back(index); \ - raft::neighbors::brute_force::knn(handle, \ - vec, \ - search, \ - indices, \ - distances, \ - static_cast(distances.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ + raft::neighbors::brute_force::knn( \ + handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \ } RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu index 07810aa576..04aa42c9f1 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu index 0cb873b40a..a8b9d4299a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ 
b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu index f8a69b896f..c97e6e936a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu index 3c23d1f3e0..87451c385a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 9b51d585de..46ef3a9150 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -121,7 +121,6 @@ void compute_bfknn(const 
raft::device_resources& handle, make_device_matrix_view(X2, n_query_rows, d), make_device_matrix_view(inds, n_query_rows, k), make_device_matrix_view(dists, n_query_rows, k), - k, metric); } diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 4bb977432c..bcd4b9cb0b 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_.data(), rows_, k_); auto metric = raft::distance::DistanceType::L2Unexpanded; - knn(handle, index, search, indices, distances, k_, metric, std::make_optional(0)); + knn(handle, index, search, indices, distances, metric, std::make_optional(0)); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data());