Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for row-major slice #1591

Merged
merged 4 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cpp/include/raft/matrix/detail/matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@ void printHost(const m_t* in, idx_t n_rows, idx_t n_cols)
*/
// Kernel: each thread copies at most one element of the half-open window
// [x1, x2) x [y1, y2) from src_d into dst_d. dst_d is written densely at
// dst_d[idx], with the x offset varying fastest. Only the .x components of
// threadIdx/blockDim/blockIdx are used.
// NOTE(review): this listing is a rendered diff without +/- markers, so removed
// and added lines sit side by side. Judging by the launch sites in sliceMatrix,
// the (m, n) parameter line and the `js * m` read are the removed variants and
// the `lda` ones are the added variants -- confirm against the merged file.
template <typename m_t, typename idx_t = int>
__global__ void slice(
// presumably the removed signature: source dimensions passed as m and n (n is never read in the body)
const m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
// presumably the added signature: a single leading dimension lda replaces (m, n)
const m_t* src_d, idx_t lda, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
{
// Flat global thread index; doubles as the linear index into dst_d.
idx_t idx = threadIdx.x + blockDim.x * blockIdx.x;
// Window extent along x (unit stride in src_d) and along y (strided by the leading dimension).
idx_t dm = x2 - x1, dn = y2 - y1;
// Bounds guard: threads whose index falls past the last window element do nothing.
if (idx < dm * dn) {
// Split idx into window-local offsets; i (the x offset) varies fastest.
idx_t i = idx % dm, j = idx / dm;
// Translate to source coordinates by adding the window origin.
idx_t is = i + x1, js = j + y1;
// presumably removed: used m as the stride between consecutive y indices
dst_d[idx] = src_d[is + js * m];
// presumably added: same read with lda as the stride
dst_d[idx] = src_d[is + js * lda];
}
}

Expand All @@ -190,12 +190,16 @@ void sliceMatrix(const m_t* in,
idx_t y1,
idx_t x2,
idx_t y2,
bool row_major,
cudaStream_t stream)
{
// Slicing
auto lda = row_major ? n_cols : n_rows;
dim3 block(64);
dim3 grid(((x2 - x1) * (y2 - y1) + block.x - 1) / block.x);
slice<<<grid, block, 0, stream>>>(in, n_rows, n_cols, out, x1, y1, x2, y2);
if (row_major)
slice<<<grid, block, 0, stream>>>(in, lda, out, y1, x1, y2, x2);
else
slice<<<grid, block, 0, stream>>>(in, lda, out, x1, y1, x2, y2);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ void sliceMatrix(m_t* in,
idx_t y2,
cudaStream_t stream)
{
detail::sliceMatrix(in, n_rows, n_cols, out, x1, y1, x2, y2, stream);
detail::sliceMatrix(in, n_rows, n_cols, out, x1, y1, x2, y2, false, stream);
}

/**
Expand Down
13 changes: 8 additions & 5 deletions cpp/include/raft/matrix/slice.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/detail/matrix.cuh>
#include <raft/util/input_validation.hpp>

namespace raft::matrix {

Expand All @@ -45,17 +46,18 @@ struct slice_coordinates {
* @tparam m_t type of matrix elements
* @tparam idx_t integer type used for indexing
* @param[in] handle: raft handle
* @param[in] in: input matrix (column-major)
* @param[out] out: output matrix (column-major)
* @param[in] in: input matrix
* @param[out] out: output matrix
* @param[in] coords: coordinates of the wanted slice, given as {row1, col1, row2, col2}; row2 and col2 are exclusive upper bounds
* example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3});
*/
template <typename m_t, typename idx_t>
template <typename m_t, typename idx_t, typename layout_t>
void slice(raft::resources const& handle,
raft::device_matrix_view<const m_t, idx_t, col_major> in,
raft::device_matrix_view<m_t, idx_t, col_major> out,
raft::device_matrix_view<const m_t, idx_t, layout_t> in,
raft::device_matrix_view<m_t, idx_t, layout_t> out,
slice_coordinates<idx_t> coords)
{
RAFT_EXPECTS(raft::is_row_or_column_major(in), "Matrix layout must be row- or column-major");
RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1");
RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1");
RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0");
Expand All @@ -72,6 +74,7 @@ void slice(raft::resources const& handle,
coords.col1,
coords.row2,
coords.col2,
raft::is_row_major(in),
resource::get_cuda_stream(handle));
}

Expand Down
67 changes: 40 additions & 27 deletions cpp/test/matrix/slice.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,29 @@ template <typename T>
// Parameters of a single slice test case.
struct SliceInputs {
  // Dimensions of the input matrix that gets sliced.
  int rows;
  int cols;
  // Seed handed to the random state that fills the input matrix.
  unsigned long long int seed;
  // true: exercise the row-major path; false: the column-major path.
  bool rowMajor;
};

// Stream-insertion operator for SliceInputs: writes the fields in braces,
// followed by std::endl, and returns the stream so calls can be chained.
// Presumably used by the test framework to print a failing parameter set -- confirm.
// NOTE(review): this listing is a rendered diff without +/- markers. The first
// statement (three fields) appears to be the removed line and the second one
// (which also prints rowMajor) the added replacement; only one of them should
// exist in the real file.
template <typename T>
::std::ostream& operator<<(::std::ostream& os, const SliceInputs<T>& I)
{
// presumably removed: prints rows, cols and seed only
os << "{ " << I.rows << ", " << I.cols << ", " << I.seed << '}' << std::endl;
// presumably added: same output extended with the rowMajor flag
os << "{ " << I.rows << ", " << I.cols << ", " << I.seed << ", " << I.rowMajor << '}'
<< std::endl;
return os;
}

// Host-side reference slice used to produce the expected result.
// Copies rows [x1, x2) and columns [y1, y2) of `in` into the densely packed `out`.
// In the added variant, `in_lda` is the leading dimension of `in` and `row_major`
// selects between row-major and column-major indexing for both buffers.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// (rows, cols) signature, `out_rows`, and the unconditional copy that uses
// `rows` appear to be the removed lines; the `in_lda` / `row_major` lines
// appear to be their replacements -- confirm against the merged file.
template <typename Type>
// presumably removed signature (column-major only)
void naiveSlice(const Type* in, Type* out, int rows, int cols, int x1, int y1, int x2, int y2)
// presumably added signature (layout chosen by row_major)
void naiveSlice(
const Type* in, Type* out, int in_lda, int x1, int y1, int x2, int y2, bool row_major)
{
// presumably removed: output leading dimension for the column-major-only version
int out_rows = x2 - x1;
// int out_cols = y2 - y1;
// Leading dimension of the output: slice width when row-major, slice height otherwise.
int out_lda = row_major ? y2 - y1 : x2 - x1;
// j walks the selected columns, i the selected rows.
for (int j = y1; j < y2; ++j) {
for (int i = x1; i < x2; ++i) {
// presumably removed: column-major copy using the old rows / out_rows strides
out[(i - x1) + (j - y1) * out_rows] = in[i + j * rows];
if (row_major)
// row-major: the column index is contiguous, rows are strided by the leading dimension
out[(i - x1) * out_lda + (j - y1)] = in[j + i * in_lda];
else
// column-major: the row index is contiguous, columns are strided by the leading dimension
out[(i - x1) + (j - y1) * out_lda] = in[i + j * in_lda];
}
}
}
Expand All @@ -67,6 +72,7 @@ class SliceTest : public ::testing::TestWithParam<SliceInputs<T>> {
std::default_random_engine dre(rd());
raft::random::RngState r(params.seed);
int rows = params.rows, cols = params.cols, len = rows * cols;
auto lda = params.rowMajor ? cols : rows;
uniform(handle, r, data.data(), len, T(-10.0), T(10.0));

std::uniform_int_distribution<int> rowGenerator(0, (rows / 2) - 1);
Expand All @@ -83,12 +89,19 @@ class SliceTest : public ::testing::TestWithParam<SliceInputs<T>> {

std::vector<T> h_data(rows * cols);
raft::update_host(h_data.data(), data.data(), rows * cols, stream);
naiveSlice(h_data.data(), exp_result.data(), rows, cols, row1, col1, row2, col2);
auto input =
naiveSlice(h_data.data(), exp_result.data(), lda, row1, col1, row2, col2, params.rowMajor);
auto input_F =
raft::make_device_matrix_view<const T, int, raft::col_major>(data.data(), rows, cols);
auto output = raft::make_device_matrix_view<T, int, raft::col_major>(
auto output_F = raft::make_device_matrix_view<T, int, raft::col_major>(
d_act_result.data(), row2 - row1, col2 - col1);
slice(handle, input, output, slice_coordinates(row1, col1, row2, col2));
auto input_C =
raft::make_device_matrix_view<const T, int, raft::row_major>(data.data(), rows, cols);
auto output_C = raft::make_device_matrix_view<T, int, raft::row_major>(
d_act_result.data(), row2 - row1, col2 - col1);
if (params.rowMajor)
slice(handle, input_C, output_C, slice_coordinates(row1, col1, row2, col2));
else
slice(handle, input_F, output_F, slice_coordinates(row1, col1, row2, col2));

raft::update_host(act_result.data(), d_act_result.data(), d_act_result.size(), stream);
resource::sync_stream(handle, stream);
Expand All @@ -104,26 +117,26 @@ class SliceTest : public ::testing::TestWithParam<SliceInputs<T>> {
};

///// Row- and column-wise tests
// Parameter tables for the float and double test instantiations.
// NOTE(review): this listing is a rendered diff without +/- markers. Each table
// appears twice in a row: first with three-field entries {rows, cols, seed}
// (presumably the removed lines), then with four-field entries
// {rows, cols, seed, rowMajor} (presumably the added lines). In the four-field
// entries the rowMajor flag alternates between true and false from one case to the next.
const std::vector<SliceInputs<float>> inputsf = {{32, 1024, 1234ULL},
{64, 1024, 1234ULL},
{128, 1024, 1234ULL},
{256, 1024, 1234ULL},
{512, 512, 1234ULL},
{1024, 32, 1234ULL},
{1024, 64, 1234ULL},
{1024, 128, 1234ULL},
{1024, 256, 1234ULL}};
// Float cases with the layout flag as the fourth field.
const std::vector<SliceInputs<float>> inputsf = {{32, 1024, 1234ULL, true},
{64, 1024, 1234ULL, false},
{128, 1024, 1234ULL, true},
{256, 1024, 1234ULL, false},
{512, 512, 1234ULL, true},
{1024, 32, 1234ULL, false},
{1024, 64, 1234ULL, true},
{1024, 128, 1234ULL, false},
{1024, 256, 1234ULL, true}};

// Double cases; same shapes and seeds as the float table.
const std::vector<SliceInputs<double>> inputsd = {
{32, 1024, 1234ULL},
{64, 1024, 1234ULL},
{128, 1024, 1234ULL},
{256, 1024, 1234ULL},
{512, 512, 1234ULL},
{1024, 32, 1234ULL},
{1024, 64, 1234ULL},
{1024, 128, 1234ULL},
{1024, 256, 1234ULL},
// Entries below carry the layout flag as the fourth field.
{32, 1024, 1234ULL, true},
{64, 1024, 1234ULL, false},
{128, 1024, 1234ULL, true},
{256, 1024, 1234ULL, false},
{512, 512, 1234ULL, true},
{1024, 32, 1234ULL, false},
{1024, 64, 1234ULL, true},
{1024, 128, 1234ULL, false},
{1024, 256, 1234ULL, true},
};

typedef SliceTest<float> SliceTestF;
Expand Down