diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index e465ee6fa2..be2f5f0286 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -152,20 +152,24 @@ void rsvdPerc(const raft::handle_t& handle, * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template -void rsvd_fixed_rank( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) +template +void rsvd_fixed_rank(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -202,22 +206,10 @@ void rsvd_fixed_rank( * * Please see above for documentation of `rsvd_fixed_rank`. */ -template -void rsvd_fixed_rank(const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_fixed_rank(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_fixed_rank(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); + rsvd_fixed_rank(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -228,20 +220,25 @@ void rsvd_fixed_rank(const raft::handle_t& handle, * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template +template void rsvd_fixed_rank_symmetric( const raft::handle_t& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -278,23 +275,10 @@ void rsvd_fixed_rank_symmetric( * * Please see above for documentation of `rsvd_fixed_rank_symmetric`. */ -template -void rsvd_fixed_rank_symmetric( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_fixed_rank_symmetric(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_fixed_rank_symmetric(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); + rsvd_fixed_rank_symmetric(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -307,22 +291,26 @@ void rsvd_fixed_rank_symmetric( * @param[in] p no. of upsamples * @param[in] tol tolerance for Jacobi-based solvers * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template -void rsvd_fixed_rank_jacobi( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - ValueType tol, - int max_sweeps, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) +template +void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -359,22 +347,10 @@ void rsvd_fixed_rank_jacobi( * * Please see above for documentation of `rsvd_fixed_rank_jacobi`. */ -template -void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_fixed_rank_jacobi(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_fixed_rank_sjacobi(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); + rsvd_fixed_rank_jacobi(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -387,12 +363,12 @@ void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, * @param[in] p no. of upsamples * @param[in] tol tolerance for Jacobi-based solvers * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template +template void rsvd_fixed_rank_symmetric_jacobi( const raft::handle_t& handle, raft::device_matrix_view M, @@ -400,9 +376,14 @@ void rsvd_fixed_rank_symmetric_jacobi( IndexType p, ValueType tol, int max_sweeps, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -439,23 +420,10 @@ void rsvd_fixed_rank_symmetric_jacobi( * * Please see above for documentation of `rsvd_fixed_rank_symmetric_jacobi`. */ -template -void rsvd_fixed_rank_symmetric_jacobi( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - IndexType p, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_fixed_rank_symmetric_jacobi(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_fixed_rank_symmetric_jacobi(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); + rsvd_fixed_rank_symmetric_jacobi(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -467,21 +435,25 @@ void rsvd_fixed_rank_symmetric_jacobi( * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed * @param[in] UpS_perc upsampling percentage - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template -void rsvd_perc( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) +template +void rsvd_perc(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -518,23 +490,10 @@ void rsvd_perc( * * Please see above for documentation of `rsvd_perc`. */ -template -void rsvd_perc(const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_perc(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_perc(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); + rsvd_perc(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -546,21 +505,25 @@ void rsvd_perc(const raft::handle_t& handle, * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed * @param[in] UpS_perc upsampling percentage - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template -void rsvd_perc_symmetric( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) +template +void rsvd_perc_symmetric(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -597,23 +560,10 @@ void rsvd_perc_symmetric( * * Please see above for documentation of `rsvd_perc_symmetric`. */ -template -void rsvd_perc_symmetric(const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_perc_symmetric(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_perc_symmetric(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); + rsvd_perc_symmetric(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -627,23 +577,27 @@ void rsvd_perc_symmetric(const raft::handle_t& handle, * @param[in] UpS_perc upsampling percentage * @param[in] tol tolerance for Jacobi-based solvers * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template -void rsvd_perc_jacobi( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - ValueType tol, - int max_sweeps, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) +template +void rsvd_perc_jacobi(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -680,23 +634,10 @@ void rsvd_perc_jacobi( * * Please see above for documentation of `rsvd_perc_jacobi`. */ -template -void rsvd_perc_jacobi(const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_perc_jacobi(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_perc_jacobi(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); + rsvd_perc_jacobi(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -710,12 +651,12 @@ void rsvd_perc_jacobi(const raft::handle_t& handle, * @param[in] UpS_perc upsampling percentage * @param[in] tol tolerance for Jacobi-based solvers * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers - * @param[out] U optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major - * @param[out] V optional right singular values of raft::device_matrix_view with layout + * @param[out] V_in optional right singular values of raft::device_matrix_view with layout * raft::col_major */ -template +template void rsvd_perc_symmetric_jacobi( const raft::handle_t& handle, raft::device_matrix_view M, @@ -724,9 +665,14 @@ void rsvd_perc_symmetric_jacobi( ValueType UpS_perc, ValueType tol, int max_sweeps, - std::optional> U = std::nullopt, - std::optional> V = std::nullopt) + UType&& U_in, + VType&& V_in) { + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), @@ -763,25 +709,10 @@ void rsvd_perc_symmetric_jacobi( * * Please see above for documentation of `rsvd_perc_symmetric_jacobi`. */ -template -void rsvd_perc_symmetric_jacobi( - const raft::handle_t& handle, - raft::device_matrix_view M, - raft::device_vector_view S_vec, - ValueType PC_perc, - ValueType UpS_perc, - ValueType tol, - int max_sweeps, - UType&& U, - VType&& V) +template > +void rsvd_perc_symmetric_jacobi(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - rsvd_perc_symmetric_jacobi( - handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); + rsvd_perc_symmetric_jacobi(std::forward(args)..., std::nullopt, std::nullopt); } /** @} */ // end of group rsvd diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index 0026ec1f7d..fb30f17477 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -192,29 +192,29 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, * @param[in] handle raft::handle_t * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) - * @param[out] left_sing_vecs optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, n) - * @param[out] right_sing_vecs optional right singular values of raft::device_matrix_view with + * @param[out] V_in optional right singular values of raft::device_matrix_view with * layout raft::col_major and dimensions (n, n) */ -template -void svd_qr( - const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view sing_vals, - std::optional> left_sing_vecs = - std::nullopt, - std::optional> right_sing_vecs = - std::nullopt) +template +void svd_qr(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + UType&& U_in, + VType&& V_in) { - if (left_sing_vecs) { - RAFT_EXPECTS(in.extent(0) == left_sing_vecs.value().extent(0) && - in.extent(1) == left_sing_vecs.value().extent(1), + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + + if (U) { + RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), "U should have dimensions m * n"); } - if (right_sing_vecs) { - RAFT_EXPECTS(in.extent(1) == right_sing_vecs.value().extent(0) && - in.extent(1) == right_sing_vecs.value().extent(1), + if (V) { + RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1), "V should have dimensions n * n"); } svdQR(handle, @@ -222,11 +222,11 @@ void svd_qr( in.extent(0), in.extent(1), sing_vals.data_handle(), - left_sing_vecs.value().data_handle(), - right_sing_vecs.value().data_handle(), + U.value().data_handle(), + V.value().data_handle(), false, - left_sing_vecs.has_value(), - right_sing_vecs.has_value(), + U.has_value(), + V.has_value(), handle.get_stream()); } @@ -237,19 +237,10 @@ void svd_qr( * * Please see above for documentation of `svd_qr`. */ -template -void svd_qr(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view sing_vals, - UType&& U, - VType&& V) +template > +void svd_qr(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - svd_qr(handle, in, sing_vals, U_optional, V_optional); + svd_qr(std::forward(args)..., std::nullopt, std::nullopt); } /** @@ -258,29 +249,30 @@ void svd_qr(const raft::handle_t& handle, * @param[in] handle raft::handle_t * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) - * @param[out] left_sing_vecs optional left singular values of raft::device_matrix_view with layout + * @param[out] U_in optional left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, n) - * @param[out] right_sing_vecs optional right singular values of raft::device_matrix_view with + * @param[out] V_in optional right singular values of raft::device_matrix_view with * layout raft::col_major and dimensions (n, n) */ -template +template void svd_qr_transpose_right_vec( const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, - std::optional> left_sing_vecs = - std::nullopt, - std::optional> right_sing_vecs = - std::nullopt) + UType&& U_in, + VType&& V_in) { - if (left_sing_vecs) { - RAFT_EXPECTS(in.extent(0) == left_sing_vecs.value().extent(0) && - in.extent(1) == left_sing_vecs.value().extent(1), + std::optional> U = + std::forward(U_in); + std::optional> V = + std::forward(V_in); + + if (U) { + RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), "U should have dimensions m * n"); } - if (right_sing_vecs) { - RAFT_EXPECTS(in.extent(1) == right_sing_vecs.value().extent(0) && - in.extent(1) == right_sing_vecs.value().extent(1), + if (V) { + RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1), "V should have dimensions n * n"); } svdQR(handle, @@ -288,11 +280,11 @@ void svd_qr_transpose_right_vec( in.extent(0), in.extent(1), sing_vals.data_handle(), - left_sing_vecs.value().data_handle(), - right_sing_vecs.value().data_handle(), + U.value().data_handle(), + V.value().data_handle(), true, - left_sing_vecs.has_value(), - right_sing_vecs.has_value(), + U.has_value(), + V.has_value(), handle.get_stream()); } @@ -303,20 +295,10 @@ void svd_qr_transpose_right_vec( * * Please see above for documentation of `svd_qr_transpose_right_vec`. */ -template -void svd_qr_transpose_right_vec( - const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view sing_vals, - UType&& U, - VType&& V) +template > +void svd_qr_transpose_right_vec(Args... args) { - std::optional> U_optional = - std::forward(U); - std::optional> V_optional = - std::forward(V); - - svd_qr_transpose_right_vec(handle, in, sing_vals, U_optional, V_optional); + svd_qr_transpose_right_vec(std::forward(args)..., std::nullopt, std::nullopt); } /**