-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add innerproduct to the pairwise distance api #1226
Add innerproduct to the pairwise distance api #1226
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe it would be ideal to use cuBLAS for inner product instead of pairwiseDistanceMatKernel perf wise.
This adds InnerProduct distance to the pairwise distances api, using cublass gemm to compute the distance. Since this requires a cublas handle, this also changes the distance api to take a raft::device_resources instead of just a cuda stream
a94ceb6
to
c18564f
Compare
@mdoijade I've updated to use cublas here - I had to change the distance api to use a I think in general though we should be passing |
@@ -212,7 +212,7 @@ class GramMatrixBase { | |||
int ld_out) | |||
{ | |||
raft::distance::distance<raft::distance::DistanceType::L2Unexpanded, math_t, math_t, math_t>( | |||
x1, x2, out, n1, n2, n_cols, stream, is_row_major); | |||
raft::device_resources(stream), x1, x2, out, n1, n2, n_cols, is_row_major); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think from the public API standpoint, we want to strive to accept device_resources
everywhere (and eventually just raft::resources
) but things inside of detail
are okay to accept just the individual resources they need- so long as the public API is pulling resources like streams from the raft::resources
speified from the user and not creating their own.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, I'm okay with this function accepting the stream directly since it's in the detail
namespace. The kernel gramm
API itself is in kind of weird state right now. We moved it over from cuml mostly because it's 1) primitives built directly from the pairwise distances API, and 2) it required just enough customization that it was hard to use the existing specializtions and thus cuml's compile time was increased.
At some point, we do need to refactor the kernel gramm API into a set of flattened stateless public API functions but for now it's kind of been left in this state just so cuml doesn't break.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah - I didn't want to break this here by changing the kernel gramm API (for the other changes, cuml is using the pairwise_distances api's which already take a raft handle instead of a stream afaict).
This is a little weird in that its upgrading a stream -> device_resources, but since the L2Unexpanded distance here only uses the stream, this won't do anything funky like allocate a cublas handle.
Thanks @benfred this looks good to me. Do you think it is worth adding test case for inner product on the c++ distance tests layer as well ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look great overall! Mostly very minor things.
cpp/src/distance/distance/specializations/detail/inner_product.cu
Outdated
Show resolved
Hide resolved
@@ -57,7 +58,10 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): | |||
|
|||
output = np.zeros((n_rows, n_rows), dtype=dtype) | |||
|
|||
expected = cdist(input1, input1, metric) | |||
if metric == "inner_product": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We mentioned offline, we should test this through both the C++ and python APIs here. If you think it'd be helpful, I don't mind helping to add a quick google test case here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the pointers! I can't believe I missed the c++ tests for the other metrics (didn't see them when I first looked for some , though its super obvious they are there). Added in the last commit -
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha no problem. I just know a lot of things have been filling up people's plates lately so I didn't mind jumping in to help if needed.
@@ -109,20 +108,20 @@ template <raft::distance::DistanceType distanceType, | |||
typename AccType, | |||
typename OutType, | |||
typename Index_ = int> | |||
void distance(const InType* x, | |||
void distance(raft::resources const& handle, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a breaking change, though it might not? (e.g. comment here). We probably want to build cuml w/ this change at least to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
testing out w/ cuml here rapidsai/cuml#5230
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs rapidsai/cuml@6f0d7fc in cuml looks like
@@ -202,21 +200,22 @@ template <raft::distance::DistanceType distanceType, | |||
typename AccType, | |||
typename OutType, | |||
typename Index_ = int> | |||
void distance(const InType* x, | |||
void distance(raft::resources const& handle, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same above (checking / updating cuml as needed).
@@ -238,7 +237,7 @@ void distance(const InType* x, | |||
* @param metric_arg metric argument (used for Minkowski distance) | |||
*/ | |||
template <typename Type, typename Index_ = int> | |||
void pairwise_distance(raft::device_resources const& handle, | |||
void pairwise_distance(raft::resources const& handle, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this one is breaking
Test out the changes in rapidsai/raft#1226, to make sure they don't break cuml
@@ -19,20 +19,8 @@ | |||
namespace raft { | |||
namespace distance { | |||
namespace detail { | |||
template void distance<raft::distance::DistanceType::Canberra, float, float, float, int>( | |||
const float* x, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh wow, so we were building the same canberra specialization multiple times? It makes sense now...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
/merge |
Test out the changes in rapidsai/raft#1226, to make sure they don't break cuml
Fixes for supporting InnerProduct distance in the pairwise_distance api - required to handle the changes in rapidsai/raft#1226
Fixes for supporting InnerProduct distance in the pairwise_distance api - required to handle the changes in rapidsai/raft#1226. Also resolves #4078. That fix was necessary to tack on to this PR due to upstream RAPIDS updates to the spdlog version (in rmm via rapids-cmake). Authors: - Ben Frederickson (https://github.com/benfred) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #5230
No description provided.