-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pairwise_distance api's for C, Python and Rust #142
Add pairwise_distance api's for C, Python and Rust #142
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think overall this PR looks great. I'm pre-approving, but do want to make sure we verify the behavior of the C api for unsupported type combinations. Maybe you've already done this, but I just want to avoid any unintended behaviors.
* @param[in] metric distance to evaluate | ||
* @param[in] metric_arg metric argument (used for Minkowski distance) | ||
*/ | ||
cuvsError_t cuvsPairwiseDistance(cuvsResources_t res, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks so simple and elegant to use!
auto y_mds = cuvs::core::from_dlpack<mdspan_type>(y_tensor); | ||
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor); | ||
|
||
cuvs::distance::pairwise_distance(*res_ptr, x_mds, y_mds, distances_mds, metric, metric_arg); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we know that is always going to use the float
(instantiated) version?
distances_tensor.dl_tensor.strides = NULL; | ||
|
||
// run pairwise distances | ||
cuvsPairwiseDistance(res, &dataset_tensor, &queries_tensor, &distances_tensor, metric, 2.0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you haven't done it before, can you try this w/ double
? (Just making sure we don't get any unnexpected surprises from the UX or the compile times / binary size).
@@ -30,9 +30,10 @@ from cuvs.distance_type cimport cuvsDistanceType | |||
from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray | |||
from pylibraft.common.cai_wrapper import wrap_array | |||
from pylibraft.common.interruptible import cuda_interruptible | |||
from pylibraft.distance.pairwise_distance import DISTANCE_TYPES | |||
from pylibraft.neighbors.common import _check_input_array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll probably need to bring this over at some point too (no rush for 24.06 since we already depend on pylibraft
.
/merge |
Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai/cuvs#142
No description provided.