Using raft::resources across raft::random (#1420)
Eventually we need to do this across all the headers in the codebase so that users have a choice between `raft::device_resources` (which implicitly depends on the CUDA math libs and Thrust) and `raft::resources` (which is agnostic of the resources it contains and lets the primitives themselves levy their dependency requirements).

cc @MatthiasKohl this *should* allow cugraph-ops to completely remove the math libs dependency (though the conda recipes will also need to be changed to depend on `libraft-headers-only`, and the CMake changed to turn off the CTK math libs dependency).

**NOTE**: Before this PR is merged, it's important that it be tested with cugraph/cuml at the very least, to spot any cases where the `device_resources.hpp` include was being picked up transitively through the RAFT headers.
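
The change itself is mechanical. Here is a minimal before/after sketch of the pattern, using a hypothetical primitive `my_prim` (the headers and `raft::resource::` accessors are the ones added in the diff below):

```cpp
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

template <typename math_t>
void my_prim(raft::resources const& handle, math_t* out)
{
  // Before: cublasHandle_t h = handle.get_cublas_handle(); required the full
  // raft::device_resources type. After: free functions in raft::resource
  // fetch the same objects from the type-agnostic raft::resources, so only
  // the primitives that actually need cuBLAS carry that dependency.
  cublasHandle_t cublas_h = raft::resource::get_cublas_handle(handle);
  cudaStream_t stream     = raft::resource::get_cuda_stream(handle);
  // ... enqueue work on `stream` using `cublas_h` ...
}
```

Since `raft::device_resources` derives from `raft::resources`, existing callers that pass a `device_resources` keep compiling unchanged.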

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1420
cjnolet authored May 5, 2023
1 parent 641f164 commit aa9d686
Showing 31 changed files with 214 additions and 150 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/core/mdarray.hpp
@@ -197,8 +197,9 @@ class mdarray
#endif // RAFT_MDARRAY_CTOR_CONSTEXPR

/**
* @brief The only constructor that can create storage, this is to make sure CUDA stream is being
* used.
* @brief The only constructor that can create storage, raft::resources is accepted
* so that the device implementation can make sure the relevant CUDA stream is
* being used for allocation.
*/
RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle,
mapping_type const& m,
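
Only the documentation changes here; the constructor's behavior is unchanged. As a quick usage sketch (the `raft::make_device_matrix` factory used below is existing RAFT API, not part of this diff):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>

void example()
{
  raft::device_resources handle;  // usable anywhere raft::resources is accepted
  // The factory forwards to the mdarray constructor above, which pulls the
  // CUDA stream out of the resources object for the allocation.
  auto mat = raft::make_device_matrix<float, int>(handle, 128, 64);
}
```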
12 changes: 7 additions & 5 deletions cpp/include/raft/linalg/detail/qr.cuh
@@ -18,6 +18,8 @@

#include "cublas_wrappers.hpp"
#include "cusolver_wrappers.hpp"
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/matrix.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
@@ -42,10 +44,10 @@ namespace detail {
*/
template <typename math_t>
void qrGetQ_inplace(
raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
raft::resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
{
RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols.");
cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle();
cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle);

rmm::device_uvector<math_t> tau(n_cols, stream);
RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream));
@@ -83,7 +85,7 @@ void qrGetQ_inplace(
}

template <typename math_t>
void qrGetQ(raft::device_resources const& handle,
void qrGetQ(raft::resources const& handle,
const math_t* M,
math_t* Q,
int n_rows,
@@ -95,15 +97,15 @@ void qrGetQ(raft::device_resources const& handle,
}

template <typename math_t>
void qrGetQR(raft::device_resources const& handle,
void qrGetQR(raft::resources const& handle,
math_t* M,
math_t* Q,
math_t* R,
int n_rows,
int n_cols,
cudaStream_t stream)
{
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle);

int m = n_rows, n = n_cols;
rmm::device_uvector<math_t> R_full(m * n, stream);
20 changes: 11 additions & 9 deletions cpp/include/raft/linalg/detail/transpose.cuh
@@ -19,7 +19,9 @@
#include "cublas_wrappers.hpp"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
@@ -29,14 +31,14 @@ namespace linalg {
namespace detail {

template <typename math_t>
void transpose(raft::device_resources const& handle,
void transpose(raft::resources const& handle,
math_t* in,
math_t* out,
int n_rows,
int n_cols,
cudaStream_t stream)
{
cublasHandle_t cublas_h = handle.get_cublas_handle();
cublasHandle_t cublas_h = resource::get_cublas_handle(handle);
RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));

int out_n_rows = n_cols;
@@ -83,7 +85,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_row_major_impl(
raft::device_resources const& handle,
raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
@@ -92,7 +94,7 @@ void transpose_row_major_impl(
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_cols,
@@ -105,12 +107,12 @@
out.stride(0),
out.data_handle(),
out.stride(0),
handle.get_stream()));
resource::get_cuda_stream(handle)));
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_col_major_impl(
raft::device_resources const& handle,
raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
@@ -119,7 +121,7 @@
T constexpr kOne = 1;
T constexpr kZero = 0;

CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_rows,
@@ -132,7 +134,7 @@
out.stride(1),
out.data_handle(),
out.stride(1),
handle.get_stream()));
resource::get_cuda_stream(handle)));
}
}; // end namespace detail
}; // end namespace linalg
19 changes: 13 additions & 6 deletions cpp/include/raft/linalg/qr.cuh
@@ -19,6 +19,8 @@
#pragma once

#include "detail/qr.cuh"
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

namespace raft {
namespace linalg {
@@ -33,7 +35,7 @@ namespace linalg {
* @param stream cuda stream
*/
template <typename math_t>
void qrGetQ(raft::device_resources const& handle,
void qrGetQ(raft::resources const& handle,
const math_t* M,
math_t* Q,
int n_rows,
@@ -54,7 +56,7 @@ void qrGetQ(raft::device_resources const& handle,
* @param stream cuda stream
*/
template <typename math_t>
void qrGetQR(raft::device_resources const& handle,
void qrGetQR(raft::resources const& handle,
math_t* M,
math_t* Q,
math_t* R,
@@ -77,13 +79,18 @@ void qrGetQR(raft::device_resources const& handle,
* @param[out] Q Output raft::device_matrix_view
*/
template <typename ElementType, typename IndexType>
void qr_get_q(raft::device_resources const& handle,
void qr_get_q(raft::resources const& handle,
raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q)
{
RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input");

qrGetQ(handle, M.data_handle(), Q.data_handle(), M.extent(0), M.extent(1), handle.get_stream());
qrGetQ(handle,
M.data_handle(),
Q.data_handle(),
M.extent(0),
M.extent(1),
resource::get_cuda_stream(handle));
}

/**
@@ -94,7 +101,7 @@ void qr_get_q(raft::device_resources const& handle,
* @param[out] R Output raft::device_matrix_view
*/
template <typename ElementType, typename IndexType>
void qr_get_qr(raft::device_resources const& handle,
void qr_get_qr(raft::resources const& handle,
raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q,
raft::device_matrix_view<ElementType, IndexType, raft::col_major> R)
@@ -107,7 +114,7 @@ void qr_get_qr(raft::device_resources const& handle,
R.data_handle(),
M.extent(0),
M.extent(1),
handle.get_stream());
resource::get_cuda_stream(handle));
}

/** @} */
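
For reference, calling the mdspan overload after this change looks like the following sketch (`raft::make_device_matrix` and `raft::make_const_mdspan` are existing RAFT helpers assumed here, not part of this diff):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/qr.cuh>

void example()
{
  raft::device_resources handle;
  int n_rows = 100, n_cols = 10;  // QR requires n_rows >= n_cols
  auto M = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
  auto Q = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
  // ... fill M with device data ...
  raft::linalg::qr_get_q(handle, raft::make_const_mdspan(M.view()), Q.view());
}
```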
5 changes: 3 additions & 2 deletions cpp/include/raft/linalg/transpose.cuh
@@ -20,6 +20,7 @@

#include "detail/transpose.cuh"
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

namespace raft {
namespace linalg {
@@ -34,7 +35,7 @@ namespace linalg {
* @param stream: cuda stream
*/
template <typename math_t>
void transpose(raft::device_resources const& handle,
void transpose(raft::resources const& handle,
math_t* in,
math_t* out,
int n_rows,
@@ -76,7 +77,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
 * @param[out] out Output matrix, storage is pre-allocated by caller.
*/
template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
auto transpose(raft::device_resources const& handle,
auto transpose(raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
-> std::enable_if_t<std::is_floating_point_v<T>, void>
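
A matching usage sketch for the mdspan overload (again assuming the existing `raft::make_device_matrix` factory):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/transpose.cuh>

void example()
{
  raft::device_resources handle;
  auto in  = raft::make_device_matrix<float, int>(handle, 3, 4);  // row-major by default
  auto out = raft::make_device_matrix<float, int>(handle, 4, 3);
  // ... fill `in` ...
  raft::linalg::transpose(handle, in.view(), out.view());
}
```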
13 changes: 6 additions & 7 deletions cpp/include/raft/random/detail/make_regression.cuh
@@ -22,7 +22,8 @@

#include <algorithm>

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/init.cuh>
@@ -52,7 +53,7 @@ static __global__ void _singular_profile_kernel(DataT* out, IdxT n, DataT tail_s

/* Internal auxiliary function to generate a low-rank matrix */
template <typename DataT, typename IdxT>
static void _make_low_rank_matrix(raft::device_resources const& handle,
static void _make_low_rank_matrix(raft::resources const& handle,
DataT* out,
IdxT n_rows,
IdxT n_cols,
@@ -61,8 +62,7 @@ static void _make_low_rank_matrix(raft::device_resources const& handle,
raft::random::RngState& r,
cudaStream_t stream)
{
cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle();
cublasHandle_t cublas_handle = handle.get_cublas_handle();
cublasHandle_t cublas_handle = resource::get_cublas_handle(handle);

IdxT n = std::min(n_rows, n_cols);

@@ -143,7 +143,7 @@ static __global__ void _gather2d_kernel(
}

template <typename DataT, typename IdxT>
void make_regression_caller(raft::device_resources const& handle,
void make_regression_caller(raft::resources const& handle,
DataT* out,
DataT* values,
IdxT n_rows,
@@ -162,8 +162,7 @@
{
n_informative = std::min(n_informative, n_cols);

cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle();
cublasHandle_t cublas_handle = handle.get_cublas_handle();
cublasHandle_t cublas_handle = resource::get_cublas_handle(handle);

cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);
raft::random::RngState r(seed, type);
36 changes: 19 additions & 17 deletions cpp/include/raft/random/detail/multi_variable_gaussian.cuh
@@ -20,7 +20,10 @@
#include <memory>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
@@ -139,18 +142,16 @@ class multi_variable_gaussian_impl {
int *info, Lwork, info_h;
syevjInfo_t syevj_params = NULL;
curandGenerator_t gen;
raft::device_resources const& handle;
raft::resources const& handle;
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
bool deinitilized = false;

public: // functions
multi_variable_gaussian_impl() = delete;
multi_variable_gaussian_impl(raft::device_resources const& handle,
const int dim,
Decomposer method)
multi_variable_gaussian_impl(raft::resources const& handle, const int dim, Decomposer method)
: handle(handle), dim(dim), method(method)
{
auto cusolverHandle = handle.get_cusolver_dn_handle();
auto cusolverHandle = resource::get_cusolver_dn_handle(handle);

CURAND_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(gen, 28)); // SEED
@@ -191,9 +192,9 @@

void give_gaussian(const int nPoints, T* P, T* X, const T* x = 0)
{
auto cusolverHandle = handle.get_cusolver_dn_handle();
auto cublasHandle = handle.get_cublas_handle();
auto cudaStream = handle.get_stream();
auto cusolverHandle = resource::get_cusolver_dn_handle(handle);
auto cublasHandle = resource::get_cublas_handle(handle);
auto cudaStream = resource::get_cuda_stream(handle);
if (method == chol_decomp) {
// lower part will contain chol_decomp
RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(
@@ -299,7 +300,7 @@ class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);
@@ -315,7 +316,7 @@ template <typename ValueType>
class multi_variable_gaussian_setup_token {
template <typename T>
friend multi_variable_gaussian_setup_token<T> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);
@@ -342,7 +343,7 @@ class multi_variable_gaussian_setup_token {

// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(raft::device_resources const& handle,
multi_variable_gaussian_setup_token(raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
@@ -399,22 +400,23 @@

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
raft::device_resources const& handle_;
raft::resources const& handle_;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

auto allocate_workspace() const
{
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
return rmm::device_uvector<ValueType>{
num_elements, resource::get_cuda_stream(handle_), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
@@ -434,7 +436,7 @@ void compute_multi_variable_gaussian_impl(

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
raft::device_resources const& handle,
raft::resources const& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
Expand All @@ -455,7 +457,7 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
// using detail::multi_variable_gaussian_impl<T>::Decomposer::qr;

multi_variable_gaussian() = delete;
multi_variable_gaussian(raft::device_resources const& handle,
multi_variable_gaussian(raft::resources const& handle,
const int dim,
typename detail::multi_variable_gaussian_impl<T>::Decomposer method)
: detail::multi_variable_gaussian_impl<T>{handle, dim, method}
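
A usage sketch for the legacy class above; the include path and the exact qualification of the `chol_decomp` enumerator are assumptions here, since the diff elides the surrounding declarations:

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/random/multi_variable_gaussian.cuh>  // assumed include path

void sample_mvg(float* P, float* X, const float* x, int dim, int n_points)
{
  raft::device_resources handle;  // derives from raft::resources, so it still binds here
  raft::random::multi_variable_gaussian<float> mvg(
    handle, dim, raft::random::multi_variable_gaussian<float>::chol_decomp);
  // P: dim x dim covariance, overwritten by its decomposition;
  // x: optional mean vector; X: dim x n_points output samples.
  mvg.give_gaussian(n_points, P, X, x);
}
```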