Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a distance epilogue function to the bfknn call #1371

Merged
merged 14 commits into from
Mar 28, 2023

Conversation

benfred
Copy link
Member

@benfred benfred commented Mar 24, 2023

Add the ability for a user to specify an epilogue function to run after the distance in the brute_force::knn call.

This lets us remove faiss from cuml, by updating the hdbscan reachability code (rapidsai/cuml#5293)

benfred added 2 commits March 22, 2023 11:50
This adds the ability for callers to post-process the distances
in the tiled_brute_force_knn call, and lets us use this function
for the hdbscan reachability code in cuml
And fuse the distance epilogue with the l2 adjustment where possible
@benfred benfred requested a review from a team as a code owner March 24, 2023 02:05
@benfred benfred added non-breaking Non-breaking change improvement Improvement / enhancement to an existing function labels Mar 24, 2023
@github-actions github-actions bot added the cpp label Mar 24, 2023
cpp/include/raft/neighbors/brute_force.cuh Show resolved Hide resolved
cpp/include/raft/neighbors/brute_force.cuh Outdated Show resolved Hide resolved
cpp/include/raft/neighbors/brute_force.cuh Show resolved Hide resolved
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t i) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to do this as an epilogue on the pairwise distances rather than burdening the (already register constrained) k selection with it.

This formulation of the algorithm (tiled knn with epilogue after tiling) makes this portion equivalent to our gramm matrix API for constructing RKHS kernels (see raft::distance::kernel) and it makes me wonder if there's something to be gained by consolidating these eventually (eg brute force knn primitive becomes a composition of [tiled gramm + epilogue] + k-selection. That could also allow us to reuse more from the gram APIs.

@@ -138,13 +138,15 @@ inline void knn_merge_parts(
* is ignored if the metric_type is not Minkowski.
* @param[in] global_id_offset: optional starting global id mapping for the local partition
* (assumes the index contains contiguous ids in the global id space)
* @param[in] distance_epilogue: optional epilogue function to run after computing distances
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the expectation if the epilogue function's argument list here too? Just a small example prototype of the function definition will do.

template <typename ElementType = float, typename IndexType = int64_t>
template <typename ElementType = float,
typename IndexType = int64_t,
typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::device_resources& handle,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, id eventually love to find ways we can consolidate (and reuse) this with the gram matrix APIs. Not an immediate priority but the current gramm matrix API is a class that doesn't need to store state and we have a todo to convert it into flattened public API functions like the rest of RAFT.

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry! I meant to request changes, not approve. I think it looks great so far but just some minor things

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@cjnolet
Copy link
Member

cjnolet commented Mar 28, 2023

/merge

@rapids-bot rapids-bot bot merged commit 0d3bd3d into rapidsai:branch-23.04 Mar 28, 2023
@benfred benfred deleted the post_distance_op branch March 28, 2023 17:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

3 participants