Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for svd API #1190

Merged
merged 8 commits into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/include/raft/linalg/detail/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ void svdQR(raft::device_resources const& handle,
stream));

// Transpose the right singular vector back
if (trans_right) raft::linalg::transpose(right_sing_vecs, n_cols, stream);
if (trans_right && right_sing_vecs != nullptr)
raft::linalg::transpose(right_sing_vecs, n_cols, stream);

RAFT_CUDA_TRY(cudaGetLastError());

Expand Down
117 changes: 72 additions & 45 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,45 +191,42 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle,
* matrix using QR decomposition
* @tparam ValueType value type of parameters
* @tparam IndexType index type of parameters
* @tparam UType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* U_in
* @tparam VType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* V_in
* @param[in] handle raft::device_resources
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* @param[out] U std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* @param[out] V std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in,
VType&& V_in)
template <typename ValueType, typename IndexType>
void svd_qr(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V = std::nullopt)
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
false,
U.has_value(),
V.has_value(),
Expand All @@ -243,57 +240,62 @@ void svd_qr(raft::device_resources const& handle,
*
* Please see above for documentation of `svd_qr`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr(Args... args)
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in = std::nullopt,
VType&& V_in = std::nullopt)
{
svd_qr(std::forward<Args>(args)..., std::nullopt, std::nullopt);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);

svd_qr(handle, in, sing_vals, U, V);
}

/**
* @brief singular value decomposition (SVD) on a column major
* matrix using QR decomposition. Right singular vector matrix is transposed before returning
* @tparam ValueType value type of parameters
* @tparam IndexType index type of parameters
* @tparam UType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* U_in
* @tparam VType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* V_in
* @param[in] handle raft::device_resources
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* @param[out] U std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* @param[out] V std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
template <typename ValueType, typename IndexType>
void svd_qr_transpose_right_vec(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in,
VType&& V_in)
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V = std::nullopt)
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
true,
U.has_value(),
V.has_value(),
Expand All @@ -307,10 +309,20 @@ void svd_qr_transpose_right_vec(
*
* Please see above for documentation of `svd_qr_transpose_right_vec`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr_transpose_right_vec(Args... args)
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr_transpose_right_vec(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in = std::nullopt,
VType&& V_in = std::nullopt)
{
svd_qr_transpose_right_vec(std::forward<Args>(args)..., std::nullopt, std::nullopt);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);

svd_qr_transpose_right_vec(handle, in, sing_vals, U, V);
}

/**
Expand All @@ -320,7 +332,7 @@ void svd_qr_transpose_right_vec(Args... args)
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] S singular values raft::device_vector_view of shape (K)
* @param[out] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* raft::col_major and dimensions (n, n)
* @param[out] U optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
*/
Expand All @@ -332,38 +344,52 @@ void svd_eig(
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt)
{
ValueType* left_sing_vecs_ptr = nullptr;
if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
RAFT_EXPECTS(in.extent(0) == V.extent(0) && in.extent(1) == V.extent(1),
RAFT_EXPECTS(in.extent(1) == V.extent(0) && in.extent(1) == V.extent(1),
"V should have dimensions n * n");
svdEig(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
S.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
V.data_handle(),
U.has_value(),
handle.get_stream());
}

template <typename ValueType, typename IndexType, typename UType>
void svd_eig(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> S,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
UType&& U = std::nullopt)
lowener marked this conversation as resolved.
Show resolved Hide resolved
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U_optional =
std::forward<UType>(U);
svd_eig(handle, in, S, V, U_optional);
}

/**
* @brief reconstruct a matrix use left and right singular vectors and
* singular values
* @param[in] handle raft::device_resources
* @param[in] U left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, k)
* @param[in] S singular values raft::device_vector_view of shape (k, k)
* @param[in] S square matrix with singular values on its diagonal of shape (k, k)
* @param[in] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (k, n)
* @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n)
*/
template <typename ValueType, typename IndexType>
void svd_reconstruction(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> U,
raft::device_vector_view<const ValueType, IndexType> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> V,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> out)
{
Expand All @@ -380,6 +406,7 @@ void svd_reconstruction(raft::device_resources const& handle,
const_cast<ValueType*>(U.data_handle()),
const_cast<ValueType*>(S.data_handle()),
const_cast<ValueType*>(V.data_handle()),
out.data_handle(),
out.extent(0),
out.extent(1),
S.extent(0),
Expand Down
64 changes: 47 additions & 17 deletions cpp/test/linalg/svd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/linalg/init.cuh>
#include <raft/linalg/svd.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -56,6 +57,49 @@ class SvdTest : public ::testing::TestWithParam<SvdInputs<T>> {
}

protected:
void test_API()
{
auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_view = raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
auto left_eig_vectors_view = raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col);
auto right_eig_vectors_view = raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col);
raft::linalg::svd_eig(handle, data_view, sing_vals_view, right_eig_vectors_view, std::nullopt);
raft::linalg::svd_qr(handle, data_view, sing_vals_view);
raft::linalg::svd_qr(
handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view));
raft::linalg::svd_qr(
handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view));
raft::linalg::svd_qr_transpose_right_vec(handle, data_view, sing_vals_view);
raft::linalg::svd_qr_transpose_right_vec(
handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view));
raft::linalg::svd_qr_transpose_right_vec(
handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view));
}

void test_qr()
{
auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_qr_view =
raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
auto left_eig_vectors_qr_view =
std::optional(raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col));
auto right_eig_vectors_trans_qr_view =
std::make_optional(raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col));

svd_qr_transpose_right_vec(handle,
data_view,
sing_vals_qr_view,
left_eig_vectors_qr_view,
right_eig_vectors_trans_qr_view);
handle.sync_stream(stream);
}

void SetUp() override
{
int len = params.len;
Expand All @@ -78,23 +122,9 @@ class SvdTest : public ::testing::TestWithParam<SvdInputs<T>> {
raft::update_device(right_eig_vectors_ref.data(), right_eig_vectors_ref_h, right_evl, stream);
raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.n_col, stream);

auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_qr_view =
raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
std::optional<raft::device_matrix_view<T, int, raft::col_major>> left_eig_vectors_qr_view =
raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col);
std::optional<raft::device_matrix_view<T, int, raft::col_major>>
right_eig_vectors_trans_qr_view = raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col);

svd_qr_transpose_right_vec(handle,
data_view,
sing_vals_qr_view,
left_eig_vectors_qr_view,
right_eig_vectors_trans_qr_view);
handle.sync_stream(stream);
test_API();
raft::update_device(data.data(), data_h, len, stream);
test_qr();
}

protected:
Expand Down