From 5d697b3ef6e12d6ffd77fbe7842a47bc7d4819e5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 14:47:53 -0400 Subject: [PATCH 01/11] Breaking apart mdspan/mdarray into host_ and device_ variants --- .../raft/core/detail/accessor_mixin.hpp | 39 + cpp/include/raft/core/device_mdarray.hpp | 174 ++++ cpp/include/raft/core/device_mdspan.hpp | 194 +++++ cpp/include/raft/core/host_mdarray.hpp | 144 ++++ cpp/include/raft/core/host_mdspan.hpp | 142 ++++ cpp/include/raft/core/mdarray.hpp | 778 +----------------- cpp/include/raft/core/mdspan.hpp | 265 +++++- cpp/include/raft/detail/mdarray.hpp | 74 +- cpp/include/raft/detail/span.hpp | 2 +- 9 files changed, 962 insertions(+), 850 deletions(-) create mode 100644 cpp/include/raft/core/detail/accessor_mixin.hpp create mode 100644 cpp/include/raft/core/device_mdarray.hpp create mode 100644 cpp/include/raft/core/device_mdspan.hpp create mode 100644 cpp/include/raft/core/host_mdarray.hpp create mode 100644 cpp/include/raft/core/host_mdspan.hpp diff --git a/cpp/include/raft/core/detail/accessor_mixin.hpp b/cpp/include/raft/core/detail/accessor_mixin.hpp new file mode 100644 index 0000000000..6edd85dbaf --- /dev/null +++ b/cpp/include/raft/core/detail/accessor_mixin.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::detail { + +/** + * @brief A mixin to distinguish host and device memory. + */ +template +struct accessor_mixin : public AccessorPolicy { + using accessor_type = AccessorPolicy; + using is_host_type = std::conditional_t; + using is_device_type = std::conditional_t; + using is_managed_type = std::conditional_t; + static constexpr bool is_host_accessible = is_host; + static constexpr bool is_device_accessible = is_device; + static constexpr bool is_managed_accessible = is_device && is_host; + // make sure the explicit ctor can fall through + using AccessorPolicy::AccessorPolicy; + using offset_policy = accessor_mixin; + accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT +}; + +} // namespace raft::detail diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp new file mode 100644 index 0000000000..c5b5d73fdc --- /dev/null +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { + +/** + * @brief mdarray with device container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using device_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim host mdarray (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using device_scalar = device_mdarray>; + +/** + * @brief Shorthand for 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_vector = device_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous device matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_matrix = device_mdarray, LayoutPolicy>; + +/** + * @brief Create a device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::handle_t + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template +auto make_device_mdarray(const raft::handle_t& handle, extents exts) +{ + using mdarray_t = device_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{handle.get_stream()}; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::handle_t + * @param mr rmm memory resource used for allocating the memory for the array + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template +auto make_device_mdarray(const raft::handle_t& handle, + rmm::mr::device_memory_resource* mr, + extents exts) +{ + using mdarray_t = device_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{handle.get_stream(), mr}; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous device mdarray. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::device_matrix + */ +template +auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_device_mdarray( + handle.get_stream(), make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a device scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] v scalar to wrap on device + * @return raft::device_scalar + */ +template +auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) +{ + scalar_extent extents; + using policy_t = typename device_scalar::container_policy_type; + policy_t policy{handle.get_stream()}; + auto scalar = device_scalar{extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] n number of elements in vector + * @return raft::device_vector + */ +template +auto make_device_vector(raft::handle_t const& handle, IndexType n) +{ + return make_device_mdarray(handle.get_stream(), + make_extents(n)); +} + +} // end namespace raft diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp new file mode 100644 index 0000000000..f835397e9d --- /dev/null +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +template +using device_accessor = detail::accessor_mixin; + +template +using managed_accessor = detail::accessor_mixin; + +/** + * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. + */ +template > +using device_mdspan = mdspan>; + +template > +using managed_mdspan = mdspan>; + +namespace detail { +template +struct is_device_mdspan : std::false_type { +}; +template +struct is_device_mdspan : std::bool_constant { +}; + +/** + * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type + */ +template +using is_device_mdspan_t = is_device_mdspan>; + +template +struct is_managed_mdspan : std::false_type { +}; +template +struct is_managed_mdspan : std::bool_constant { +}; + +/** + * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type + */ +template +using is_managed_mdspan_t = is_managed_mdspan>; + +} // end namespace detail + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a + * derived type + */ +template +inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; + +template +using enable_if_device_mdspan = std::enable_if_t>; + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a + * derived type + */ +template +inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; + +template +using enable_if_managed_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim host mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using device_scalar_view = device_mdspan>; + +/** + * @brief Shorthand for 1-dim device mdspan. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_vector_view = device_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous device matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_matrix_view = device_mdspan, LayoutPolicy>; + +/** + * @brief Create a raft::managed_mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::managed_mdspan + */ +template +auto make_managed_mdspan(ElementType* ptr, extents exts) +{ + return make_mdspan(ptr, exts); +} + +/** + * @brief Create a 0-dim (scalar) mdspan instance for device value. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on device to wrap + */ +template +auto make_device_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return device_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam IndexType the index type of the extents + * @param[in] ptr on device to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return device_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 1-dim mdspan instance for device pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n number of elements in pointer + * @return raft::device_vector_view + */ +template +auto make_device_vector_view(ElementType* ptr, IndexType n) +{ + vector_extent extents{n}; + return device_vector_view{ptr, extents}; +} + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp new file mode 100644 index 0000000000..872c007255 --- /dev/null +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { +/** + * @brief mdarray with host container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using host_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim host mdarray (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using host_scalar = host_mdarray>; + +/** + * @brief Shorthand for 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_vector = host_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous host matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_matrix = host_mdarray, LayoutPolicy>; + +/** + * @brief Create a host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param exts dimensionality of the array (series of integers) + * @return raft::host_mdarray + */ +template +auto make_host_mdarray(extents exts) +{ + using mdarray_t = host_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::host_matrix + */ +template +auto make_host_matrix(IndexType n_rows, IndexType n_cols) +{ + return make_host_mdarray( + make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a host scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] v scalar type to wrap + * @return raft::host_scalar + */ +template +auto make_host_scalar(ElementType const& v) +{ + // FIXME(jiamingy): We can optimize this by using std::array as container policy, which + // requires some more compile time dispatching. This is enabled in the ref impl but + // hasn't been ported here yet. + scalar_extent extents; + using policy_t = typename host_scalar::container_policy_type; + policy_t policy; + auto scalar = host_scalar{extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n number of elements in vector + * @return raft::host_vector + */ +template +auto make_host_vector(IndexType n) +{ + return make_host_mdarray(make_extents(n)); +} + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp new file mode 100644 index 0000000000..f46fb6ff17 --- /dev/null +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +template +using host_accessor = detail::accessor_mixin; + +/** + * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. + */ +template > +using host_mdspan = mdspan>; + +namespace detail { + +template +struct is_host_mdspan : std::false_type { +}; +template +struct is_host_mdspan : std::bool_constant { +}; + +/** + * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type + */ +template +using is_host_mdspan_t = is_host_mdspan>; + +} // namespace detail + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a + * derived type + */ +template +inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; + +template +using enable_if_host_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim host mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using host_scalar_view = host_mdspan>; + +/** + * @brief Shorthand for 1-dim host mdspan. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + */ +template +using host_vector_view = host_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous host matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_matrix_view = host_mdspan, LayoutPolicy>; + +/** + * @brief Create a 0-dim (scalar) mdspan instance for host value. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on device to wrap + */ +template +auto make_host_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return host_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on host to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return host_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 1-dim mdspan instance for host pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on host to wrap + * @param[in] n number of elements in pointer + * @return raft::host_vector_view + */ +template +auto make_host_vector_view(ElementType* ptr, IndexType n) +{ + vector_extent extents{n}; + return host_vector_view{ptr, extents}; +} +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index d251c2b419..e918d81ff1 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,212 +24,15 @@ #include +#include #include +#include #include #include - #include #include namespace raft { -/** - * @brief Dimensions extents for raft::host_mdspan or raft::device_mdspan - */ -template -using extents = std::experimental::extents; - -/** - * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. - * @{ - */ -using detail::stdex::layout_right; -using layout_c_contiguous = layout_right; -using row_major = layout_right; -/** @} */ - -/** - * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. - * @{ - */ -using detail::stdex::layout_left; -using layout_f_contiguous = layout_left; -using col_major = layout_left; -/** @} */ - -/** - * @brief Strided layout for non-contiguous memory. - */ -using detail::stdex::layout_stride; - -/** - * @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension - * is known at run time (dynamic_extent in each dimension). - * @{ - */ -using detail::matrix_extent; -using detail::scalar_extent; -using detail::vector_extent; - -template -using extent_1d = vector_extent; - -template -using extent_2d = matrix_extent; - -template -using extent_3d = detail::stdex::extents; - -template -using extent_4d = - detail::stdex::extents; - -template -using extent_5d = detail::stdex::extents; -/** @} */ - -template > -using mdspan = detail::stdex::mdspan; - -namespace detail { -/** - * @\brief Template checks and helpers to determine if type T is an std::mdspan - * or a derived type - */ - -template -void __takes_an_mdspan_ptr(mdspan*); - -template -struct is_mdspan : std::false_type { -}; -template -struct is_mdspan()))>> - : std::true_type { -}; - -template -using is_mdspan_t = is_mdspan>; - -template -inline constexpr bool is_mdspan_v = is_mdspan_t::value; -} // namespace detail - -/** - * @\brief Boolean to determine if variadic template types Tn are either - * raft::host_mdspan/raft::device_mdspan or their derived types - */ -template -inline constexpr bool is_mdspan_v = std::conjunction_v...>; - -template -using enable_if_mdspan = std::enable_if_t>; - -/** - * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. - */ -template > -using device_mdspan = - mdspan>; - -/** - * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. - */ -template > -using host_mdspan = - mdspan>; - -template > -using managed_mdspan = - mdspan>; - -namespace detail { -template -struct is_device_mdspan : std::false_type { -}; -template -struct is_device_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type - */ -template -using is_device_mdspan_t = is_device_mdspan>; - -template -struct is_host_mdspan : std::false_type { -}; -template -struct is_host_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type - */ -template -using is_host_mdspan_t = is_host_mdspan>; - -template -struct is_managed_mdspan : std::false_type { -}; -template -struct is_managed_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type - */ -template -using is_managed_mdspan_t = is_managed_mdspan>; -} // namespace detail - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a - * derived type - */ -template -inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; - -template -using enable_if_device_mdspan = std::enable_if_t>; - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a - * derived type - */ -template -inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; - -template -using enable_if_host_mdspan = std::enable_if_t>; - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a - * derived type - */ -template -inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; - -template -using enable_if_managed_mdspan = std::enable_if_t>; - /** * @brief Interface to implement an owning multi-dimensional array * @@ -531,521 +334,6 @@ class mdarray container_type c_; }; -/** - * @brief mdarray with host container policy - * @tparam ElementType the data type of the elements - * @tparam Extents defines the shape - * @tparam LayoutPolicy policy for indexing strides and layout ordering - * @tparam ContainerPolicy storage and accessor policy - */ -template > -using host_mdarray = - mdarray>; - -/** - * @brief mdarray with device container policy - * @tparam ElementType the data type of the elements - * @tparam Extents defines the shape - * @tparam LayoutPolicy policy for indexing strides and layout ordering - * @tparam ContainerPolicy storage and accessor policy - */ -template > -using device_mdarray = - mdarray>; - -/** - * @brief Shorthand for 0-dim host mdarray (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using host_scalar = host_mdarray>; - -/** - * @brief Shorthand for 0-dim host mdarray (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using device_scalar = device_mdarray>; - -/** - * @brief Shorthand for 1-dim host mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_vector = host_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for 1-dim device mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_vector = device_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous host matrix. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_matrix = host_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous device matrix. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_matrix = device_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for 0-dim host mdspan (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using host_scalar_view = host_mdspan>; - -/** - * @brief Shorthand for 0-dim host mdspan (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using device_scalar_view = device_mdspan>; - -/** - * @brief Shorthand for 1-dim host mdspan. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - */ -template -using host_vector_view = host_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for 1-dim device mdspan. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_vector_view = device_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous host matrix view. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_matrix_view = host_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous device matrix view. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_matrix_view = device_mdspan, LayoutPolicy>; - -/** - * @brief Create a raft::mdspan - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @tparam is_host_accessible whether the data is accessible on host - * @tparam is_device_accessible whether the data is accessible on device - * @param ptr Pointer to the data - * @param exts dimensionality of the array (series of integers) - * @return raft::mdspan - */ -template -auto make_mdspan(ElementType* ptr, extents exts) -{ - using accessor_type = detail::accessor_mixin, - is_host_accessible, - is_device_accessible>; - - return mdspan{ptr, exts}; -} - -/** - * @brief Create a raft::managed_mdspan - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param ptr Pointer to the data - * @param exts dimensionality of the array (series of integers) - * @return raft::managed_mdspan - */ -template -auto make_managed_mdspan(ElementType* ptr, extents exts) -{ - return make_mdspan(ptr, exts); -} - -/** - * @brief Create a 0-dim (scalar) mdspan instance for host value. - * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - */ -template -auto make_host_scalar_view(ElementType* ptr) -{ - scalar_extent extents; - return host_scalar_view{ptr, extents}; -} - -/** - * @brief Create a 0-dim (scalar) mdspan instance for device value. - * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - */ -template -auto make_device_scalar_view(ElementType* ptr) -{ - scalar_extent extents; - return device_scalar_view{ptr, extents}; -} - -/** - * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. It's - * expected that the given layout policy match the layout of the underlying - * pointer. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] ptr on host to wrap - * @param[in] n_rows number of rows in pointer - * @param[in] n_cols number of columns in pointer - */ -template -auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) -{ - matrix_extent extents{n_rows, n_cols}; - return host_matrix_view{ptr, extents}; -} -/** - * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. It's - * expected that the given layout policy match the layout of the underlying - * pointer. - * @tparam ElementType the data type of the matrix elements - * @tparam LayoutPolicy policy for strides and layout ordering - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - * @param[in] n_rows number of rows in pointer - * @param[in] n_cols number of columns in pointer - */ -template -auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) -{ - matrix_extent extents{n_rows, n_cols}; - return device_matrix_view{ptr, extents}; -} - -/** - * @brief Create a 1-dim mdspan instance for host pointer. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on host to wrap - * @param[in] n number of elements in pointer - * @return raft::host_vector_view - */ -template -auto make_host_vector_view(ElementType* ptr, IndexType n) -{ - vector_extent extents{n}; - return host_vector_view{ptr, extents}; -} - -/** - * @brief Create a 1-dim mdspan instance for device pointer. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] ptr on device to wrap - * @param[in] n number of elements in pointer - * @return raft::device_vector_view - */ -template -auto make_device_vector_view(ElementType* ptr, IndexType n) -{ - vector_extent extents{n}; - return device_vector_view{ptr, extents}; -} - -/** - * @brief Create a host mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param exts dimensionality of the array (series of integers) - * @return raft::host_mdarray - */ -template -auto make_host_mdarray(extents exts) -{ - using mdarray_t = host_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create a device mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t - * @param exts dimensionality of the array (series of integers) - * @return raft::device_mdarray - */ -template -auto make_device_mdarray(const raft::handle_t& handle, extents exts) -{ - using mdarray_t = device_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream()}; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create a device mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t - * @param mr rmm memory resource used for allocating the memory for the array - * @param exts dimensionality of the array (series of integers) - * @return raft::device_mdarray - */ -template -auto make_device_mdarray(const raft::handle_t& handle, - rmm::mr::device_memory_resource* mr, - extents exts) -{ - using mdarray_t = device_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream(), mr}; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create raft::extents to specify dimensionality - * - * @tparam IndexType The type of each dimension of the extents - * @tparam Extents Dimensions (a series of integers) - * @param exts The desired dimensions - * @return raft::extents - */ -template > -auto make_extents(Extents... exts) -{ - return extents{exts...}; -} - -/** - * @brief Create a 2-dim c-contiguous host mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] n_rows number or rows in matrix - * @param[in] n_cols number of columns in matrix - * @return raft::host_matrix - */ -template -auto make_host_matrix(IndexType n_rows, IndexType n_cols) -{ - return make_host_mdarray( - make_extents(n_rows, n_cols)); -} - -/** - * @brief Create a 2-dim c-contiguous device mdarray. - * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] handle raft handle for managing expensive resources - * @param[in] n_rows number or rows in matrix - * @param[in] n_cols number of columns in matrix - * @return raft::device_matrix - */ -template -auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) -{ - return make_device_mdarray( - handle.get_stream(), make_extents(n_rows, n_cols)); -} - -/** - * @brief Create a host scalar from v. - * - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - * @param[in] v scalar type to wrap - * @return raft::host_scalar - */ -template -auto make_host_scalar(ElementType const& v) -{ - // FIXME(jiamingy): We can optimize this by using std::array as container policy, which - // requires some more compile time dispatching. This is enabled in the ref impl but - // hasn't been ported here yet. - scalar_extent extents; - using policy_t = typename host_scalar::container_policy_type; - policy_t policy; - auto scalar = host_scalar{extents, policy}; - scalar(0) = v; - return scalar; -} - -/** - * @brief Create a device scalar from v. - * - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - * @param[in] handle raft handle for managing expensive cuda resources - * @param[in] v scalar to wrap on device - * @return raft::device_scalar - */ -template -auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) -{ - scalar_extent extents; - using policy_t = typename device_scalar::container_policy_type; - policy_t policy{handle.get_stream()}; - auto scalar = device_scalar{extents, policy}; - scalar(0) = v; - return scalar; -} - -/** - * @brief Create a 1-dim host mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] n number of elements in vector - * @return raft::host_vector - */ -template -auto make_host_vector(IndexType n) -{ - return make_host_mdarray(make_extents(n)); -} - -/** - * @brief Create a 1-dim device mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] handle raft handle for managing expensive cuda resources - * @param[in] n number of elements in vector - * @return raft::device_vector - */ -template -auto make_device_vector(raft::handle_t const& handle, IndexType n) -{ - return make_device_mdarray(handle.get_stream(), - make_extents(n)); -} - -/** - * @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view - * - * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan - * @param mds raft::host_mdspan or raft::device_mdspan object - * @return raft::host_mdspan or raft::device_mdspan with vector_extent - * depending on AccessoryPolicy - */ -template > -auto flatten(mdspan_type mds) -{ - RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); - - vector_extent ext{mds.size()}; - - return detail::stdex::mdspan(mds.data_handle(), ext); -} - /** * @brief Flatten object implementing raft::array_interface into a 1-dim array view * @@ -1061,36 +349,6 @@ auto flatten(const array_interface_type& mda) return flatten(mda.view()); } -/** - * @brief Reshape raft::host_mdspan or raft::device_mdspan - * - * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan - * @tparam IndexType the index type of the extents - * @tparam Extents raft::extents for dimensions - * @param mds raft::host_mdspan or raft::device_mdspan object - * @param new_shape Desired new shape of the input - * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy - */ -template > -auto reshape(mdspan_type mds, extents new_shape) -{ - RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); - - size_t new_size = 1; - for (size_t i = 0; i < new_shape.rank(); ++i) { - new_size *= new_shape.extent(i); - } - RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); - - return detail::stdex::mdspan(mds.data_handle(), new_shape); -} - /** * @brief Reshape object implementing raft::array_interface * @@ -1111,36 +369,4 @@ auto reshape(const array_interface_type& mda, extents new return reshape(mda.view(), new_shape); } -/** - * \brief Turns linear index into coordinate. Similar to numpy unravel_index. - * - * \code - * auto m = make_host_matrix(7, 6); - * auto m_v = m.view(); - * auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); - * std::apply(m_v, coord) = 2; - * \endcode - * - * \param idx The linear index. - * \param shape The shape of the array to use. - * \param layout Must be `layout_c_contiguous` (row-major) in current implementation. - * - * \return A std::tuple that represents the coordinate. - */ -template -MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, - extents shape, - LayoutPolicy const& layout) -{ - static_assert(std::is_same_v>, - layout_c_contiguous>, - "Only C layout is supported."); - static_assert(std::is_integral_v, "Index must be integral."); - auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); - if (kIs64 && static_cast(idx) > std::numeric_limits::max()) { - return detail::unravel_index_impl(static_cast(idx), shape); - } else { - return detail::unravel_index_impl(static_cast(idx), shape); - } -} } // namespace raft diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 809134e96e..6078c730bf 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -22,4 +22,267 @@ */ #pragma once -#include \ No newline at end of file +#include +#include // dynamic_extent +#include + +namespace raft { +/** + * @brief Dimensions extents for raft::mdspan + */ +template +using extents = std::experimental::extents; + +namespace stdex = std::experimental; + +/** + * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. + * @{ + */ +using stdex::layout_right; +using layout_c_contiguous = layout_right; +using row_major = layout_right; +/** @} */ + +/** + * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @{ + */ +using stdex::layout_left; +using layout_f_contiguous = layout_left; +using col_major = layout_left; +/** @} */ + +template +using vector_extent = stdex::extents; + +template +using matrix_extent = stdex::extents; + +template +using scalar_extent = stdex::extents; + +/** + * @brief Strided layout for non-contiguous memory. + */ +using stdex::layout_stride; + +template +using extent_1d = vector_extent; + +template +using extent_2d = matrix_extent; + +template +using extent_3d = stdex::extents; + +template +using extent_4d = + stdex::extents; + +template +using extent_5d = stdex::extents; +/** @} */ + +template > +using mdspan = stdex::mdspan; + +/** + * Ensure all types listed in the parameter pack `Extents` are integral types. + * Usage: + * put it as the last nameless template parameter of a function: + * `typename = ensure_integral_extents` + */ +template +using ensure_integral_extents = std::enable_if_t...>>; + +/** + * @\brief Template checks and helpers to determine if type T is an std::mdspan + * or a derived type + */ + +template +void __takes_an_mdspan_ptr(mdspan*); + +template +struct is_mdspan : std::false_type { +}; +template +struct is_mdspan()))>> + : std::true_type { +}; + +template +using is_mdspan_t = is_mdspan>; + +// template +// inline constexpr bool is_mdspan_v = is_mdspan_t::value; + +/** + * @\brief Boolean to determine if variadic template types Tn are either + * raft::host_mdspan/raft::device_mdspan or their derived types + */ +template +inline constexpr bool is_mdspan_v = std::conjunction_v...>; + +template +using enable_if_mdspan = std::enable_if_t>; + +// uint division optimization inspired by the CIndexer in cupy. Division operation is +// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 +// bit when the index is smaller, then try to avoid division when it's exp of 2. +template +MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) +{ + constexpr auto kRank = static_cast(shape.rank()); + std::size_t index[shape.rank()]{0}; // NOLINT + static_assert(std::is_signed::value, + "Don't change the type without changing the for loop."); + for (int32_t dim = kRank; --dim > 0;) { + auto s = static_cast>>(shape.extent(dim)); + if (s & (s - 1)) { + auto t = idx / s; + index[dim] = idx - t * s; + idx = t; + } else { // exp of 2 + index[dim] = idx & (s - 1); + idx >>= popc(s - 1); + } + } + index[0] = idx; + return arr_to_tup(index); +} + +/** + * @brief Create a raft::mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam is_host_accessible whether the data is accessible on host + * @tparam is_device_accessible whether the data is accessible on device + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::mdspan + */ +template +auto make_mdspan(ElementType* ptr, extents exts) +{ + using accessor_type = detail::accessor_mixin, + is_host_accessible, + is_device_accessible>; + + return mdspan{ptr, exts}; +} + +/** + * @brief Create raft::extents to specify dimensionality + * + * @tparam IndexType The type of each dimension of the extents + * @tparam Extents Dimensions (a series of integers) + * @param exts The desired dimensions + * @return raft::extents + */ +template > +auto make_extents(Extents... exts) +{ + return extents{exts...}; +} + +/** + * @brief Flatten raft::mdspan into a 1-dim array view + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy + */ +template > +auto flatten(mdspan_type mds) +{ + RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); + + vector_extent ext{mds.size()}; + + return stdex::mdspan(mds.data_handle(), ext); +} + +/** + * @brief Reshape raft::host_mdspan or raft::device_mdspan + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @tparam IndexType the index type of the extents + * @tparam Extents raft::extents for dimensions + * @param mds raft::host_mdspan or raft::device_mdspan object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy + */ +template > +auto reshape(mdspan_type mds, extents new_shape) +{ + RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); + + size_t new_size = 1; + for (size_t i = 0; i < new_shape.rank(); ++i) { + new_size *= new_shape.extent(i); + } + RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); + + return stdex::mdspan(mds.data_handle(), new_shape); +} + +/** + * \brief Turns linear index into coordinate. Similar to numpy unravel_index. + * + * \code + * auto m = make_host_matrix(7, 6); + * auto m_v = m.view(); + * auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); + * std::apply(m_v, coord) = 2; + * \endcode + * + * \param idx The linear index. + * \param shape The shape of the array to use. + * \param layout Must be `layout_c_contiguous` (row-major) in current implementation. + * + * \return A std::tuple that represents the coordinate. + */ +template +MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, + extents shape, + LayoutPolicy const& layout) +{ + static_assert(std::is_same_v>, + layout_c_contiguous>, + "Only C layout is supported."); + static_assert(std::is_integral_v, "Index must be integral."); + auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); + if (kIs64 && static_cast(idx) > std::numeric_limits::max()) { + return unravel_index_impl(static_cast(idx), shape); + } else { + return unravel_index_impl(static_cast(idx), shape); + } +} + +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index dd813a7c18..e3d2b0bf9e 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -21,7 +21,9 @@ * limitations under the License. */ #pragma once +#include #include + #include #include // dynamic_extent @@ -228,44 +230,6 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -/** - * @brief A mixin to distinguish host and device memory. - */ -template -struct accessor_mixin : public AccessorPolicy { - using accessor_type = AccessorPolicy; - using is_host_type = std::conditional_t; - using is_device_type = std::conditional_t; - using is_managed_type = std::conditional_t; - static constexpr bool is_host_accessible = is_host; - static constexpr bool is_device_accessible = is_device; - static constexpr bool is_managed_accessible = is_device && is_host; - // make sure the explicit ctor can fall through - using AccessorPolicy::AccessorPolicy; - using offset_policy = accessor_mixin; - accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT -}; - -template -using host_accessor = accessor_mixin; - -template -using device_accessor = accessor_mixin; - -template -using managed_accessor = accessor_mixin; - -namespace stdex = std::experimental; - -template -using vector_extent = stdex::extents; - -template -using matrix_extent = stdex::extents; - -template -using scalar_extent = stdex::extents; - template MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t { @@ -310,38 +274,4 @@ MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) return arr_to_tup(arr, std::make_index_sequence{}); } -// uint division optimization inspired by the CIndexer in cupy. Division operation is -// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 -// bit when the index is smaller, then try to avoid division when it's exp of 2. -template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) -{ - constexpr auto kRank = static_cast(shape.rank()); - std::size_t index[shape.rank()]{0}; // NOLINT - static_assert(std::is_signed::value, - "Don't change the type without changing the for loop."); - for (int32_t dim = kRank; --dim > 0;) { - auto s = static_cast>>(shape.extent(dim)); - if (s & (s - 1)) { - auto t = idx / s; - index[dim] = idx - t * s; - idx = t; - } else { // exp of 2 - index[dim] = idx & (s - 1); - idx >>= popc(s - 1); - } - } - index[0] = idx; - return arr_to_tup(index); -} - -/** - * Ensure all types listed in the parameter pack `Extents` are integral types. - * Usage: - * put it as the last nameless template parameter of a function: - * `typename = ensure_integral_extents` - */ -template -using ensure_integral_extents = std::enable_if_t...>>; - } // namespace raft::detail diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index 555b47dcae..de76ff3138 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -16,7 +16,7 @@ #pragma once #include // numeric_limits -#include +#include #include // __host__ __device__ #include From 0e6cc8618a139fda8f476a28edcd616c13d717db Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 15:28:18 -0400 Subject: [PATCH 02/11] Updates --- cpp/include/raft/cluster/detail/kmeans.cuh | 3 +- .../raft/cluster/detail/kmeans_common.cuh | 2 +- cpp/include/raft/core/detail/mdspan_util.hpp | 68 +++++++++++++++++++ cpp/include/raft/core/device_mdarray.hpp | 3 +- cpp/include/raft/core/host_mdarray.hpp | 4 +- cpp/include/raft/core/mdspan.hpp | 6 +- cpp/include/raft/detail/mdarray.hpp | 44 ------------ cpp/include/raft/linalg/transpose.cuh | 2 +- cpp/include/raft/random/make_blobs.cuh | 2 +- .../raft/spatial/knn/ivf_flat_types.hpp | 2 +- cpp/test/linalg/transpose.cu | 3 +- cpp/test/mdarray.cu | 10 +-- cpp/test/mdspan_utils.cu | 7 +- cpp/test/random/make_blobs.cu | 4 +- 14 files changed, 95 insertions(+), 65 deletions(-) create mode 100644 cpp/include/raft/core/detail/mdspan_util.hpp diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 303de77078..b3f17295aa 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -29,9 +29,10 @@ #include #include #include +#include #include +#include #include -#include #include #include #include diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 358c8ce16e..4a77d230c2 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -29,9 +29,9 @@ #include #include +#include #include #include -#include #include #include #include diff --git a/cpp/include/raft/core/detail/mdspan_util.hpp b/cpp/include/raft/core/detail/mdspan_util.hpp new file mode 100644 index 0000000000..af5eabf3e3 --- /dev/null +++ b/cpp/include/raft/core/detail/mdspan_util.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::detail { + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence) +{ + return std::make_tuple(arr[Idx]...); +} + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) +{ + return arr_to_tup(arr, std::make_index_sequence{}); +} + +template +MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t +{ + int c = 0; + for (; v != 0; v &= v - 1) { + c++; + } + return c; +} + +MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popc(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcount(v); +#else + return native_popc(v); +#endif // compiler +} + +MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popcll(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcountll(v); +#else + return native_popc(v); +#endif // compiler +} + +} // end namespace raft::detail \ No newline at end of file diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index c5b5d73fdc..393ff45815 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -32,7 +33,7 @@ template > using device_mdarray = - mdarray>; + mdarray>; /** * @brief Shorthand for 0-dim host mdarray (scalar). diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 872c007255..448a639390 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -30,8 +31,7 @@ template > -using host_mdarray = - mdarray>; +using host_mdarray = mdarray>; /** * @brief Shorthand for 0-dim host mdarray (scalar). diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 6078c730bf..93645e01cc 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -23,6 +23,8 @@ #pragma once #include +#include +#include #include // dynamic_extent #include @@ -154,11 +156,11 @@ MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents>= popc(s - 1); + idx >>= detail::popc(s - 1); } } index[0] = idx; - return arr_to_tup(index); + return detail::arr_to_tup(index); } /** diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index e3d2b0bf9e..1c13f5d1fc 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -230,48 +230,4 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -template -MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t -{ - int c = 0; - for (; v != 0; v &= v - 1) { - c++; - } - return c; -} - -MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t -{ -#if defined(__CUDA_ARCH__) - return __popc(v); -#elif defined(__GNUC__) || defined(__clang__) - return __builtin_popcount(v); -#else - return native_popc(v); -#endif // compiler -} - -MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t -{ -#if defined(__CUDA_ARCH__) - return __popcll(v); -#elif defined(__GNUC__) || defined(__clang__) - return __builtin_popcountll(v); -#else - return native_popc(v); -#endif // compiler -} - -template -MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence) -{ - return std::make_tuple(arr[Idx]...); -} - -template -MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) -{ - return arr_to_tup(arr, std::make_index_sequence{}); -} - } // namespace raft::detail diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index cd78a2f495..e765ea7925 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -19,7 +19,7 @@ #pragma once #include "detail/transpose.cuh" -#include +#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/random/make_blobs.cuh b/cpp/include/raft/random/make_blobs.cuh index 8bd78d98eb..82c940b471 100644 --- a/cpp/include/raft/random/make_blobs.cuh +++ b/cpp/include/raft/random/make_blobs.cuh @@ -21,7 +21,7 @@ #include "detail/make_blobs.cuh" #include -#include +#include namespace raft::random { diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 02c4e30c1f..9beba05a5c 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -18,8 +18,8 @@ #include "common.hpp" +#include #include -#include #include #include diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 98f6d5e7e4..6edf9448b0 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -233,7 +234,7 @@ void test_transpose_submatrix() } auto vv = v.view(); - auto submat = raft::detail::stdex::submdspan( + auto submat = raft::stdex::submdspan( vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end)); static_assert(std::is_same_v); diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index af7bb7adf3..39cf76333b 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -14,8 +14,8 @@ * limitations under the License. */ #include -#include -#include +#include +#include #include #include #include @@ -467,19 +467,19 @@ void test_mdarray_unravel() // examples from numpy unravel_index { - auto coord = unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(22, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 3); ASSERT_EQ(std::get<1>(coord), 4); } { - auto coord = unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(41, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 6); ASSERT_EQ(std::get<1>(coord), 5); } { - auto coord = unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(37, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 6); ASSERT_EQ(std::get<1>(coord), 1); diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 0d7d180b8f..7f1efb78bb 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -15,7 +15,8 @@ */ #include -#include +#include +#include namespace raft { @@ -24,7 +25,7 @@ namespace stdex = std::experimental; template > + typename AccessorPolicy = stdex::default_accessor> struct derived_device_mdspan : public device_mdspan { }; @@ -37,7 +38,7 @@ void test_template_asserts() using d_mdspan = derived_device_mdspan; static_assert( - std::is_same_v, device_mdspan>>, + std::is_same_v, device_mdspan>>, "not same"); static_assert(std::is_same_v, device_mdspan>>, diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 3f75a4cf0a..2971f62c56 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -18,8 +18,8 @@ #include #include #include -#include -#include +#include +#include #include namespace raft { From 71922f61ab4b9c1e4319d19efd48069b994fccbf Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 15:33:14 -0400 Subject: [PATCH 03/11] Fixing style --- cpp/test/linalg/transpose.cu | 2 +- cpp/test/mdarray.cu | 2 +- cpp/test/random/make_blobs.cu | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 6edf9448b0..b5ad073f62 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -20,8 +20,8 @@ #include #include -#include #include +#include #include diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 39cf76333b..0954073d86 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -14,8 +14,8 @@ * limitations under the License. */ #include -#include #include +#include #include #include #include diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 2971f62c56..d06fa4c1cc 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -17,9 +17,9 @@ #include "../test_utils.h" #include #include -#include #include #include +#include #include namespace raft { From b0e5a02233597d2c2691c194bda27d5903bfaa86 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 15:47:55 -0400 Subject: [PATCH 04/11] Separating host_span and device_span as well --- cpp/include/raft/core/device_span.hpp | 29 +++++++++++++ cpp/include/raft/core/host_span.hpp | 28 ++++++++++++ cpp/include/raft/core/mdarray.hpp | 2 +- cpp/include/raft/core/mdspan.hpp | 61 ++++++++++++++------------- cpp/include/raft/core/span.hpp | 42 ++++++++---------- cpp/include/raft/detail/span.hpp | 30 +++++++------ cpp/test/span.cpp | 2 +- cpp/test/span.cu | 2 +- 8 files changed, 126 insertions(+), 70 deletions(-) create mode 100644 cpp/include/raft/core/device_span.hpp create mode 100644 cpp/include/raft/core/host_span.hpp diff --git a/cpp/include/raft/core/device_span.hpp b/cpp/include/raft/core/device_span.hpp new file mode 100644 index 0000000000..0730b20bfb --- /dev/null +++ b/cpp/include/raft/core/device_span.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { + +/** + * @brief A span class for device pointer. + */ +template +using device_span = span; + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_span.hpp b/cpp/include/raft/core/host_span.hpp new file mode 100644 index 0000000000..3cad62b7cd --- /dev/null +++ b/cpp/include/raft/core/host_span.hpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { +/** + * @brief A span class for host pointer. + */ +template +using host_span = span; + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index e918d81ff1..1110ecb75e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -104,7 +104,7 @@ inline constexpr bool is_array_interface_v = is_array_interface::value; * are some inconsistencies in between them. We have made some modificiations to fit our * needs, which are listed below. * - * - Layout policy is different, the mdarray in raft uses `stdex::extent` directly just + * - Layout policy is different, the mdarray in raft uses `std::experimental::extent` directly just * like `mdspan`, while the `mdarray` in the reference implementation uses varidic * template. * diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 93645e01cc..64a69db171 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -25,23 +25,23 @@ #include #include #include -#include // dynamic_extent #include namespace raft { + +constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; + /** * @brief Dimensions extents for raft::mdspan */ template using extents = std::experimental::extents; -namespace stdex = std::experimental; - /** * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. * @{ */ -using stdex::layout_right; +using std::experimental::layout_right; using layout_c_contiguous = layout_right; using row_major = layout_right; /** @} */ @@ -50,24 +50,24 @@ using row_major = layout_right; * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. * @{ */ -using stdex::layout_left; +using std::experimental::layout_left; using layout_f_contiguous = layout_left; using col_major = layout_left; /** @} */ template -using vector_extent = stdex::extents; +using vector_extent = std::experimental::extents; template -using matrix_extent = stdex::extents; +using matrix_extent = std::experimental::extents; template -using scalar_extent = stdex::extents; +using scalar_extent = std::experimental::extents; /** * @brief Strided layout for non-contiguous memory. */ -using stdex::layout_stride; +using std::experimental::layout_stride; template using extent_1d = vector_extent; @@ -76,26 +76,27 @@ template using extent_2d = matrix_extent; template -using extent_3d = stdex::extents; +using extent_3d = + std::experimental::extents; template -using extent_4d = - stdex::extents; +using extent_4d = std::experimental:: + extents; template -using extent_5d = stdex::extents; +using extent_5d = std::experimental::extents; /** @} */ template > -using mdspan = stdex::mdspan; + typename AccessorPolicy = std::experimental::default_accessor> +using mdspan = std::experimental::mdspan; /** * Ensure all types listed in the parameter pack `Extents` are integral types. @@ -142,7 +143,8 @@ using enable_if_mdspan = std::enable_if_t>; // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) +MDSPAN_INLINE_FUNCTION auto unravel_index_impl( + I idx, std::experimental::extents shape) { constexpr auto kRank = static_cast(shape.rank()); std::size_t index[shape.rank()]{0}; // NOLINT @@ -218,10 +220,10 @@ auto flatten(mdspan_type mds) vector_extent ext{mds.size()}; - return stdex::mdspan(mds.data_handle(), ext); + return std::experimental::mdspan(mds.data_handle(), ext); } /** @@ -248,10 +250,11 @@ auto reshape(mdspan_type mds, extents new_shape) } RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); - return stdex::mdspan(mds.data_handle(), new_shape); + return std::experimental::mdspan(mds.data_handle(), + new_shape); } /** diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index 96950e979e..e32cf47138 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -35,7 +35,7 @@ namespace raft { * auto view = device_span{uvec.data(), uvec.size()}; * @endcode */ -template +template class span { public: using element_type = T; @@ -62,7 +62,7 @@ class span { */ constexpr span(pointer ptr, size_type count) noexcept : storage_{ptr, count} { - assert(!(Extent != dynamic_extent && count != Extent)); + assert(!(Extent != std::experimental::dynamic_extent && count != Extent)); assert(ptr || count == 0); } /** @@ -159,7 +159,8 @@ class span { return {data(), Count}; } - constexpr auto first(std::size_t _count) const -> span + constexpr auto first(std::size_t _count) const + -> span { assert(_count <= size()); return {data(), _count}; @@ -172,47 +173,40 @@ class span { return {data() + size() - Count, Count}; } - constexpr auto last(std::size_t _count) const -> span + constexpr auto last(std::size_t _count) const + -> span { assert(_count <= size()); return subspan(size() - _count, _count); } /*! - * If Count is std::dynamic_extent, r.size() == this->size() - Offset; + * If Count is std::std::experimental::dynamic_extent, r.size() == this->size() - Offset; * Otherwise r.size() == Count. */ - template + template constexpr auto subspan() const -> span::value> { - assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); - return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; + assert((Count == std::experimental::dynamic_extent) ? (Offset <= size()) + : (Offset + Count <= size())); + return {data() + Offset, Count == std::experimental::dynamic_extent ? size() - Offset : Count}; } - constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const - -> span + constexpr auto subspan(size_type _offset, + size_type _count = std::experimental::dynamic_extent) const + -> span { - assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); - return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; + assert((_count == std::experimental::dynamic_extent) ? (_offset <= size()) + : (_offset + _count <= size())); + return {data() + _offset, + _count == std::experimental::dynamic_extent ? size() - _offset : _count}; } private: detail::span_storage storage_; }; -/** - * @brief A span class for host pointer. - */ -template -using host_span = span; - -/** - * @brief A span class for device pointer. - */ -template -using device_span = span; - template constexpr auto operator==(span l, span r) -> bool { diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index de76ff3138..408224a617 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -17,11 +17,9 @@ #include // numeric_limits #include -#include // __host__ __device__ #include namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; template class span; @@ -30,31 +28,35 @@ namespace detail { /*! * The extent E of the span returned by subspan is determined as follows: * - * - If Count is not dynamic_extent, Count; - * - Otherwise, if Extent is not dynamic_extent, Extent - Offset; - * - Otherwise, dynamic_extent. + * - If Count is not std::experimental::dynamic_extent, Count; + * - Otherwise, if Extent is not std::experimental::dynamic_extent, Extent - Offset; + * - Otherwise, std::experimental::dynamic_extent. */ template struct extent_value_t - : public std::integral_constant< - std::size_t, - Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> { + : public std::integral_constant { }; /*! - * If N is dynamic_extent, the extent of the returned span E is also - * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. + * If N is std::experimental::dynamic_extent, the extent of the returned span E is also + * std::experimental::dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. */ template struct extent_as_bytes_value_t - : public std::integral_constant { + : public std::integral_constant< + std::size_t, + Extent == std::experimental::dynamic_extent ? Extent : sizeof(T) * Extent> { }; template struct is_allowed_extent_conversion_t : public std::integral_constant { + From == To || From == std::experimental::dynamic_extent || + To == std::experimental::dynamic_extent> { }; template @@ -101,7 +103,7 @@ struct span_storage { }; template -struct span_storage { +struct span_storage { private: T* ptr_{nullptr}; std::size_t size_{0}; diff --git a/cpp/test/span.cpp b/cpp/test/span.cpp index 6163811b95..0867427e1b 100644 --- a/cpp/test/span.cpp +++ b/cpp/test/span.cpp @@ -16,7 +16,7 @@ #include "test_span.hpp" #include #include // iota -#include +#include namespace raft { TEST(Span, DlfConstructors) diff --git a/cpp/test/span.cu b/cpp/test/span.cu index dcde9b5432..e03d1ddf5a 100644 --- a/cpp/test/span.cu +++ b/cpp/test/span.cu @@ -18,7 +18,7 @@ #include // iota #include #include -#include +#include #include #include From d69b16335d8de5880761855cfa403f5e5e54d9da Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 17:21:58 -0400 Subject: [PATCH 05/11] Cleanup and getting to build --- cpp/include/raft/core/device_mdspan.hpp | 6 +- cpp/include/raft/core/host_mdspan.hpp | 4 +- cpp/include/raft/core/mdarray.hpp | 1 + cpp/include/raft/core/mdspan.hpp | 62 +----------------- cpp/include/raft/core/mdspan_types.hpp | 85 +++++++++++++++++++++++++ cpp/include/raft/core/span.hpp | 41 ++++++------ cpp/include/raft/detail/span.hpp | 31 ++++----- cpp/test/linalg/transpose.cu | 2 +- cpp/test/span.cpp | 2 +- cpp/test/span.cu | 2 +- 10 files changed, 127 insertions(+), 109 deletions(-) create mode 100644 cpp/include/raft/core/mdspan_types.hpp diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index f835397e9d..88cd3dbf8e 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -28,18 +28,18 @@ template using managed_accessor = detail::accessor_mixin; /** - * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. + * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. */ template > + typename AccessorPolicy = std::experimental::default_accessor> using device_mdspan = mdspan>; template > + typename AccessorPolicy = std::experimental::default_accessor> using managed_mdspan = mdspan>; namespace detail { diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index f46fb6ff17..4602088a44 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -25,12 +25,12 @@ template using host_accessor = detail::accessor_mixin; /** - * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. + * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. */ template > + typename AccessorPolicy = std::experimental::default_accessor> using host_mdspan = mdspan>; namespace detail { diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 1110ecb75e..611d01fb70 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 64a69db171..55b651d69f 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -25,71 +25,11 @@ #include #include #include +#include #include namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; - -/** - * @brief Dimensions extents for raft::mdspan - */ -template -using extents = std::experimental::extents; - -/** - * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. - * @{ - */ -using std::experimental::layout_right; -using layout_c_contiguous = layout_right; -using row_major = layout_right; -/** @} */ - -/** - * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. - * @{ - */ -using std::experimental::layout_left; -using layout_f_contiguous = layout_left; -using col_major = layout_left; -/** @} */ - -template -using vector_extent = std::experimental::extents; - -template -using matrix_extent = std::experimental::extents; - -template -using scalar_extent = std::experimental::extents; - -/** - * @brief Strided layout for non-contiguous memory. - */ -using std::experimental::layout_stride; - -template -using extent_1d = vector_extent; - -template -using extent_2d = matrix_extent; - -template -using extent_3d = - std::experimental::extents; - -template -using extent_4d = std::experimental:: - extents; - -template -using extent_5d = std::experimental::extents; /** @} */ template + +namespace raft { + +constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; + +/** + * @brief Dimensions extents for raft::mdspan + */ +template +using extents = std::experimental::extents; + +/** + * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. + * @{ + */ +using std::experimental::layout_right; +using layout_c_contiguous = layout_right; +using row_major = layout_right; +/** @} */ + +/** + * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @{ + */ +using std::experimental::layout_left; +using layout_f_contiguous = layout_left; +using col_major = layout_left; +/** @} */ + +template +using vector_extent = std::experimental::extents; + +template +using matrix_extent = std::experimental::extents; + +template +using scalar_extent = std::experimental::extents; + +/** + * @brief Strided layout for non-contiguous memory. + */ +using std::experimental::layout_stride; + +template +using extent_1d = vector_extent; + +template +using extent_2d = matrix_extent; + +template +using extent_3d = + std::experimental::extents; + +template +using extent_4d = std::experimental:: + extents; + +template +using extent_5d = std::experimental::extents; + +} // namespace raft diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index e32cf47138..3dec7e6fa8 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -18,10 +18,11 @@ #include #include // size_t #include // std::byte +#include #include #include #include -#include // __host__ __device__ +#include // _MDSPAN_HOST_DEVICE #include #include @@ -35,7 +36,7 @@ namespace raft { * auto view = device_span{uvec.data(), uvec.size()}; * @endcode */ -template +template class span { public: using element_type = T; @@ -62,7 +63,7 @@ class span { */ constexpr span(pointer ptr, size_type count) noexcept : storage_{ptr, count} { - assert(!(Extent != std::experimental::dynamic_extent && count != Extent)); + assert(!(Extent != dynamic_extent && count != Extent)); assert(ptr || count == 0); } /** @@ -108,22 +109,22 @@ class span { constexpr auto cend() const noexcept -> const_iterator { return data() + size(); } - __host__ __device__ constexpr auto rbegin() const noexcept -> reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator { return reverse_iterator{end()}; } - __host__ __device__ constexpr auto rend() const noexcept -> reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator { return reverse_iterator{begin()}; } - __host__ __device__ constexpr auto crbegin() const noexcept -> const_reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cend()}; } - __host__ __device__ constexpr auto crend() const noexcept -> const_reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cbegin()}; } @@ -159,8 +160,7 @@ class span { return {data(), Count}; } - constexpr auto first(std::size_t _count) const - -> span + constexpr auto first(std::size_t _count) const -> span { assert(_count <= size()); return {data(), _count}; @@ -173,34 +173,29 @@ class span { return {data() + size() - Count, Count}; } - constexpr auto last(std::size_t _count) const - -> span + constexpr auto last(std::size_t _count) const -> span { assert(_count <= size()); return subspan(size() - _count, _count); } /*! - * If Count is std::std::experimental::dynamic_extent, r.size() == this->size() - Offset; + * If Count is std::dynamic_extent, r.size() == this->size() - Offset; * Otherwise r.size() == Count. */ - template + template constexpr auto subspan() const -> span::value> { - assert((Count == std::experimental::dynamic_extent) ? (Offset <= size()) - : (Offset + Count <= size())); - return {data() + Offset, Count == std::experimental::dynamic_extent ? size() - Offset : Count}; + assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); + return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; } - constexpr auto subspan(size_type _offset, - size_type _count = std::experimental::dynamic_extent) const - -> span + constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const + -> span { - assert((_count == std::experimental::dynamic_extent) ? (_offset <= size()) - : (_offset + _count <= size())); - return {data() + _offset, - _count == std::experimental::dynamic_extent ? size() - _offset : _count}; + assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); + return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; } private: diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index 408224a617..c11e6ba32b 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -16,6 +16,7 @@ #pragma once #include // numeric_limits +#include #include #include @@ -28,35 +29,31 @@ namespace detail { /*! * The extent E of the span returned by subspan is determined as follows: * - * - If Count is not std::experimental::dynamic_extent, Count; - * - Otherwise, if Extent is not std::experimental::dynamic_extent, Extent - Offset; - * - Otherwise, std::experimental::dynamic_extent. + * - If Count is not dynamic_extent, Count; + * - Otherwise, if Extent is not dynamic_extent, Extent - Offset; + * - Otherwise, dynamic_extent. */ template struct extent_value_t - : public std::integral_constant { + : public std::integral_constant< + std::size_t, + Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> { }; /*! - * If N is std::experimental::dynamic_extent, the extent of the returned span E is also - * std::experimental::dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. + * If N is dynamic_extent, the extent of the returned span E is also + * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. */ template struct extent_as_bytes_value_t - : public std::integral_constant< - std::size_t, - Extent == std::experimental::dynamic_extent ? Extent : sizeof(T) * Extent> { + : public std::integral_constant { }; template struct is_allowed_extent_conversion_t : public std::integral_constant { + From == To || From == dynamic_extent || To == dynamic_extent> { }; template @@ -77,7 +74,7 @@ struct is_span_t : public is_span_oracle_t::type> { }; template -__host__ __device__ constexpr auto lexicographical_compare(InputIt1 first1, +_MDSPAN_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) -> bool @@ -103,7 +100,7 @@ struct span_storage { }; template -struct span_storage { +struct span_storage { private: T* ptr_{nullptr}; std::size_t size_{0}; diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index b5ad073f62..adfaf6e49d 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -234,7 +234,7 @@ void test_transpose_submatrix() } auto vv = v.view(); - auto submat = raft::stdex::submdspan( + auto submat = std::experimental::submdspan( vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end)); static_assert(std::is_same_v); diff --git a/cpp/test/span.cpp b/cpp/test/span.cpp index 0867427e1b..f8d9345a12 100644 --- a/cpp/test/span.cpp +++ b/cpp/test/span.cpp @@ -16,7 +16,7 @@ #include "test_span.hpp" #include #include // iota -#include +#include namespace raft { TEST(Span, DlfConstructors) diff --git a/cpp/test/span.cu b/cpp/test/span.cu index e03d1ddf5a..91833d4dc7 100644 --- a/cpp/test/span.cu +++ b/cpp/test/span.cu @@ -16,9 +16,9 @@ #include "test_span.hpp" #include #include // iota +#include #include #include -#include #include #include From 50d750b63b823ee81ed455fed1ef2c868e5de270 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 18:00:43 -0400 Subject: [PATCH 06/11] Updates --- .../detail/device_mdarray.hpp} | 50 ++------------ cpp/include/raft/core/detail/host_mdarray.hpp | 69 +++++++++++++++++++ .../{mdspan_util.hpp => mdspan_util.cuh} | 1 + cpp/include/raft/{ => core}/detail/span.hpp | 0 cpp/include/raft/core/device_mdarray.hpp | 1 + cpp/include/raft/core/host_mdarray.hpp | 2 + cpp/include/raft/core/host_mdspan.hpp | 3 +- cpp/include/raft/core/mdarray.hpp | 16 ++--- cpp/include/raft/core/mdspan.hpp | 6 +- cpp/include/raft/core/span.hpp | 6 +- cpp/include/raft/distance/distance.cuh | 2 +- 11 files changed, 98 insertions(+), 58 deletions(-) rename cpp/include/raft/{detail/mdarray.hpp => core/detail/device_mdarray.hpp} (77%) create mode 100644 cpp/include/raft/core/detail/host_mdarray.hpp rename cpp/include/raft/core/detail/{mdspan_util.hpp => mdspan_util.cuh} (99%) rename cpp/include/raft/{ => core}/detail/span.hpp (100%) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp similarity index 77% rename from cpp/include/raft/detail/mdarray.hpp rename to cpp/include/raft/core/detail/device_mdarray.hpp index 1c13f5d1fc..569f573c19 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -21,11 +21,12 @@ * limitations under the License. */ #pragma once -#include -#include +#include +#include +#include -#include -#include // dynamic_extent +#include +#include // dynamic_extent #include #include @@ -189,45 +190,4 @@ class device_uvector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -/** - * @brief A container policy for host mdarray. - */ -template > -class host_vector_policy { - public: - using element_type = ElementType; - using container_type = std::vector; - using allocator_type = typename container_type::allocator_type; - using pointer = typename container_type::pointer; - using const_pointer = typename container_type::const_pointer; - using reference = element_type&; - using const_reference = element_type const&; - using accessor_policy = std::experimental::default_accessor; - using const_accessor_policy = std::experimental::default_accessor; - - public: - auto create(size_t n) -> container_type { return container_type(n); } - - constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v) = - default; - explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( - std::is_nothrow_default_constructible_v) - : host_vector_policy() - { - } - - [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference - { - return c[n]; - } - [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept - -> const_reference - { - return c[n]; - } - - [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } - [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } -}; - } // namespace raft::detail diff --git a/cpp/include/raft/core/detail/host_mdarray.hpp b/cpp/include/raft/core/detail/host_mdarray.hpp new file mode 100644 index 0000000000..74bd55e78c --- /dev/null +++ b/cpp/include/raft/core/detail/host_mdarray.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (2019) Sandia Corporation + * + * The source code is licensed under the 3-clause BSD license found in the LICENSE file + * thirdparty/LICENSES/mdarray.license + */ + +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace raft::detail { + +/** + * @brief A container policy for host mdarray. + */ +template > +class host_vector_policy { + public: + using element_type = ElementType; + using container_type = std::vector; + using allocator_type = typename container_type::allocator_type; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = element_type&; + using const_reference = element_type const&; + using accessor_policy = std::experimental::default_accessor; + using const_accessor_policy = std::experimental::default_accessor; + + public: + auto create(size_t n) -> container_type { return container_type(n); } + + constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v) = + default; + explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( + std::is_nothrow_default_constructible_v) + : host_vector_policy() + { + } + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } +}; +} // namespace raft::detail diff --git a/cpp/include/raft/core/detail/mdspan_util.hpp b/cpp/include/raft/core/detail/mdspan_util.cuh similarity index 99% rename from cpp/include/raft/core/detail/mdspan_util.hpp rename to cpp/include/raft/core/detail/mdspan_util.cuh index af5eabf3e3..b2f73857b2 100644 --- a/cpp/include/raft/core/detail/mdspan_util.hpp +++ b/cpp/include/raft/core/detail/mdspan_util.cuh @@ -16,6 +16,7 @@ #pragma once #include + #include #include diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/core/detail/span.hpp similarity index 100% rename from cpp/include/raft/detail/span.hpp rename to cpp/include/raft/core/detail/span.hpp diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 393ff45815..1b20629f5f 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 448a639390..6221ca59f0 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -17,6 +17,8 @@ #pragma once #include + +#include #include namespace raft { diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 4602088a44..fcd637f3a5 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -16,9 +16,10 @@ #pragma once -#include #include +#include + namespace raft { template diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 611d01fb70..64e8504708 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,14 +24,11 @@ #include -#include -#include -#include +#include #include #include -#include + #include -#include namespace raft { /** @@ -158,9 +155,12 @@ class mdarray typename container_policy_type::const_accessor_policy, typename container_policy_type::accessor_policy>> using view_type_impl = - std::conditional_t, - device_mdspan>; + mdspan>; public: /** diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 55b651d69f..f608f3d085 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -22,10 +22,12 @@ */ #pragma once -#include -#include #include #include + +#include +#include + #include namespace raft { diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index 3dec7e6fa8..db3b25296b 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -19,7 +19,11 @@ #include // size_t #include // std::byte #include -#include + +#include + +// TODO (cjnolet): Remove thrust dependencies here so host_span can be used without CUDA Toolkit +// being installed. Reference: https://github.com/rapidsai/raft/issues/812. #include #include #include // _MDSPAN_HOST_DEVICE diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 3db1749bb4..8aaeadd531 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -23,7 +23,7 @@ #include #include -#include +#include /** * @defgroup pairwise_distance pairwise distance prims From 6fda1fd454260ed5478c414b090bfbedb1663a82 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 8 Sep 2022 09:05:41 -0400 Subject: [PATCH 07/11] Fixing docs --- cpp/include/raft/core/mdspan.hpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index f608f3d085..289cc82f1b 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -1,10 +1,3 @@ -/* - * Copyright (2019) Sandia Corporation - * - * The source code is licensed under the 3-clause BSD license found in the LICENSE file - * thirdparty/LICENSES/mdarray.license - */ - /* * Copyright (c) 2022, NVIDIA CORPORATION. * @@ -32,8 +25,6 @@ namespace raft { -/** @} */ - template Date: Thu, 8 Sep 2022 09:54:34 -0400 Subject: [PATCH 08/11] Updating readme to use proper header paths --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2159f128bf..8194c27cf9 100755 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ The APIs in RAFT currently accept raw pointers to device memory and we are in th The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: ```c++ -#include +#include int n_rows = 10; int n_cols = 10; @@ -56,8 +56,8 @@ Most of the primitives in RAFT accept a `raft::handle_t` object for the manageme The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing pairwise Euclidean distances: ```c++ -#include -#include +#include +#include #include #include From 3aeb5304f0024a05580c011bc36f30fdaff3d72e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 8 Sep 2022 17:14:40 -0400 Subject: [PATCH 09/11] More updates based on review feedback --- cpp/include/raft.hpp | 5 +-- .../raft/core/detail/device_mdarray.hpp | 2 +- ...sor_mixin.hpp => host_device_accessor.hpp} | 6 ++-- cpp/include/raft/core/detail/macros.hpp | 35 ++++++++++++++++++ cpp/include/raft/core/detail/mdspan_util.cuh | 1 + cpp/include/raft/core/detail/span.hpp | 10 +++--- cpp/include/raft/core/device_mdarray.hpp | 2 +- cpp/include/raft/core/device_mdspan.hpp | 21 +++++------ cpp/include/raft/core/host_mdspan.hpp | 19 +++++----- cpp/include/raft/core/mdarray.hpp | 36 +++++++++---------- cpp/include/raft/core/mdspan.hpp | 21 ++++++----- cpp/include/raft/core/mdspan_types.hpp | 9 ++--- cpp/include/raft/core/span.hpp | 11 +++--- cpp/test/mdspan_utils.cu | 11 +++--- 14 files changed, 112 insertions(+), 77 deletions(-) rename cpp/include/raft/core/detail/{accessor_mixin.hpp => host_device_accessor.hpp} (88%) create mode 100644 cpp/include/raft/core/detail/macros.hpp diff --git a/cpp/include/raft.hpp b/cpp/include/raft.hpp index b1b8255b7e..7b997bfb87 100644 --- a/cpp/include/raft.hpp +++ b/cpp/include/raft.hpp @@ -17,9 +17,10 @@ /** * This file is deprecated and will be removed in release 22.06. */ +#include "raft/core/device_mdarray.hpp" +#include "raft/core/device_mdspan.hpp" +#include "raft/core/device_span.hpp" #include "raft/handle.hpp" -#include "raft/mdarray.hpp" -#include "raft/span.hpp" #include diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index 569f573c19..aad35b6282 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -25,7 +25,7 @@ #include #include -#include +#include #include // dynamic_extent #include diff --git a/cpp/include/raft/core/detail/accessor_mixin.hpp b/cpp/include/raft/core/detail/host_device_accessor.hpp similarity index 88% rename from cpp/include/raft/core/detail/accessor_mixin.hpp rename to cpp/include/raft/core/detail/host_device_accessor.hpp index 6edd85dbaf..3a71e6366b 100644 --- a/cpp/include/raft/core/detail/accessor_mixin.hpp +++ b/cpp/include/raft/core/detail/host_device_accessor.hpp @@ -22,7 +22,7 @@ namespace raft::detail { * @brief A mixin to distinguish host and device memory. */ template -struct accessor_mixin : public AccessorPolicy { +struct host_device_accessor : public AccessorPolicy { using accessor_type = AccessorPolicy; using is_host_type = std::conditional_t; using is_device_type = std::conditional_t; @@ -32,8 +32,8 @@ struct accessor_mixin : public AccessorPolicy { static constexpr bool is_managed_accessible = is_device && is_host; // make sure the explicit ctor can fall through using AccessorPolicy::AccessorPolicy; - using offset_policy = accessor_mixin; - accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT + using offset_policy = host_device_accessor; + host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT }; } // namespace raft::detail diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp new file mode 100644 index 0000000000..00fbab1530 --- /dev/null +++ b/cpp/include/raft/core/detail/macros.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef _RAFT_HAS_CUDA +#if defined(__CUDACC__) +#define _RAFT_HAS_CUDA __CUDACC__ +#endif +#endif + +#ifndef _RAFT_HOST_DEVICE +#if defined(_RAFT_HAS_CUDA) +#define _RAFT_HOST_DEVICE __host__ __device__ +#else +#define _RAFT_HOST_DEVICE +#endif +#endif + +#ifndef RAFT_INLINE_FUNCTION +#define RAFT_INLINE_FUNCTION inline _RAFT_HOST_DEVICE +#endif diff --git a/cpp/include/raft/core/detail/mdspan_util.cuh b/cpp/include/raft/core/detail/mdspan_util.cuh index b2f73857b2..6b2c90abcc 100644 --- a/cpp/include/raft/core/detail/mdspan_util.cuh +++ b/cpp/include/raft/core/detail/mdspan_util.cuh @@ -17,6 +17,7 @@ #include +#include #include #include diff --git a/cpp/include/raft/core/detail/span.hpp b/cpp/include/raft/core/detail/span.hpp index c11e6ba32b..20500d618b 100644 --- a/cpp/include/raft/core/detail/span.hpp +++ b/cpp/include/raft/core/detail/span.hpp @@ -16,8 +16,8 @@ #pragma once #include // numeric_limits +#include #include -#include #include namespace raft { @@ -74,10 +74,10 @@ struct is_span_t : public is_span_oracle_t::type> { }; template -_MDSPAN_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, - InputIt1 last1, - InputIt2 first2, - InputIt2 last2) -> bool +_RAFT_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, + InputIt1 last1, + InputIt2 first2, + InputIt2 last2) -> bool { Compare comp; for (; first1 != last1 && first2 != last2; ++first1, ++first2) { diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 1b20629f5f..1c17b5bcb9 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include namespace raft { diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 88cd3dbf8e..4d6c1836fc 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -16,16 +16,16 @@ #pragma once -#include +#include #include namespace raft { template -using device_accessor = detail::accessor_mixin; +using device_accessor = detail::host_device_accessor; template -using managed_accessor = detail::accessor_mixin; +using managed_accessor = detail::host_device_accessor; /** * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. @@ -44,17 +44,18 @@ using managed_mdspan = mdspan -struct is_device_mdspan : std::false_type { +struct is_device_accessible_mdspan : std::false_type { }; template -struct is_device_mdspan : std::bool_constant { +struct is_device_accessible_mdspan + : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type */ template -using is_device_mdspan_t = is_device_mdspan>; +using is_device_accessible_mdspan_t = is_device_accessible_mdspan>; template struct is_managed_mdspan : std::false_type { @@ -76,10 +77,11 @@ using is_managed_mdspan_t = is_managed_mdspan>; * derived type */ template -inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; +inline constexpr bool is_device_accessible_mdspan_v = + std::conjunction_v...>; template -using enable_if_device_mdspan = std::enable_if_t>; +using enable_if_device_mdspan = std::enable_if_t>; /** * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a @@ -187,8 +189,7 @@ template auto make_device_vector_view(ElementType* ptr, IndexType n) { - vector_extent extents{n}; - return device_vector_view{ptr, extents}; + return device_vector_view{ptr, n}; } } // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index fcd637f3a5..e6ab22004e 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -18,12 +18,12 @@ #include -#include +#include namespace raft { template -using host_accessor = detail::accessor_mixin; +using host_accessor = detail::host_device_accessor; /** * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. @@ -37,17 +37,18 @@ using host_mdspan = mdspan -struct is_host_mdspan : std::false_type { +struct is_host_accessible_mdspan : std::false_type { }; template -struct is_host_mdspan : std::bool_constant { +struct is_host_accessible_mdspan + : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -using is_host_mdspan_t = is_host_mdspan>; +using is_host_accessible_mdspan_t = is_host_accessible_mdspan>; } // namespace detail @@ -56,10 +57,11 @@ using is_host_mdspan_t = is_host_mdspan>; * derived type */ template -inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; +inline constexpr bool is_host_accessible_mdspan_v = + std::conjunction_v...>; template -using enable_if_host_mdspan = std::enable_if_t>; +using enable_if_host_mdspan = std::enable_if_t>; /** * @brief Shorthand for 0-dim host mdspan (scalar). @@ -137,7 +139,6 @@ template auto make_host_vector_view(ElementType* ptr, IndexType n) { - vector_extent extents{n}; - return host_vector_view{ptr, extents}; + return host_vector_view{ptr, n}; } } // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 64e8504708..44730d901e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,10 +24,10 @@ #include -#include +#include +#include #include #include - #include namespace raft { @@ -158,9 +158,9 @@ class mdarray mdspan>; + detail::host_device_accessor>; public: /** @@ -266,61 +266,61 @@ class mdarray } // basic_mdarray observers of the domain multidimensional index space (also in basic_mdspan) - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto rank() noexcept -> rank_type + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto rank() noexcept -> rank_type { return extents_type::rank(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto rank_dynamic() noexcept -> rank_type + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto rank_dynamic() noexcept -> rank_type { return extents_type::rank_dynamic(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto static_extent(size_t r) noexcept + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto static_extent(size_t r) noexcept -> index_type { return extents_type::static_extent(r); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto extents() const noexcept -> extents_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto extents() const noexcept -> extents_type { return map_.extents(); } /** * @brief the extent of rank r */ - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto extent(size_t r) const noexcept -> index_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto extent(size_t r) const noexcept -> index_type { return map_.extents().extent(r); } // mapping - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto mapping() const noexcept -> mapping_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto mapping() const noexcept -> mapping_type { return map_; } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_unique() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_unique() const noexcept -> bool { return map_.is_unique(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_exhaustive() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_exhaustive() const noexcept -> bool { return map_.is_exhaustive(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_strided() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_strided() const noexcept -> bool { return map_.is_strided(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto stride(size_t r) const -> index_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto stride(size_t r) const -> index_type { return map_.stride(r); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_unique() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_unique() noexcept -> bool { return mapping_type::is_always_unique(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_exhaustive() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_exhaustive() noexcept -> bool { return mapping_type::is_always_exhaustive(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_strided() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_strided() noexcept -> bool { return mapping_type::is_always_strided(); } diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 289cc82f1b..7169a010b6 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -18,7 +18,8 @@ #include #include -#include +#include +#include #include #include @@ -59,9 +60,6 @@ struct is_mdspan( template using is_mdspan_t = is_mdspan>; -// template -// inline constexpr bool is_mdspan_v = is_mdspan_t::value; - /** * @\brief Boolean to determine if variadic template types Tn are either * raft::host_mdspan/raft::device_mdspan or their derived types @@ -76,7 +74,7 @@ using enable_if_mdspan = std::enable_if_t>; // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl( +RAFT_INLINE_FUNCTION auto unravel_index_impl( I idx, std::experimental::extents shape) { constexpr auto kRank = static_cast(shape.rank()); @@ -117,9 +115,10 @@ template auto make_mdspan(ElementType* ptr, extents exts) { - using accessor_type = detail::accessor_mixin, - is_host_accessible, - is_device_accessible>; + using accessor_type = + detail::host_device_accessor, + is_host_accessible, + is_device_accessible>; return mdspan{ptr, exts}; } @@ -207,9 +206,9 @@ auto reshape(mdspan_type mds, extents new_shape) * \return A std::tuple that represents the coordinate. */ template -MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, - extents shape, - LayoutPolicy const& layout) +RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, + extents shape, + LayoutPolicy const& layout) { static_assert(std::is_same_v>, layout_c_contiguous>, diff --git a/cpp/include/raft/core/mdspan_types.hpp b/cpp/include/raft/core/mdspan_types.hpp index 1417dea8f6..bc2ba314a3 100644 --- a/cpp/include/raft/core/mdspan_types.hpp +++ b/cpp/include/raft/core/mdspan_types.hpp @@ -20,13 +20,8 @@ namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; - -/** - * @brief Dimensions extents for raft::mdspan - */ -template -using extents = std::experimental::extents; +using std::experimental::dynamic_extent; +using std::experimental::extents; /** * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index db3b25296b..188d58c896 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -20,13 +20,14 @@ #include // std::byte #include +#include #include // TODO (cjnolet): Remove thrust dependencies here so host_span can be used without CUDA Toolkit // being installed. Reference: https://github.com/rapidsai/raft/issues/812. #include #include -#include // _MDSPAN_HOST_DEVICE +#include // _RAFT_HOST_DEVICE #include #include @@ -113,22 +114,22 @@ class span { constexpr auto cend() const noexcept -> const_iterator { return data() + size(); } - _MDSPAN_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator + _RAFT_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator { return reverse_iterator{end()}; } - _MDSPAN_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator + _RAFT_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator { return reverse_iterator{begin()}; } - _MDSPAN_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator + _RAFT_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cend()}; } - _MDSPAN_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator + _RAFT_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cbegin()}; } diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 7f1efb78bb..5683c0267a 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -55,16 +55,17 @@ void test_template_asserts() static_assert(is_mdspan_v, "Derived device mdspan type is not mdspan"); // Checking if types are device_mdspan - static_assert(is_device_mdspan_v>, + static_assert(is_device_accessible_mdspan_v>, "device_matrix_view type not a device_mdspan"); - static_assert(!is_device_mdspan_v>, + static_assert(!is_device_accessible_mdspan_v>, "host_matrix_view type is a device_mdspan"); - static_assert(is_device_mdspan_v, "Derived device mdspan type is not device_mdspan"); + static_assert(is_device_accessible_mdspan_v, + "Derived device mdspan type is not device_mdspan"); // Checking if types are host_mdspan - static_assert(!is_host_mdspan_v>, + static_assert(!is_host_accessible_mdspan_v>, "device_matrix_view type is a host_mdspan"); - static_assert(is_host_mdspan_v>, + static_assert(is_host_accessible_mdspan_v>, "host_matrix_view type is not a host_mdspan"); // checking variadics From b6c758c44e6186149f36f2ccabb857353d64b5be Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 14 Sep 2022 19:55:06 -0400 Subject: [PATCH 10/11] Fixing style --- cpp/include/raft/core/detail/device_mdarray.hpp | 2 +- cpp/include/raft/spatial/knn/ivf_flat_types.hpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index ad387db33d..ff7c31000d 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -21,9 +21,9 @@ * limitations under the License. */ #pragma once -#include #include #include +#include #include #include // dynamic_extent diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 5014e6c41e..aabd6edfe3 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -20,7 +20,6 @@ #include #include -#include #include #include From 7eae6e3d16ad20ceab58b0bf97a041aab177d04a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 15 Sep 2022 13:00:12 -0400 Subject: [PATCH 11/11] Fixing bad merge --- cpp/include/raft/spatial/knn/ivf_flat_types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index aabd6edfe3..41fa1dd8ce 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include