[FEA] Implement matrix transpose with mdspan. #739

Merged: 16 commits, Jul 28, 2022
10 changes: 7 additions & 3 deletions cpp/include/raft/core/mdarray.hpp
@@ -33,7 +33,7 @@

namespace raft {
/**
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
* @brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
*/
template <typename IndexType, size_t... ExtentsPack>
using extents = std::experimental::extents<IndexType, ExtentsPack...>;
@@ -56,6 +56,11 @@ using layout_f_contiguous = layout_left;
using col_major = layout_left;
/** @} */

/**
* @brief Strided layout for non-contiguous memory.
*/
using detail::stdex::layout_stride;

/**
* @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension
* is known at run time (dynamic_extent in each dimension).
@@ -424,8 +429,7 @@ class mdarray
auto operator()(IndexType&&... indices)
-> std::enable_if_t<sizeof...(IndexType) == extents_type::rank() &&
(std::is_convertible_v<IndexType, index_type> && ...) &&
std::is_constructible_v<extents_type, IndexType...> &&
std::is_constructible_v<mapping_type, extents_type>,
std::is_constructible_v<extents_type, IndexType...>,
/* device policy is not default constructible due to requirement for CUDA
stream. */
/* std::is_default_constructible_v<container_policy_type> */
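The mdarray.hpp hunk above also re-exports `detail::stdex::layout_stride` as `raft::layout_stride`, which the new transpose overload and its tests use to describe non-contiguous (sub)matrices. Below is a minimal host-only sketch of what such a layout encodes; the function name, extents, and strides are made up for illustration, and the mapping construction mirrors the test code later in this diff:

```cpp
#include <array>

#include <raft/core/mdarray.hpp>

// Describe a 3 x 4 logical matrix embedded in a larger row-major buffer whose
// rows are 10 elements apart, i.e. strides {10, 1}.
size_t layout_stride_sketch()
{
  raft::matrix_extent<size_t> exts{3, 4};
  std::array<size_t, 2> strides{10, 1};
  auto mapping = raft::layout_stride::mapping<raft::matrix_extent<size_t>>{exts, strides};

  // The mapping turns (i, j) into the linear offset i * 10 + j, so this returns 12.
  return mapping(1, 2);
}
```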
52 changes: 52 additions & 0 deletions cpp/include/raft/linalg/detail/transpose.cuh
@@ -18,6 +18,7 @@

#include "cublas_wrappers.hpp"

#include <raft/core/mdarray.hpp>
#include <raft/handle.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/for_each.h>
@@ -79,6 +80,57 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
});
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_row_major_impl(
handle_t const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
auto out_n_rows = in.extent(1);
auto out_n_cols = in.extent(0);
T constexpr kOne = 1;
T constexpr kZero = 0;
CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_cols,
out_n_rows,
&kOne,
in.data_handle(),
in.stride(0),
&kZero,
static_cast<T*>(nullptr),
out.stride(0),
out.data_handle(),
out.stride(0),
handle.get_stream()));
}

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
void transpose_col_major_impl(
handle_t const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
auto out_n_rows = in.extent(1);
auto out_n_cols = in.extent(0);
T constexpr kOne = 1;
T constexpr kZero = 0;
CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_n_rows,
out_n_cols,
&kOne,
in.data_handle(),
in.stride(1),
&kZero,
static_cast<T*>(nullptr),
out.stride(1),
out.data_handle(),
out.stride(1),
handle.get_stream()));
}
}; // end namespace detail
}; // end namespace linalg
}; // end namespace raft
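Both `transpose_row_major_impl` and `transpose_col_major_impl` map the transpose onto cuBLAS `geam`, which computes C = alpha * op(A) + beta * op(B) on column-major data; here alpha = 1, beta = 0, and op(A) = A^T. The sketch below is a hypothetical standalone helper, calling `cublasSgeam` directly rather than RAFT's `cublasgeam` wrapper, that spells out the same call for a plain row-major input:

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Hypothetical helper (not part of this PR): transpose a row-major m x n
// matrix `in` into a row-major n x m matrix `out` via C = 1 * A^T + 0 * B,
// the same call transpose_row_major_impl issues above.
inline cublasStatus_t transpose_row_major_sketch(
  cublasHandle_t handle, const float* in, float* out, int m, int n, cudaStream_t stream)
{
  const float kOne  = 1.0f;
  const float kZero = 0.0f;
  cublasSetStream(handle, stream);  // status check omitted for brevity

  // Row-major m x n data is, to cuBLAS, a column-major n x m matrix with
  // leading dimension n.  Its transpose (CUBLAS_OP_T) is column-major m x n
  // with leading dimension m, which is exactly a row-major n x m `out`.
  return cublasSgeam(handle,
                     CUBLAS_OP_T,
                     CUBLAS_OP_N,
                     m,   // rows of the result, in column-major terms
                     n,   // cols of the result, in column-major terms
                     &kOne,
                     in,
                     n,   // lda == in.stride(0) for a row-major input
                     &kZero,
                     static_cast<const float*>(nullptr),  // B unused since beta == 0
                     m,   // ldb
                     out,
                     m);  // ldc == out.stride(0) for the row-major output
}
```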
41 changes: 40 additions & 1 deletion cpp/include/raft/linalg/transpose.cuh
@@ -19,6 +19,7 @@
#pragma once

#include "detail/transpose.cuh"
#include <raft/core/mdarray.hpp>

namespace raft {
namespace linalg {
@@ -55,7 +56,45 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
detail::transpose(inout, n, stream);
}

/**
* @brief Transpose a matrix. The output has the same layout policy as the input.
*
* @tparam T Data type of input matrix element.
* @tparam IndexType Index type of matrix extent.
* @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
* be a submatrix of a larger matrix. Arbitrary stride is not supported.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
*
* @param[in] handle raft handle for managing expensive cuda resources.
* @param[in] in Input matrix.
* @param[out] out Output matrix, storage is pre-allocated by the caller.
*/
template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
auto transpose(handle_t const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
-> std::enable_if_t<std::is_floating_point_v<T>, void>
{
RAFT_EXPECTS(out.extent(0) == in.extent(1), "Invalid shape for transpose.");
RAFT_EXPECTS(out.extent(1) == in.extent(0), "Invalid shape for transpose.");

if constexpr (std::is_same_v<typename decltype(in)::layout_type, layout_c_contiguous>) {
detail::transpose_row_major_impl(handle, in, out);
} else if constexpr (std::is_same_v<typename decltype(in)::layout_type, layout_f_contiguous>) {
detail::transpose_col_major_impl(handle, in, out);
} else {
RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
if (in.stride(1) == 1) {
// row-major submatrix
detail::transpose_row_major_impl(handle, in, out);
} else {
// col-major submatrix
detail::transpose_col_major_impl(handle, in, out);
}
}
}
}; // end namespace linalg
}; // end namespace raft

#endif
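For reference, a minimal usage sketch of the new mdspan-based overload. The function name and shapes are illustrative; the pattern follows the tests added at the end of this PR:

```cpp
#include <cuda_runtime.h>

#include <raft/core/mdarray.hpp>
#include <raft/handle.hpp>
#include <raft/linalg/transpose.cuh>

void transpose_usage_sketch()
{
  raft::handle_t handle;

  // 4 x 3 row-major input and a pre-allocated 3 x 4 output with the same layout.
  auto in  = raft::make_device_matrix<float, size_t, raft::layout_c_contiguous>(handle, 4, 3);
  auto out = raft::make_device_matrix<float, size_t, raft::layout_c_contiguous>(handle, 3, 4);

  // ... fill `in` on handle.get_stream() ...

  raft::linalg::transpose(handle, in.view(), out.view());

  // Work is enqueued on the handle's stream; synchronize before reading `out`.
  cudaStreamSynchronize(handle.get_stream());
}
```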
145 changes: 145 additions & 0 deletions cpp/test/linalg/transpose.cu
@@ -113,5 +113,150 @@ INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValF, ::testing::ValuesIn(

INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValD, ::testing::ValuesIn(inputsd2));

namespace {
/**
* We hide these functions in tests for now until we have a heterogeneous mdarray
* implementation.
*/

/**
* @brief Transpose a matrix. The output has the same layout policy as the input.
*
* @tparam T Data type of input matrix elements.
* @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
* be a submatrix of a larger matrix. Arbitrary stride is not supported.
*
* @param[in] handle raft handle for managing expensive cuda resources.
* @param[in] in Input matrix.
*
* @return The transposed matrix.
*/
template <typename T, typename IndexType, typename LayoutPolicy>
[[nodiscard]] auto transpose(handle_t const& handle,
device_matrix_view<T, IndexType, LayoutPolicy> in)
-> std::enable_if_t<std::is_floating_point_v<T> &&
(std::is_same_v<LayoutPolicy, layout_c_contiguous> ||
std::is_same_v<LayoutPolicy, layout_f_contiguous>),
device_matrix<T, IndexType, LayoutPolicy>>
{
auto out = make_device_matrix<T, IndexType, LayoutPolicy>(handle, in.extent(1), in.extent(0));
::raft::linalg::transpose(handle, in, out.view());
return out;
}

/**
* @brief Transpose a matrix. The output has the same layout policy as the input.
*
* @tparam T Data type of input matrix elements.
* @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
* be a submatrix of a larger matrix. Arbitrary stride is not supported.
*
* @param[in] handle raft handle for managing expensive cuda resources.
* @param[in] in Input matrix.
*
* @return The transposed matrix.
*/
template <typename T, typename IndexType>
[[nodiscard]] auto transpose(handle_t const& handle,
device_matrix_view<T, IndexType, layout_stride> in)
-> std::enable_if_t<std::is_floating_point_v<T>, device_matrix<T, IndexType, layout_stride>>
{
matrix_extent<size_t> exts{in.extent(1), in.extent(0)};
using policy_type =
typename raft::device_matrix<T, IndexType, layout_stride>::container_policy_type;
policy_type policy(handle.get_stream());

RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
if (in.stride(1) == 1) {
// row-major submatrix
std::array<size_t, 2> strides{in.extent(0), 1};
auto layout = layout_stride::mapping<matrix_extent<size_t>>{exts, strides};
raft::device_matrix<T, IndexType, layout_stride> out{layout, policy};
::raft::linalg::transpose(handle, in, out.view());
return out;
} else {
// col-major submatrix
std::array<size_t, 2> strides{1, in.extent(1)};
auto layout = layout_stride::mapping<matrix_extent<size_t>>{exts, strides};
raft::device_matrix<T, IndexType, layout_stride> out{layout, policy};
::raft::linalg::transpose(handle, in, out.view());
return out;
}
}
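The output strides chosen above can be spot-checked with concrete numbers. The snippet below is illustrative only and assumes a hypothetical 10 x 9 input, i.e. a 9 x 10 transposed output:

```cpp
#include <array>
#include <cassert>

#include <raft/core/mdarray.hpp>

void check_transpose_output_strides()
{
  raft::matrix_extent<size_t> out_exts{9, 10};

  // Row-major output: element (i, j) sits at i * 10 + j, so the strides are
  // {in.extent(0), 1} == {10, 1}.
  auto row_major = raft::layout_stride::mapping<raft::matrix_extent<size_t>>{
    out_exts, std::array<size_t, 2>{10, 1}};
  assert(row_major(2, 3) == 23);

  // Col-major output: element (i, j) sits at i + j * 9, so the strides are
  // {1, in.extent(1)} == {1, 9}.
  auto col_major = raft::layout_stride::mapping<raft::matrix_extent<size_t>>{
    out_exts, std::array<size_t, 2>{1, 9}};
  assert(col_major(2, 3) == 29);
}
```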

template <typename T, typename LayoutPolicy>
void test_transpose_with_mdspan()
{
handle_t handle;
auto v = make_device_matrix<T, size_t, LayoutPolicy>(handle, 32, 3);
T k{0};
for (size_t i = 0; i < v.extent(0); ++i) {
for (size_t j = 0; j < v.extent(1); ++j) {
v(i, j) = k++;
}
}
auto out = transpose(handle, v.view());
static_assert(std::is_same_v<LayoutPolicy, typename decltype(out)::layout_type>);
ASSERT_EQ(out.extent(0), v.extent(1));
ASSERT_EQ(out.extent(1), v.extent(0));

k = 0;
for (size_t i = 0; i < out.extent(1); ++i) {
for (size_t j = 0; j < out.extent(0); ++j) {
ASSERT_EQ(out(j, i), k++);
}
}
}
} // namespace

TEST(TransposeTest, MDSpan)
{
test_transpose_with_mdspan<float, layout_c_contiguous>();
test_transpose_with_mdspan<double, layout_c_contiguous>();

test_transpose_with_mdspan<float, layout_f_contiguous>();
test_transpose_with_mdspan<double, layout_f_contiguous>();
}

namespace {
template <typename T, typename LayoutPolicy>
void test_transpose_submatrix()
{
handle_t handle;
auto v = make_device_matrix<T, size_t, LayoutPolicy>(handle, 32, 33);
T k{0};
size_t row_beg{3}, row_end{13}, col_beg{2}, col_end{11};
for (size_t i = row_beg; i < row_end; ++i) {
for (size_t j = col_beg; j < col_end; ++j) {
v(i, j) = k++;
}
}

auto vv = v.view();
auto submat = raft::detail::stdex::submdspan(
vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end));
static_assert(std::is_same_v<typename decltype(submat)::layout_type, layout_stride>);

auto out = transpose(handle, submat);
ASSERT_EQ(out.extent(0), submat.extent(1));
ASSERT_EQ(out.extent(1), submat.extent(0));

k = 0;
for (size_t i = 0; i < out.extent(1); ++i) {
for (size_t j = 0; j < out.extent(0); ++j) {
ASSERT_EQ(out(j, i), k++);
}
}
}
} // namespace

TEST(TransposeTest, SubMatrix)
{
test_transpose_submatrix<float, layout_c_contiguous>();
test_transpose_submatrix<double, layout_c_contiguous>();

test_transpose_submatrix<float, layout_f_contiguous>();
test_transpose_submatrix<double, layout_f_contiguous>();
}
} // end namespace linalg
} // end namespace raft