Skip to content

Commit

Permalink
Add struct for slice coordinates and simplify reverse code
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 27, 2022
1 parent 2de4cf7 commit ddc8a52
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 58 deletions.
56 changes: 19 additions & 37 deletions cpp/include/raft/matrix/reverse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,59 +18,41 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/detail/matrix.cuh>
#include <raft/util/input_validation.hpp>

namespace raft::matrix {

/**
* @brief Columns of a column major matrix are reversed in place (i.e. first column and
* @brief Reverse the columns of a matrix in place (i.e. first column and
* last column are swapped)
* @param[in] handle: raft handle
* @param[inout] inout: input and output matrix
*/
template <typename m_t, typename idx_t>
void col_reverse(const raft::handle_t& handle,
raft::device_matrix_view<m_t, idx_t, col_major> inout)
template <typename m_t, typename idx_t, typename layout_t>
void col_reverse(const raft::handle_t& handle, raft::device_matrix_view<m_t, idx_t, layout_t> inout)
{
detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout");
if (raft::is_col_major(inout)) {
detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
} else {
detail::rowReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream());
}
}

/**
* @brief Columns of a row major matrix are reversed in place (i.e. first column and
* last column are swapped)
* @param[in] handle: raft handle
* @param[inout] inout: input and output matrix
*/
template <typename m_t, typename idx_t>
void col_reverse(const raft::handle_t& handle,
raft::device_matrix_view<m_t, idx_t, row_major> inout)
{
detail::rowReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream());
}

/**
* @brief Rows of a column major matrix are reversed in place (i.e. first row and last
* @brief Reverse the rows of a matrix in place (i.e. first row and last
* row are swapped)
* @param[in] handle: raft handle
* @param[inout] inout: input and output matrix
*/
template <typename m_t, typename idx_t>
void row_reverse(const raft::handle_t& handle,
raft::device_matrix_view<m_t, idx_t, col_major> inout)
template <typename m_t, typename idx_t, typename layout_t>
void row_reverse(const raft::handle_t& handle, raft::device_matrix_view<m_t, idx_t, layout_t> inout)
{
detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout");
if (raft::is_col_major(inout)) {
detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
} else {
detail::colReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream());
}
}

/**
* @brief Rows of a row major matrix are reversed in place (i.e. first row and last
* row are swapped)
* @param[in] handle: raft handle
* @param[inout] inout: input and output matrix
*/
template <typename m_t, typename idx_t>
void row_reverse(const raft::handle_t& handle,
raft::device_matrix_view<m_t, idx_t, row_major> inout)
{
detail::colReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream());
}

} // namespace raft::matrix
50 changes: 30 additions & 20 deletions cpp/include/raft/matrix/slice.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,51 @@

namespace raft::matrix {

template <typename idx_t>
struct slice_coordinates {
idx_t row1; ///< row coordinate of the top-left point of the wanted area (0-based)

This comment has been minimized.

Copy link
@cjnolet

cjnolet Oct 27, 2022

Member

This looks very clean, thanks!

idx_t col1; ///< column coordinate of the top-left point of the wanted area (0-based)
idx_t row2; ///< row coordinate of the bottom-right point of the wanted area (1-based)
idx_t col2; ///< column coordinate of the bottom-right point of the wanted area (1-based)

slice_coordinates(idx_t row1_, idx_t col1_, idx_t row2_, idx_t col2_)
: row1(row1_), col1(col1_), row2(row2_), col2(col2_)
{
}
};

/**
* @brief Slice a matrix (in-place)
* @tparam m_t type of matrix elements
* @tparam idx_t integer type used for indexing
* @param[in] handle: raft handle
* @param[in] in: input matrix (column-major)
* @param[inout] out: output matrix (column-major)
* @param[in] x1, y1: coordinate of the top-left point of the wanted area (0-based)
* @param[in] x2, y2: coordinate of the bottom-right point of the wanted area
* (1-based)
* example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice_matrix(M_d, 4,
* 3, 0, 1, 4, 3);
* @param[out] out: output matrix (column-major)
* @param[in] coords: coordinates of the wanted slice
* example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3});
*/
template <typename m_t, typename idx_t>
void slice(const raft::handle_t& handle,
raft::device_matrix_view<const m_t, idx_t, col_major> in,
raft::device_matrix_view<m_t, idx_t, col_major> out,
idx_t x1,
idx_t y1,
idx_t x2,
idx_t y2)
slice_coordinates<idx_t> coords)

This comment has been minimized.

Copy link
@cjnolet

cjnolet Oct 27, 2022

Member

We probably want to pass this by reference, though.

{
RAFT_EXPECTS(x2 > x1, "x2 must be > x1");
RAFT_EXPECTS(y2 > y1, "y2 must be > y1");
RAFT_EXPECTS(x1 >= 0, "x1 must be >= 0");
RAFT_EXPECTS(x2 <= in.extent(0), "x2 must be <= number of rows in the input matrix");
RAFT_EXPECTS(y1 >= 0, "y1 must be >= 0");
RAFT_EXPECTS(y2 <= in.extent(1), "y2 must be <= number of columns in the input matrix");
RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1");
RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1");
RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0");
RAFT_EXPECTS(coords.row2 <= in.extent(0), "row2 must be <= number of rows in the input matrix");
RAFT_EXPECTS(coords.col1 >= 0, "col1 must be >= 0");
RAFT_EXPECTS(coords.col2 <= in.extent(1),
"col2 must be <= number of columns in the input matrix");

detail::sliceMatrix(in.data_handle(),
in.extent(0),
in.extent(1),
out.data_handle(),
x1,
y1,
x2,
y2,
coords.row1,
coords.col1,
coords.row2,
coords.col2,
handle.get_stream());
}
} // namespace raft::matrix
2 changes: 1 addition & 1 deletion cpp/test/matrix/slice.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SliceTest : public ::testing::TestWithParam<SliceInputs<T>> {
raft::make_device_matrix_view<const T, int, raft::col_major>(data.data(), rows, cols);
auto output = raft::make_device_matrix_view<T, int, raft::col_major>(
d_act_result.data(), row2 - row1, col2 - col1);
slice(handle, input, output, row1, col1, row2, col2);
slice(handle, input, output, slice_coordinates(row1, col1, row2, col2));

raft::update_host(act_result.data(), d_act_result.data(), d_act_result.size(), stream);
handle.sync_stream(stream);
Expand Down

0 comments on commit ddc8a52

Please sign in to comment.