Skip to content

Commit

Permalink
Sparse Pairwise Distances API Updates (#1502)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1502
  • Loading branch information
divyegala authored Jul 3, 2023
1 parent 7358763 commit e9d86f1
Show file tree
Hide file tree
Showing 17 changed files with 294 additions and 199 deletions.
50 changes: 30 additions & 20 deletions cpp/include/raft/core/device_coo_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,37 @@

namespace raft {

template <typename ElementType,
typename RowType,
/**
* Specialization for a sparsity-preserving coordinate structure view which uses device memory
*/
template <typename RowType, typename ColType, typename NZType>
using device_coordinate_structure_view = coordinate_structure_view<RowType, ColType, NZType, true>;

/**
* Specialization for a sparsity-owning coordinate structure which uses device memory
*/
template <typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using device_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, true, ContainerPolicy, sparsity_type>;
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_coordinate_structure =
coordinate_structure<RowType, ColType, NZType, true, ContainerPolicy>;

/**
* Specialization for a coo matrix view which uses device memory
*/
template <typename ElementType, typename RowType, typename ColType, typename NZType>
using device_coo_matrix_view = coo_matrix_view<ElementType, RowType, ColType, NZType, true>;

template <typename ElementType,
typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using device_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, true, ContainerPolicy, sparsity_type>;

/**
* Specialization for a sparsity-owning coo matrix which uses device memory
*/
Expand All @@ -62,21 +78,15 @@ using device_sparsity_preserving_coo_matrix = coo_matrix<ElementType,
ContainerPolicy,
SparsityType::PRESERVING>;

/**
* Specialization for a sparsity-owning coordinate structure which uses device memory
*/
template <typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_coordinate_structure =
coordinate_structure<RowType, ColType, NZType, true, ContainerPolicy>;
template <typename T>
struct is_device_coo_matrix_view : std::false_type {};

/**
* Specialization for a sparsity-preserving coordinate structure view which uses device memory
*/
template <typename RowType, typename ColType, typename NZType>
using device_coordinate_structure_view = coordinate_structure_view<RowType, ColType, NZType, true>;
template <typename ElementType, typename RowType, typename ColType, typename NZType>
struct is_device_coo_matrix_view<device_coo_matrix_view<ElementType, RowType, ColType, NZType>>
: std::true_type {};

template <typename T>
constexpr bool is_device_coo_matrix_view_v = is_device_coo_matrix_view<T>::value;

template <typename T>
struct is_device_coo_matrix : std::false_type {};
Expand Down
94 changes: 49 additions & 45 deletions cpp/include/raft/core/device_csr_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,29 @@

namespace raft {

/**
* Specialization for a sparsity-preserving compressed structure view which uses device memory
*/
template <typename IndptrType, typename IndicesType, typename NZType>
using device_compressed_structure_view =
compressed_structure_view<IndptrType, IndicesType, NZType, true>;

/**
* Specialization for a sparsity-owning compressed structure which uses device memory
*/
template <typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_compressed_structure =
compressed_structure<IndptrType, IndicesType, NZType, true, ContainerPolicy>;

/**
* Specialization for a csr matrix view which uses device memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using device_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, true>;

template <typename ElementType,
typename IndptrType,
typename IndicesType,
Expand All @@ -45,6 +68,32 @@ template <typename ElementType,
using device_sparsity_owning_csr_matrix =
csr_matrix<ElementType, IndptrType, IndicesType, NZType, true, ContainerPolicy>;

/**
* Specialization for a sparsity-preserving csr matrix which uses device memory
*/
template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_sparsity_preserving_csr_matrix = csr_matrix<ElementType,
IndptrType,
IndicesType,
NZType,
true,
ContainerPolicy,
SparsityType::PRESERVING>;

template <typename T>
struct is_device_csr_matrix_view : std::false_type {};

template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
struct is_device_csr_matrix_view<
device_csr_matrix_view<ElementType, IndptrType, IndicesType, NZType>> : std::true_type {};

template <typename T>
constexpr bool is_device_csr_matrix_view_v = is_device_csr_matrix_view<T>::value;

template <typename T>
struct is_device_csr_matrix : std::false_type {};

Expand All @@ -70,51 +119,6 @@ template <typename T>
constexpr bool is_device_csr_sparsity_preserving_v =
is_device_csr_matrix<T>::value and T::get_sparsity_type() == PRESERVING;

/**
* Specialization for a csr matrix view which uses device memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using device_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, true>;

/**
* Specialization for a sparsity-preserving csr matrix which uses device memory
*/
template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_sparsity_preserving_csr_matrix = csr_matrix<ElementType,
IndptrType,
IndicesType,
NZType,
true,
ContainerPolicy,
SparsityType::PRESERVING>;

/**
* Specialization for a csr matrix view which uses device memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using device_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, true>;

/**
* Specialization for a sparsity-owning compressed structure which uses device memory
*/
template <typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = device_uvector_policy>
using device_compressed_structure =
compressed_structure<IndptrType, IndicesType, NZType, true, ContainerPolicy>;

/**
* Specialization for a sparsity-preserving compressed structure view which uses device memory
*/
template <typename IndptrType, typename IndicesType, typename NZType>
using device_compressed_structure_view =
compressed_structure_view<IndptrType, IndicesType, NZType, true>;

/**
* Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning
* means that all of the underlying vectors (data, indptr, indices) are owned by the csr_matrix
Expand Down
50 changes: 30 additions & 20 deletions cpp/include/raft/core/host_coo_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,37 @@

namespace raft {

template <typename ElementType,
typename RowType,
/**
* Specialization for a sparsity-preserving coordinate structure view which uses host memory
*/
template <typename RowType, typename ColType, typename NZType>
using host_coordinate_structure_view = coordinate_structure_view<RowType, ColType, NZType, false>;

/**
* Specialization for a sparsity-owning coordinate structure which uses host memory
*/
template <typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using host_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, false, ContainerPolicy, sparsity_type>;
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_coordinate_structure =
coordinate_structure<RowType, ColType, NZType, false, ContainerPolicy>;

/**
* Specialization for a coo matrix view which uses host memory
*/
template <typename ElementType, typename RowType, typename ColType, typename NZType>
using host_coo_matrix_view = coo_matrix_view<ElementType, RowType, ColType, NZType, false>;

template <typename ElementType,
typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy,
SparsityType sparsity_type = SparsityType::OWNING>
using host_coo_matrix =
coo_matrix<ElementType, RowType, ColType, NZType, false, ContainerPolicy, sparsity_type>;

/**
* Specialization for a sparsity-owning coo matrix which uses host memory
*/
Expand All @@ -61,21 +77,15 @@ using host_sparsity_preserving_coo_matrix = coo_matrix<ElementType,
ContainerPolicy,
SparsityType::PRESERVING>;

/**
* Specialization for a sparsity-owning coordinate structure which uses host memory
*/
template <typename RowType,
typename ColType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_coordinate_structure =
coordinate_structure<RowType, ColType, NZType, false, ContainerPolicy>;
template <typename T>
struct is_host_coo_matrix_view : std::false_type {};

/**
* Specialization for a sparsity-preserving coordinate structure view which uses host memory
*/
template <typename RowType, typename ColType, typename NZType>
using host_coordinate_structure_view = coordinate_structure_view<RowType, ColType, NZType, false>;
template <typename ElementType, typename RowType, typename ColType, typename NZType>
struct is_host_coo_matrix_view<host_coo_matrix_view<ElementType, RowType, ColType, NZType>>
: std::true_type {};

template <typename T>
constexpr bool is_host_coo_matrix_view_v = is_host_coo_matrix_view<T>::value;

template <typename T>
struct is_host_coo_matrix : std::false_type {};
Expand Down
99 changes: 52 additions & 47 deletions cpp/include/raft/core/host_csr_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@

namespace raft {

/**
* Specialization for a sparsity-preserving compressed structure view which uses host memory
*/
template <typename IndptrType, typename IndicesType, typename NZType>
using host_compressed_structure_view =
compressed_structure_view<IndptrType, IndicesType, NZType, false>;

/**
* Specialization for a sparsity-owning compressed structure which uses host memory
*/
template <typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_compressed_structure =
compressed_structure<IndptrType, IndicesType, NZType, false, ContainerPolicy>;

/**
* Specialization for a csr matrix view which uses host memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using host_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, false>;

template <typename ElementType,
typename IndptrType,
typename IndicesType,
Expand All @@ -44,6 +67,32 @@ template <typename ElementType,
using host_sparsity_owning_csr_matrix =
csr_matrix<ElementType, IndptrType, IndicesType, NZType, false, ContainerPolicy>;

/**
* Specialization for a sparsity-preserving csr matrix which uses host memory
*/
template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_sparsity_preserving_csr_matrix = csr_matrix<ElementType,
IndptrType,
IndicesType,
NZType,
false,
ContainerPolicy,
SparsityType::PRESERVING>;

template <typename T>
struct is_host_csr_matrix_view : std::false_type {};

template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
struct is_host_csr_matrix_view<host_csr_matrix_view<ElementType, IndptrType, IndicesType, NZType>>
: std::true_type {};

template <typename T>
constexpr bool is_host_csr_matrix_view_v = is_host_csr_matrix_view<T>::value;

template <typename T>
struct is_host_csr_matrix : std::false_type {};

Expand All @@ -66,53 +115,9 @@ constexpr bool is_host_csr_sparsity_owning_v =
is_host_csr_matrix<T>::value and T::get_sparsity_type() == OWNING;

template <typename T>
constexpr bool is_host_csr_sparsity_preserving_v =
is_host_csr_matrix<T>::value and T::get_sparsity_type() == PRESERVING;

/**
* Specialization for a csr matrix view which uses host memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using host_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, false>;

/**
* Specialization for a sparsity-preserving csr matrix which uses host memory
*/
template <typename ElementType,
typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_sparsity_preserving_csr_matrix = csr_matrix<ElementType,
IndptrType,
IndicesType,
NZType,
false,
ContainerPolicy,
SparsityType::PRESERVING>;

/**
* Specialization for a csr matrix view which uses host memory
*/
template <typename ElementType, typename IndptrType, typename IndicesType, typename NZType>
using host_csr_matrix_view = csr_matrix_view<ElementType, IndptrType, IndicesType, NZType, false>;

/**
* Specialization for a sparsity-owning compressed structure which uses host memory
*/
template <typename IndptrType,
typename IndicesType,
typename NZType,
template <typename T> typename ContainerPolicy = host_vector_policy>
using host_compressed_structure =
compressed_structure<IndptrType, IndicesType, NZType, false, ContainerPolicy>;

/**
* Specialization for a sparsity-preserving compressed structure view which uses host memory
*/
template <typename IndptrType, typename IndicesType, typename NZType>
using host_compressed_structure_view =
compressed_structure_view<IndptrType, IndicesType, NZType, false>;
constexpr bool is_host_csr_sparsity_preserving_v = std::disjunction_v<
is_host_csr_matrix_view<T>,
std::bool_constant<is_host_csr_matrix<T>::value and T::get_sparsity_type() == PRESERVING>>;

/**
* Create a sparsity-owning sparse matrix in the compressed-sparse row format. sparsity-owning
Expand Down
Loading

0 comments on commit e9d86f1

Please sign in to comment.