From fc96108d24fcf318ad2e808337bc83897b465979 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 12 Jun 2023 18:59:16 +0200 Subject: [PATCH 1/2] Add support for row-major slice --- cpp/include/raft/matrix/detail/matrix.cuh | 12 ++-- cpp/include/raft/matrix/matrix.cuh | 2 +- cpp/include/raft/matrix/slice.cuh | 13 +++-- cpp/test/matrix/slice.cu | 67 ++++++++++++++--------- 4 files changed, 57 insertions(+), 37 deletions(-) diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index aba119ee73..48821df5b2 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -170,14 +170,14 @@ void printHost(const m_t* in, idx_t n_rows, idx_t n_cols) */ template __global__ void slice( - const m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2) + const m_t* src_d, idx_t lda, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2) { idx_t idx = threadIdx.x + blockDim.x * blockIdx.x; idx_t dm = x2 - x1, dn = y2 - y1; if (idx < dm * dn) { idx_t i = idx % dm, j = idx / dm; idx_t is = i + x1, js = j + y1; - dst_d[idx] = src_d[is + js * m]; + dst_d[idx] = src_d[is + js * lda]; } } @@ -190,12 +190,16 @@ void sliceMatrix(const m_t* in, idx_t y1, idx_t x2, idx_t y2, + bool row_major, cudaStream_t stream) { - // Slicing + auto lda = row_major ? n_cols : n_rows; dim3 block(64); dim3 grid(((x2 - x1) * (y2 - y1) + block.x - 1) / block.x); - slice<<>>(in, n_rows, n_cols, out, x1, y1, x2, y2); + if (row_major) + slice<<>>(in, lda, out, y1, x1, y2, x2); + else + slice<<>>(in, lda, out, x1, y1, x2, y2); } /** diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh index 6851b8739e..63c33ff034 100644 --- a/cpp/include/raft/matrix/matrix.cuh +++ b/cpp/include/raft/matrix/matrix.cuh @@ -203,7 +203,7 @@ void sliceMatrix(m_t* in, idx_t y2, cudaStream_t stream) { - detail::sliceMatrix(in, n_rows, n_cols, out, x1, y1, x2, y2, stream); + detail::sliceMatrix(in, n_rows, n_cols, out, x1, y1, x2, y2, false, stream); } /** diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index b739f1c732..6c541e0349 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft::matrix { @@ -45,17 +46,18 @@ struct slice_coordinates { * @tparam m_t type of matrix elements * @tparam idx_t integer type used for indexing * @param[in] handle: raft handle - * @param[in] in: input matrix (column-major) - * @param[out] out: output matrix (column-major) + * @param[in] in: input matrix + * @param[out] out: output matrix * @param[in] coords: coordinates of the wanted slice * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3}); */ -template +template void slice(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, slice_coordinates coords) { + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Unsupported matrix layout"); RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1"); RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1"); RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0"); @@ -72,6 +74,7 @@ void slice(raft::resources const& handle, coords.col1, coords.row2, coords.col2, + raft::is_row_major(in), resource::get_cuda_stream(handle)); } diff --git a/cpp/test/matrix/slice.cu b/cpp/test/matrix/slice.cu index 332db379b7..fbf735aaf7 100644 --- a/cpp/test/matrix/slice.cu +++ b/cpp/test/matrix/slice.cu @@ -29,24 +29,29 @@ template struct SliceInputs { int rows, cols; unsigned long long int seed; + bool rowMajor; }; template ::std::ostream& operator<<(::std::ostream& os, const SliceInputs& I) { - os << "{ " << I.rows << ", " << I.cols << ", " << I.seed << '}' << std::endl; + os << "{ " << I.rows << ", " << I.cols << ", " << I.seed << ", " << I.rowMajor << '}' + << std::endl; return os; } // Col-major slice reference test template -void naiveSlice(const Type* in, Type* out, int rows, int cols, int x1, int y1, int x2, int y2) +void naiveSlice( + const Type* in, Type* out, int in_lda, int x1, int y1, int x2, int y2, bool row_major) { - int out_rows = x2 - x1; - // int out_cols = y2 - y1; + int out_lda = row_major ? y2 - y1 : x2 - x1; for (int j = y1; j < y2; ++j) { for (int i = x1; i < x2; ++i) { - out[(i - x1) + (j - y1) * out_rows] = in[i + j * rows]; + if (row_major) + out[(i - x1) * out_lda + (j - y1)] = in[j + i * in_lda]; + else + out[(i - x1) + (j - y1) * out_lda] = in[i + j * in_lda]; } } } @@ -67,6 +72,7 @@ class SliceTest : public ::testing::TestWithParam> { std::default_random_engine dre(rd()); raft::random::RngState r(params.seed); int rows = params.rows, cols = params.cols, len = rows * cols; + auto lda = params.rowMajor ? cols : rows; uniform(handle, r, data.data(), len, T(-10.0), T(10.0)); std::uniform_int_distribution rowGenerator(0, (rows / 2) - 1); @@ -83,12 +89,19 @@ class SliceTest : public ::testing::TestWithParam> { std::vector h_data(rows * cols); raft::update_host(h_data.data(), data.data(), rows * cols, stream); - naiveSlice(h_data.data(), exp_result.data(), rows, cols, row1, col1, row2, col2); - auto input = + naiveSlice(h_data.data(), exp_result.data(), lda, row1, col1, row2, col2, params.rowMajor); + auto input_F = raft::make_device_matrix_view(data.data(), rows, cols); - auto output = raft::make_device_matrix_view( + auto output_F = raft::make_device_matrix_view( d_act_result.data(), row2 - row1, col2 - col1); - slice(handle, input, output, slice_coordinates(row1, col1, row2, col2)); + auto input_C = + raft::make_device_matrix_view(data.data(), rows, cols); + auto output_C = raft::make_device_matrix_view( + d_act_result.data(), row2 - row1, col2 - col1); + if (params.rowMajor) + slice(handle, input_C, output_C, slice_coordinates(row1, col1, row2, col2)); + else + slice(handle, input_F, output_F, slice_coordinates(row1, col1, row2, col2)); raft::update_host(act_result.data(), d_act_result.data(), d_act_result.size(), stream); resource::sync_stream(handle, stream); @@ -104,26 +117,26 @@ class SliceTest : public ::testing::TestWithParam> { }; ///// Row- and column-wise tests -const std::vector> inputsf = {{32, 1024, 1234ULL}, - {64, 1024, 1234ULL}, - {128, 1024, 1234ULL}, - {256, 1024, 1234ULL}, - {512, 512, 1234ULL}, - {1024, 32, 1234ULL}, - {1024, 64, 1234ULL}, - {1024, 128, 1234ULL}, - {1024, 256, 1234ULL}}; +const std::vector> inputsf = {{32, 1024, 1234ULL, true}, + {64, 1024, 1234ULL, false}, + {128, 1024, 1234ULL, true}, + {256, 1024, 1234ULL, false}, + {512, 512, 1234ULL, true}, + {1024, 32, 1234ULL, false}, + {1024, 64, 1234ULL, true}, + {1024, 128, 1234ULL, false}, + {1024, 256, 1234ULL, true}}; const std::vector> inputsd = { - {32, 1024, 1234ULL}, - {64, 1024, 1234ULL}, - {128, 1024, 1234ULL}, - {256, 1024, 1234ULL}, - {512, 512, 1234ULL}, - {1024, 32, 1234ULL}, - {1024, 64, 1234ULL}, - {1024, 128, 1234ULL}, - {1024, 256, 1234ULL}, + {32, 1024, 1234ULL, true}, + {64, 1024, 1234ULL, false}, + {128, 1024, 1234ULL, true}, + {256, 1024, 1234ULL, false}, + {512, 512, 1234ULL, true}, + {1024, 32, 1234ULL, false}, + {1024, 64, 1234ULL, true}, + {1024, 128, 1234ULL, false}, + {1024, 256, 1234ULL, true}, }; typedef SliceTest SliceTestF; From 564362fd3735609bbd2fe6186dfe570b0b2d70fa Mon Sep 17 00:00:00 2001 From: Micka Date: Thu, 22 Jun 2023 23:33:08 +0200 Subject: [PATCH 2/2] Update assert message --- cpp/include/raft/matrix/slice.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index 6c541e0349..e81c186960 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -57,7 +57,7 @@ void slice(raft::resources const& handle, raft::device_matrix_view out, slice_coordinates coords) { - RAFT_EXPECTS(raft::is_row_or_column_major(in), "Unsupported matrix layout"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Matrix layout must be row- or column-major"); RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1"); RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1"); RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0");