diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index efab8a1601..a5f7c05493 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -71,12 +71,6 @@ class coordinate_structure_view { } - /** - * Create a view from this view. Note that this is for interface compatibility - * @return - */ - view_type view() { return view_type(rows_, cols_, this->get_n_rows(), this->get_n_cols()); } - /** * Return span containing underlying rows array * @return span containing underlying rows array @@ -209,6 +203,10 @@ class coo_matrix_view coordinate_structure_view, is_device> { public: + using element_type = ElementType; + using row_type = RowType; + using col_type = ColType; + using nnz_type = NZType; coo_matrix_view(raft::span element_span, coordinate_structure_view structure_view) : sparse_matrix_view { public: using element_type = ElementType; + using row_type = RowType; + using col_type = ColType; + using nnz_type = NZType; using structure_view_type = typename structure_type::view_type; using container_type = typename ContainerPolicy::container_type; using sparse_matrix_type = @@ -258,14 +259,9 @@ class coo_matrix // Constructor that owns the data but not the structure template > - coo_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + coo_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) : sparse_matrix_type(handle, structure){}; - /** - * Return a view of the structure underlying this matrix - * @return - */ - structure_view_type structure_view() { return this->structure_.get()->view(); } /** * Initialize the sparsity on this instance if it was not known upon construction @@ -277,7 +273,20 @@ class coo_matrix void initialize_sparsity(NZType nnz) { sparse_matrix_type::initialize_sparsity(nnz); - this->structure_.get()->initialize_sparsity(nnz); + this->structure_.initialize_sparsity(nnz); + } + + /** + * Return a view of the structure underlying this matrix + * @return + */ + structure_view_type structure_view() + { + if constexpr (get_sparsity_type() == SparsityType::OWNING) { + return this->structure_.view(); + } else { + return this->structure_; + } } }; } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/csr_matrix.hpp b/cpp/include/raft/core/csr_matrix.hpp index fac656b3f9..c37cfa41c8 100644 --- a/cpp/include/raft/core/csr_matrix.hpp +++ b/cpp/include/raft/core/csr_matrix.hpp @@ -87,12 +87,6 @@ class compressed_structure_view */ span get_indices() override { return indices_; } - /** - * Create a view from this view. Note that this is for interface compatibility - * @return - */ - view_type view() { return view_type(indptr_, indices_, this->get_n_cols()); } - protected: raft::span indptr_; raft::span indices_; @@ -221,6 +215,10 @@ class csr_matrix_view compressed_structure_view, is_device> { public: + using element_type = ElementType; + using indptr_type = IndptrType; + using indices_type = IndicesType; + using nnz_type = NZType; csr_matrix_view( raft::span element_span, compressed_structure_view structure_view) @@ -249,6 +247,9 @@ class csr_matrix ContainerPolicy> { public: using element_type = ElementType; + using indptr_type = IndptrType; + using indices_type = IndicesType; + using nnz_type = NZType; using structure_view_type = typename structure_type::view_type; static constexpr auto get_sparsity_type() { return sparsity_type; } using sparse_matrix_type = @@ -271,7 +272,7 @@ class csr_matrix template > - csr_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + csr_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) : sparse_matrix_type(handle, structure){}; @@ -284,13 +285,20 @@ class csr_matrix void initialize_sparsity(NZType nnz) { sparse_matrix_type::initialize_sparsity(nnz); - this->structure_.get()->initialize_sparsity(nnz); + this->structure_.initialize_sparsity(nnz); } /** * Return a view of the structure underlying this matrix * @return */ - structure_view_type structure_view() { return this->structure_.get()->view(); } + structure_view_type structure_view() + { + if constexpr (get_sparsity_type() == SparsityType::OWNING) { + return this->structure_.view(); + } else { + return this->structure_; + } + } }; } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index b1e9ca30fc..35be67431d 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -174,16 +174,15 @@ auto make_device_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] handle raft handle for managing expensive device resources - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_device_coo_matrix(raft::resources const& handle, - device_coordinate_structure_view structure_) + device_coordinate_structure_view structure) { - return device_sparsity_preserving_coo_matrix( - handle, - std::make_shared>(structure_)); + return device_sparsity_preserving_coo_matrix(handle, + structure); } /** @@ -212,16 +211,15 @@ auto make_device_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_device_coo_matrix_view( - ElementType* ptr, device_coordinate_structure_view structure_) + ElementType* ptr, device_coordinate_structure_view structure) { return device_coo_matrix_view( - raft::device_span(ptr, structure_.get_nnz()), - std::make_shared>(structure_)); + raft::device_span(ptr, structure.get_nnz()), structure); } /** @@ -251,19 +249,17 @@ auto make_device_coo_matrix_view( * @tparam ColType * @tparam NZType * @param[in] elements a device span containing nonzero matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return */ template auto make_device_coo_matrix_view( raft::device_span elements, - device_coordinate_structure_view structure_) + device_coordinate_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return device_coo_matrix_view( - elements, - std::make_shared>(structure_)); + return device_coo_matrix_view(elements, structure); } /** @@ -338,7 +334,7 @@ auto make_device_coordinate_structure(raft::resources const& handle, * @return a sparsity-preserving coordinate structural view */ template -auto make_device_coo_structure_view( +auto make_device_coordinate_structure_view( RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) { return device_coordinate_structure_view( @@ -376,10 +372,10 @@ auto make_device_coo_structure_view( * @return a sparsity-preserving coordinate structural view */ template -auto make_device_coo_structure_view(raft::device_span rows, - raft::device_span cols, - RowType n_rows, - ColType n_cols) +auto make_device_coordinate_structure_view(raft::device_span rows, + raft::device_span cols, + RowType n_rows, + ColType n_cols) { return device_coordinate_structure_view(rows, cols, n_rows, n_cols); } diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index 59cabacf6d..e4ec15f9bd 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -189,7 +189,7 @@ auto make_device_csr_matrix(raft::device_resources const& handle, * @tparam IndicesType * @tparam NZType * @param[in] handle raft handle for managing expensive device resources - * @param[in] structure_ a sparsity-preserving compressed structural view + * @param[in] structure a sparsity-preserving compressed structural view * @return a sparsity-preserving sparse matrix in compressed (csr) format */ template auto make_device_csr_matrix( raft::device_resources const& handle, - device_compressed_structure_view structure_) + device_compressed_structure_view structure) { return device_sparsity_preserving_csr_matrix( - handle, - std::make_shared>( - structure_)); + handle, structure); } /** @@ -232,7 +230,7 @@ auto make_device_csr_matrix( * @tparam IndicesType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) - * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @param[in] structure a sparsity-preserving compressed sparse structural view * @return a sparsity-preserving csr matrix view */ template auto make_device_csr_matrix_view( - ElementType* ptr, device_compressed_structure_view structure_) + ElementType* ptr, device_compressed_structure_view structure) { return device_csr_matrix_view( - raft::device_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::device_span(ptr, structure.get_nnz()), structure); } /** @@ -273,7 +271,7 @@ auto make_device_csr_matrix_view( * @tparam IndicesType * @tparam NZType * @param[in] elements device span containing array of matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving structural view + * @param[in] structure a sparsity-preserving structural view * @return a sparsity-preserving csr matrix view */ template auto make_device_csr_matrix_view( raft::device_span elements, - device_compressed_structure_view structure_) + device_compressed_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return device_csr_matrix_view( - elements, std::make_shared(structure_)); + return device_csr_matrix_view(elements, structure); } /** @@ -365,7 +362,7 @@ auto make_device_compressed_structure(raft::device_resources const& handle, * @return a sparsity-preserving compressed structural view */ template -auto make_device_csr_structure_view( +auto make_device_compressed_structure_view( IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) { return device_compressed_structure_view( @@ -408,9 +405,9 @@ auto make_device_csr_structure_view( * */ template -auto make_device_csr_structure_view(raft::device_span indptr, - raft::device_span indices, - IndicesType n_cols) +auto make_device_compressed_structure_view(raft::device_span indptr, + raft::device_span indices, + IndicesType n_cols) { return device_compressed_structure_view(indptr, indices, n_cols); } diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 45ec278a7d..8fabf5aa95 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -173,15 +173,15 @@ auto make_host_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] handle raft handle for managing expensive resources - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_host_coo_matrix(raft::resources const& handle, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { - return host_sparsity_preserving_coo_matrix( - handle, std::make_shared>(structure_)); + return host_sparsity_preserving_coo_matrix(handle, + structure); } /** @@ -210,15 +210,15 @@ auto make_host_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_host_coo_matrix_view(ElementType* ptr, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { return host_coo_matrix_view( - raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::host_span(ptr, structure.get_nnz()), structure); } /** @@ -248,17 +248,16 @@ auto make_host_coo_matrix_view(ElementType* ptr, * @tparam ColType * @tparam NZType * @param[in] elements a host span containing nonzero matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return */ template auto make_host_coo_matrix_view(raft::host_span elements, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return host_coo_matrix_view(elements, - std::make_shared(structure_)); + return host_coo_matrix_view(elements, structure); } /** @@ -333,7 +332,7 @@ auto make_host_coordinate_structure(raft::resources const& handle, * @return a sparsity-preserving coordinate structural view */ template -auto make_host_coo_structure_view( +auto make_host_coordinate_structure_view( RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) { return host_coordinate_structure_view( @@ -371,10 +370,10 @@ auto make_host_coo_structure_view( * @return a sparsity-preserving coordinate structural view */ template -auto make_host_coo_structure_view(raft::host_span rows, - raft::host_span cols, - RowType n_rows, - ColType n_cols) +auto make_host_coordinate_structure_view(raft::host_span rows, + raft::host_span cols, + RowType n_rows, + ColType n_cols) { return host_coordinate_structure_view(rows, cols, n_rows, n_cols); } diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index 437f60814e..c64bcdcea6 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -189,20 +189,18 @@ auto make_host_csr_matrix(raft::resources const& handle, * @tparam IndicesType * @tparam NZType * @param[in] handle raft handle for managing expensive resources - * @param[in] structure_ a sparsity-preserving compressed structural view + * @param[in] structure a sparsity-preserving compressed structural view * @return a sparsity-preserving sparse matrix in compressed (csr) format */ template -auto make_host_csr_matrix( - raft::resources const& handle, - host_compressed_structure_view structure_) +auto make_host_csr_matrix(raft::resources const& handle, + host_compressed_structure_view structure) { return host_sparsity_preserving_csr_matrix( - handle, - std::make_shared>(structure_)); + handle, structure); } /** @@ -231,7 +229,7 @@ auto make_host_csr_matrix( * @tparam IndicesType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) - * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @param[in] structure a sparsity-preserving compressed sparse structural view * @return a sparsity-preserving csr matrix view */ template auto make_host_csr_matrix_view( - ElementType* ptr, host_compressed_structure_view structure_) + ElementType* ptr, host_compressed_structure_view structure) { return host_csr_matrix_view( - raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::host_span(ptr, structure.get_nnz()), structure); } /** @@ -272,7 +270,7 @@ auto make_host_csr_matrix_view( * @tparam IndicesType * @tparam NZType * @param[in] elements host span containing array of matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving structural view + * @param[in] structure a sparsity-preserving structural view * @return a sparsity-preserving csr matrix view */ template auto make_host_csr_matrix_view( raft::host_span elements, - host_compressed_structure_view structure_) + host_compressed_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return host_csr_matrix_view( - elements, std::make_shared(structure_)); + return host_csr_matrix_view(elements, structure); } /** @@ -365,7 +362,7 @@ auto make_host_compressed_structure(raft::resources const& handle, * @return a sparsity-preserving compressed structural view */ template -auto make_host_csr_structure_view( +auto make_host_compressed_structure_view( IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) { return host_compressed_structure_view( @@ -408,9 +405,9 @@ auto make_host_csr_structure_view( * */ template -auto make_host_csr_structure_view(raft::host_span indptr, - raft::host_span indices, - IndicesType n_cols) +auto make_host_compressed_structure_view(raft::host_span indptr, + raft::host_span indices, + IndicesType n_cols) { return host_compressed_structure_view(indptr, indices, n_cols); } diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp index 207cc944d2..a14944ed5b 100644 --- a/cpp/include/raft/core/sparse_types.hpp +++ b/cpp/include/raft/core/sparse_types.hpp @@ -109,7 +109,7 @@ class sparse_matrix_view { * Return a view of the structure underlying this matrix * @return */ - structure_view_type get_structure() { return structure_view_; } + structure_view_type structure_view() { return structure_view_; } /** * Return a span of the nonzero elements of the matrix @@ -158,18 +158,19 @@ class sparse_matrix { using container_policy_type = ContainerPolicy; using container_type = typename container_policy_type::container_type; + // constructor that owns the data and the structure sparse_matrix(raft::resources const& handle, row_type n_rows, col_type n_cols, nnz_type nnz = 0) noexcept(std::is_nothrow_default_constructible_v) - : structure_{std::make_shared(handle, n_rows, n_cols, nnz)}, - cp_{}, - c_elements_{cp_.create(handle, 0)} {}; + : structure_{handle, n_rows, n_cols, nnz}, cp_{}, c_elements_{cp_.create(handle, 0)} {}; // Constructor that owns the data but not the structure - sparse_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + // This constructor is only callable with a `structure_type == *_structure_view` + // which makes it okay to copy + sparse_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) - : structure_{structure}, cp_{}, c_elements_{cp_.create(handle, structure.get()->get_nnz())} {}; + : structure_{structure}, cp_{}, c_elements_{cp_.create(handle, structure_.get_nnz())} {}; constexpr sparse_matrix(sparse_matrix const&) noexcept( std::is_nothrow_copy_constructible_v) = default; @@ -187,7 +188,7 @@ class sparse_matrix { raft::span get_elements() { - return raft::span(c_elements_.data(), structure_view().get_nnz()); + return raft::span(c_elements_.data(), structure_.get_nnz()); } /** @@ -209,7 +210,7 @@ class sparse_matrix { } protected: - std::shared_ptr structure_; + structure_type structure_; container_policy_type cp_; container_type c_elements_; };