From de24a9490b0b0b4c578ca0e264f1e5ad36f513b6 Mon Sep 17 00:00:00 2001 From: Micka <9810050+lowener@users.noreply.github.com> Date: Thu, 8 Dec 2022 11:39:51 +0100 Subject: [PATCH] Add support for 64bit svdeig (#1060) `raft::linalg::eigDC` already supports 64bit. This will be useful for https://github.com/rapidsai/cuml/issues/4906 Authors: - Micka (https://github.com/lowener) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1060 --- cpp/include/raft/linalg/detail/svd.cuh | 24 ++++++++++++------------ cpp/include/raft/linalg/svd.cuh | 14 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 90a7ddec1f..8626c7888b 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -101,14 +101,14 @@ void svdQR(const raft::handle_t& handle, "This usually occurs when some of the features do not vary enough."); } -template +template void svdEig(const raft::handle_t& handle, - T* in, - int n_rows, - int n_cols, - T* S, - T* U, - T* V, + math_t* in, + idx_t n_rows, + idx_t n_cols, + math_t* S, + math_t* U, + math_t* V, bool gen_left_vec, cudaStream_t stream) { @@ -117,11 +117,11 @@ void svdEig(const raft::handle_t& handle, cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); cublasHandle_t cublasH = handle.get_cublas_handle(); - int len = n_cols * n_cols; - rmm::device_uvector in_cross_mult(len, stream); + auto len = n_cols * n_cols; + rmm::device_uvector in_cross_mult(len, stream); - T alpha = T(1); - T beta = T(0); + math_t alpha = math_t(1); + math_t beta = math_t(0); raft::linalg::gemm(handle, in, n_rows, @@ -139,7 +139,7 @@ void svdEig(const raft::handle_t& handle, raft::linalg::eigDC(handle, in_cross_mult.data(), n_cols, n_cols, V, S, stream); raft::matrix::colReverse(V, n_cols, n_cols, stream); - raft::matrix::rowReverse(S, n_cols, 1, stream); + raft::matrix::rowReverse(S, n_cols, idx_t(1), stream); raft::matrix::seqRoot(S, S, alpha, n_cols, stream, true); diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index 7be1b9d63c..2c1b5a5cb7 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -66,14 +66,14 @@ void svdQR(const raft::handle_t& handle, stream); } -template +template void svdEig(const raft::handle_t& handle, - T* in, - int n_rows, - int n_cols, - T* S, - T* U, - T* V, + math_t* in, + idx_t n_rows, + idx_t n_cols, + math_t* S, + math_t* U, + math_t* V, bool gen_left_vec, cudaStream_t stream) {