Skip to content

Commit

Permalink
Implement matrix transpose with mdspan. (#739)
Browse files Browse the repository at this point in the history
* Implement a transpose function that works on both column- and row-major matrices.
* Sub-matrices are supported as well.

Authors:
  - Jiaming Yuan (https://github.com/trivialfis)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #739
  • Loading branch information
trivialfis authored Jul 28, 2022
1 parent 20b6ee5 commit ad55c7b
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 4 deletions.
10 changes: 7 additions & 3 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

namespace raft {
/**
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
* @brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
*/
template <typename IndexType, size_t... ExtentsPack>
using extents = std::experimental::extents<IndexType, ExtentsPack...>;
Expand All @@ -56,6 +56,11 @@ using layout_f_contiguous = layout_left;
using col_major = layout_left;
/** @} */

/**
* @brief Strided layout for non-contiguous memory.
*/
using detail::stdex::layout_stride;

/**
* @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension
* is known at run time (dynamic_extent in each dimension).
Expand Down Expand Up @@ -424,8 +429,7 @@ class mdarray
auto operator()(IndexType&&... indices)
-> std::enable_if_t<sizeof...(IndexType) == extents_type::rank() &&
(std::is_convertible_v<IndexType, index_type> && ...) &&
std::is_constructible_v<extents_type, IndexType...> &&
std::is_constructible_v<mapping_type, extents_type>,
std::is_constructible_v<extents_type, IndexType...>,
/* device policy is not default constructible due to requirement for CUDA
stream. */
/* std::is_default_constructible_v<container_policy_type> */
Expand Down
52 changes: 52 additions & 0 deletions cpp/include/raft/linalg/detail/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "cublas_wrappers.hpp"

#include <raft/core/mdarray.hpp>
#include <raft/handle.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/for_each.h>
Expand Down Expand Up @@ -79,6 +80,57 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
});
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_row_major_impl(
  handle_t const& handle,
  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
  // The transposed result swaps the input's extents.
  auto const n_out_rows = in.extent(1);
  auto const n_out_cols = in.extent(0);
  // geam computes C = alpha * op(A) + beta * op(B). With beta == 0 and a null B
  // this reduces to a transpose-copy of `in` into `out`.
  T constexpr kAlpha = 1;
  T constexpr kBeta  = 0;
  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
                        CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        n_out_cols,
                        n_out_rows,
                        &kAlpha,
                        in.data_handle(),
                        in.stride(0),
                        &kBeta,
                        static_cast<T*>(nullptr),
                        out.stride(0),
                        out.data_handle(),
                        out.stride(0),
                        handle.get_stream()));
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_col_major_impl(
  handle_t const& handle,
  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
  // Column-major: the leading dimension is the column stride (stride(1)).
  auto const in_ld  = in.stride(1);
  auto const out_ld = out.stride(1);
  T constexpr kAlpha = 1;
  T constexpr kBeta  = 0;
  // geam with beta == 0 and a null B performs a pure transpose-copy.
  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
                        CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        in.extent(1),
                        in.extent(0),
                        &kAlpha,
                        in.data_handle(),
                        in_ld,
                        &kBeta,
                        static_cast<T*>(nullptr),
                        out_ld,
                        out.data_handle(),
                        out_ld,
                        handle.get_stream()));
}
}; // end namespace detail
}; // end namespace linalg
}; // end namespace raft
41 changes: 40 additions & 1 deletion cpp/include/raft/linalg/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#pragma once

#include "detail/transpose.cuh"
#include <raft/core/mdarray.hpp>

namespace raft {
namespace linalg {
Expand Down Expand Up @@ -55,7 +56,45 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
detail::transpose(inout, n, stream);
}

/**
 * @brief Transpose a matrix. The output has same layout policy as the input.
 *
 * @tparam T Data type of input matrix element.
 * @tparam IndexType Index type of matrix extent.
 * @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
 *                      be a submatrix of a larger matrix. Arbitrary stride is not supported.
 * @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
 *                        device.
 *
 * @param[in] handle raft handle for managing expensive cuda resources.
 * @param[in] in     Input matrix.
 * @param[out] out   Output matrix, storage is pre-allocated by caller.
 */
template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
auto transpose(handle_t const& handle,
               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
  -> std::enable_if_t<std::is_floating_point_v<T>, void>
{
  RAFT_EXPECTS(out.extent(0) == in.extent(1), "Invalid shape for transpose.");
  RAFT_EXPECTS(out.extent(1) == in.extent(0), "Invalid shape for transpose.");

  // Dispatch on the layout at compile time when it is contiguous.  Note the
  // `else if constexpr`: without `constexpr` the second comparison is a runtime
  // check and every branch gets instantiated for every layout.
  if constexpr (std::is_same_v<LayoutPolicy, layout_c_contiguous>) {
    detail::transpose_row_major_impl(handle, in, out);
  } else if constexpr (std::is_same_v<LayoutPolicy, layout_f_contiguous>) {
    detail::transpose_col_major_impl(handle, in, out);
  } else {
    // Strided layout: only sub-matrices of row-/col-major parents are
    // supported, i.e. one of the strides must be 1.
    RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
    if (in.stride(1) == 1) {
      // row-major submatrix
      detail::transpose_row_major_impl(handle, in, out);
    } else {
      // col-major submatrix
      detail::transpose_col_major_impl(handle, in, out);
    }
  }
}
}; // end namespace linalg
}; // end namespace raft

#endif
#endif
145 changes: 145 additions & 0 deletions cpp/test/linalg/transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,150 @@ INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValF, ::testing::ValuesIn(

INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValD, ::testing::ValuesIn(inputsd2));

namespace {
/**
* We hide these functions in tests for now until we have a heterogeneous mdarray
* implementation.
*/

/**
* @brief Transpose a matrix. The output has same layout policy as the input.
*
* @tparam T Data type of input matrix elements.
* @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
* be a submatrix of a larger matrix. Arbitrary stride is not supported.
*
* @param[in] handle raft handle for managing expensive cuda resources.
* @param[in] in Input matrix.
*
* @return The transposed matrix.
*/
template <typename T, typename IndexType, typename LayoutPolicy>
[[nodiscard]] auto transpose(handle_t const& handle,
                             device_matrix_view<T, IndexType, LayoutPolicy> in)
  -> std::enable_if_t<std::is_floating_point_v<T> &&
                        (std::is_same_v<LayoutPolicy, layout_c_contiguous> ||
                         std::is_same_v<LayoutPolicy, layout_f_contiguous>),
                      device_matrix<T, IndexType, LayoutPolicy>>
{
  // The result has the input's extents swapped and the same layout policy.
  auto const n_rows = in.extent(1);
  auto const n_cols = in.extent(0);
  auto transposed   = make_device_matrix<T, IndexType, LayoutPolicy>(handle, n_rows, n_cols);
  ::raft::linalg::transpose(handle, in, transposed.view());
  return transposed;
}

/**
* @brief Transpose a matrix. The output has same layout policy as the input.
*
* @tparam T Data type of input matrix elements.
* @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
* be a submatrix of a larger matrix. Arbitrary stride is not supported.
*
* @param[in] handle raft handle for managing expensive cuda resources.
* @param[in] in Input matrix.
*
* @return The transposed matrix.
*/
/**
 * @brief Transpose a strided (sub-)matrix, returning a newly allocated result.
 *
 * Uses `IndexType` consistently for extents and strides (the previous version
 * hard-coded `size_t`, which would fail to instantiate for other index types).
 *
 * @param[in] handle raft handle for managing expensive cuda resources.
 * @param[in] in     Input strided matrix view; one of its strides must be 1
 *                   (row- or col-major submatrix of a larger matrix).
 *
 * @return The transposed matrix, with the same (row-/col-major) orientation
 *         as the input, now contiguous.
 */
template <typename T, typename IndexType>
[[nodiscard]] auto transpose(handle_t const& handle,
                             device_matrix_view<T, IndexType, layout_stride> in)
  -> std::enable_if_t<std::is_floating_point_v<T>, device_matrix<T, IndexType, layout_stride>>
{
  matrix_extent<IndexType> exts{in.extent(1), in.extent(0)};
  using policy_type =
    typename raft::device_matrix<T, IndexType, layout_stride>::container_policy_type;
  policy_type policy(handle.get_stream());

  RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
  // The output is allocated contiguously with the same orientation as the
  // input: row-major when in.stride(1) == 1, col-major otherwise.
  auto const strides = (in.stride(1) == 1)
                         ? std::array<IndexType, 2>{in.extent(0), 1}   // row-major
                         : std::array<IndexType, 2>{1, in.extent(1)};  // col-major
  auto layout = layout_stride::mapping<matrix_extent<IndexType>>{exts, strides};
  raft::device_matrix<T, IndexType, layout_stride> out{layout, policy};
  ::raft::linalg::transpose(handle, in, out.view());
  return out;
}

// Fills a 32x3 matrix with sequential values, transposes it, and verifies that
// the result has swapped extents, the same layout policy, and transposed data.
template <typename T, typename LayoutPolicy>
void test_transpose_with_mdspan()
{
  handle_t handle;
  auto in = make_device_matrix<T, size_t, LayoutPolicy>(handle, 32, 3);
  T value{0};
  for (size_t row = 0; row < in.extent(0); ++row) {
    for (size_t col = 0; col < in.extent(1); ++col) {
      in(row, col) = value++;
    }
  }
  auto transposed = transpose(handle, in.view());
  // Layout policy must be preserved by the transpose.
  static_assert(std::is_same_v<LayoutPolicy, typename decltype(transposed)::layout_type>);
  ASSERT_EQ(transposed.extent(0), in.extent(1));
  ASSERT_EQ(transposed.extent(1), in.extent(0));

  // Walking the transposed matrix column-by-column recovers the original
  // row-by-row sequence.
  value = 0;
  for (size_t col = 0; col < transposed.extent(1); ++col) {
    for (size_t row = 0; row < transposed.extent(0); ++row) {
      ASSERT_EQ(transposed(row, col), value++);
    }
  }
}
} // namespace

// Exercise the mdspan-based transpose on contiguous matrices, covering both
// row-major and col-major layouts for float and double.
TEST(TransposeTest, MDSpan)
{
  test_transpose_with_mdspan<float, layout_c_contiguous>();
  test_transpose_with_mdspan<double, layout_c_contiguous>();

  test_transpose_with_mdspan<float, layout_f_contiguous>();
  test_transpose_with_mdspan<double, layout_f_contiguous>();
}

namespace {
// Fills a rectangular region of a larger matrix, takes a strided submdspan of
// that region, transposes it, and verifies the transposed contents.
template <typename T, typename LayoutPolicy>
void test_transpose_submatrix()
{
  handle_t handle;
  auto mat = make_device_matrix<T, size_t, LayoutPolicy>(handle, 32, 33);
  size_t const row_beg{3}, row_end{13}, col_beg{2}, col_end{11};
  T value{0};
  for (size_t i = row_beg; i < row_end; ++i) {
    for (size_t j = col_beg; j < col_end; ++j) {
      mat(i, j) = value++;
    }
  }

  auto full = mat.view();
  auto sub  = raft::detail::stdex::submdspan(
    full, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end));
  // A submdspan of a contiguous matrix has a strided layout.
  static_assert(std::is_same_v<typename decltype(sub)::layout_type, layout_stride>);

  auto transposed = transpose(handle, sub);
  ASSERT_EQ(transposed.extent(0), sub.extent(1));
  ASSERT_EQ(transposed.extent(1), sub.extent(0));

  // Column-by-column traversal of the transpose recovers the original
  // row-by-row fill order.
  value = 0;
  for (size_t i = 0; i < transposed.extent(1); ++i) {
    for (size_t j = 0; j < transposed.extent(0); ++j) {
      ASSERT_EQ(transposed(j, i), value++);
    }
  }
}
} // namespace

// Exercise the strided-layout transpose path on submatrices carved out of both
// row-major and col-major parents, for float and double.
TEST(TransposeTest, SubMatrix)
{
  test_transpose_submatrix<float, layout_c_contiguous>();
  test_transpose_submatrix<double, layout_c_contiguous>();

  test_transpose_submatrix<float, layout_f_contiguous>();
  test_transpose_submatrix<double, layout_f_contiguous>();
}
} // end namespace linalg
} // end namespace raft

0 comments on commit ad55c7b

Please sign in to comment.