Skip to content

Commit

Permalink
Use device_memory_resource instead of workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
mhoemmen committed Sep 27, 2022
1 parent 5e7750d commit 9b8ced5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
49 changes: 36 additions & 13 deletions cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cmath>
#include <memory>
#include <optional>
#include <rmm/device_uvector.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
Expand Down Expand Up @@ -300,7 +301,11 @@ multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method);

template <typename ValueType>
std::size_t workspace_size(const multi_variable_gaussian_setup_token<ValueType>& token);
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);

// @param x[in] vector of dim elements
// @param P[inout] On input, dim x dim matrix; overwritten on output
Expand All @@ -323,7 +328,11 @@ class multi_variable_gaussian_setup_token {
multi_variable_gaussian_decomposition_method method);

template <typename T>
friend std::size_t workspace_size(multi_variable_gaussian_setup_token<T>& token);
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);

template <typename T>
friend void compute_multi_variable_gaussian(
Expand All @@ -348,50 +357,62 @@ class multi_variable_gaussian_setup_token {
// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
: impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
handle, dim, new_enum_to_old_enum(method))),
handle_(handle),
mem_resource_(mem_resource),
dim_(dim)
{
RAFT_EXPECTS(mem_resource_ != nullptr, "multi_variable_gaussian: device_memory_resource pointer must be nonnull");
}

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
const raft::handle_t& handle_;
rmm::mr::device_memory_resource* mem_resource_ = nullptr;
int dim_ = 0;

public:
// FIXME (mfh 2022/09/23) Just a hack, because my friend declarations aren't working.
multi_variable_gaussian_impl<ValueType>& get_impl() const { return *impl_; }

auto allocate_workspace() const {
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method)
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
{
return multi_variable_gaussian_setup_token<ValueType>(handle, dim, method);
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
std::size_t workspace_size(multi_variable_gaussian_setup_token<ValueType>& token)
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method)
{
return token.get_impl().get_workspace_size();
rmm::mr::device_memory_resource* mem_resource =
rmm::mr::get_current_device_resource();
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace)
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
RAFT_EXPECTS(static_cast<std::size_t>(workspace.extent(0)) >= workspace_size(token),
"multi_variable_gaussian: Insufficient workspace");
token.get_impl().set_workspace(workspace.data_handle());

const int dim = P.extent(0);
RAFT_EXPECTS(dim == token.dim(),
"multi_variable_gaussian: "
Expand All @@ -409,7 +430,6 @@ void compute_multi_variable_gaussian(
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));

const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
Expand All @@ -420,6 +440,9 @@ void compute_multi_variable_gaussian(

const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = token.allocate_workspace();
token.get_impl().set_workspace(workspace.data());
token.get_impl().give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
}

Expand Down
16 changes: 9 additions & 7 deletions cpp/test/random/multi_variable_gaussian.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,21 @@ class MVGMdspanTest : public ::testing::TestWithParam<MVGInputs<T>> {
raft::update_device(x_d.data(), x.data(), dim, stream);

// Set up the multivariable Gaussian computation
auto token = detail::setup_multi_variable_gaussian<T>(handle, dim, method);
std::size_t o = detail::workspace_size(token);

// give the workspace area to mvg
workspace_d.resize(o, stream);
{
// Test that setup with a default memory resource compiles.
auto token = detail::setup_multi_variable_gaussian<T>(handle, dim, method);
ASSERT_EQ(dim, token.dim()); // just so token is used
}
rmm::mr::device_memory_resource* mem_resource = rmm::mr::get_current_device_resource();
ASSERT_TRUE(mem_resource != nullptr);
auto token = detail::setup_multi_variable_gaussian<T>(handle, mem_resource, dim, method);

raft::device_vector_view<T, int> workspace_view(workspace_d.data(), o);
std::optional<raft::device_vector_view<const T, int>> x_view(std::in_place, x_d.data(), dim);
raft::device_matrix_view<T, int, raft::col_major> P_view(P_d.data(), dim, dim);
raft::device_matrix_view<T, int, raft::col_major> X_view(X_d.data(), dim, nPoints);

// X_view is the output.
detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view, workspace_view);
detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view);

// saving the mean of the randoms in Rand_mean
//@todo can be swapped with a API that calculates mean
Expand Down

0 comments on commit 9b8ced5

Please sign in to comment.