Skip to content

Commit

Permalink
resolving pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aamijar committed Aug 27, 2024
1 parent 74908f2 commit 379de5a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 56 deletions.
100 changes: 52 additions & 48 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
#pragma once

// for cmath:
#include "raft/core/device_csr_matrix.hpp"
#define _USE_MATH_DEFINES

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -1544,38 +1544,40 @@ void lanczos_solve_ritz(
}

template <typename index_type_t, typename value_type_t>
void lanczos_aux(raft::resources const& handle,
// spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A,
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A,
raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V,
raft::device_matrix_view<value_type_t> u,
raft::device_matrix_view<value_type_t> alpha,
raft::device_matrix_view<value_type_t> beta,
int start_idx,
int end_idx,
int ncv,
raft::device_matrix_view<value_type_t> v,
raft::device_matrix_view<value_type_t> uu,
raft::device_matrix_view<value_type_t> vv)
void lanczos_aux(
raft::resources const& handle,
// spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A,
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A,
raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V,
raft::device_matrix_view<value_type_t> u,
raft::device_matrix_view<value_type_t> alpha,
raft::device_matrix_view<value_type_t> beta,
int start_idx,
int end_idx,
int ncv,
raft::device_matrix_view<value_type_t> v,
raft::device_matrix_view<value_type_t> uu,
raft::device_matrix_view<value_type_t> vv)
{
auto stream = resource::get_cuda_stream(handle);

auto A_structure = A.get_structure_view();
index_type_t n = A_structure.get_n_rows();
auto A_structure = A.structure_view();
index_type_t n = A_structure.get_n_rows();

raft::copy(v.data_handle(), &(V(start_idx, 0)), n, stream);

std::cout << start_idx << " " << end_idx << std::endl;

auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A;
raft::sparse::detail::cusparsecreatecsr(&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<index_type_t*>(A_structure.get_indptr().data()),
const_cast<index_type_t*>(A_structure.get_indices().data()),
const_cast<value_type_t*>(A_structure.get_elements().data()));
raft::sparse::detail::cusparsecreatecsr(
&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<index_type_t*>(A_structure.get_indptr().data()),
const_cast<index_type_t*>(A_structure.get_indices().data()),
const_cast<value_type_t*>(A.get_elements().data()));

cusparseDnVecDescr_t cusparse_v;
cusparseDnVecDescr_t cusparse_u;
Expand Down Expand Up @@ -1683,21 +1685,22 @@ void lanczos_aux(raft::resources const& handle,
}

template <typename index_type_t, typename value_type_t>
int lanczos_smallest(raft::resources const& handle,
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A,
int nEigVecs,
int maxIter,
int restartIter,
value_type_t tol,
value_type_t* eigVals_dev,
value_type_t* eigVecs_dev,
value_type_t* v0,
uint64_t seed)
int lanczos_smallest(
raft::resources const& handle,
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A,
int nEigVecs,
int maxIter,
int restartIter,
value_type_t tol,
value_type_t* eigVals_dev,
value_type_t* eigVecs_dev,
value_type_t* v0,
uint64_t seed)
{
auto A_structure = A.structure_view();
int n = A_structure.get_n_rows();
int ncv = restartIter;
auto stream = resource::get_cuda_stream(handle);
int n = A_structure.get_n_rows();
int ncv = restartIter;
auto stream = resource::get_cuda_stream(handle);

std::cout << std::fixed << std::setprecision(7); // Set precision to 10 decimal places

Expand Down Expand Up @@ -1864,16 +1867,17 @@ int lanczos_smallest(raft::resources const& handle,

auto cusparse_h = resource::get_cusparse_handle(handle);
cusparseSpMatDescr_t cusparse_A;
// input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
// input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
// input_config.a_data = const_cast<ElementType*>(x.get_elements().data());
raft::sparse::detail::cusparsecreatecsr(&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<index_type_t*>(A_structure.get_indptr().data()),
const_cast<index_type_t*>(A_structure.get_indices().data()),
const_cast<value_type_t*>(A_structure.get_elements().data()));
// input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
// input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
// input_config.a_data = const_cast<ElementType*>(x.get_elements().data());
raft::sparse::detail::cusparsecreatecsr(
&cusparse_A,
A_structure.get_n_rows(),
A_structure.get_n_cols(),
A_structure.get_nnz(),
const_cast<index_type_t*>(A_structure.get_indptr().data()),
const_cast<index_type_t*>(A_structure.get_indices().data()),
const_cast<value_type_t*>(A.get_elements().data()));

cusparseDnVecDescr_t cusparse_v;
cusparseDnVecDescr_t cusparse_u;
Expand Down Expand Up @@ -2058,14 +2062,14 @@ int lanczos_smallest(raft::resources const& handle,
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, ValueTypeT> A,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
lanczos_solver_config<IndexTypeT, ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t> v0,
raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
return lanczos_smallest(handle,
&A,
A,
config.n_components,
config.max_iterations,
config.ncv,
Expand Down
23 changes: 15 additions & 8 deletions cpp/include/raft/sparse/solver/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,23 @@ auto lanczos_compute_smallest_eigenvectors(
// raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);

// auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);

// FIXME: move out of function
auto csr_structure = raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
A.row_offsets_,
A.col_indices_,
A.ncols_,
A.nrows_,
static_cast<IndexTypeT>(A.nnz_));
IndexTypeT ncols = A.ncols_;
IndexTypeT nrows = A.nrows_;
IndexTypeT nnz = A.nnz_;

auto csr_structure =
raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<IndexTypeT*>(A.row_offsets_),
const_cast<IndexTypeT*>(A.col_indices_),
ncols,
nrows,
nnz);

auto csr_matrix = raft::make_device_matrix_view<ValueTypeT>(A.values_, csr_structure);
auto csr_matrix =
raft::make_device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT>(
const_cast<ValueTypeT*>(A.values_), csr_structure);

return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
handle, csr_matrix, config, v0, eigenvalues, eigenvectors);
Expand Down

0 comments on commit 379de5a

Please sign in to comment.