Skip to content

Commit

Permalink
Gram matrix support for sparse input (rapidsai#1296)
Browse files Browse the repository at this point in the history
This PR adds sparse input support (CSR) for GramMatrix kernel computation. This is a requirement to enable SVM support for sparse input in [cuML issue 2197](rapidsai/cuml#2197).

It also adds row norm computation for CSR which is utilized for expanded L2 norm computation within RBF kernels.

Although this branch introduces a new API it is still backwards compatible with the old GramMatrix API (which is marked as deprecated).

CC @cjnolet @tfeher

Authors:
  - Malte Förster (https://github.com/mfoerste4)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1296
  • Loading branch information
mfoerste4 authored and ahendriksen committed Apr 27, 2023
1 parent 8aed3bd commit 909e2c3
Show file tree
Hide file tree
Showing 14 changed files with 1,855 additions and 293 deletions.
30 changes: 30 additions & 0 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,36 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
return device_matrix_view<ElementType, IndexType, LayoutPolicy>{ptr, extents};
}

/**
* @brief Create a 2-dim mdspan instance for device pointer with a strided layout
* that is restricted to stride 1 in the trailing dimension. It's
* expected that the given layout policy match the layout of the underlying
* pointer.
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] n_rows number of rows in pointer
* @param[in] n_cols number of columns in pointer
* @param[in] stride leading dimension / stride of data
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
{
constexpr auto is_row_major = std::is_same_v<LayoutPolicy, layout_c_contiguous>;
IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1;
IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows);

assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows);
matrix_extent<IndexType> extents{n_rows, n_cols};

auto layout = make_strided_layout(extents, std::array<IndexType, 2>{stride0, stride1});
return device_matrix_view<ElementType, IndexType, layout_stride>{ptr, layout};
}

/**
* @brief Create a 1-dim mdspan instance for device pointer.
* @tparam ElementType the data type of the vector elements
Expand Down
Loading

0 comments on commit 909e2c3

Please sign in to comment.