CUBLAS wrappers with switchable host/device pointer mode #453

Merged: 4 commits into rapidsai:branch-22.04 on Feb 4, 2022

Conversation

achirkin (Contributor) commented Jan 20, 2022

Add a few overloads for raft-CUBLAS gemv, gemm, axpy functions to support switching between host and device pointer mode. This allows passing some of the parameters (constants alpha, beta) as device pointers, which sometimes improves performance.

By default, a CUBLAS context is created in host pointer mode. To preserve this assumption, device pointer mode is enabled only for the duration of a particular CUBLAS call.
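As an illustration only, the "switch for one call, then restore" idea can be sketched as an RAII guard. To keep the sketch runnable without a GPU, `MockHandle` and `PointerMode` are stand-ins for a cuBLAS handle and `cublasPointerMode_t`, and `PointerModeGuard` and `call_in_device_mode` are hypothetical names; none of this is the code from this PR. With a real handle, the save and restore steps would go through `cublasGetPointerMode` and `cublasSetPointerMode`.

```cpp
#include <cassert>

// Stand-in for cublasPointerMode_t (CUBLAS_POINTER_MODE_HOST / _DEVICE).
enum class PointerMode { Host, Device };

// Stand-in for a cuBLAS handle. A real handle keeps the mode internally and
// is queried/updated via cublasGetPointerMode / cublasSetPointerMode.
struct MockHandle {
  PointerMode mode = PointerMode::Host;  // host pointer mode is the default
};

// Hypothetical RAII guard: remembers the current mode, switches to the
// requested one, and restores the saved mode when it goes out of scope.
class PointerModeGuard {
 public:
  PointerModeGuard(MockHandle& handle, PointerMode temporary)
    : handle_(handle), saved_(handle.mode)
  {
    handle_.mode = temporary;
  }
  ~PointerModeGuard() { handle_.mode = saved_; }

  PointerModeGuard(const PointerModeGuard&)            = delete;
  PointerModeGuard& operator=(const PointerModeGuard&) = delete;

 private:
  MockHandle& handle_;
  PointerMode saved_;
};

// Hypothetical wrapper: returns the mode that the "call" observed.
PointerMode call_in_device_mode(MockHandle& handle)
{
  PointerModeGuard guard(handle, PointerMode::Device);
  // A real wrapper would invoke the BLAS routine here (for example
  // cublasSgemv), passing device pointers for alpha and beta.
  assert(handle.mode == PointerMode::Device);
  return handle.mode;
}  // guard is destroyed here, so the handle is back in its previous mode
```

Because the restore happens in the destructor, callers that rely on host pointer mode are unaffected even if the wrapper returns early.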

This feature is required for rapidsai/cuml#4446.

@achirkin requested review from a team as code owners January 20, 2022 08:22
@github-actions bot added the cpp label Jan 20, 2022
@achirkin added the enhancement, non-breaking, improvement, and 3 - Ready for Review labels and removed the cpp label Jan 20, 2022
@achirkin requested a review from cjnolet January 20, 2022 08:55
achirkin (Contributor Author) commented

rerun tests

@achirkin changed the base branch from branch-22.02 to branch-22.04 January 25, 2022 06:56
@github-actions bot added the cpp label Feb 2, 2022
tfeher (Contributor) left a comment

Thanks Artem for this PR! The changes look good to me.

tfeher (Contributor) commented Feb 2, 2022

One thing still worth discussing with @cjnolet: how many overloads for gemm and gemv do we want to keep in the API? The advantage of the wrappers introduced by Artem is that they follow the standard BLAS argument order and they also expose the ld parameter.

If we are still consolidating the RAFT API, then one could consider a follow-up PR to remove some of the overloads (together with cleanup in downstream repositories):

  • For gemv, I am not convinced that we need any of the additional overloads.
  • For gemm, one could argue that the additional versions are useful, because they set several default parameters. I am still not sure if we need all of the overloads.

cjnolet (Member) commented Feb 2, 2022

@tfeher,

> how many overloads for gemm and gemv do we want to keep in the API?

Indeed, that was the first thing I noticed when I began reviewing this PR. I don't think we need them, and more specifically, once the mdspan PR is merged, we should be able to accept device_scalar and host_scalar in place of the new pointers (or have them be implicitly convertible).

That said, if these changes are urgent, I'm okay with creating a GitHub issue to follow up when mdspan is ready and getting these merged.

@cjnolet removed the request for review from a team February 2, 2022 20:27
achirkin (Contributor Author) commented Feb 3, 2022

> That said, if these changes are urgent, I'm okay with creating a GitHub issue to follow up when mdspan is ready and getting these merged.

Thanks, I'd prefer to get this merged now to proceed with rapidsai/cuml#4446, which is almost two months old now.

tfeher (Contributor) commented Feb 3, 2022

@cjnolet I had a discussion with Artem about this:

  • We need an additional overload that accepts device pointers (or device_scalar in the future).
  • Indeed, the preferred way would be to have an overload that accepts a device scalar or something equivalent.
  • But to make it actually useful in rapidsai/cuml#4446 (Rewrite CD solver using more BLAS), we would need a way to get a device scalar from a device_vector, which we do not have yet.
  • Since RMM device_scalar owns the scalar, constructing that would mean memcpy, which we want to avoid.
  • Alternatively, one could think about accepting a device_span<math_t, 1> once #399 ([REVIEW] Span implementation) is merged.

Considering these points, we would prefer to go forward with the current PR and rapidsai/cuml#4446. Once span / mdspan is merged, we shall consolidate the API of gemm / gemv to use mdspan, and refactor the relevant calls in the CD solver.

cjnolet (Member) commented Feb 3, 2022

> Since RMM device_scalar owns the scalar, constructing that would mean memcpy, which we want to avoid.

I'd prefer not to use RMM in the public API. The new mdspan will provide a device_scalar_view which can be created from an rmm device_uvector without copying. I would prefer to use that when ready.

Otherwise, I don't see any reason to hold up this PR.
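To make the owning versus non-owning distinction above concrete, here is a minimal host-side sketch. `scalar_view` and `make_scalar_view` are hypothetical names, and `std::vector` stands in for a device buffer such as an rmm device_uvector so that the example runs without a GPU. It does not reflect the actual mdspan-based device_scalar_view API, which was not yet merged at the time of this discussion.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical non-owning view of a single scalar stored inside a larger
// buffer. It holds only a pointer, so creating one copies no data.
template <typename T>
class scalar_view {
 public:
  explicit scalar_view(T* ptr) : ptr_(ptr) {}
  T* data() const { return ptr_; }

 private:
  T* ptr_;
};

// Hypothetical helper: view element `index` of an existing buffer.
// The buffer must outlive the view and must not reallocate while it is used.
template <typename T>
scalar_view<T> make_scalar_view(std::vector<T>& buffer, std::size_t index)
{
  assert(index < buffer.size());
  return scalar_view<T>(buffer.data() + index);
}
```

A wrapper that takes such a view could forward `view.data()` directly as the alpha or beta pointer, whereas an owning scalar would first have to be filled with a copy of the value.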

cjnolet (Member) commented Feb 3, 2022

I've gone ahead and added our plan to an existing issue, which I created to clean up the ctk wrappers in cuml in general: #475

cjnolet (Member) left a comment

@achirkin, just missing doxygen but otherwise looks great.

cpp/include/raft/linalg/axpy.h (resolved)
cpp/include/raft/linalg/gemm.cuh (resolved)
cpp/include/raft/linalg/gemv.h (resolved)
@achirkin requested a review from a team as a code owner February 3, 2022 15:36
cjnolet (Member) left a comment

LGTM. Thanks again @achirkin!

achirkin (Contributor Author) commented Feb 4, 2022

rerun tests

cjnolet (Member) commented Feb 4, 2022

rerun tests

cjnolet (Member) commented Feb 4, 2022

@gpucibot merge

@rapids-bot bot merged commit b7640ae into rapidsai:branch-22.04 Feb 4, 2022
@achirkin deleted the fea-more-cublas-wrappers branch March 31, 2022 06:10
Labels
3 - Ready for Review, cpp, enhancement, improvement, non-breaking

3 participants