Skip to content

Commit

Permalink
Add support for 64bit svdeig (#1060)
Browse files Browse the repository at this point in the history
`raft::linalg::eigDC` already supports 64bit.
This will be useful for rapidsai/cuml#4906

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1060
  • Loading branch information
lowener authored Dec 8, 2022
1 parent ce7f4c2 commit de24a94
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
24 changes: 12 additions & 12 deletions cpp/include/raft/linalg/detail/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ void svdQR(const raft::handle_t& handle,
"This usually occurs when some of the features do not vary enough.");
}

template <typename T>
template <typename math_t, typename idx_t>
void svdEig(const raft::handle_t& handle,
T* in,
int n_rows,
int n_cols,
T* S,
T* U,
T* V,
math_t* in,
idx_t n_rows,
idx_t n_cols,
math_t* S,
math_t* U,
math_t* V,
bool gen_left_vec,
cudaStream_t stream)
{
Expand All @@ -117,11 +117,11 @@ void svdEig(const raft::handle_t& handle,
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cublasHandle_t cublasH = handle.get_cublas_handle();

int len = n_cols * n_cols;
rmm::device_uvector<T> in_cross_mult(len, stream);
auto len = n_cols * n_cols;
rmm::device_uvector<math_t> in_cross_mult(len, stream);

T alpha = T(1);
T beta = T(0);
math_t alpha = math_t(1);
math_t beta = math_t(0);
raft::linalg::gemm(handle,
in,
n_rows,
Expand All @@ -139,7 +139,7 @@ void svdEig(const raft::handle_t& handle,
raft::linalg::eigDC(handle, in_cross_mult.data(), n_cols, n_cols, V, S, stream);

raft::matrix::colReverse(V, n_cols, n_cols, stream);
raft::matrix::rowReverse(S, n_cols, 1, stream);
raft::matrix::rowReverse(S, n_cols, idx_t(1), stream);

raft::matrix::seqRoot(S, S, alpha, n_cols, stream, true);

Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ void svdQR(const raft::handle_t& handle,
stream);
}

template <typename T>
template <typename math_t, typename idx_t>
void svdEig(const raft::handle_t& handle,
T* in,
int n_rows,
int n_cols,
T* S,
T* U,
T* V,
math_t* in,
idx_t n_rows,
idx_t n_cols,
math_t* S,
math_t* U,
math_t* V,
bool gen_left_vec,
cudaStream_t stream)
{
Expand Down

0 comments on commit de24a94

Please sign in to comment.