From 379de5a90a69ef066d008304f48bf39e1a33c1f0 Mon Sep 17 00:00:00 2001 From: Anupam Mijar Date: Tue, 27 Aug 2024 19:42:49 +0000 Subject: [PATCH] resolving pr comments --- .../raft/sparse/solver/detail/lanczos.cuh | 100 +++++++++--------- cpp/include/raft/sparse/solver/lanczos.cuh | 23 ++-- 2 files changed, 67 insertions(+), 56 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lanczos.cuh b/cpp/include/raft/sparse/solver/detail/lanczos.cuh index 9004d1f288..e4af9372d5 100644 --- a/cpp/include/raft/sparse/solver/detail/lanczos.cuh +++ b/cpp/include/raft/sparse/solver/detail/lanczos.cuh @@ -17,10 +17,10 @@ #pragma once // for cmath: -#include "raft/core/device_csr_matrix.hpp" #define _USE_MATH_DEFINES #include +#include #include #include #include @@ -1544,24 +1544,25 @@ void lanczos_solve_ritz( } template -void lanczos_aux(raft::resources const& handle, - // spectral::matrix::sparse_matrix_t const* A, - raft::device_csr_matrix_view A, - raft::device_matrix_view V, - raft::device_matrix_view u, - raft::device_matrix_view alpha, - raft::device_matrix_view beta, - int start_idx, - int end_idx, - int ncv, - raft::device_matrix_view v, - raft::device_matrix_view uu, - raft::device_matrix_view vv) +void lanczos_aux( + raft::resources const& handle, + // spectral::matrix::sparse_matrix_t const* A, + raft::device_csr_matrix_view A, + raft::device_matrix_view V, + raft::device_matrix_view u, + raft::device_matrix_view alpha, + raft::device_matrix_view beta, + int start_idx, + int end_idx, + int ncv, + raft::device_matrix_view v, + raft::device_matrix_view uu, + raft::device_matrix_view vv) { auto stream = resource::get_cuda_stream(handle); - auto A_structure = A.get_structure_view(); - index_type_t n = A_structure.get_n_rows(); + auto A_structure = A.structure_view(); + index_type_t n = A_structure.get_n_rows(); raft::copy(v.data_handle(), &(V(start_idx, 0)), n, stream); @@ -1569,13 +1570,14 @@ void lanczos_aux(raft::resources const& handle, auto cusparse_h = resource::get_cusparse_handle(handle); cusparseSpMatDescr_t cusparse_A; - raft::sparse::detail::cusparsecreatecsr(&cusparse_A, - A_structure.get_n_rows(), - A_structure.get_n_cols(), - A_structure.get_nnz(), - const_cast(A_structure.get_indptr().data()), - const_cast(A_structure.get_indices().data()), - const_cast(A_structure.get_elements().data())); + raft::sparse::detail::cusparsecreatecsr( + &cusparse_A, + A_structure.get_n_rows(), + A_structure.get_n_cols(), + A_structure.get_nnz(), + const_cast(A_structure.get_indptr().data()), + const_cast(A_structure.get_indices().data()), + const_cast(A.get_elements().data())); cusparseDnVecDescr_t cusparse_v; cusparseDnVecDescr_t cusparse_u; @@ -1683,21 +1685,22 @@ void lanczos_aux(raft::resources const& handle, } template -int lanczos_smallest(raft::resources const& handle, - raft::device_csr_matrix_view A, - int nEigVecs, - int maxIter, - int restartIter, - value_type_t tol, - value_type_t* eigVals_dev, - value_type_t* eigVecs_dev, - value_type_t* v0, - uint64_t seed) +int lanczos_smallest( + raft::resources const& handle, + raft::device_csr_matrix_view A, + int nEigVecs, + int maxIter, + int restartIter, + value_type_t tol, + value_type_t* eigVals_dev, + value_type_t* eigVecs_dev, + value_type_t* v0, + uint64_t seed) { auto A_structure = A.structure_view(); - int n = A_structure.get_n_rows(); - int ncv = restartIter; - auto stream = resource::get_cuda_stream(handle); + int n = A_structure.get_n_rows(); + int ncv = restartIter; + auto stream = resource::get_cuda_stream(handle); std::cout << std::fixed << std::setprecision(7); // Set precision to 10 decimal places @@ -1864,16 +1867,17 @@ int lanczos_smallest(raft::resources const& handle, auto cusparse_h = resource::get_cusparse_handle(handle); cusparseSpMatDescr_t cusparse_A; - // input_config.a_indptr = const_cast(x_structure.get_indptr().data()); - // input_config.a_indices = const_cast(x_structure.get_indices().data()); - // input_config.a_data = const_cast(x.get_elements().data()); - raft::sparse::detail::cusparsecreatecsr(&cusparse_A, - A_structure.get_n_rows(), - A_structure.get_n_cols(), - A_structure.get_nnz(), - const_cast(A_structure.get_indptr().data()), - const_cast(A_structure.get_indices().data()), - const_cast(A_structure.get_elements().data())); + // input_config.a_indptr = const_cast(x_structure.get_indptr().data()); + // input_config.a_indices = const_cast(x_structure.get_indices().data()); + // input_config.a_data = const_cast(x.get_elements().data()); + raft::sparse::detail::cusparsecreatecsr( + &cusparse_A, + A_structure.get_n_rows(), + A_structure.get_n_cols(), + A_structure.get_nnz(), + const_cast(A_structure.get_indptr().data()), + const_cast(A_structure.get_indices().data()), + const_cast(A.get_elements().data())); cusparseDnVecDescr_t cusparse_v; cusparseDnVecDescr_t cusparse_u; @@ -2058,14 +2062,14 @@ int lanczos_smallest(raft::resources const& handle, template auto lanczos_compute_smallest_eigenvectors( raft::resources const& handle, - raft::device_csr_matrix_view A, + raft::device_csr_matrix_view A, lanczos_solver_config const& config, raft::device_vector_view v0, raft::device_vector_view eigenvalues, raft::device_matrix_view eigenvectors) -> int { return lanczos_smallest(handle, - &A, + A, config.n_components, config.max_iterations, config.ncv, diff --git a/cpp/include/raft/sparse/solver/lanczos.cuh b/cpp/include/raft/sparse/solver/lanczos.cuh index 218b4eb806..d65a675dac 100644 --- a/cpp/include/raft/sparse/solver/lanczos.cuh +++ b/cpp/include/raft/sparse/solver/lanczos.cuh @@ -48,16 +48,23 @@ auto lanczos_compute_smallest_eigenvectors( // raft::core::bitmap_view(bitmap_d.data(), params.m, params.n); // auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); - + // FIXME: move out of function - auto csr_structure = raft::make_device_compressed_structure_view( - A.row_offsets_, - A.col_indices_, - A.ncols_, - A.nrows_, - static_cast(A.nnz_)); + IndexTypeT ncols = A.ncols_; + IndexTypeT nrows = A.nrows_; + IndexTypeT nnz = A.nnz_; + + auto csr_structure = + raft::make_device_compressed_structure_view( + const_cast(A.row_offsets_), + const_cast(A.col_indices_), + ncols, + nrows, + nnz); - auto csr_matrix = raft::make_device_matrix_view(A.values_, csr_structure); + auto csr_matrix = + raft::make_device_csr_matrix_view( + const_cast(A.values_), csr_structure); return detail::lanczos_compute_smallest_eigenvectors( handle, csr_matrix, config, v0, eigenvalues, eigenvectors);