Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Develop new mdspan-ified multi_variable_gaussian interface #845

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 159 additions & 1 deletion cpp/include/raft/random/detail/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

#pragma once
#include "curand_wrappers.hpp"
#include "random_types.hpp"
#include <cmath>
#include <memory>
#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdio.h>
#include <type_traits>

// mvg.cuh takes in matrices that are colomn major (as in fortan)
#define IDX2C(i, j, ld) (j * ld + i)
Expand Down Expand Up @@ -286,5 +292,157 @@ class multi_variable_gaussian_impl {
~multi_variable_gaussian_impl() { deinit(); }
}; // end of multi_variable_gaussian_impl

template <typename ValueType>
class multi_variable_gaussian_setup_token;

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X);

template <typename ValueType>
class multi_variable_gaussian_setup_token {
template <typename T>
friend multi_variable_gaussian_setup_token<T> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method);

template <typename T>
friend void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<T>& token,
std::optional<raft::device_vector_view<const T, int>> x,
raft::device_matrix_view<T, int, raft::col_major> P,
raft::device_matrix_view<T, int, raft::col_major> X);

private:
typename multi_variable_gaussian_impl<ValueType>::Decomposer new_enum_to_old_enum(
multi_variable_gaussian_decomposition_method method)
{
if (method == multi_variable_gaussian_decomposition_method::CHOLESKY) {
return multi_variable_gaussian_impl<ValueType>::chol_decomp;
} else if (method == multi_variable_gaussian_decomposition_method::JACOBI) {
return multi_variable_gaussian_impl<ValueType>::jacobi;
} else {
return multi_variable_gaussian_impl<ValueType>::qr;
}
}

// Constructor, only for use by friend functions.
// Hiding this will let us change the implementation in the future.
multi_variable_gaussian_setup_token(const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
: impl_(std::make_unique<multi_variable_gaussian_impl<ValueType>>(
handle, dim, new_enum_to_old_enum(method))),
handle_(handle),
mem_resource_(mem_resource),
dim_(dim)
{
}

/**
* @brief Compute the multivariable Gaussian.
*
* @param[in] x vector of dim elements
* @param[inout] P On input, dim x dim matrix; overwritten on output
* @param[out] X dim x nPoints matrix
*/
void compute(std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
const int input_dim = P.extent(0);
RAFT_EXPECTS(input_dim == dim(),
"multi_variable_gaussian: "
"P.extent(0) = %d does not match the extent %d "
"with which the token was created",
input_dim,
dim());
RAFT_EXPECTS(P.extent(0) == P.extent(1),
"multi_variable_gaussian: "
"P must be square, but P.extent(0) = %d != P.extent(1) = %d",
P.extent(0),
P.extent(1));
RAFT_EXPECTS(P.extent(0) == X.extent(0),
"multi_variable_gaussian: "
"P.extent(0) = %d != X.extent(0) = %d",
P.extent(0),
X.extent(0));
const bool x_has_value = x.has_value();
const int x_extent_0 = x_has_value ? (*x).extent(0) : 0;
RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0,
"multi_variable_gaussian: "
"P.extent(0) = %d != x.extent(0) = %d",
P.extent(0),
x_extent_0);
const int nPoints = X.extent(1);
const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr;

auto workspace = allocate_workspace();
impl_->set_workspace(workspace.data());
impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr);
}

private:
std::unique_ptr<multi_variable_gaussian_impl<ValueType>> impl_;
const raft::handle_t& handle_;
rmm::mr::device_memory_resource& mem_resource_;
int dim_ = 0;

auto allocate_workspace() const
{
const auto num_elements = impl_->get_workspace_size();
return rmm::device_uvector<ValueType>{num_elements, handle_.get_stream(), &mem_resource_};
}

int dim() const { return dim_; }
};

template <typename ValueType>
multi_variable_gaussian_setup_token<ValueType> build_multi_variable_gaussian_token_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
const int dim,
const multi_variable_gaussian_decomposition_method method)
{
return multi_variable_gaussian_setup_token<ValueType>(handle, mem_resource, dim, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
multi_variable_gaussian_setup_token<ValueType>& token,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X)
{
token.compute(x, P, X);
}

template <typename ValueType>
void compute_multi_variable_gaussian_impl(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
auto token =
build_multi_variable_gaussian_token_impl<ValueType>(handle, mem_resource, P.extent(0), method);
compute_multi_variable_gaussian_impl(token, x, P, X);
}

}; // end of namespace detail
}; // end of namespace raft::random
}; // end of namespace raft::random
23 changes: 23 additions & 0 deletions cpp/include/raft/random/detail/random_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace raft::random::detail {

enum class multi_variable_gaussian_decomposition_method { CHOLESKY, JACOBI, QR };

}; // end of namespace raft::random::detail
48 changes: 47 additions & 1 deletion cpp/include/raft/random/multi_variable_gaussian.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl<T> {
~multi_variable_gaussian() { deinit(); }
}; // end of multi_variable_gaussian

/**
* @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use.
*
* `compute_multi_variable_gaussian` can use any of the following methods.
*
* - `CHOLESKY`: Uses Cholesky decomposition on the normal equations.
* This may be faster than the other two methods, but less accurate.
*
* - `JACOBI`: Uses the singular value decomposition (SVD) computed with
* cuSOLVER's gesvdj algorithm, which is based on the Jacobi method
* (sweeps of plane rotations). This exposes more parallelism
* for small and medium size matrices than the QR option below.
*
* - `QR`: Uses the SVD computed with cuSOLVER's gesvd algorithm,
* which is based on the QR algortihm.
*/
using detail::multi_variable_gaussian_decomposition_method;

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
rmm::mr::device_memory_resource& mem_resource,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method);
}

template <typename ValueType>
void compute_multi_variable_gaussian(
const raft::handle_t& handle,
std::optional<raft::device_vector_view<const ValueType, int>> x,
raft::device_matrix_view<ValueType, int, raft::col_major> P,
raft::device_matrix_view<ValueType, int, raft::col_major> X,
const multi_variable_gaussian_decomposition_method method)
{
rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource();
RAFT_EXPECTS(mem_resource_ptr != nullptr,
"compute_multi_variable_gaussian: "
"rmm::mr::get_current_device_resource() returned null; "
"please report this bug to the RAPIDS RAFT developers.");
detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method);
}

}; // end of namespace raft::random

#endif
#endif
Loading