Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067) #2067

Merged
merged 5 commits into from
Jan 16, 2024

Conversation

rhdong
Copy link
Member

@rhdong rhdong commented Dec 18, 2023

  • Add support for SDDMM by wrapping the cusparseSDDMM
  • This PR also moved some APIs shared with SpMM to the utils.cuh file.

@rhdong rhdong requested review from a team as code owners December 18, 2023 01:52
Copy link

copy-pr-bot bot commented Dec 18, 2023

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@rhdong rhdong requested a review from benfred December 18, 2023 01:52
@rhdong rhdong requested a review from cjnolet December 18, 2023 01:55
@rhdong rhdong added feature request New feature or request non-breaking Non-breaking change labels Dec 18, 2023
@rhdong rhdong changed the title [FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2xxx) [FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067) Dec 18, 2023
@benfred
Copy link
Member

benfred commented Dec 18, 2023

/ok to test

copy-pr-bot bot pushed a commit that referenced this pull request Dec 18, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
@benfred
Copy link
Member

benfred commented Dec 18, 2023

/ok to test

copy-pr-bot bot pushed a commit that referenced this pull request Dec 18, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
Copy link
Member

@benfred benfred left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great - thanks for the PR!

cpp/include/raft/sparse/linalg/sddmm.cuh Outdated Show resolved Hide resolved
@rhdong
Copy link
Member Author

rhdong commented Dec 18, 2023

/ok to test

1 similar comment
@benfred
Copy link
Member

benfred commented Dec 18, 2023

/ok to test

copy-pr-bot bot pushed a commit that referenced this pull request Dec 18, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
@benfred benfred added improvement Improvement / enhancement to an existing function and removed feature request New feature or request labels Dec 18, 2023
@rhdong rhdong force-pushed the rhdong/sddmm branch 2 times, most recently from 13b6a45 to 61d9558 Compare December 19, 2023 00:38
copy-pr-bot bot pushed a commit that referenced this pull request Dec 19, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
Copy link
Member

@benfred benfred left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

typename NZType,
typename LayoutPolicyA,
typename LayoutPolicyB>
void sddmm(raft::resources const& handle,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with the RAFT API functions, the order of parameters should be:

  1. Handle
  2. Input view
  3. Output view
  4. Extra parameter (alpha, beta, trans_a, trans_b)

(Even though spmm currently doesn't do that)
And using raft::host_scalar_view should be a good idea instead of a raw pointer for alpha and beta.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense! But considering the minimum of the customer's learning cost, may I suggest here to keep the params in a similar sequence with the cuSparse original API?

Copy link
Member

@cjnolet cjnolet Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RAFT APIs don't assume the user is familiar with, nor even knows that cusparse is being used under the hood, so it's best to keep up with the conventions established by RAFT so that the user has a consistent experience.

@lowener is correct to point this out- we had a lot of discussions about this when the convention was established and we strive to use the same conventions everywhere.

(Resources, param structs, in, out, params)

Our APIs are also intentionally based on mdspan so we shouldn't accept pointers anywhere. All of the pointer-based APIs that are exposed publicly are deprecated.

If alpha and beta can be specified on device or host, we should capture this with mdspans. This also makes the APIs self documenting.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed it.

typename IndexType,
typename LayoutPolicyA,
typename LayoutPolicyB>
bool is_row_major(raft::device_matrix_view<ValueTypeA, IndexType, LayoutPolicyA>& a,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function can be constexpr. Can it also reuse raft::util::is_row_major ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept, and I removed it because the new implement doesn't need the API anymore.

copy-pr-bot bot pushed a commit that referenced this pull request Dec 20, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
copy-pr-bot bot pushed a commit that referenced this pull request Dec 20, 2023
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
@rhdong rhdong force-pushed the rhdong/sddmm branch 2 times, most recently from 523f4f7 to cdae2b5 Compare January 6, 2024 01:48
copy-pr-bot bot pushed a commit that referenced this pull request Jan 6, 2024
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
@rhdong
Copy link
Member Author

rhdong commented Jan 6, 2024

Hi @cjnolet @benfred @lowener, I updated the code inspired by our discussion. Please help review it if you have time. Thank you!

@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS };
*/
enum class FillMode { UPPER, LOWER };

/**
* @brief Enum for this type indicates which operations is applied to the related input (e.g. sparse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operation

*/
template <typename ValueType, typename IndexType, typename LayoutPolicy>
cusparseDnMatDescr_t create_descriptor(
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Views/mdspans are usually passed by value and not by reference because they are lightweight.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_row_major is not necessary here because this information can be inferred from the layout of the matrix view. Call raft::is_row_major() inside this function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept, just inherit from the spmm.

* @brief convert the operation to cusparseOperation_t type
* @tparam OpVal type of operation
*/
static inline cusparseOperation_t convert_operation(const raft::linalg::Operation& op)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this function need to be static? The reference is not needed as well, this can be passed by value.

Copy link
Member Author

@rhdong rhdong Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept!


/**
* @brief convert the operation to cusparseOperation_t type
* @tparam OpVal type of operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this comment: param[in] op type of operation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

copy-pr-bot bot pushed a commit that referenced this pull request Jan 9, 2024
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
copy-pr-bot bot pushed a commit that referenced this pull request Jan 9, 2024
- Add support for SDDMM by wrapping the `cusparseSDDMM`
- This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file.

Authors:
  - James Rong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2067
@cjnolet
Copy link
Member

cjnolet commented Jan 9, 2024

@rhdong Why so many force pushes to the branch? You should be able to merge upstream into your branch cleanly. The commits are squashed automatically upon merging the PR so there's no reason to rewrite history.

@rhdong
Copy link
Member Author

rhdong commented Jan 9, 2024

@rhdong Why so many force pushes to the branch? You should be able to merge upstream into your branch cleanly. The commits are squashed automatically upon merging the PR so there's no reason to rewrite history.

OK, get it~(just want to make the PR as one commit and keep the history clear)

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great so far! I mostly have minor things (in addition to automating the expected test data and contributing the benchmarks that you've put a lot of effort into).

&bufferSize,
resource::get_cuda_stream(handle)));

raft::interruptible::synchronize(resource::get_cuda_stream(handle));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to centralize the interruptible calls instead of calling it directly, please use resource::sync_stream() instead.

@@ -0,0 +1,103 @@
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to avoid confusion, can we rename this file to cusparse_utils.hpp? (Please note this shouldn't be a cuh because it's not creating any device functions).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept

@@ -0,0 +1,83 @@
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this file doesn't create or use any device functions, please rename to sddmm.hpp. This is a great designation to users that it only reuires the CUDA runtime APIs, math libs, and nothing else.

@@ -19,6 +19,7 @@
#pragma once

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should probaly be hpp as well. Up to you whether you want to rename it in this PR (since it's already quite big).

@@ -0,0 +1,425 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New file should only have current year

@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New files should only have current year.

@@ -0,0 +1,103 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New files should only have current year

@@ -0,0 +1,99 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New files should only have current year

4,
4,
3,
1.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few places where we do this because it's really hard to generate the outputs for automating the comparisons. The problem with hardcoding these things is that it's non-trivial to add new test cases and so when a user comes comes to us with an issue where they are seeing strange results with some sparsities or specific inputs, we couldn't otherwise be able to quickly reproduce in a test once we fix the issue.

I think it's easy enough to automate theexpected outputs, though- since you are already creating a mask, just copy out the code from your micro benchmarks that creates a random mask, generate two input dense arrays (using raft::random::make_blobs), compute the pairwise distances between A and B, and then copy in the rows/cols of the results to your "expected output" parse structure.

Given how much work went into benchmarking, I'd also highly suggest we commit your benchmarking code with these changes also. Since it was already written, I can't express enough how convenient it is to be able to load up a simple microbenchmark when someone finds a specific case to test. I understand this version is a lightweight wrapper around the SDDMM from cusparse, but that's likely not always going to be the case and so automated tests and benchmarks helps us evolve the code over time (and back-ends).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense!

ValueType(1.0f),
uint64_t(2024));

raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new tests look great, thanks!

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a publicly facing header, we should add a new file called spmm.cuh and have it import this header (with an include guard, of course), so that it doesn't break users downstream.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense.

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks @rhdong!

@cjnolet
Copy link
Member

cjnolet commented Jan 16, 2024

/merge

@rapids-bot rapids-bot bot merged commit 3c7586f into rapidsai:branch-24.02 Jan 16, 2024
61 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

4 participants