Skip to content

Commit

Permalink
Add new multi_variable_gaussian interface
Browse files Browse the repository at this point in the history
It's not ready for public consumption yet,
as it still lives in the detail namespace.
However, it builds and passes tests.
  • Loading branch information
mhoemmen committed Sep 27, 2022
1 parent 694deda commit 3a74368
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 5 deletions.
140 changes: 139 additions & 1 deletion cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
#pragma once
#include "curand_wrappers.hpp"
#include <cmath>
#include <memory>
#include <optional>
#include <type_traits>
#include <raft/core/handle.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
Expand Down Expand Up @@ -286,5 +290,139 @@ class multi_variable_gaussian_impl {
~multi_variable_gaussian_impl() { deinit(); }
}; // end of multi_variable_gaussian_impl


/**
 * @brief Matrix decomposition selector for the multi-variable Gaussian
 *        interface in this file.
 *
 * multi_variable_gaussian_setup_token translates each value into the
 * corresponding multi_variable_gaussian_impl<ValueType>::Decomposer.
 */
enum class multi_variable_gaussian_decomposition_method {
  CHOLESKY,  ///< translated to multi_variable_gaussian_impl<ValueType>::chol_decomp
  JACOBI,    ///< translated to multi_variable_gaussian_impl<ValueType>::jacobi
  QR         ///< translated to multi_variable_gaussian_impl<ValueType>::qr
};

// Forward declaration of the token type so that the free functions below can
// be declared before the class (which befriends them) is defined.
template<typename ValueType>
class multi_variable_gaussian_setup_token;

// Creates the token consumed by workspace_size() and
// compute_multi_variable_gaussian().
//
// @param handle[in] RAFT handle forwarded to the underlying implementation
// @param dim[in] problem dimension; compute_multi_variable_gaussian() requires
//   P.extent(0) to equal this value
// @param method[in] decomposition method to use
template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType>
setup_multi_variable_gaussian(const raft::handle_t& handle,
const int dim,
multi_variable_gaussian_decomposition_method method);

// Workspace size query for a token. compute_multi_variable_gaussian() requires
// workspace.extent(0) to be at least the value returned by workspace_size().
template <typename ValueType>
std::size_t
workspace_size(const multi_variable_gaussian_setup_token<ValueType>& token);

// Validates the arguments against the token and then fills X through the
// implementation held by the token.
//
// @param token[inout] token returned by setup_multi_variable_gaussian; the
//   dimension it was created with must equal P.extent(0)
// @param x[in] optional vector of dim elements (presumably the mean of the
//   distribution -- TODO confirm against the implementation); when empty, a
//   null pointer is forwarded to the implementation
// @param P[inout] On input, dim x dim matrix; overwritten on output
// @param X[out] dim x nPoints matrix
// @param workspace scratch storage; workspace.extent(0) must be at least
//   workspace_size(token)
template <typename ValueType>
void
compute_multi_variable_gaussian(multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
raft::device_vector_view<ValueType, int> workspace);

/**
 * @brief Opaque token tying together the state needed by the multi-variable
 *        Gaussian free functions in this file.
 *
 * A token owns a multi_variable_gaussian_impl<ValueType> and remembers the
 * dimension it was created with. It can only be constructed through
 * setup_multi_variable_gaussian(); the constructor is private so that the
 * implementation behind the token can change later.
 *
 * @tparam ValueType element type of the matrices and vectors
 */
template <typename ValueType>
class multi_variable_gaussian_setup_token {
 private:
  template <typename T>
  friend multi_variable_gaussian_setup_token<T> setup_multi_variable_gaussian(
    const raft::handle_t& handle,
    const int dim,
    multi_variable_gaussian_decomposition_method method);

  template <typename T>
  friend std::size_t workspace_size(multi_variable_gaussian_setup_token<T>& token);

  // Every parameter must be spelled with the friend template's own parameter
  // T. Using the class's ValueType for P, X and workspace would declare a
  // different function template, and the namespace-scope
  // compute_multi_variable_gaussian would not actually be a friend.
  template <typename T>
  friend void compute_multi_variable_gaussian(
    multi_variable_gaussian_setup_token<T>& token,
    std::optional<raft::device_vector_view<const T, int>> x,
    raft::device_matrix_view<T, int, raft::col_major> P,
    raft::device_matrix_view<T, int, raft::col_major> X,
    raft::device_vector_view<T, int> workspace);

  /**
   * @brief Translate the new decomposition enum into the implementation's
   *        Decomposer enum.
   *
   * Static because it is called from the constructor's member initializer
   * list, before any member of the token has been initialized.
   *
   * @param method new-style decomposition selector
   * @return chol_decomp for CHOLESKY, jacobi for JACOBI, qr otherwise
   */
  static typename multi_variable_gaussian_impl<ValueType>::Decomposer new_enum_to_old_enum(
    multi_variable_gaussian_decomposition_method method)
  {
    switch (method) {
      case multi_variable_gaussian_decomposition_method::CHOLESKY:
        return multi_variable_gaussian_impl<ValueType>::chol_decomp;
      case multi_variable_gaussian_decomposition_method::JACOBI:
        return multi_variable_gaussian_impl<ValueType>::jacobi;
      default: return multi_variable_gaussian_impl<ValueType>::qr;
    }
  }

  // Constructor, only for use by friend functions.
  // Hiding this will let us change the implementation in the future.
  multi_variable_gaussian_setup_token(const raft::handle_t& handle,
                                      const int dim,
                                      multi_variable_gaussian_decomposition_method method)
    : impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
        handle, dim, new_enum_to_old_enum(method))),
      dim_(dim)
  {
  }

  // Owned implementation object that performs the actual work.
  std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
  // Dimension passed to the constructor.
  int dim_ = 0;

 public:
  /**
   * @brief Access the underlying implementation.
   *
   * Kept public because the free functions in this file reach the
   * implementation through it. The returned reference is non-const even
   * though the accessor is const: the token only holds a pointer to the
   * implementation.
   */
  multi_variable_gaussian_impl<ValueType>& get_impl() const { return *impl_; }

  /** @brief Dimension the token was created with. */
  int dim() const { return dim_; }
};


/**
 * @brief Create the token required by workspace_size() and
 *        compute_multi_variable_gaussian().
 *
 * @tparam ValueType element type of the matrices and vectors
 * @param handle RAFT handle forwarded to the token's constructor
 * @param dim problem dimension stored in the token
 * @param method decomposition method forwarded to the token's constructor
 * @return the newly constructed token
 */
template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> setup_multi_variable_gaussian(
  const raft::handle_t& handle,
  const int dim,
  multi_variable_gaussian_decomposition_method method)
{
  using token_type = multi_variable_gaussian_setup_token<ValueType>;
  return token_type(handle, dim, method);
}

template <typename ValueType>
std::size_t
workspace_size(multi_variable_gaussian_setup_token<ValueType>& token)
{
return token.get_impl().get_workspace_size();
}

/**
 * @brief Check the arguments against the token, then fill X through the
 *        token's implementation.
 *
 * The checks run in this order: workspace length, P.extent(0) against the
 * token's dimension, P square, P.extent(0) against X.extent(0), and (only if
 * x holds a value) P.extent(0) against x.extent(0). The workspace pointer is
 * handed to the implementation right after the workspace-length check.
 *
 * @tparam ValueType element type of the matrices and vectors
 * @param token token returned by setup_multi_variable_gaussian()
 * @param x optional vector of dim elements; a null pointer is forwarded to the
 *          implementation when it is empty
 * @param P dim x dim matrix; overwritten on output
 * @param X dim x nPoints output matrix
 * @param workspace scratch storage of at least workspace_size(token) elements
 */
template <typename ValueType>
void compute_multi_variable_gaussian(
  multi_variable_gaussian_setup_token<ValueType>& token,
  std::optional<raft::device_vector_view<const ValueType, int>> x,
  raft::device_matrix_view<ValueType, int, raft::col_major> P,
  raft::device_matrix_view<ValueType, int, raft::col_major> X,
  raft::device_vector_view<ValueType, int> workspace)
{
  const auto workspace_len = static_cast<std::size_t>(workspace.extent(0));
  RAFT_EXPECTS(workspace_len >= workspace_size(token),
               "multi_variable_gaussian: Insufficient workspace");
  token.get_impl().set_workspace(workspace.data_handle());

  const int p_rows = P.extent(0);
  const int p_cols = P.extent(1);
  const int X_rows = X.extent(0);

  RAFT_EXPECTS(p_rows == token.dim(),
               "multi_variable_gaussian: "
               "P.extent(0) = %d does not match the dimension %d "
               "with which the token was created",
               p_rows,
               token.dim());
  RAFT_EXPECTS(p_rows == p_cols,
               "multi_variable_gaussian: "
               "P must be square, but P.extent(0) = %d != P.extent(1) = %d",
               p_rows,
               p_cols);
  RAFT_EXPECTS(p_rows == X_rows,
               "multi_variable_gaussian: "
               "P.extent(0) = %d != X.extent(0) = %d",
               p_rows,
               X_rows);

  // An empty optional skips the extent check and forwards a null pointer.
  int x_len                = 0;
  const ValueType* x_data  = nullptr;
  const bool x_is_provided = x.has_value();
  if (x_is_provided) {
    x_len  = x->extent(0);
    x_data = x->data_handle();
  }
  RAFT_EXPECTS(!x_is_provided || p_rows == x_len,
               "multi_variable_gaussian: "
               "P.extent(0) = %d != x.extent(0) = %d",
               p_rows,
               x_len);

  const int num_points = X.extent(1);
  token.get_impl().give_gaussian(num_points, P.data_handle(), X.data_handle(), x_data);
}

}; // end of namespace detail
}; // end of namespace raft::random
}; // end of namespace raft::random
183 changes: 179 additions & 4 deletions cpp/test/random/multi_variable_gaussian.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <random>
#include <rmm/device_uvector.hpp>

// mvg.h takes in matrices that are colomn major (as in fortan)
// mvg.h takes in column-major matrices (as in Fortran)
// Linear offset of element (i, j) in a column-major matrix with leading
// dimension ld. Each argument and the whole result are parenthesized so that
// expression arguments (e.g. IDX2C(r + 1, c, n)) expand correctly.
#define IDX2C(i, j, ld) (((j) * (ld)) + (i))

namespace raft::random {
Expand Down Expand Up @@ -206,6 +206,142 @@ class MVGTest : public ::testing::TestWithParam<MVGInputs<T>> {
raft::handle_t handle;
}; // end of MVGTest class

/**
 * @brief Fixture that runs the mdspan-based multi-variable Gaussian interface
 *        (detail::setup_multi_variable_gaussian, detail::workspace_size,
 *        detail::compute_multi_variable_gaussian) and stores the sample mean
 *        and sample covariance of the generated points for the test cases.
 *
 * After SetUp(): x_d holds the input vector x, P_d holds the input matrix P
 * (restored after the computation), Rand_mean holds the mean computed from the
 * generated points and Rand_cov the covariance computed from them.
 *
 * @tparam T floating-point element type
 */
template <typename T>
class MVGMdspanTest : public ::testing::TestWithParam<MVGInputs<T>> {
 private:
  // Translate the Decomposer stored in MVGInputs into the enum used by the
  // new interface.
  static auto old_enum_to_new_enum(typename multi_variable_gaussian<T>::Decomposer method)
  {
    if (method == multi_variable_gaussian<T>::chol_decomp) {
      return detail::multi_variable_gaussian_decomposition_method::CHOLESKY;
    } else if (method == multi_variable_gaussian<T>::jacobi) {
      return detail::multi_variable_gaussian_decomposition_method::JACOBI;
    } else {
      return detail::multi_variable_gaussian_decomposition_method::QR;
    }
  }

 protected:
  // All device buffers start empty on the handle's stream; SetUp() sizes them.
  MVGMdspanTest()
    : workspace_d(0, handle.get_stream()),
      P_d(0, handle.get_stream()),
      x_d(0, handle.get_stream()),
      X_d(0, handle.get_stream()),
      Rand_cov(0, handle.get_stream()),
      Rand_mean(0, handle.get_stream())
  {
  }

  void SetUp() override
  {
    params      = ::testing::TestWithParam<MVGInputs<T>>::GetParam();
    dim         = params.dim;
    nPoints     = params.nPoints;
    auto method = old_enum_to_new_enum(params.method);
    corr        = params.corr;
    tolerance   = params.tolerance;

    auto cublasH   = handle.get_cublas_handle();
    auto cusolverH = handle.get_cusolver_dn_handle();
    auto stream    = handle.get_stream();

    // preparing to store stuff
    P.resize(dim * dim);
    x.resize(dim);
    X.resize(dim * nPoints);
    P_d.resize(dim * dim, stream);
    X_d.resize(nPoints * dim, stream);
    x_d.resize(dim, stream);
    Rand_cov.resize(dim * dim, stream);
    Rand_mean.resize(dim, stream);

    // generating random mean and cov.
    srand(params.seed);
    for (int j = 0; j < dim; j++)
      x.data()[j] = rand() % 100 + 5.0f;

    // for random Cov. matrix
    std::default_random_engine generator(params.seed);
    std::uniform_real_distribution<T> distribution(0.0, 1.0);

    // P (symmetric positive definite matrix): symmetric off-diagonal entries
    // (zero when uncorrelated) plus dim added to every diagonal entry.
    for (int j = 0; j < dim; j++) {
      for (int i = 0; i < j + 1; i++) {
        T k = distribution(generator);
        if (corr == UNCORRELATED) k = 0.0;
        P.data()[IDX2C(i, j, dim)] = k;
        P.data()[IDX2C(j, i, dim)] = k;
        if (i == j) P.data()[IDX2C(i, j, dim)] += dim;
      }
    }

    // porting inputs to gpu
    raft::update_device(P_d.data(), P.data(), dim * dim, stream);
    raft::update_device(x_d.data(), x.data(), dim, stream);

    // Set up the multivariable Gaussian computation
    auto token    = detail::setup_multi_variable_gaussian<T>(handle, dim, method);
    std::size_t o = detail::workspace_size(token);

    // give the workspace area to mvg
    workspace_d.resize(o, stream);

    raft::device_vector_view<T, int> workspace_view(workspace_d.data(), o);
    std::optional<raft::device_vector_view<const T, int>> x_view(std::in_place, x_d.data(), dim);
    raft::device_matrix_view<T, int, raft::col_major> P_view(P_d.data(), dim, dim);
    raft::device_matrix_view<T, int, raft::col_major> X_view(X_d.data(), dim, nPoints);

    // X_view is the output.
    detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view, workspace_view);

    // saving the mean of the randoms in Rand_mean
    //@todo can be swapped with a API that calculates mean
    // Zero the accumulator on the same stream as the kernels that follow, so
    // the ordering does not depend on implicit default-stream synchronization.
    RAFT_CUDA_TRY(cudaMemsetAsync(Rand_mean.data(), 0, dim * sizeof(T), stream));
    dim3 block(64);
    dim3 grid(raft::ceildiv(nPoints * dim, (int)block.x));
    En_KF_accumulate<<<grid, block, 0, stream>>>(nPoints, dim, X_d.data(), Rand_mean.data());
    RAFT_CUDA_TRY(cudaPeekAtLastError());
    grid = dim3(raft::ceildiv(dim, (int)block.x));
    En_KF_normalize<<<grid, block, 0, stream>>>(nPoints, dim, Rand_mean.data());
    RAFT_CUDA_TRY(cudaPeekAtLastError());

    // storing the error wrt random point mean in X_d
    grid = dim3(raft::ceildiv(dim * nPoints, (int)block.x));
    En_KF_dif<<<grid, block, 0, stream>>>(nPoints, dim, X_d.data(), Rand_mean.data(), X_d.data());
    RAFT_CUDA_TRY(cudaPeekAtLastError());

    // finding the cov matrix, placing in Rand_cov:
    // Rand_cov = X_d * X_d^T / (nPoints - 1)
    T alfa = 1.0 / (nPoints - 1), beta = 0.0;

    RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublasH,
                                                     CUBLAS_OP_N,
                                                     CUBLAS_OP_T,
                                                     dim,
                                                     dim,
                                                     nPoints,
                                                     &alfa,
                                                     X_d.data(),
                                                     dim,
                                                     X_d.data(),
                                                     dim,
                                                     &beta,
                                                     Rand_cov.data(),
                                                     dim,
                                                     stream));

    // restoring cov provided into P_d (P is overwritten by the computation)
    raft::update_device(P_d.data(), P.data(), dim * dim, stream);
  }

 protected:
  // The handle must be declared before the device buffers: members are
  // initialized in declaration order and the constructor initializes every
  // buffer from handle.get_stream(). This also makes the buffers get destroyed
  // before the handle.
  raft::handle_t handle;
  MVGInputs<T> params;
  // Host copies of the covariance matrix, the vector x, and storage sized for
  // the points.
  std::vector<T> P, x, X;
  rmm::device_uvector<T> workspace_d, P_d, x_d, X_d, Rand_cov, Rand_mean;
  int dim, nPoints;
  Correlation corr;
  T tolerance;
};  // end of MVGMdspanTest class

///@todo find out the reason that Un-correlated covs are giving problems (in qr)
// Declare your inputs
const std::vector<MVGInputs<float>> inputsf = {
Expand Down Expand Up @@ -273,8 +409,8 @@ const std::vector<MVGInputs<double>> inputsd = {
};

// make the tests
typedef MVGTest<float> MVGTestF;
typedef MVGTest<double> MVGTestD;
using MVGTestF = MVGTest<float>;
using MVGTestD = MVGTest<double>;
TEST_P(MVGTestF, MeanIsCorrectF)
{
EXPECT_TRUE(raft::devArrMatch(
Expand Down Expand Up @@ -312,4 +448,43 @@ TEST_P(MVGTestD, CovIsCorrectD)
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf));
INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd));

}; // end of namespace raft::random
// Concrete fixtures for the mdspan-based interface, one per element type.
using MVGMdspanTestF = MVGMdspanTest<float>;
using MVGMdspanTestD = MVGMdspanTest<double>;
// float: the mean computed from the generated points (Rand_mean) must match
// the input vector x_d over dim elements, within the parameter's tolerance.
TEST_P(MVGMdspanTestF, MeanIsCorrectF)
{
EXPECT_TRUE(raft::devArrMatch(
x_d.data(), Rand_mean.data(), dim, raft::CompareApprox<float>(tolerance), handle.get_stream()))
<< " in MeanIsCorrect";
}
// float: the covariance computed from the generated points (Rand_cov) must
// match the input matrix P_d over dim x dim elements, within tolerance.
TEST_P(MVGMdspanTestF, CovIsCorrectF)
{
EXPECT_TRUE(raft::devArrMatch(P_d.data(),
Rand_cov.data(),
dim,
dim,
raft::CompareApprox<float>(tolerance),
handle.get_stream()))
<< " in CovIsCorrect";
}
// double: same mean check as MeanIsCorrectF.
TEST_P(MVGMdspanTestD, MeanIsCorrectD)
{
EXPECT_TRUE(raft::devArrMatch(
x_d.data(), Rand_mean.data(), dim, raft::CompareApprox<double>(tolerance), handle.get_stream()))
<< " in MeanIsCorrect";
}
// double: same covariance check as CovIsCorrectF.
TEST_P(MVGMdspanTestD, CovIsCorrectD)
{
EXPECT_TRUE(raft::devArrMatch(P_d.data(),
Rand_cov.data(),
dim,
dim,
raft::CompareApprox<double>(tolerance),
handle.get_stream()))
<< " in CovIsCorrect";
}

// call the tests: run every case above over the inputsf / inputsd parameter
// lists defined earlier in this file.
INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestF, ::testing::ValuesIn(inputsf));
INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestD, ::testing::ValuesIn(inputsd));

}; // end of namespace raft::random

0 comments on commit 3a74368

Please sign in to comment.