add a distance epilogue function to the bfknn call #1371
This adds the ability for callers to post-process the distances in the tiled_brute_force_knn call, which lets us use this function for the HDBSCAN reachability code in cuML.
Also fuse the distance epilogue with the L2 adjustment where possible.
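To illustrate what fusing the epilogue with the L2 adjustment could look like, here is a minimal host-side sketch. It is not the RAFT kernel: the helper name `l2_adjust_with_epilogue`, the use of `std::vector`, the clamp to zero, and the `(value, row, col)` call shape of the epilogue are all assumptions made for illustration. The idea is that the pass which turns inner products into squared L2 distances also applies the epilogue, so the tile is traversed once.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side helper, not RAFT code.
// `dots` holds the row-major inner products x_i . y_j for one tile of
// n_rows x n_cols entries. The L2 adjustment converts each inner product
// into a squared distance,
//   d(i, j) = |x_i|^2 + |y_j|^2 - 2 * (x_i . y_j),
// and the epilogue is applied in the same pass over the tile.
template <typename Epilogue>
void l2_adjust_with_epilogue(std::vector<float>& dots,
                             const std::vector<float>& row_norms,  // |x_i|^2
                             const std::vector<float>& col_norms,  // |y_j|^2
                             Epilogue epilogue)
{
  const std::size_t n_rows = row_norms.size();
  const std::size_t n_cols = col_norms.size();
  assert(dots.size() == n_rows * n_cols);

  for (std::size_t i = 0; i < n_rows; ++i) {
    for (std::size_t j = 0; j < n_cols; ++j) {
      float d = row_norms[i] + col_norms[j] - 2.0f * dots[i * n_cols + j];
      if (d < 0.0f) { d = 0.0f; }  // assumed clamp for rounding error
      // Assumed epilogue call shape: (value, row, col) -> value.
      dots[i * n_cols + j] = epilogue(d, i, j);
    }
  }
}
```

For metrics that do not go through this adjustment, the epilogue would need its own pass over the distances, which is what the non-L2 branch in the diff appears to do.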
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
  auto distances_ptr = temp_distances.data();
  raft::linalg::map_offset(
Very nice!
    handle,
    raft::make_device_vector_view(temp_distances.data(),
                                  current_query_size * current_centroid_size),
    [=] __device__(size_t i) {
I think it makes sense to do this as an epilogue on the pairwise distances rather than burdening the (already register-constrained) k-selection with it.
This formulation of the algorithm (tiled knn with the epilogue applied after tiling) makes this portion equivalent to our gram matrix API for constructing RKHS kernels (see raft::distance::kernel), and it makes me wonder if there's something to be gained by consolidating these eventually (e.g. the brute force knn primitive becomes a composition of [tiled gram + epilogue] + k-selection). That could also allow us to reuse more from the gram APIs.
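A rough host-side sketch of that composition, assuming squared L2 distances, an epilogue called as `(value, row, col)`, and a hypothetical function name `tiled_knn_sketch`. None of this is the RAFT implementation; it only shows the three stages (tile of pairwise distances, epilogue, k-selection) kept separate.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical host-side sketch, not the RAFT implementation.
// Points are stored row-major with `dim` floats per point.
// Returns, for each query, up to k (distance, index id) pairs in ascending order.
template <typename Epilogue>
std::vector<std::vector<std::pair<float, std::size_t>>> tiled_knn_sketch(
  const std::vector<float>& queries,
  const std::vector<float>& index,
  std::size_t dim,
  std::size_t k,
  std::size_t tile_cols,
  Epilogue epilogue)
{
  assert(dim > 0 && tile_cols > 0);
  const std::size_t n_queries = queries.size() / dim;
  const std::size_t n_index   = index.size() / dim;
  std::vector<std::vector<std::pair<float, std::size_t>>> result(n_queries);

  // Walk the index in column tiles so only one tile of distances is live at a time.
  for (std::size_t col0 = 0; col0 < n_index; col0 += tile_cols) {
    const std::size_t col1 = std::min(col0 + tile_cols, n_index);
    for (std::size_t i = 0; i < n_queries; ++i) {
      auto& best = result[i];
      for (std::size_t j = col0; j < col1; ++j) {
        // Stage 1: pairwise squared L2 distance for this tile entry.
        float d = 0.0f;
        for (std::size_t c = 0; c < dim; ++c) {
          const float diff = queries[i * dim + c] - index[j * dim + c];
          d += diff * diff;
        }
        // Stage 2: epilogue on the distance, before any selection happens.
        best.emplace_back(epilogue(d, i, j), j);
      }
      // Stage 3: k-selection, merging this tile with the best candidates so far.
      const std::size_t keep = std::min(k, best.size());
      std::partial_sort(best.begin(),
                        best.begin() + static_cast<std::ptrdiff_t>(keep),
                        best.end());
      best.resize(keep);
    }
  }
  return result;
}
```

Because the epilogue runs before selection, it can change which neighbors are chosen, not just the reported distances. Stages 1 and 2 together are the part that resembles a gram matrix computation with a post-processing step.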
@@ -138,13 +138,15 @@ inline void knn_merge_parts(
 * is ignored if the metric_type is not Minkowski.
 * @param[in] global_id_offset: optional starting global id mapping for the local partition
 * (assumes the index contains contiguous ids in the global id space)
 * @param[in] distance_epilogue: optional epilogue function to run after computing distances
Can you add the expectation of the epilogue function's argument list here too? Just a small example prototype of the function definition will do.
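For reference, a prototype along these lines is one possibility. This is a sketch under assumptions: the call shape `(value, row, col)` returning the transformed value is assumed rather than taken from the diff shown here, and the functor name, its members, and the `alpha` scaling are hypothetical. It uses a mutual-reachability style transform as the example since that is the motivating use case, and it is written as plain host C++ so it can be compiled on its own.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical example epilogue; names are illustrative only.
// Assumed call shape: epilogue(distance, row, col) -> transformed distance,
// where `row` indexes the query point and `col` the index point.
struct mutual_reachability_epilogue {
  const float* core_dists;  // one precomputed core distance per point
  float alpha;              // hypothetical scaling applied to the raw distance

  float operator()(float value, std::size_t row, std::size_t col) const
  {
    assert(core_dists != nullptr);
    // The result is the largest of the two core distances and the scaled distance.
    return std::max(core_dists[col], std::max(core_dists[row], alpha * value));
  }
};
```

A functor passed to a device-side call would additionally need to be callable from device code, and `core_dists` would have to point at device memory.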
-template <typename ElementType = float, typename IndexType = int64_t>
+template <typename ElementType = float,
+          typename IndexType = int64_t,
+          typename DistanceEpilogue = raft::identity_op>
 void tiled_brute_force_knn(const raft::device_resources& handle,
As mentioned above, I'd eventually love to find ways we can consolidate (and reuse) this with the gram matrix APIs. It's not an immediate priority, but the current gram matrix API is a class that doesn't need to store state, and we have a TODO to convert it into flattened public API functions like the rest of RAFT.
I'm sorry! I meant to request changes, not approve. I think it looks great so far; there are just some minor things to address.
Use the extents from the output mdspan instead
LGTM!
/merge
Add the ability for a user to specify an epilogue function to run after the distance computation in the brute_force::knn call.
This lets us remove faiss from cuml, by updating the hdbscan reachability code (rapidsai/cuml#5293)