From fbe6f29a201270a0d12d6c41b4a3321851da5cd8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 22 Mar 2023 11:50:23 -0700 Subject: [PATCH 1/7] Add a post processing op for the tiled bfknn call This adds the ability for callers to post-process the distances in the tiled_brute_force_knn call, and lets us use this function for the hdbscan reachability code in cuml --- .../raft/neighbors/detail/knn_brute_force.cuh | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 875fc3b37c..01cd5cf353 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -47,7 +47,9 @@ using namespace raft::spatial::knn; * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::device_resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 0.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0) + float metric_arg = 0.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + PostDistanceOp post_distance_op = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -173,6 +176,8 @@ void tiled_brute_force_knn(const raft::device_resources& handle, }); } + post_distance_op(temp_distances.data(), i, j, current_query_size, current_centroid_size); + select_k(temp_distances.data(), nullptr, current_query_size, From 319ffeac1bfb3d30fb836d8b8387c5c0e67b3f9a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 23 Mar 2023 19:02:37 -0700 Subject: [PATCH 2/7] Use device distance epilogue And fuse the distance epilogue with the l2 adjustment where possible --- cpp/include/raft/neighbors/brute_force.cuh | 9 ++-- .../raft/neighbors/detail/knn_brute_force.cuh | 44 ++++++++++++++----- .../neighbors/specializations/brute_force.cuh | 3 +- .../brute_force_knn_impl_long_float_int.cu | 3 +- .../brute_force_knn_impl_long_float_uint.cu | 3 +- .../brute_force_knn_impl_uint_float_int.cu | 3 +- .../brute_force_knn_impl_uint_float_uint.cu | 3 +- 7 files changed, 48 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4891cc5f8d..7f93171901 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -144,7 +144,8 @@ template + typename search_layout, + typename epilogue_op = raft::identity_op> void knn(raft::device_resources const& handle, std::vector> index, raft::device_matrix_view search, @@ -153,7 +154,8 @@ void knn(raft::device_resources const& handle, value_int k, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt) + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) { RAFT_EXPECTS(index[0].extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); @@ -194,7 +196,8 @@ void knn(raft::device_resources const& handle, rowMajorQuery, trans_arg, metric, - metric_arg.value_or(2.0f)); + metric_arg.value_or(2.0f), + distance_epilogue); } /** diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 01cd5cf353..09d9b96244 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -47,9 +47,9 @@ using namespace raft::spatial::knn; * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::device_resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -60,10 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 0.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - PostDistanceOp post_distance_op = raft::identity_op()) + float metric_arg = 0.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -172,12 +172,24 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // cause NaN values in the subsequent sqrt. Clamp to 0 val = val * (val >= 0.0001); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + val = distance_epilogue(val, row, col); return val; }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t i) { + return distance_epilogue( + distances_ptr[i], i / current_centroid_size, i % current_centroid_size); + }); + } } - post_distance_op(temp_distances.data(), i, j, current_query_size, current_centroid_size); - select_k(temp_distances.data(), nullptr, current_query_size, @@ -255,7 +267,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl( raft::device_resources const& handle, std::vector& input, @@ -270,7 +285,8 @@ void brute_force_knn_impl( bool rowMajorQuery = true, std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) + float metricArg = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { auto userStream = handle.get_stream(); @@ -360,6 +376,7 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + std::is_same_v && (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || @@ -429,7 +446,10 @@ void brute_force_knn_impl( out_d_ptr, out_i_ptr, tiled_metric, - metricArg); + metricArg, + 0, + 0, + distance_epilogue); break; } } diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index d418d40185..1337beb68a 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -36,7 +36,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); RAFT_INST(long, float, unsigned int); RAFT_INST(uint32_t, float, int); diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu index 07810aa576..04aa42c9f1 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu index 0cb873b40a..a8b9d4299a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu index f8a69b896f..c97e6e936a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu index 3c23d1f3e0..87451c385a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail From acc674f96e3702d26e2613ab1923b1574a786ba3 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 24 Mar 2023 10:54:28 -0700 Subject: [PATCH 3/7] add docstring --- cpp/include/raft/neighbors/brute_force.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 7f93171901..d0ad7051d6 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -138,6 +138,7 @@ inline void knn_merge_parts( * is ignored if the metric_type is not Minkowski. * @param[in] global_id_offset: optional starting global id mapping for the local partition * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances */ template Date: Mon, 27 Mar 2023 11:46:18 -0700 Subject: [PATCH 4/7] Fix bug with epilogue --- .../raft/neighbors/detail/knn_brute_force.cuh | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 09d9b96244..904703cc62 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -155,17 +155,18 @@ void tiled_brute_force_knn(const raft::device_resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data() + i; - auto col_norms = index_norms.data() + j; + auto row_norms = search_norms.data(); + auto col_norms = index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( handle, raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType i) { - IndexType row = i / current_centroid_size, col = i % current_centroid_size; + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); - auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx]; // due to numerical instability (especially around self-distance) // the distances here could be slightly negative, which will @@ -183,9 +184,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, handle, raft::make_device_vector_view(temp_distances.data(), current_query_size * current_centroid_size), - [=] __device__(size_t i) { - return distance_epilogue( - distances_ptr[i], i / current_centroid_size, i % current_centroid_size); + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); }); } } From cd335a26657742022296d6a97e4eb20c73ec111e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 27 Mar 2023 11:57:04 -0700 Subject: [PATCH 5/7] changes from codereview --- cpp/include/raft/neighbors/brute_force.cuh | 5 ++++- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index d0ad7051d6..cce5f2cd3a 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -138,7 +138,10 @@ inline void knn_merge_parts( * is ignored if the metric_type is not Minkowski. * @param[in] global_id_offset: optional starting global id mapping for the local partition * (assumes the index contains contiguous ids in the global id space) - * @param[in] distance_epilogue: optional epilogue function to run after computing distances + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. */ template Date: Mon, 27 Mar 2023 16:16:19 -0700 Subject: [PATCH 6/7] Remove k parameter from brute_force::knn api Use the extents from the output mdspan instead --- cpp/include/raft/neighbors/brute_force.cuh | 15 ++++++--------- .../raft/spatial/knn/detail/ball_cover.cuh | 1 - .../neighbors/brute_force_knn_int64_t_float.cu | 11 ++--------- cpp/test/neighbors/ball_cover.cu | 1 - cpp/test/neighbors/knn.cu | 2 +- 5 files changed, 9 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index cce5f2cd3a..c558c0fff5 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -145,7 +145,6 @@ inline void knn_merge_parts( */ template search, raft::device_matrix_view indices, raft::device_matrix_view distances, - value_int k, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, std::optional metric_arg = std::make_optional(2.0f), std::optional global_id_offset = std::nullopt, @@ -167,15 +165,14 @@ void knn(raft::device_resources const& handle, RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS( - indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); bool rowMajorIndex = std::is_same_v; bool rowMajorQuery = std::is_same_v; std::vector inputs; - std::vector sizes; + std::vector sizes; for (std::size_t i = 0; i < index.size(); ++i) { inputs.push_back(const_cast(index[i].data_handle())); sizes.push_back(index[i].extent(0)); @@ -189,13 +186,13 @@ void knn(raft::device_resources const& handle, raft::neighbors::detail::brute_force_knn_impl(handle, inputs, sizes, - static_cast(index[0].extent(1)), + index[0].extent(1), // TODO: This is unfortunate. Need to fix. const_cast(search.data_handle()), - static_cast(search.extent(0)), + search.extent(0), indices.data_handle(), distances.data_handle(), - k, + indices.extent(1), rowMajorIndex, rowMajorQuery, trans_arg, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 99d688e232..c8fc6eefda 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle, make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)), make_device_matrix_view(R_knn_inds, n_query_pts, k), make_device_matrix_view(R_knn_dists, n_query_pts, k), - k, index.get_metric()); } diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index 585084fc97..88545b3607 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -38,15 +38,8 @@ namespace raft::runtime::neighbors::brute_force { { \ std::vector> vec; \ vec.push_back(index); \ - raft::neighbors::brute_force::knn(handle, \ - vec, \ - search, \ - indices, \ - distances, \ - static_cast(distances.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ + raft::neighbors::brute_force::knn( \ + handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \ } RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 9b51d585de..46ef3a9150 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -121,7 +121,6 @@ void compute_bfknn(const raft::device_resources& handle, make_device_matrix_view(X2, n_query_rows, d), make_device_matrix_view(inds, n_query_rows, k), make_device_matrix_view(dists, n_query_rows, k), - k, metric); } diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 4bb977432c..bcd4b9cb0b 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_.data(), rows_, k_); auto metric = raft::distance::DistanceType::L2Unexpanded; - knn(handle, index, search, indices, distances, k_, metric, std::make_optional(0)); + knn(handle, index, search, indices, distances, metric, std::make_optional(0)); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); From 18e19a95b7e48fef1f32379469a46a1ba3334b3d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 27 Mar 2023 21:33:05 -0700 Subject: [PATCH 7/7] fix docs --- cpp/include/raft/neighbors/brute_force.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index c558c0fff5..dac1a29c7f 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -122,9 +122,8 @@ inline void knn_merge_parts( * * raft::raft::device_resources handle; * ... - * int k = 10; * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, k, metric); + * brute_force::knn(handle, index, search, indices, distances, metric); * @endcode * * @param[in] handle: the cuml handle to use @@ -132,7 +131,6 @@ inline void knn_merge_parts( * @param[in] search: matrix (size n*d) to be used for searching the index * @param[out] indices: matrix (size n*k) to store output knn indices * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] k: the number of nearest neighbors to return * @param[in] metric: distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski.