Skip to content

Commit

Permalink
Expose public interface
Browse files Browse the repository at this point in the history
New mdspan-based compute_multi_variable_gaussian interface
is a one-pass interface.  It uses device_memory_manager
so that the user no longer needs to allocate workspace by hand.
  • Loading branch information
mhoemmen committed Sep 28, 2022
1 parent 9b8ced5 commit d866f64
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 101 deletions.
159 changes: 79 additions & 80 deletions cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <cmath>
#include <memory>
#include <optional>
#include <rmm/device_uvector.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
Expand All @@ -28,6 +27,7 @@
#include <raft/linalg/unary_op.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdio.h>
#include <type_traits>

Expand Down Expand Up @@ -297,51 +297,36 @@ template <typename ValueType>
class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method);

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);
const multi_variable_gaussian_decomposition_method method);

// @param x[in] vector of dim elements
// @param P[inout] On input, dim x dim matrix; overwritten on output
// @param X[out] dim x nPoints matrix
template <typename ValueType>
void compute_multi_variable_gaussian(
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace);
raft::device_matrix_view<ValueType, int, raft::col_major> X);

template <typename ValueType>
class multi_variable_gaussian_setup_token {
private:
template <typename T>
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);
const multi_variable_gaussian_decomposition_method method);

template <typename T>
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);

template <typename T>
friend void compute_multi_variable_gaussian(
friend void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<T>& token,
std::optional<raft::device_vector_view<const T, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace);
raft::device_matrix_view<T, int, raft::col_major> P,
raft::device_matrix_view<T, int, raft::col_major> X);

private:
typename multi_variable_gaussian_impl<ValueType>::Decomposer new_enum_to_old_enum(
multi_variable_gaussian_decomposition_method method)
{
Expand All @@ -357,93 +342,107 @@ class multi_variable_gaussian_setup_token {
// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
const multi_variable_gaussian_decomposition_method method)
: impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
handle, dim, new_enum_to_old_enum(method))),
handle_(handle),
mem_resource_(mem_resource),
dim_(dim)
{
RAFT_EXPECTS(mem_resource_ != nullptr, "multi_variable_gaussian: device_memory_resource pointer must be nonnull");
}

/**
* @brief Compute the multivariable Gaussian.
*
* @param[in] x vector of dim elements
* @param[inout] P On input, dim x dim matrix; overwritten on output
* @param[out] X dim x nPoints matrix
*/
void compute(std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
const int input_dim = P.extent(0);
RAFT_EXPECTS(input_dim == dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the extent %d "
"with which the token was created",
input_dim,
dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0),
P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0),
x_extent_0);
const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = allocate_workspace();
impl_->set_workspace(workspace.data());
impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
}

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
const raft::handle_t& handle_;
rmm::mr::device_memory_resource* mem_resource_ = nullptr;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

public:
// FIXME (mfh 2022/09/23) Just a hack, because my friend declarations aren't working.
multi_variable_gaussian_impl<ValueType>& get_impl() const { return *impl_; }

auto allocate_workspace() const {
auto allocate_workspace() const
{
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), mem_resource_};
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
const multi_variable_gaussian_decomposition_method method)
{
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method)
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
rmm::mr::device_memory_resource* mem_resource =
rmm::mr::get_current_device_resource();
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
token.compute(x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType>& token,
void compute_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
const int dim = P.extent(0);
RAFT_EXPECTS(dim == token.dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the dimension %d "
"with which the token was created",
P.extent(0),
token.dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0),
P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0),
x_extent_0);

const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = token.allocate_workspace();
token.get_impl().set_workspace(workspace.data());
token.get_impl().give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
auto token =
setup_multi_variable_gaussian_impl<ValueType>(handle, mem_resource, P.extent(0), method);
compute_multi_variable_gaussian_impl(token, x, P, X);
}

}; // end of namespace detail
Expand Down
77 changes: 76 additions & 1 deletion cpp/include/raft/random/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,81 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
~multi_variable_gaussian() { deinit(); }
}; // end of multi_variable_gaussian

/**
* @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use.
*
* `compute_multi_variable_gaussian` can use any of the following methods.
*
* - `CHOLESKY`: Cholesky decomposition
* - `JACOBI`:
*/
using detail::multi_variable_gaussian_decomposition_method;
using detail::multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle,
const int dim,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr,
"setup_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
return detail::setup_multi_variable_gaussian_impl<ValueType>(
handle, *mem_resource_ptr, dim, method);
}

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
{
return detail::setup_multi_variable_gaussian_impl<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
detail::compute_multi_variable_gaussian_impl(token, x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr,
"compute_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method);
}

}; // end of namespace raft::random

#endif
#endif
Loading

0 comments on commit d866f64

Please sign in to comment.