diff --git a/cpp/include/raft/matrix/reverse.cuh b/cpp/include/raft/matrix/reverse.cuh index 13000332a0..e00a240577 100644 --- a/cpp/include/raft/matrix/reverse.cuh +++ b/cpp/include/raft/matrix/reverse.cuh @@ -18,59 +18,41 @@ #include #include +#include namespace raft::matrix { /** - * @brief Columns of a column major matrix are reversed in place (i.e. first column and + * @brief Reverse the columns of a matrix in place (i.e. first column and * last column are swapped) * @param[in] handle: raft handle * @param[inout] inout: input and output matrix */ -template -void col_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) +template +void col_reverse(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout"); + if (raft::is_col_major(inout)) { + detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + } else { + detail::rowReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); + } } /** - * @brief Columns of a row major matrix are reversed in place (i.e. first column and - * last column are swapped) - * @param[in] handle: raft handle - * @param[inout] inout: input and output matrix - */ -template -void col_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::rowReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); -} - -/** - * @brief Rows of a column major matrix are reversed in place (i.e. first row and last + * @brief Reverse the rows of a matrix in place (i.e. first row and last * row are swapped) * @param[in] handle: raft handle * @param[inout] inout: input and output matrix */ -template -void row_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) +template +void row_reverse(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout"); + if (raft::is_col_major(inout)) { + detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + } else { + detail::colReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); + } } - -/** - * @brief Rows of a row major matrix are reversed in place (i.e. first row and last - * row are swapped) - * @param[in] handle: raft handle - * @param[inout] inout: input and output matrix - */ -template -void row_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::colReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); -} - } // namespace raft::matrix \ No newline at end of file diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index ef7ff3d28d..eda2853c78 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -21,41 +21,51 @@ namespace raft::matrix { +template +struct slice_coordinates { + idx_t row1; ///< row coordinate of the top-left point of the wanted area (0-based) + idx_t col1; ///< column coordinate of the top-left point of the wanted area (0-based) + idx_t row2; ///< row coordinate of the bottom-right point of the wanted area (1-based) + idx_t col2; ///< column coordinate of the bottom-right point of the wanted area (1-based) + + slice_coordinates(idx_t row1_, idx_t col1_, idx_t row2_, idx_t col2_) + : row1(row1_), col1(col1_), row2(row2_), col2(col2_) + { + } +}; + /** * @brief Slice a matrix (in-place) + * @tparam m_t type of matrix elements + * @tparam idx_t integer type used for indexing * @param[in] handle: raft handle * @param[in] in: input matrix (column-major) - * @param[inout] out: output matrix (column-major) - * @param[in] x1, y1: coordinate of the top-left point of the wanted area (0-based) - * @param[in] x2, y2: coordinate of the bottom-right point of the wanted area - * (1-based) - * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice_matrix(M_d, 4, - * 3, 0, 1, 4, 3); + * @param[out] out: output matrix (column-major) + * @param[in] coords: coordinates of the wanted slice + * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3}); */ template void slice(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - idx_t x1, - idx_t y1, - idx_t x2, - idx_t y2) + slice_coordinates coords) { - RAFT_EXPECTS(x2 > x1, "x2 must be > x1"); - RAFT_EXPECTS(y2 > y1, "y2 must be > y1"); - RAFT_EXPECTS(x1 >= 0, "x1 must be >= 0"); - RAFT_EXPECTS(x2 <= in.extent(0), "x2 must be <= number of rows in the input matrix"); - RAFT_EXPECTS(y1 >= 0, "y1 must be >= 0"); - RAFT_EXPECTS(y2 <= in.extent(1), "y2 must be <= number of columns in the input matrix"); + RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1"); + RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1"); + RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0"); + RAFT_EXPECTS(coords.row2 <= in.extent(0), "row2 must be <= number of rows in the input matrix"); + RAFT_EXPECTS(coords.col1 >= 0, "col1 must be >= 0"); + RAFT_EXPECTS(coords.col2 <= in.extent(1), + "col2 must be <= number of columns in the input matrix"); detail::sliceMatrix(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), - x1, - y1, - x2, - y2, + coords.row1, + coords.col1, + coords.row2, + coords.col2, handle.get_stream()); } } // namespace raft::matrix \ No newline at end of file diff --git a/cpp/test/matrix/slice.cu b/cpp/test/matrix/slice.cu index f0cce2c184..9744e3724a 100644 --- a/cpp/test/matrix/slice.cu +++ b/cpp/test/matrix/slice.cu @@ -87,7 +87,7 @@ class SliceTest : public ::testing::TestWithParam> { raft::make_device_matrix_view(data.data(), rows, cols); auto output = raft::make_device_matrix_view( d_act_result.data(), row2 - row1, col2 - col1); - slice(handle, input, output, row1, col1, row2, col2); + slice(handle, input, output, slice_coordinates(row1, col1, row2, col2)); raft::update_host(act_result.data(), d_act_result.data(), d_act_result.size(), stream); handle.sync_stream(stream);