diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 4850744f51..998bea5b1b 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -89,7 +89,8 @@ void svdQR(raft::device_resources const& handle, stream)); // Transpose the right singular vector back - if (trans_right) raft::linalg::transpose(right_sing_vecs, n_cols, stream); + if (trans_right && right_sing_vecs != nullptr) + raft::linalg::transpose(right_sing_vecs, n_cols, stream); RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index eb51093240..4b78f2ef61 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -191,45 +191,42 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, * matrix using QR decomposition * @tparam ValueType value type of parameters * @tparam IndexType index type of parameters - * @tparam UType std::optional> @c - * U_in - * @tparam VType std::optional> @c - * V_in * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) - * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout + * @param[out] U std::optional left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, n) - * @param[out] V_in std::optional right singular values of raft::device_matrix_view with + * @param[out] V std::optional right singular values of raft::device_matrix_view with * layout raft::col_major and dimensions (n, n) */ -template -void svd_qr(raft::device_resources const& handle, - raft::device_matrix_view in, - raft::device_vector_view sing_vals, - UType&& U_in, - VType&& V_in) +template +void svd_qr( + raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) { - std::optional> U = - std::forward(U_in); - std::optional> V = - std::forward(V_in); + ValueType* left_sing_vecs_ptr = nullptr; + ValueType* right_sing_vecs_ptr = nullptr; if (U) { RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), "U should have dimensions m * n"); + left_sing_vecs_ptr = U.value().data_handle(); } if (V) { RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1), "V should have dimensions n * n"); + right_sing_vecs_ptr = V.value().data_handle(); } svdQR(handle, const_cast(in.data_handle()), in.extent(0), in.extent(1), sing_vals.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + left_sing_vecs_ptr, + right_sing_vecs_ptr, false, U.has_value(), V.has_value(), @@ -243,10 +240,19 @@ void svd_qr(raft::device_resources const& handle, * * Please see above for documentation of `svd_qr`. */ -template > -void svd_qr(Args... args) +template +void svd_qr(raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + UType&& U_in = std::nullopt, + VType&& V_in = std::nullopt) { - svd_qr(std::forward(args)..., std::nullopt, std::nullopt); + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + + svd_qr(handle, in, sing_vals, U, V); } /** @@ -254,46 +260,42 @@ void svd_qr(Args... args) * matrix using QR decomposition. Right singular vector matrix is transposed before returning * @tparam ValueType value type of parameters * @tparam IndexType index type of parameters - * @tparam UType std::optional> @c - * U_in - * @tparam VType std::optional> @c - * V_in * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) - * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout + * @param[out] U std::optional left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, n) - * @param[out] V_in std::optional right singular values of raft::device_matrix_view with + * @param[out] V std::optional right singular values of raft::device_matrix_view with * layout raft::col_major and dimensions (n, n) */ -template +template void svd_qr_transpose_right_vec( raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, - UType&& U_in, - VType&& V_in) + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) { - std::optional> U = - std::forward(U_in); - std::optional> V = - std::forward(V_in); + ValueType* left_sing_vecs_ptr = nullptr; + ValueType* right_sing_vecs_ptr = nullptr; if (U) { RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), "U should have dimensions m * n"); + left_sing_vecs_ptr = U.value().data_handle(); } if (V) { RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1), "V should have dimensions n * n"); + right_sing_vecs_ptr = V.value().data_handle(); } svdQR(handle, const_cast(in.data_handle()), in.extent(0), in.extent(1), sing_vals.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + left_sing_vecs_ptr, + right_sing_vecs_ptr, true, U.has_value(), V.has_value(), @@ -307,10 +309,20 @@ void svd_qr_transpose_right_vec( * * Please see above for documentation of `svd_qr_transpose_right_vec`. */ -template > -void svd_qr_transpose_right_vec(Args... args) +template +void svd_qr_transpose_right_vec( + raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + UType&& U_in = std::nullopt, + VType&& V_in = std::nullopt) { - svd_qr_transpose_right_vec(std::forward(args)..., std::nullopt, std::nullopt); + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + + svd_qr_transpose_right_vec(handle, in, sing_vals, U, V); } /** @@ -320,7 +332,7 @@ void svd_qr_transpose_right_vec(Args... args) * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S singular values raft::device_vector_view of shape (K) * @param[out] V right singular values of raft::device_matrix_view with layout - * raft::col_major and dimensions (m, n) + * raft::col_major and dimensions (n, n) * @param[out] U optional left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, n) */ @@ -332,30 +344,44 @@ void svd_eig( raft::device_matrix_view V, std::optional> U = std::nullopt) { + ValueType* left_sing_vecs_ptr = nullptr; if (U) { RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), "U should have dimensions m * n"); + left_sing_vecs_ptr = U.value().data_handle(); } - RAFT_EXPECTS(in.extent(0) == V.extent(0) && in.extent(1) == V.extent(1), + RAFT_EXPECTS(in.extent(1) == V.extent(0) && in.extent(1) == V.extent(1), "V should have dimensions n * n"); svdEig(handle, const_cast(in.data_handle()), in.extent(0), in.extent(1), S.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + left_sing_vecs_ptr, + V.data_handle(), U.has_value(), handle.get_stream()); } +template +void svd_eig(raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_vector_view S, + raft::device_matrix_view V, + UType&& U = std::nullopt) +{ + std::optional> U_optional = + std::forward(U); + svd_eig(handle, in, S, V, U_optional); +} + /** * @brief reconstruct a matrix use left and right singular vectors and * singular values * @param[in] handle raft::device_resources * @param[in] U left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, k) - * @param[in] S singular values raft::device_vector_view of shape (k, k) + * @param[in] S square matrix with singular values on its diagonal of shape (k, k) * @param[in] V right singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (k, n) * @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n) @@ -363,7 +389,7 @@ void svd_eig( template void svd_reconstruction(raft::device_resources const& handle, raft::device_matrix_view U, - raft::device_vector_view S, + raft::device_matrix_view S, raft::device_matrix_view V, raft::device_matrix_view out) { @@ -380,6 +406,7 @@ void svd_reconstruction(raft::device_resources const& handle, const_cast(U.data_handle()), const_cast(S.data_handle()), const_cast(V.data_handle()), + out.data_handle(), out.extent(0), out.extent(1), S.extent(0), diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index 9eee0f538e..bd66459962 100644 --- a/cpp/test/linalg/svd.cu +++ b/cpp/test/linalg/svd.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include @@ -56,6 +57,49 @@ class SvdTest : public ::testing::TestWithParam> { } protected: + void test_API() + { + auto data_view = raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col); + auto sing_vals_view = raft::make_device_vector_view(sing_vals_qr.data(), params.n_col); + auto left_eig_vectors_view = raft::make_device_matrix_view( + left_eig_vectors_qr.data(), params.n_row, params.n_col); + auto right_eig_vectors_view = raft::make_device_matrix_view( + right_eig_vectors_trans_qr.data(), params.n_col, params.n_col); + raft::linalg::svd_eig(handle, data_view, sing_vals_view, right_eig_vectors_view, std::nullopt); + raft::linalg::svd_qr(handle, data_view, sing_vals_view); + raft::linalg::svd_qr( + handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view)); + raft::linalg::svd_qr( + handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view)); + raft::linalg::svd_qr_transpose_right_vec(handle, data_view, sing_vals_view); + raft::linalg::svd_qr_transpose_right_vec( + handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view)); + raft::linalg::svd_qr_transpose_right_vec( + handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view)); + } + + void test_qr() + { + auto data_view = raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col); + auto sing_vals_qr_view = + raft::make_device_vector_view(sing_vals_qr.data(), params.n_col); + auto left_eig_vectors_qr_view = + std::optional(raft::make_device_matrix_view( + left_eig_vectors_qr.data(), params.n_row, params.n_col)); + auto right_eig_vectors_trans_qr_view = + std::make_optional(raft::make_device_matrix_view( + right_eig_vectors_trans_qr.data(), params.n_col, params.n_col)); + + svd_qr_transpose_right_vec(handle, + data_view, + sing_vals_qr_view, + left_eig_vectors_qr_view, + right_eig_vectors_trans_qr_view); + handle.sync_stream(stream); + } + void SetUp() override { int len = params.len; @@ -78,23 +122,9 @@ class SvdTest : public ::testing::TestWithParam> { raft::update_device(right_eig_vectors_ref.data(), right_eig_vectors_ref_h, right_evl, stream); raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.n_col, stream); - auto data_view = raft::make_device_matrix_view( - data.data(), params.n_row, params.n_col); - auto sing_vals_qr_view = - raft::make_device_vector_view(sing_vals_qr.data(), params.n_col); - std::optional> left_eig_vectors_qr_view = - raft::make_device_matrix_view( - left_eig_vectors_qr.data(), params.n_row, params.n_col); - std::optional> - right_eig_vectors_trans_qr_view = raft::make_device_matrix_view( - right_eig_vectors_trans_qr.data(), params.n_col, params.n_col); - - svd_qr_transpose_right_vec(handle, - data_view, - sing_vals_qr_view, - left_eig_vectors_qr_view, - right_eig_vectors_trans_qr_view); - handle.sync_stream(stream); + test_API(); + raft::update_device(data.data(), data_h, len, stream); + test_qr(); } protected: