Skip to content

Commit

Permalink
Develop new mdspan-ified multi_variable_gaussian interface (#845)
Browse files Browse the repository at this point in the history
Develop a new `multi_variable_gaussian` interface that uses `mdspan`.  The new interface uses free functions, rather than a class.

TODO:

- [x] Do not expose workspaces to the public API (see e.g., #802 (comment)); use RMM instead.
- [x] Add a public API not in in the `detail` namespace, and test it.
- [x] File issues for refactoring and improving the implementation.

Authors:
  - Mark Hoemmen (https://github.com/mhoemmen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #845
  • Loading branch information
mhoemmen authored Oct 5, 2022
1 parent d998eff commit b3d5103
Show file tree
Hide file tree
Showing 4 changed files with 398 additions and 6 deletions.
160 changes: 159 additions & 1 deletion cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

#pragma once
#include "curand_wrappers.hpp"
#include "random_types.hpp"
#include <cmath>
#include <memory>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdio.h>
#include <type_traits>

// mvg.cuh takes in matrices that are colomn major (as in fortan)
#define IDX2C(i, j, ld) (j * ld + i)
Expand Down Expand Up @@ -286,5 +292,157 @@ class multi_variable_gaussian_impl {
~multi_variable_gaussian_impl() { deinit(); }
}; // end of multi_variable_gaussian_impl

template <typename ValueType>
class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X);

template <typename ValueType>
class multi_variable_gaussian_setup_token {
template <typename T>
friend multi_variable_gaussian_setup_token<T> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);

template <typename T>
friend void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<T>& token,
std::optional<raft::device_vector_view<const T, int>> x,
raft::device_matrix_view<T, int, raft::col_major> P,
raft::device_matrix_view<T, int, raft::col_major> X);

private:
typename multi_variable_gaussian_impl<ValueType>::Decomposer new_enum_to_old_enum(
multi_variable_gaussian_decomposition_method method)
{
if (method == multi_variable_gaussian_decomposition_method::CHOLESKY) {
return multi_variable_gaussian_impl<ValueType>::chol_decomp;
} else if (method == multi_variable_gaussian_decomposition_method::JACOBI) {
return multi_variable_gaussian_impl<ValueType>::jacobi;
} else {
return multi_variable_gaussian_impl<ValueType>::qr;
}
}

// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
: impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
handle, dim, new_enum_to_old_enum(method))),
handle_(handle),
mem_resource_(mem_resource),
dim_(dim)
{
}

/**
* @brief Compute the multivariable Gaussian.
*
* @param[in] x vector of dim elements
* @param[inout] P On input, dim x dim matrix; overwritten on output
* @param[out] X dim x nPoints matrix
*/
void compute(std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
const int input_dim = P.extent(0);
RAFT_EXPECTS(input_dim == dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the extent %d "
"with which the token was created",
input_dim,
dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0),
P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0),
x_extent_0);
const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = allocate_workspace();
impl_->set_workspace(workspace.data());
impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
}

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
const raft::handle_t& handle_;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

auto allocate_workspace() const
{
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
{
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
token.compute(x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
auto token =
build_multi_variable_gaussian_token_impl<ValueType>(handle, mem_resource, P.extent(0), method);
compute_multi_variable_gaussian_impl(token, x, P, X);
}

}; // end of namespace detail
}; // end of namespace raft::random
}; // end of namespace raft::random
23 changes: 23 additions & 0 deletions cpp/include/raft/random/detail/random_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace raft::random::detail {

enum class multi_variable_gaussian_decomposition_method { CHOLESKY, JACOBI, QR };

}; // end of namespace raft::random::detail
48 changes: 47 additions & 1 deletion cpp/include/raft/random/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
~multi_variable_gaussian() { deinit(); }
}; // end of multi_variable_gaussian

/**
* @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use.
*
* `compute_multi_variable_gaussian` can use any of the following methods.
*
* - `CHOLESKY`: Uses Cholesky decomposition on the normal equations.
* This may be faster than the other two methods, but less accurate.
*
* - `JACOBI`: Uses the singular value decomposition (SVD) computed with
* cuSOLVER's gesvdj algorithm, which is based on the Jacobi method
* (sweeps of plane rotations). This exposes more parallelism
* for small and medium size matrices than the QR option below.
*
* - `QR`: Uses the SVD computed with cuSOLVER's gesvd algorithm,
* which is based on the QR algortihm.
*/
using detail::multi_variable_gaussian_decomposition_method;

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr,
"compute_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method);
}

}; // end of namespace raft::random

#endif
#endif
Loading

0 comments on commit b3d5103

Please sign in to comment.