Skip to content

Commit

Permalink
Expose public interface
Browse files Browse the repository at this point in the history
New mdspan-based compute_multi_variable_gaussian interface
is a one-pass interface.  It uses device_memory_manager
so that the user no longer needs to allocate workspace by hand.
  • Loading branch information
mhoemmen committed Sep 28, 2022
1 parent 9b8ced5 commit fb949f9
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 99 deletions.
151 changes: 73 additions & 78 deletions cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -297,51 +297,36 @@ template <typename ValueType>
class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method);

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);
const multi_variable_gaussian_decomposition_method method);

// @param x[in] vector of dim elements
// @param P[inout] On input, dim x dim matrix; overwritten on output
// @param X[out] dim x nPoints matrix
template <typename ValueType>
void compute_multi_variable_gaussian(
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace);
raft::device_matrix_view<ValueType, int, raft::col_major> X);

template <typename ValueType>
class multi_variable_gaussian_setup_token {
private:
template <typename T>
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
const raft::handle_t& handle,
const int dim,
multi_variable_gaussian_decomposition_method method);

template <typename T>
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method);
const multi_variable_gaussian_decomposition_method method);

template <typename T>
friend void compute_multi_variable_gaussian(
friend void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<T>& token,
std::optional<raft::device_vector_view<const T, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace);
raft::device_matrix_view<T, int, raft::col_major> P,
raft::device_matrix_view<T, int, raft::col_major> X);

private:
typename multi_variable_gaussian_impl<ValueType>::Decomposer new_enum_to_old_enum(
multi_variable_gaussian_decomposition_method method)
{
Expand All @@ -357,94 +342,104 @@ class multi_variable_gaussian_setup_token {
// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
const multi_variable_gaussian_decomposition_method method)
: impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
handle, dim, new_enum_to_old_enum(method))),
handle_(handle),
mem_resource_(mem_resource),
dim_(dim)
{
RAFT_EXPECTS(mem_resource_ != nullptr, "multi_variable_gaussian: device_memory_resource pointer must be nonnull");
}

/**
* @brief Compute the multivariable Gaussian.
*
* @param[in] x vector of dim elements
* @param[inout] P On input, dim x dim matrix; overwritten on output
* @param[out] X dim x nPoints matrix
*/
void compute(std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
const int input_dim = P.extent(0);
RAFT_EXPECTS(input_dim == dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the extent %d "
"with which the token was created",
input_dim, dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0), P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0), X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0), x_extent_0);
const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = allocate_workspace();
impl_->set_workspace(workspace.data());
impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
}

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
const raft::handle_t& handle_;
rmm::mr::device_memory_resource* mem_resource_ = nullptr;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

public:
// FIXME (mfh 2022/09/23) Just a hack, because my friend declarations aren't working.
multi_variable_gaussian_impl<ValueType>& get_impl() const { return *impl_; }

auto allocate_workspace() const {
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), mem_resource_};
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource* mem_resource,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
const multi_variable_gaussian_decomposition_method method)
{
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method)
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
rmm::mr::device_memory_resource* mem_resource =
rmm::mr::get_current_device_resource();
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
token.compute(x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType>& token,
void compute_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
const int dim = P.extent(0);
RAFT_EXPECTS(dim == token.dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the dimension %d "
"with which the token was created",
P.extent(0),
token.dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0),
P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0),
x_extent_0);

const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = token.allocate_workspace();
token.get_impl().set_workspace(workspace.data());
token.get_impl().give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
auto token = setup_multi_variable_gaussian_impl<ValueType>(handle, mem_resource,
P.extent(0), method);
compute_multi_variable_gaussian_impl(token, x, P, X);
}


}; // end of namespace detail
}; // end of namespace raft::random
78 changes: 77 additions & 1 deletion cpp/include/raft/random/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,82 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
~multi_variable_gaussian() { deinit(); }
}; // end of multi_variable_gaussian

/**
* @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use.
*
* `compute_multi_variable_gaussian` can use any of the following methods.
*
* - `CHOLESKY`: Cholesky decomposition
* - `JACOBI`:
*/
using detail::multi_variable_gaussian_decomposition_method;
using detail::multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType>
setup_multi_variable_gaussian(
const raft::handle_t& handle,
const int dim,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr =
rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr, "setup_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
return detail::setup_multi_variable_gaussian_impl<ValueType>(handle, *mem_resource_ptr, dim, method);
}

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType>
setup_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
multi_variable_gaussian_decomposition_method method)
{
return detail::setup_multi_variable_gaussian_impl<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
detail::compute_multi_variable_gaussian_impl(token, x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr =
rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr, "compute_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method);
}

}; // end of namespace raft::random

#endif
#endif
37 changes: 17 additions & 20 deletions cpp/test/random/multi_variable_gaussian.cu
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ class MVGMdspanTest : public ::testing::TestWithParam<MVGInputs<T>> {
auto cusolverH = handle.get_cusolver_dn_handle();
auto stream = handle.get_stream();

// preparing to store stuff
P.resize(dim * dim);
x.resize(dim);
X.resize(dim * nPoints);
Expand All @@ -254,12 +253,10 @@ class MVGMdspanTest : public ::testing::TestWithParam<MVGInputs<T>> {
Rand_cov.resize(dim * dim, stream);
Rand_mean.resize(dim, stream);

// generating random mean and cov.
srand(params.seed);
for (int j = 0; j < dim; j++)
x.data()[j] = rand() % 100 + 5.0f;

// for random Cov. matrix
std::default_random_engine generator(params.seed);
std::uniform_real_distribution<T> distribution(0.0, 1.0);

Expand All @@ -274,26 +271,24 @@ class MVGMdspanTest : public ::testing::TestWithParam<MVGInputs<T>> {
}
}

// porting inputs to gpu
raft::update_device(P_d.data(), P.data(), dim * dim, stream);
raft::update_device(x_d.data(), x.data(), dim, stream);

// Set up the multivariable Gaussian computation
{
// Test that setup with a default memory resource compiles.
auto token = detail::setup_multi_variable_gaussian<T>(handle, dim, method);
ASSERT_EQ(dim, token.dim()); // just so token is used
}
rmm::mr::device_memory_resource* mem_resource = rmm::mr::get_current_device_resource();
ASSERT_TRUE(mem_resource != nullptr);
auto token = detail::setup_multi_variable_gaussian<T>(handle, mem_resource, dim, method);

std::optional<raft::device_vector_view<const T, int>> x_view(std::in_place, x_d.data(), dim);
raft::device_matrix_view<T, int, raft::col_major> P_view(P_d.data(), dim, dim);
raft::device_matrix_view<T, int, raft::col_major> X_view(X_d.data(), dim, nPoints);

// X_view is the output.
detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view);
{
// Test that setup with a default memory resource compiles.
auto token = raft::random::setup_multi_variable_gaussian<T>(handle, dim, method);
(void) token;
}

rmm::mr::device_memory_resource* mem_resource_ptr =
rmm::mr::get_current_device_resource();
ASSERT_TRUE(mem_resource_ptr != nullptr);
raft::random::compute_multi_variable_gaussian(handle, *mem_resource_ptr,
x_view, P_view, X_view, method);

// saving the mean of the randoms in Rand_mean
//@todo can be swapped with a API that calculates mean
Expand Down Expand Up @@ -446,10 +441,6 @@ TEST_P(MVGTestD, CovIsCorrectD)
<< " in CovIsCorrect";
}

// call the tests
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf));
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd));

using MVGMdspanTestF = MVGMdspanTest<float>;
using MVGMdspanTestD = MVGMdspanTest<double>;
TEST_P(MVGMdspanTestF, MeanIsCorrectF)
Expand All @@ -476,6 +467,8 @@ TEST_P(MVGMdspanTestD, MeanIsCorrectD)
}
TEST_P(MVGMdspanTestD, CovIsCorrectD)
{
EXPECT_TRUE(false);

EXPECT_TRUE(raft::devArrMatch(P_d.data(),
Rand_cov.data(),
dim,
Expand All @@ -485,6 +478,10 @@ TEST_P(MVGMdspanTestD, CovIsCorrectD)
<< " in CovIsCorrect";
}

// call the tests
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf));
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd));

// call the tests
INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestF, ::testing::ValuesIn(inputsf));
INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestD, ::testing::ValuesIn(inputsd));
Expand Down

0 comments on commit fb949f9

Please sign in to comment.