diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 595c0161cd..ab6a04587a 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -21,12 +21,21 @@ * limitations under the License. */ #pragma once + +#include + #include #include #include #include namespace raft { +/** + * @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan + */ +template +using extents = std::experimental::extents; + /** * @\brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. */ @@ -37,6 +46,30 @@ using layout_c_contiguous = detail::stdex::layout_right; */ using layout_f_contiguous = detail::stdex::layout_left; +/** + * @\brief Template checks and helpers to determine if type T is an std::mdspan + * or a derived type + */ + +template +void __takes_an_mdspan_ptr( + detail::stdex::mdspan*); + +template +struct __is_mdspan : std::false_type { +}; + +template +struct __is_mdspan()))>> + : std::true_type { +}; + +template +using __is_mdspan_t = __is_mdspan>; + +template +inline constexpr bool __is_mdspan_v = __is_mdspan_t::value; + /** * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. */ @@ -57,6 +90,85 @@ template >; +template +struct __is_device_mdspan : std::false_type { +}; + +template +struct __is_device_mdspan : std::bool_constant { +}; + +/** + * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type + */ +template +inline constexpr bool is_device_mdspan_v = __is_device_mdspan>::value; + +template +struct __is_host_mdspan : std::false_type { +}; + +template +struct __is_host_mdspan : T::accessor_type::is_host_type { +}; + +/** + * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type + */ +template +inline constexpr bool is_host_mdspan_v = __is_host_mdspan>::value; + +/** + * @\brief Boolean to determine if template type T is either raft::host_mdspan/raft::device_mdspan + * or their derived types + * This is structured such that it will short-circuit if the type is not std::mdspan + * or a derived type, and otherwise it will check whether it is a raft::device_mdspan + * or raft::host_mdspan assuming the type was found to be std::mdspan or a derived type + */ +template +inline constexpr bool is_mdspan_v = + std::conjunction_v<__is_mdspan_t, + std::disjunction<__is_device_mdspan, __is_host_mdspan>>; + +/** + * @brief Interface to implement an owning multi-dimensional array + * + * raft::array_interace is an interface to owning container types for mdspan. + * Check implementation of raft::mdarray which implements raft::array_interface + * using Curiously Recurring Template Pattern. + * This interface calls into method `view()` whose implementation is provided by + * the implementing class. `view()` must return an object of type raft::host_mdspan + * or raft::device_mdspan or any types derived from the them. + */ +template +class array_interface { + /** + * @brief Get a mdspan that can be passed down to CUDA kernels. + */ + auto view() noexcept { return static_cast(this)->view(); } + /** + * @brief Get a mdspan that can be passed down to CUDA kernels. + */ + auto view() const noexcept { return static_cast(this)->view(); } +}; + +template +struct __is_array_interface : std::false_type { +}; + +template +struct __is_array_interface().view())>> + : std::bool_constant().view())>> { +}; + +/** + * @\brief Boolean to determine if template type T is raft::array_interface or derived type + * or any type that has a member function `view()` that returns either + * raft::host_mdspan or raft::device_mdspan + */ +template +inline constexpr bool is_array_interface_v = __is_array_interface>::value; + /** * @brief Modified from the c++ mdarray proposal * @@ -87,7 +199,8 @@ using host_mdspan = * removed. */ template -class mdarray { +class mdarray + : public array_interface> { static_assert(!std::is_const::value, "Element type for container must not be const."); @@ -340,15 +453,15 @@ using device_scalar = device_mdarray; * @brief Shorthand for 1-dim host mdarray. * @tparam ElementType the data type of the vector elements */ -template -using host_vector = host_mdarray; +template +using host_vector = host_mdarray; /** * @brief Shorthand for 1-dim device mdarray. * @tparam ElementType the data type of the vector elements */ -template -using device_vector = device_mdarray; +template +using device_vector = device_mdarray; /** * @brief Shorthand for c-contiguous host matrix. @@ -384,15 +497,15 @@ using device_scalar_view = device_mdspan; * @brief Shorthand for 1-dim host mdspan. * @tparam ElementType the data type of the vector elements */ -template -using host_vector_view = host_mdspan; +template +using host_vector_view = host_mdspan; /** * @brief Shorthand for 1-dim device mdspan. * @tparam ElementType the data type of the vector elements */ -template -using device_vector_view = device_mdspan; +template +using device_vector_view = device_mdspan; /** * @brief Shorthand for c-contiguous host matrix view. @@ -478,11 +591,11 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) * @param[in] n number of elements in pointer * @return raft::host_vector_view */ -template +template auto make_host_vector_view(ElementType* ptr, size_t n) { detail::vector_extent extents{n}; - return host_vector_view{ptr, extents}; + return host_vector_view{ptr, extents}; } /** @@ -492,11 +605,11 @@ auto make_host_vector_view(ElementType* ptr, size_t n) * @param[in] n number of elements in pointer * @return raft::device_vector_view */ -template +template auto make_device_vector_view(ElementType* ptr, size_t n) { detail::vector_extent extents{n}; - return device_vector_view{ptr, extents}; + return device_vector_view{ptr, extents}; } /** @@ -610,13 +723,13 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) * @param[in] n number of elements in vector * @return raft::host_vector */ -template +template auto make_host_vector(size_t n) { detail::vector_extent extents{n}; - using policy_t = typename host_vector::container_policy_type; + using policy_t = typename host_vector::container_policy_type; policy_t policy; - return host_vector{extents, policy}; + return host_vector{extents, policy}; } /** @@ -626,13 +739,13 @@ auto make_host_vector(size_t n) * @param[in] stream the cuda stream for ordering events * @return raft::device_vector */ -template +template auto make_device_vector(size_t n, rmm::cuda_stream_view stream) { detail::vector_extent extents{n}; - using policy_t = typename device_vector::container_policy_type; + using policy_t = typename device_vector::container_policy_type; policy_t policy{stream}; - return device_vector{extents, policy}; + return device_vector{extents, policy}; } /** @@ -642,9 +755,92 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream) * @param[in] n number of elements in vector * @return raft::device_vector */ -template +template auto make_device_vector(raft::handle_t const& handle, size_t n) { - return make_device_vector(n, handle.get_stream()); + return make_device_vector(n, handle.get_stream()); } + +/** + * @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy + */ +template >* = nullptr> +auto flatten(mdspan_type mds) +{ + RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); + + detail::vector_extent ext{mds.size()}; + + return detail::stdex::mdspan(mds.data(), ext); +} + +/** + * @brief Flatten object implementing raft::array_interface into a 1-dim array view + * + * @tparam array_interface_type Expected type implementing raft::array_interface + * @param mda raft::array_interace implementing object + * @return Either raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on the underlying ContainerPolicy + */ +template >* = nullptr> +auto flatten(const array_interface_type& mda) +{ + return flatten(mda.view()); +} + +/** + * @brief Reshape raft::host_mdspan or raft::device_mdspan + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @tparam Extents raft::extents for dimensions + * @param mds raft::host_mdspan or raft::device_mdspan object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy + */ +template >* = nullptr> +auto reshape(mdspan_type mds, extents new_shape) +{ + RAFT_EXPECTS(mds.is_contiguous(), "Input must be contiguous."); + + size_t new_size = 1; + for (size_t i = 0; i < new_shape.rank(); ++i) { + new_size *= new_shape.extent(i); + } + RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); + + return detail::stdex::mdspan(mds.data(), new_shape); +} + +/** + * @brief Reshape object implementing raft::array_interface + * + * @tparam array_interface_type Expected type implementing raft::array_interface + * @tparam Extents raft::extents for dimensions + * @param mda raft::array_interace implementing object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on the underlying + * ContainerPolicy + */ +template >* = nullptr> +auto reshape(const array_interface_type& mda, extents new_shape) +{ + return reshape(mda.view(), new_shape); +} + } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index 624c7a4d07..03a56839a7 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -22,6 +22,7 @@ */ #pragma once #include +#include #include // dynamic_extent #include #include @@ -58,7 +59,7 @@ class device_reference { auto operator=(T const& other) -> device_reference& { auto* raw = ptr_.get(); - update_device(raw, &other, 1, stream_); + raft::update_device(raw, &other, 1, stream_); return *this; } }; diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index d8e60550ca..3fbe7fc085 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -258,7 +258,7 @@ void distance(raft::handle_t const& handle, RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); - auto is_rowmajor = std::is_same::value; + constexpr auto is_rowmajor = std::is_same_v; distance(x.data(), y.data(), @@ -433,7 +433,7 @@ void pairwise_distance(raft::handle_t const& handle, RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous."); - bool rowmajor = x.stride(0) == 0; + constexpr auto rowmajor = std::is_same_v; rmm::device_uvector workspace(0, handle.get_stream()); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 354b5e8fc4..43c6257966 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -75,6 +75,7 @@ add_executable(test_raft test/matrix/columnSort.cu test/matrix/linewise_op.cu test/mdarray.cu + test/mdspan_utils.cu test/mst.cu test/random/make_blobs.cu test/random/make_regression.cu diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu new file mode 100644 index 0000000000..15388a5cef --- /dev/null +++ b/cpp/test/mdspan_utils.cu @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft { + +namespace stdex = std::experimental; + +template > +struct derived_device_mdspan + : public device_mdspan { +}; + +void test_template_asserts() +{ + // Testing 3d device mdspan to be an mdspan + using three_d_extents = extents; + using three_d_mdspan = device_mdspan; + using d_mdspan = derived_device_mdspan; + + static_assert(std::is_same_v, device_mdspan>, + "not same"); + static_assert(std::is_same_v, + device_mdspan>>, + "not same"); + + // Checking if types are mdspan, supposed to fail for std::vector + static_assert(is_mdspan_v, "3d mdspan type not an mdspan"); + static_assert(is_mdspan_v>, "device_matrix_view type not an mdspan"); + static_assert(is_mdspan_v>, + "const host_vector_view type not an mdspan"); + static_assert(is_mdspan_v>, + "const host_scalar_view type not an mdspan"); + static_assert(!is_mdspan_v>, "std::vector is an mdspan"); + static_assert(is_mdspan_v, "Derived device mdspan type is not mdspan"); + + // Checking if types are device_mdspan + static_assert(is_device_mdspan_v>, + "device_matrix_view type not a device_mdspan"); + static_assert(!is_device_mdspan_v>, + "host_matrix_view type is a device_mdspan"); + static_assert(is_device_mdspan_v, "Derived device mdspan type is not device_mdspan"); + + // Checking if types are host_mdspan + static_assert(!is_host_mdspan_v>, + "device_matrix_view type not a host_mdspan"); + static_assert(is_host_mdspan_v>, + "host_matrix_view type is a host_mdspan"); +} + +TEST(MDSpan, TemplateAsserts) { test_template_asserts(); } + +void test_host_flatten() +{ + // flatten 3d host mdspan + { + using three_d_extents = extents; + using three_d_mdarray = host_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy; + three_d_mdarray mda{extents, policy}; + + auto flat_view = flatten(mda); + + static_assert(std::is_same_v, + "layouts not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.size(), mda.size()); + } + + // flatten host vector + { + auto hv = make_host_vector(27); + auto flat_view = flatten(hv.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(hv.extents().rank(), flat_view.extents().rank()); + ASSERT_EQ(hv.extent(0), flat_view.extent(0)); + } + + // flatten host scalar + { + auto hs = make_host_scalar(27); + auto flat_view = flatten(hs.view()); + + ASSERT_EQ(flat_view.extent(0), 1); + } +} + +TEST(MDArray, HostFlatten) { test_host_flatten(); } + +void test_device_flatten() +{ + raft::handle_t handle{}; + // flatten 3d device mdspan + { + using three_d_extents = extents; + using three_d_mdarray = device_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy{handle.get_stream()}; + three_d_mdarray mda{extents, policy}; + + auto flat_view = flatten(mda); + + static_assert(std::is_same_v, + "layouts not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.size(), mda.size()); + } + + // flatten device vector + { + auto dv = make_device_vector(27, handle.get_stream()); + auto flat_view = flatten(dv.view()); + + static_assert(std::is_same_v, "types not the same"); + + ASSERT_EQ(dv.extents().rank(), flat_view.extents().rank()); + ASSERT_EQ(dv.extent(0), flat_view.extent(0)); + } + + // flatten device scalar + { + auto ds = make_device_scalar(27, handle.get_stream()); + auto flat_view = flatten(ds.view()); + + ASSERT_EQ(flat_view.extent(0), 1); + } +} + +TEST(MDArray, DeviceFlatten) { test_device_flatten(); } + +void test_reshape() +{ + // reshape 3d host array to vector + { + using three_d_extents = extents; + using three_d_mdarray = host_mdarray; + + three_d_extents extents{3, 3, 3}; + three_d_mdarray::container_policy_type policy; + three_d_mdarray mda{extents, policy}; + + auto flat_view = reshape(mda, raft::extents{27}); + // this confirms aliasing works as intended + static_assert(std::is_same_v>, + "types not the same"); + + ASSERT_EQ(flat_view.extents().rank(), 1); + ASSERT_EQ(flat_view.size(), mda.size()); + } + + // reshape 4d device array to 2d + { + raft::handle_t handle{}; + using four_d_extents = extents; + using four_d_mdarray = device_mdarray; + + four_d_extents extents{2, 2, 2, 2}; + four_d_mdarray::container_policy_type policy{handle.get_stream()}; + four_d_mdarray mda{extents, policy}; + + auto matrix = reshape(mda, raft::extents{4, 4}); + // this confirms aliasing works as intended + static_assert(std::is_same_v>, + "types not the same"); + + ASSERT_EQ(matrix.extents().rank(), 2); + ASSERT_EQ(matrix.extent(0), 4); + ASSERT_EQ(matrix.extent(1), 4); + } + + // reshape 2d host matrix with static extents to vector + { + using two_d_extents = extents<5, 5>; + using two_d_mdarray = host_mdarray; + + two_d_mdarray::container_policy_type policy; + two_d_mdarray mda{two_d_extents{}, policy}; + + auto vector = reshape(mda, raft::extents<25>{}); + + ASSERT_EQ(vector.extents().rank(), 1); + ASSERT_EQ(vector.size(), mda.size()); + } +} + +TEST(MDArray, Reshape) { test_reshape(); } + +} // namespace raft \ No newline at end of file