diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp
index 43d6dc702f..f1e735c4ab 100644
--- a/cpp/include/raft/core/mdarray.hpp
+++ b/cpp/include/raft/core/mdarray.hpp
@@ -33,7 +33,7 @@
 namespace raft {
 
 /**
- * @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
+ * @brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
  */
 template <typename IndexType, size_t... ExtentsPack>
 using extents = std::experimental::extents<IndexType, ExtentsPack...>;
@@ -56,6 +56,11 @@
 using layout_f_contiguous = layout_left;
 using col_major           = layout_left;
 /** @} */
+/**
+ * @brief Strided layout for non-contiguous memory.
+ */
+using detail::stdex::layout_stride;
+
 /**
  * @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension
  * is known at run time (dynamic_extent in each dimension).
@@ -424,8 +429,7 @@
   auto operator()(IndexType&&... indices)
     -> std::enable_if_t<sizeof...(IndexType) == extents_type::rank() &&
                           (std::is_integral_v<IndexType> && ...) &&
-                          std::is_constructible_v<extents_type, IndexType...> &&
-                          std::is_constructible_v<mapping_type, extents_type>,
+                          std::is_constructible_v<extents_type, IndexType...>,
                         /* device policy is not default constructible due to requirement for CUDA
                            stream. */
                         /* std::is_default_constructible_v<container_policy_type> */
diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh
index c09b7a2450..242d3a3912 100644
--- a/cpp/include/raft/linalg/detail/transpose.cuh
+++ b/cpp/include/raft/linalg/detail/transpose.cuh
@@ -18,6 +18,7 @@
 
 #include "cublas_wrappers.hpp"
 
+#include <raft/core/mdarray.hpp>
 #include <raft/handle.hpp>
 #include <raft/linalg/eltwise.cuh>
 #include <rmm/exec_policy.hpp>
@@ -79,6 +80,57 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
   });
 }
 
+template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
+void transpose_row_major_impl(
+  handle_t const& handle,
+  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
+  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
+{
+  auto out_n_rows = in.extent(1);
+  auto out_n_cols = in.extent(0);
+  T constexpr kOne  = 1;
+  T constexpr kZero = 0;
+  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
+                        CUBLAS_OP_T,
+                        CUBLAS_OP_N,
+                        out_n_cols,
+                        out_n_rows,
+                        &kOne,
+                        in.data_handle(),
+                        in.stride(0),
+                        &kZero,
+                        static_cast<T*>(nullptr),
+                        out.stride(0),
+                        out.data_handle(),
+                        out.stride(0),
+                        handle.get_stream()));
+}
+
+template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
+void transpose_col_major_impl(
+  handle_t const& handle,
+  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
+  raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
+{
+  auto out_n_rows = in.extent(1);
+  auto out_n_cols = in.extent(0);
+  T constexpr kOne  = 1;
+  T constexpr kZero = 0;
+  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
+                        CUBLAS_OP_T,
+                        CUBLAS_OP_N,
+                        out_n_rows,
+                        out_n_cols,
+                        &kOne,
+                        in.data_handle(),
+                        in.stride(1),
+                        &kZero,
+                        static_cast<T*>(nullptr),
+                        out.stride(1),
+                        out.data_handle(),
+                        out.stride(1),
+                        handle.get_stream()));
+}
 };  // end namespace detail
 };  // end namespace linalg
 };  // end namespace raft
diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh
index a9ada5125a..cd78a2f495 100644
--- a/cpp/include/raft/linalg/transpose.cuh
+++ b/cpp/include/raft/linalg/transpose.cuh
@@ -19,6 +19,7 @@
 #pragma once
 
 #include "detail/transpose.cuh"
+#include <raft/core/mdarray.hpp>
 
 namespace raft {
 namespace linalg {
@@ -55,7 +56,45 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
   detail::transpose(inout, n, stream);
 }
 
+/**
+ * @brief Transpose a matrix. The output has same layout policy as the input.
+ *
+ * @tparam T Data type of input matrix element.
+ * @tparam IndexType Index type of matrix extent.
+ * @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
+ *                      be a submatrix of a larger matrix. Arbitrary stride is not supported.
+ * @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
+ *                        device.
+ *
+ * @param[in] handle raft handle for managing expensive cuda resources.
+ * @param[in] in     Input matrix.
+ * @param[out] out   Output matrix, storage is pre-allocated by caller.
+ */
+template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
+auto transpose(handle_t const& handle,
+               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
+               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
+  -> std::enable_if_t<std::is_floating_point_v<T>, void>
+{
+  RAFT_EXPECTS(out.extent(0) == in.extent(1), "Invalid shape for transpose.");
+  RAFT_EXPECTS(out.extent(1) == in.extent(0), "Invalid shape for transpose.");
+
+  if constexpr (std::is_same_v<LayoutPolicy, layout_c_contiguous>) {
+    detail::transpose_row_major_impl(handle, in, out);
+  } else if constexpr (std::is_same_v<LayoutPolicy, layout_f_contiguous>) {
+    detail::transpose_col_major_impl(handle, in, out);
+  } else {
+    RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
+    if (in.stride(1) == 1) {
+      // row-major submatrix
+      detail::transpose_row_major_impl(handle, in, out);
+    } else {
+      // col-major submatrix
+      detail::transpose_col_major_impl(handle, in, out);
+    }
+  }
+}
 };  // end namespace linalg
 };  // end namespace raft
 
-#endif
\ No newline at end of file
+#endif
diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu
index 3bb30c9f33..98f6d5e7e4 100644
--- a/cpp/test/linalg/transpose.cu
+++ b/cpp/test/linalg/transpose.cu
@@ -113,5 +113,150 @@
 INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValF, ::testing::ValuesIn(inputsf2));
 INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValD,
                          ::testing::ValuesIn(inputsd2));
+namespace {
+/**
+ * We hide these functions in tests for now until we have a heterogeneous mdarray
+ * implementation.
+ */
+
+/**
+ * @brief Transpose a matrix. The output has same layout policy as the input.
+ *
+ * @tparam T Data type of input matrix elements.
+ * @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
+ *                      be a submatrix of a larger matrix. Arbitrary stride is not supported.
+ *
+ * @param[in] handle raft handle for managing expensive cuda resources.
+ * @param[in] in     Input matrix.
+ *
+ * @return The transposed matrix.
+ */
+template <typename T, typename LayoutPolicy>
+[[nodiscard]] auto transpose(handle_t const& handle,
+                             device_matrix_view<T, LayoutPolicy> in)
+  -> std::enable_if_t<std::is_floating_point_v<T> &&
+                        (std::is_same_v<LayoutPolicy, layout_c_contiguous> ||
+                         std::is_same_v<LayoutPolicy, layout_f_contiguous>),
+                      device_matrix<T, LayoutPolicy>>
+{
+  auto out = make_device_matrix<T, LayoutPolicy>(handle, in.extent(1), in.extent(0));
+  ::raft::linalg::transpose(handle, in, out.view());
+  return out;
+}
+
+/**
+ * @brief Transpose a matrix. The output has same layout policy as the input.
+ *
+ * @tparam T Data type of input matrix elements.
+ * @tparam LayoutPolicy Layout type of the input matrix. When layout is strided, it can
+ *                      be a submatrix of a larger matrix. Arbitrary stride is not supported.
+ *
+ * @param[in] handle raft handle for managing expensive cuda resources.
+ * @param[in] in     Input matrix.
+ *
+ * @return The transposed matrix.
+ */
+template <typename T>
+[[nodiscard]] auto transpose(handle_t const& handle,
+                             device_matrix_view<T, layout_stride> in)
+  -> std::enable_if_t<std::is_floating_point_v<T>, device_matrix<T, layout_stride>>
+{
+  matrix_extent<size_t> exts{in.extent(1), in.extent(0)};
+  using policy_type =
+    typename raft::device_matrix<T, layout_stride>::container_policy_type;
+  policy_type policy(handle.get_stream());
+
+  RAFT_EXPECTS(in.stride(0) == 1 || in.stride(1) == 1, "Unsupported matrix layout.");
+  if (in.stride(1) == 1) {
+    // row-major submatrix
+    std::array<size_t, 2> strides{in.extent(0), 1};
+    auto layout = layout_stride::mapping<matrix_extent<size_t>>{exts, strides};
+    raft::device_matrix<T, layout_stride> out{layout, policy};
+    ::raft::linalg::transpose(handle, in, out.view());
+    return out;
+  } else {
+    // col-major submatrix
+    std::array<size_t, 2> strides{1, in.extent(1)};
+    auto layout = layout_stride::mapping<matrix_extent<size_t>>{exts, strides};
+    raft::device_matrix<T, layout_stride> out{layout, policy};
+    ::raft::linalg::transpose(handle, in, out.view());
+    return out;
+  }
+}
+
+template <typename T, typename LayoutPolicy>
+void test_transpose_with_mdspan()
+{
+  handle_t handle;
+  auto v = make_device_matrix<T, LayoutPolicy>(handle, 32, 3);
+  T k{0};
+  for (size_t i = 0; i < v.extent(0); ++i) {
+    for (size_t j = 0; j < v.extent(1); ++j) {
+      v(i, j) = k++;
+    }
+  }
+  auto out = transpose(handle, v.view());
+  static_assert(std::is_same_v<LayoutPolicy, typename decltype(out)::layout_type>);
+  ASSERT_EQ(out.extent(0), v.extent(1));
+  ASSERT_EQ(out.extent(1), v.extent(0));
+
+  k = 0;
+  for (size_t i = 0; i < out.extent(1); ++i) {
+    for (size_t j = 0; j < out.extent(0); ++j) {
+      ASSERT_EQ(out(j, i), k++);
+    }
+  }
+}
+}  // namespace
+
+TEST(TransposeTest, MDSpan)
+{
+  test_transpose_with_mdspan<float, layout_c_contiguous>();
+  test_transpose_with_mdspan<double, layout_c_contiguous>();
+
+  test_transpose_with_mdspan<float, layout_f_contiguous>();
+  test_transpose_with_mdspan<double, layout_f_contiguous>();
+}
+
+namespace {
+template <typename T, typename LayoutPolicy>
+void test_transpose_submatrix()
+{
+  handle_t handle;
+  auto v = make_device_matrix<T, LayoutPolicy>(handle, 32, 33);
+  T k{0};
+  size_t row_beg{3}, row_end{13}, col_beg{2}, col_end{11};
+  for (size_t i = row_beg; i < row_end; ++i) {
+    for (size_t j = col_beg; j < col_end; ++j) {
+      v(i, j) = k++;
+    }
+  }
+
+  auto vv = v.view();
+  auto submat = raft::detail::stdex::submdspan(
+    vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end));
+  static_assert(std::is_same_v<typename decltype(submat)::layout_type, layout_stride>);
+
+  auto out = transpose(handle, submat);
+  ASSERT_EQ(out.extent(0), submat.extent(1));
+  ASSERT_EQ(out.extent(1), submat.extent(0));
+
+  k = 0;
+  for (size_t i = 0; i < out.extent(1); ++i) {
+    for (size_t j = 0; j < out.extent(0); ++j) {
+      ASSERT_EQ(out(j, i), k++);
+    }
+  }
+}
+}  // namespace
+
+TEST(TransposeTest, SubMatrix)
+{
+  test_transpose_submatrix<float, layout_c_contiguous>();
+  test_transpose_submatrix<double, layout_c_contiguous>();
+
+  test_transpose_submatrix<float, layout_f_contiguous>();
+  test_transpose_submatrix<double, layout_f_contiguous>();
+}
 }  // end namespace linalg
 }  // end namespace raft