From 0d6595462cf0d35b1440295056a72fb24c4ad6da Mon Sep 17 00:00:00 2001
From: William Hicks
Date: Wed, 3 Jan 2024 21:16:46 -0500
Subject: [PATCH] Implement maybe-owning multi-dimensional container (mdbuffer) (#1999)

### What is mdbuffer?
This PR introduces a maybe-owning multi-dimensional abstraction called `mdbuffer` to help simplify code that _may_ require an `mdarray` but only if the data are not already in a desired form or location.

As a concrete example, consider a function `foo_device` which operates on memory accessible from the device. If we wish to pass it data originating on the host, a separate code path must be created in which a `device_mdarray` is created and the data are explicitly copied from host to device. This leads to a proliferation of branches as `foo_device` interacts with other functions with similar requirements.

As an initial simplification, `mdbuffer` allows us to write a single template that accepts an `mdspan` pointing to memory on either host _or_ device and routes it through the same code:
```c++
template <typename mdspan_type>
void foo_device(raft::resources const& res, mdspan_type data) {
  auto buf = raft::mdbuffer{res, raft::mdbuffer{data}, raft::memory_type::device};
  // Data in buf is now guaranteed to be accessible from device.
  // If it was already accessible from device, no copy was performed. If it
  // was not, a copy was performed.

  some_kernel<<<...>>>(buf.view());

  // It is sometimes useful to know whether or not a copy was performed to
  // e.g. determine whether the transformed data should be copied back to its original
  // location. This can be checked via the `is_owning()` method.
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view());
  }
}

foo_device(res, some_host_mdspan);  // Still works; memory is allocated and copy is performed
foo_device(res, some_device_mdspan);  // Still works and no allocation or copy is required
foo_device(res, some_managed_mdspan);  // Still works and no allocation or copy is required
```

While this is a useful simplification, it still leads to a proliferation of template instantiations. If this is undesirable, `mdbuffer` permits a further consolidation through implicit conversion of an mdspan to an mdbuffer:

```c++
void foo_device(raft::resources const& res, raft::mdbuffer>&& data) {
  auto buf = raft::mdbuffer{res, data, raft::memory_type::device};
  some_kernel<<<...>>>(buf.view());
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view());
  }
}

// All of the following work exactly as before but no longer require separate template instantiations
foo_device(res, some_host_mdspan);
foo_device(res, some_device_mdspan);
foo_device(res, some_managed_mdspan);
```

`mdbuffer` also offers a simple way to perform runtime dispatching based on the memory type passed to it using standard C++ patterns. While mdbuffer's `.view()` method takes an optional template parameter indicating the mdspan type to retrieve as a view, that parameter can be omitted to retrieve a `std::variant` of all mdspan types which may provide a view on the `mdbuffer`'s data (depending on its memory type). We can then use `std::visit` to perform runtime dispatching based on where the data are stored:

```c++
void foo(raft::resources const& res, raft::mdbuffer>&& data) {
  std::visit([](auto view) {
    if constexpr (decltype(view)::accessor_type::is_device_accessible) {
      // Do something with these data on device
    } else {
      // Do something with these data on host
    }
  }, data.view());
}
```

In addition to moving data among various memory types (host, device, managed, and pinned currently), `mdbuffer` can be used to coerce data to a desired in-memory layout or to a compatible data type (e.g. floats to doubles). As with changes in the memory type, a copy will be performed if and only if it is necessary.

```c++
template <typename mdspan_type>
void foo_device(raft::resources const& res, mdspan_type data) {
  auto buf = raft::mdbuffer, raft::row_major>{res, raft::mdbuffer{data}, raft::memory_type::device};
  // Data in buf is now guaranteed to be accessible from device, and
  // represented by floats in row-major order.

  some_kernel<<<...>>>(buf.view());

  // The same check can be used to determine whether or not a copy was
  // required, regardless of the cause. I.e. if the data were already on
  // device but in column-major order, the is_owning() method would still
  // return true because new storage needed to be allocated.
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view());
  }
}
```

### What mdbuffer is **not**
`mdbuffer` is **not** a replacement for either `mdspan` or `mdarray`. `mdspan` remains the standard object for passing data views throughout the RAFT codebase, and `mdarray` remains the standard object for allocating new multi-dimensional data. This is reflected in the fact that `mdbuffer` can _only_ be constructed from an existing `mdspan` or `mdarray` or another `mdbuffer`. `mdbuffer` is intended to be used solely to simplify code where data _may_ need to be copied to a different location.
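
For readers who want to try the pattern outside of RAFT, the following is a minimal host-only sketch of the two ideas described above: a buffer that owns storage only when a copy was needed, and a `std::variant`-returning `view()` that is dispatched with `std::visit`. Every name in it (`toy_memory_type`, `toy_view`, `toy_buffer`, `took_device_path`) is hypothetical and chosen for illustration; none of them are part of RAFT, the "device" here is ordinary host memory, and the copy is a plain `std::vector` assignment rather than `raft::copy`. It assumes a C++17 compiler.

```cpp
#include <cassert>
#include <cstddef>
#include <variant>
#include <vector>

// Hypothetical stand-ins for illustration only; none of these names exist in RAFT.
enum class toy_memory_type { host, device };

// Non-owning view tagged at compile time with the memory type it refers to.
template <toy_memory_type MemType>
struct toy_view {
  static constexpr bool is_device_accessible = (MemType == toy_memory_type::device);
  float* data;
  std::size_t size;
};
using toy_host_view   = toy_view<toy_memory_type::host>;
using toy_device_view = toy_view<toy_memory_type::device>;

// Maybe-owning buffer: borrows the input when it already has the requested
// memory type, otherwise allocates new storage and copies into it.
class toy_buffer {
 public:
  template <toy_memory_type MemType>
  toy_buffer(toy_view<MemType> in, toy_memory_type wanted) : mem_type_{wanted}, size_{in.size}
  {
    assert(in.data != nullptr || in.size == 0);
    if (MemType == wanted) {
      data_ = in.data;  // already in the right place: no allocation, no copy
    } else {
      owned_.assign(in.data, in.data + in.size);  // stand-in for a real host/device copy
      data_   = owned_.data();
      owning_ = true;
    }
  }
  // Copying would leave data_ pointing into another object's storage.
  toy_buffer(toy_buffer const&)            = delete;
  toy_buffer& operator=(toy_buffer const&) = delete;

  bool is_owning() const { return owning_; }

  // Analogue of calling view() without a template parameter: a variant over
  // every view type the buffer might hold.
  std::variant<toy_host_view, toy_device_view> view() const
  {
    if (mem_type_ == toy_memory_type::device) { return toy_device_view{data_, size_}; }
    return toy_host_view{data_, size_};
  }

 private:
  toy_memory_type mem_type_;
  std::size_t size_;
  std::vector<float> owned_{};
  float* data_{nullptr};
  bool owning_{false};
};

// Runtime dispatch on the stored memory type; returns true if the
// "device" branch was selected.
inline bool took_device_path(toy_buffer const& buf)
{
  return std::visit(
    [](auto view) {
      if constexpr (decltype(view)::is_device_accessible) {
        return true;  // a real implementation would launch a kernel here
      } else {
        return false;  // a real implementation would run host code here
      }
    },
    buf.view());
}
```

Constructing a `toy_buffer` from a `toy_host_view` with `toy_memory_type::host` borrows the input and reports `is_owning() == false`; requesting `toy_memory_type::device` from the same view allocates and copies, so `is_owning()` becomes true and `took_device_path` selects the device branch.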
### Follow-ups
- I have omitted the mdbuffer-based replacement for and generalization of `temporary_device_buffer` since this PR is already enormous. I have this partially written however, and I'll post a link to its current state to help motivate the changes here.
- For all necessary copies, `mdbuffer` uses `raft::copy`. For _some_ transformations that require a change in data type or layout, `raft::copy` is not fully optimized. See #1842 for more information. Optimizing this will be an important change to ensure that `mdbuffer` can be used with absolutely minimal overhead in all cases. These non-optimized cases represent a small fraction of the real-world use cases we can expect for `mdbuffer`, however, so there should be little concern about beginning to use it as is.
- `std::visit`'s performance for a small number of variants is sometimes non-optimal. As a followup, it would be good to benchmark `mdbuffer`'s current performance and compare to internal use of a `visit` implementation that uses a `switch` on the available memory types.

Resolve #1602

Authors:
  - William Hicks (https://github.com/wphicks)
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Ben Frederickson (https://github.com/benfred)

URL: https://github.com/rapidsai/raft/pull/1999
---
 .../core/detail/fail_container_policy.hpp     |  146 +++
 .../raft/core/device_container_policy.hpp     |   22 +-
 cpp/include/raft/core/device_mdspan.hpp       |   88 +-
 .../raft/core/host_container_policy.hpp       |    3 +-
 .../raft/core/host_device_accessor.hpp        |   12 +-
 cpp/include/raft/core/host_mdspan.hpp         |   10 +-
 .../raft/core/managed_container_policy.hpp    |   86 ++
 cpp/include/raft/core/managed_mdarray.hpp     |  152 +++
 cpp/include/raft/core/managed_mdspan.hpp      |  273 +++++
 cpp/include/raft/core/mdbuffer.cuh            | 1020 +++++++++++++++++
 cpp/include/raft/core/mdbuffer.hpp            |   26 +
 cpp/include/raft/core/memory_type.hpp         |   56 +-
 .../raft/core/pinned_container_policy.hpp     |  142 +++
 cpp/include/raft/core/pinned_mdarray.hpp      |  152 +++
 cpp/include/raft/core/pinned_mdspan.hpp       |  270 +++++
 cpp/include/raft/core/serialize.hpp           |    3 +-
 cpp/include/raft/core/stream_view.hpp         |    3 +-
 .../raft/util/memory_type_dispatcher.cuh      |  209 ++++
 cpp/include/raft/util/variant_utils.hpp       |   64 ++
 cpp/test/CMakeLists.txt                       |    4 +-
 cpp/test/core/mdarray.cu                      |    3 +-
 cpp/test/core/mdbuffer.cu                     |  330 ++++++
 cpp/test/core/memory_type.cpp                 |   34 +-
 cpp/test/core/numpy_serializer.cu             |    3 +-
 cpp/test/util/memory_type_dispatcher.cu       |  421 +++++++
 docs/source/cpp_api/mdspan.rst                |    2 +
 docs/source/cpp_api/mdspan_mdarray.rst        |   66 +-
 docs/source/cpp_api/mdspan_mdbuffer.rst       |   13 +
 docs/source/cpp_api/mdspan_mdspan.rst         |   39 +-
 .../source/cpp_api/memory_type_dispatcher.rst |   13 +
 30 files changed, 3563 insertions(+), 102 deletions(-)
 create mode 100644 cpp/include/raft/core/detail/fail_container_policy.hpp
 create mode 100644 cpp/include/raft/core/managed_container_policy.hpp
 create mode 100644 cpp/include/raft/core/managed_mdarray.hpp
 create mode 100644 cpp/include/raft/core/managed_mdspan.hpp
 create mode 100644 cpp/include/raft/core/mdbuffer.cuh
 create mode 100644 cpp/include/raft/core/mdbuffer.hpp
 create mode 100644 cpp/include/raft/core/pinned_container_policy.hpp
 create mode 100644 cpp/include/raft/core/pinned_mdarray.hpp
 create mode 100644 cpp/include/raft/core/pinned_mdspan.hpp
 create mode 100644 cpp/include/raft/util/memory_type_dispatcher.cuh
 create mode 100644 cpp/include/raft/util/variant_utils.hpp
 create mode 100644 cpp/test/core/mdbuffer.cu
 create mode 100644 cpp/test/util/memory_type_dispatcher.cu
 create mode 100644 docs/source/cpp_api/mdspan_mdbuffer.rst
 create mode 100644 docs/source/cpp_api/memory_type_dispatcher.rst

diff --git a/cpp/include/raft/core/detail/fail_container_policy.hpp b/cpp/include/raft/core/detail/fail_container_policy.hpp
new file mode 100644
index 0000000000..ff36659f04
--- /dev/null
+++ b/cpp/include/raft/core/detail/fail_container_policy.hpp
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#pragma once +#include +#include +#include +#include +#include + +namespace raft { +namespace detail { + +template +struct fail_reference { + using value_type = typename std::remove_cv_t; + using pointer = T*; + using const_pointer = T const*; + + fail_reference() = default; + template + fail_reference(T* ptr, StreamViewType stream) + { + throw non_cuda_build_error{"Attempted to construct reference to device data in non-CUDA build"}; + } + + operator value_type() const // NOLINT + { + throw non_cuda_build_error{"Attempted to dereference device data in non-CUDA build"}; + return value_type{}; + } + auto operator=(T const& other) -> fail_reference& + { + throw non_cuda_build_error{"Attempted to assign to device data in non-CUDA build"}; + return *this; + } +}; + +/** A placeholder container which throws an exception on use + * + * This placeholder is used in non-CUDA builds for container types that would + * otherwise be provided with CUDA code. Attempting to construct a non-empty + * container of this type throws an exception indicating that there was an + * attempt to use the device from a non-CUDA build. An example of when this + * might happen is if a downstream application attempts to allocate a device + * mdarray using a library built with non-CUDA RAFT. 
+ */ +template +struct fail_container { + using value_type = T; + using size_type = std::size_t; + + using reference = fail_reference; + using const_reference = fail_reference; + + using pointer = value_type*; + using const_pointer = value_type const*; + + using iterator = pointer; + using const_iterator = const_pointer; + + explicit fail_container(size_t n = size_t{}) + { + if (n != size_t{}) { + throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"}; + } + } + + template + auto operator[](Index i) noexcept -> reference + { + RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build"); + return reference{}; + } + + template + auto operator[](Index i) const noexcept -> const_reference + { + RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build"); + return const_reference{}; + } + void resize(size_t n) + { + if (n != size_t{}) { + throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"}; + } + } + + [[nodiscard]] auto data() noexcept -> pointer { return nullptr; } + [[nodiscard]] auto data() const noexcept -> const_pointer { return nullptr; } +}; + +/** A placeholder container policy which throws an exception on use + * + * This placeholder is used in non-CUDA builds for container types that would + * otherwise be provided with CUDA code. Attempting to construct a non-empty + * container of this type throws an exception indicating that there was an + * attempt to use the device from a non-CUDA build. An example of when this + * might happen is if a downstream application attempts to allocate a device + * mdarray using a library built with non-CUDA RAFT. 
+ */ +template +struct fail_container_policy { + using element_type = ElementType; + using container_type = fail_container; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + + using accessor_policy = std::experimental::default_accessor; + using const_accessor_policy = std::experimental::default_accessor; + + auto create(raft::resources const& res, size_t n) -> container_type { return container_type(n); } + + fail_container_policy() = default; + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } +}; + +} // namespace detail +} // namespace raft diff --git a/cpp/include/raft/core/device_container_policy.hpp b/cpp/include/raft/core/device_container_policy.hpp index 011de307db..e8717d4c5e 100644 --- a/cpp/include/raft/core/device_container_policy.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -6,7 +6,7 @@ */ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ * limitations under the License. 
*/ #pragma once +#ifndef RAFT_DISABLE_CUDA #include #include @@ -196,3 +197,22 @@ class device_uvector_policy { }; } // namespace raft +#else +#include +namespace raft { + +// Provide placeholders that will allow CPU-GPU interoperable codebases to +// compile in non-CUDA mode but which will throw exceptions at runtime on any +// attempt to touch device data + +template +using device_reference = detail::fail_reference; + +template +using device_uvector = detail::fail_container; + +template +using device_uvector_policy = detail::fail_container_policy; + +} // namespace raft +#endif diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index c1898a3f09..3b6165b86a 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,9 +26,6 @@ namespace raft { template using device_accessor = host_device_accessor; -template -using managed_accessor = host_device_accessor; - /** * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. 
*/ @@ -38,12 +35,6 @@ template > using device_mdspan = mdspan>; -template > -using managed_mdspan = mdspan>; - template struct is_device_mdspan : std::false_type {}; template @@ -61,23 +52,6 @@ using is_input_device_mdspan_t = is_device_mdspan>; template using is_output_device_mdspan_t = is_device_mdspan>; -template -struct is_managed_mdspan : std::false_type {}; -template -struct is_managed_mdspan : std::bool_constant {}; - -/** - * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type - */ -template -using is_managed_mdspan_t = is_managed_mdspan>; - -template -using is_input_managed_mdspan_t = is_managed_mdspan>; - -template -using is_output_managed_mdspan_t = is_managed_mdspan>; - /** * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a * derived type @@ -102,30 +76,6 @@ using enable_if_input_device_mdspan = std::enable_if_t using enable_if_output_device_mdspan = std::enable_if_t>; -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a - * derived type - */ -template -inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; - -template -inline constexpr bool is_input_managed_mdspan_v = - std::conjunction_v...>; - -template -inline constexpr bool is_output_managed_mdspan_v = - std::conjunction_v...>; - -template -using enable_if_managed_mdspan = std::enable_if_t>; - -template -using enable_if_input_managed_mdspan = std::enable_if_t>; - -template -using enable_if_output_managed_mdspan = std::enable_if_t>; - /** * @brief Shorthand for 0-dim host mdspan (scalar). 
* @tparam ElementType the data type of the scalar element @@ -186,7 +136,7 @@ using device_aligned_matrix_view = template > -auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +auto constexpr make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) { using data_handle_type = typename std::experimental::aligned_accessor{aligned_pointer, extents}; } -/** - * @brief Create a raft::managed_mdspan - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param ptr Pointer to the data - * @param exts dimensionality of the array (series of integers) - * @return raft::managed_mdspan - */ -template -auto make_managed_mdspan(ElementType* ptr, extents exts) -{ - return make_mdspan(ptr, exts); -} - /** * @brief Create a 0-dim (scalar) mdspan instance for device value. * @@ -229,7 +161,7 @@ auto make_managed_mdspan(ElementType* ptr, extents exts) * @param[in] ptr on device to wrap */ template -auto make_device_scalar_view(ElementType* ptr) +auto constexpr make_device_scalar_view(ElementType* ptr) { scalar_extent extents; return device_scalar_view{ptr, extents}; @@ -249,7 +181,7 @@ auto make_device_scalar_view(ElementType* ptr) template -auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +auto constexpr make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) { matrix_extent extents{n_rows, n_cols}; return device_matrix_view{ptr, extents}; @@ -269,10 +201,10 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col * @param[in] stride leading dimension / stride of data */ template -auto make_device_strided_matrix_view(ElementType* ptr, - IndexType n_rows, - IndexType n_cols, - IndexType stride) +auto constexpr make_device_strided_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols, + IndexType 
stride) { constexpr auto is_row_major = std::is_same_v; IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; @@ -295,7 +227,7 @@ auto make_device_strided_matrix_view(ElementType* ptr, * @return raft::device_vector_view */ template -auto make_device_vector_view(ElementType* ptr, IndexType n) +auto constexpr make_device_vector_view(ElementType* ptr, IndexType n) { return device_vector_view{ptr, n}; } @@ -310,7 +242,7 @@ auto make_device_vector_view(ElementType* ptr, IndexType n) * @return raft::device_vector_view */ template -auto make_device_vector_view( +auto constexpr make_device_vector_view( ElementType* ptr, const typename LayoutPolicy::template mapping>& mapping) { diff --git a/cpp/include/raft/core/host_container_policy.hpp b/cpp/include/raft/core/host_container_policy.hpp index 3b3538ea20..0192436934 100644 --- a/cpp/include/raft/core/host_container_policy.hpp +++ b/cpp/include/raft/core/host_container_policy.hpp @@ -6,7 +6,7 @@ */ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,4 +62,5 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; + } // namespace raft diff --git a/cpp/include/raft/core/host_device_accessor.hpp b/cpp/include/raft/core/host_device_accessor.hpp index e9ebdb6c9f..7cb2aaf487 100644 --- a/cpp/include/raft/core/host_device_accessor.hpp +++ b/cpp/include/raft/core/host_device_accessor.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,6 +42,16 @@ struct host_device_accessor : public AccessorPolicy { using AccessorPolicy::AccessorPolicy; using offset_policy = host_device_accessor; host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT + + // Prevent implicit conversion from incompatible host_device_accessor types + template + host_device_accessor(host_device_accessor const& that) = delete; + + template > + host_device_accessor(host_device_accessor const& that) + : AccessorPolicy{that} + { + } }; } // namespace raft diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 9a675680ac..d5f431f4a2 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -134,7 +134,7 @@ using host_aligned_matrix_view = template > -auto make_host_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +auto constexpr make_host_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) { using data_handle_type = typename std::experimental::aligned_accessor -auto make_host_scalar_view(ElementType* ptr) +auto constexpr make_host_scalar_view(ElementType* ptr) { scalar_extent extents; return host_scalar_view{ptr, extents}; @@ -179,7 +179,7 @@ auto make_host_scalar_view(ElementType* ptr) template -auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +auto constexpr make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) { matrix_extent extents{n_rows, n_cols}; return host_matrix_view{ptr, extents}; @@ -196,7 +196,7 @@ auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) template -auto make_host_vector_view(ElementType* ptr, IndexType n) +auto constexpr 
make_host_vector_view(ElementType* ptr, IndexType n) { return host_vector_view{ptr, n}; } diff --git a/cpp/include/raft/core/managed_container_policy.hpp b/cpp/include/raft/core/managed_container_policy.hpp new file mode 100644 index 0000000000..f4e26c6ef1 --- /dev/null +++ b/cpp/include/raft/core/managed_container_policy.hpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#ifndef RAFT_DISABLE_CUDA +#include +#include +#include + +#include // dynamic_extent +#include + +#include +#include +#include + +namespace raft { +/** + * @brief A container policy for managed mdarray. 
+ */ +template +class managed_uvector_policy { + public: + using element_type = ElementType; + using container_type = device_uvector; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = device_reference; + using const_reference = device_reference; + + using accessor_policy = std::experimental::default_accessor; + using const_accessor_policy = std::experimental::default_accessor; + + auto create(raft::resources const& res, size_t n) -> container_type + { + return container_type(n, resource::get_cuda_stream(res), mr_); + } + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } + + private: + static auto* get_default_memory_resource() + { + auto static result = rmm::mr::managed_memory_resource{}; + return &result; + } + rmm::mr::managed_memory_resource* mr_{get_default_memory_resource()}; +}; + +} // namespace raft +#else +#include +namespace raft { + +// Provide placeholders that will allow CPU-GPU interoperable codebases to +// compile in non-CUDA mode but which will throw exceptions at runtime on any +// attempt to touch device data + +template +using managed_uvector_policy = detail::fail_container_policy; + +} // namespace raft +#endif diff --git a/cpp/include/raft/core/managed_mdarray.hpp b/cpp/include/raft/core/managed_mdarray.hpp new file mode 100644 index 0000000000..c1438d941d --- /dev/null +++ b/cpp/include/raft/core/managed_mdarray.hpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief mdarray with managed container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using managed_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim host mdarray (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using managed_scalar = managed_mdarray>; + +/** + * @brief Shorthand for 1-dim managed mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using managed_vector = managed_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous managed matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using managed_matrix = managed_mdarray, LayoutPolicy>; + +/** + * @brief Create a managed mdarray. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::resources + * @param exts dimensionality of the array (series of integers) + * @return raft::managed_mdarray + */ +template +auto make_managed_mdarray(raft::resources const& handle, extents exts) +{ + using mdarray_t = managed_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{}; + + return mdarray_t{handle, layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous managed mdarray. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::managed_matrix + */ +template +auto make_managed_matrix(raft::resources const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_managed_mdarray( + handle, make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a managed scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] v scalar to wrap on managed + * @return raft::managed_scalar + */ +template +auto make_managed_scalar(raft::resources const& handle, ElementType const& v) +{ + scalar_extent extents; + using policy_t = typename managed_scalar::container_policy_type; + policy_t policy{}; + auto scalar = managed_scalar{handle, extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim managed mdarray. 
+ * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] n number of elements in vector + * @return raft::managed_vector + */ +template +auto make_managed_vector(raft::resources const& handle, IndexType n) +{ + return make_managed_mdarray(handle, + make_extents(n)); +} + +} // end namespace raft diff --git a/cpp/include/raft/core/managed_mdspan.hpp b/cpp/include/raft/core/managed_mdspan.hpp new file mode 100644 index 0000000000..9c2976ec6b --- /dev/null +++ b/cpp/include/raft/core/managed_mdspan.hpp @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace raft { + +template +using managed_accessor = host_device_accessor; + +/** + * @brief std::experimental::mdspan with managed tag to indicate host/device accessibility + */ +template > +using managed_mdspan = mdspan>; + +template +struct is_managed_mdspan : std::false_type {}; +template +struct is_managed_mdspan + : std::bool_constant {}; + +/** + * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type + */ +template +using is_managed_mdspan_t = is_managed_mdspan>; + +template +using is_input_managed_mdspan_t = is_managed_mdspan>; + +template +using is_output_managed_mdspan_t = is_managed_mdspan>; + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a + * derived type + */ +template +inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; + +template +inline constexpr bool is_input_managed_mdspan_v = + std::conjunction_v...>; + +template +inline constexpr bool is_output_managed_mdspan_v = + std::conjunction_v...>; + +template +using enable_if_managed_mdspan = std::enable_if_t>; + +template +using enable_if_input_managed_mdspan = std::enable_if_t>; + +template +using enable_if_output_managed_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim managed mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using managed_scalar_view = managed_mdspan>; + +/** + * @brief Shorthand for 1-dim managed mdspan. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using managed_vector_view = managed_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous managed matrix view. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using managed_matrix_view = managed_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for 128 byte aligned managed matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy must be of type layout_{left/right}_padded + */ +template , + typename = enable_if_layout_padded> +using managed_aligned_matrix_view = + managed_mdspan, + LayoutPolicy, + std::experimental::aligned_accessor>; + +/** + * @brief Create a 2-dim 128 byte aligned mdspan instance for managed pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy must be of type layout_{left/right}_padded + * @tparam IndexType the index type of the extents + * @param[in] ptr to managed memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template > +auto constexpr make_managed_aligned_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols) +{ + using data_handle_type = + typename std::experimental::aligned_accessor::data_handle_type; + static_assert(std::is_same>::value || + std::is_same>::value); + assert(reinterpret_cast(ptr) == + std::experimental::details::alignTo(reinterpret_cast(ptr), + detail::alignment::value)); + + data_handle_type aligned_pointer = ptr; + + matrix_extent extents{n_rows, n_cols}; + return managed_aligned_matrix_view{aligned_pointer, + extents}; +} + +/** + * @brief Create a 0-dim (scalar) mdspan instance for managed value. 
+ * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @param[in] ptr to managed memory to wrap + */ +template +auto constexpr make_managed_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return managed_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for managed pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam IndexType the index type of the extents + * @param[in] ptr to managed memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto constexpr make_managed_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return managed_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim mdspan instance for managed pointer with a strided layout + * that is restricted to stride 1 in the trailing dimension. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to managed memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + * @param[in] stride leading dimension / stride of data + */ +template +auto constexpr make_managed_strided_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols, + IndexType stride) +{ + constexpr auto is_row_major = std::is_same_v; + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? 
stride : n_rows); + + assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows); + matrix_extent extents{n_rows, n_cols}; + + auto layout = make_strided_layout(extents, std::array{stride0, stride1}); + return managed_matrix_view{ptr, layout}; +} + +/** + * @brief Create a 1-dim mdspan instance for managed pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to managed memory to wrap + * @param[in] n number of elements in pointer + * @return raft::managed_vector_view + */ +template +auto constexpr make_managed_vector_view(ElementType* ptr, IndexType n) +{ + return managed_vector_view{ptr, n}; +} + +/** + * @brief Create a 1-dim mdspan instance for managed pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to managed memory to wrap + * @param[in] mapping The layout mapping to use for this vector + * @return raft::managed_vector_view + */ +template +auto constexpr make_managed_vector_view( + ElementType* ptr, + const typename LayoutPolicy::template mapping>& mapping) +{ + return managed_vector_view{ptr, mapping}; +} + +/** + * @brief Create a raft::managed_mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::managed_mdspan + */ +template +auto constexpr make_managed_mdspan(ElementType* ptr, extents exts) +{ + return make_mdspan(ptr, exts); +} +} // end namespace raft diff --git a/cpp/include/raft/core/mdbuffer.cuh b/cpp/include/raft/core/mdbuffer.cuh new file mode 100644 index 0000000000..18533ce882 --- /dev/null +++ 
b/cpp/include/raft/core/mdbuffer.cuh @@ -0,0 +1,1020 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef RAFT_DISABLE_CUDA +#include +#include +#include +#else +#include +#endif + +namespace raft { + +/** + * @defgroup mdbuffer_apis multi-dimensional maybe-owning type + * @{ + */ + +/** + * @brief Retrieve a canonical index associated with a given memory type. + * + * For variants based on memory type, this index can be used to help keep a + * consistent ordering of the memory types in the variant. + */ +inline auto constexpr variant_index_from_memory_type(raft::memory_type mem_type) +{ + return static_cast>(mem_type); +} + +/** + * @brief Retrieve the memory type associated with a canonical index + */ +inline auto constexpr memory_type_from_variant_index( + std::underlying_type_t index) +{ + return static_cast(index); +} + +/** + * @brief Retrieve a type from a variant based on a given memory type. 
+ */ +template +using alternate_from_mem_type = + std::variant_alternative_t, + Variant>; + +namespace detail { +template +struct memory_type_to_default_policy {}; +template +struct memory_type_to_default_policy { + using type = typename raft::host_vector_policy; +}; +template +struct memory_type_to_default_policy { + using type = typename raft::device_uvector_policy; +}; +template +struct memory_type_to_default_policy { + using type = typename raft::managed_uvector_policy; +}; +template +struct memory_type_to_default_policy { + using type = typename raft::pinned_vector_policy; +}; + +template +using memory_type_to_default_policy_t = typename memory_type_to_default_policy::type; +} // namespace detail + +/** + * @brief A variant of container policies for each memory type which can be + * used to build the default container policy for a buffer. + */ +template +using default_container_policy_variant = + std::variant, + detail::memory_type_to_default_policy_t, + detail::memory_type_to_default_policy_t, + detail::memory_type_to_default_policy_t>; + +/** + * @brief A template used to translate a variant of underlying mdarray + * container policies into a container policy that can be used by an mdbuffer. 
+ */ +template >> +struct default_buffer_container_policy { + using element_type = ElementType; + using value_type = std::remove_cv_t; + + private: + template + using raw_container_policy_at_index = std::variant_alternative_t; + + public: + using container_policy_variant = + std::variant, + static_cast(0)>, + host_device_accessor, + static_cast(1)>, + host_device_accessor, + static_cast(2)>, + host_device_accessor, + static_cast(3)>>; + template + using container_policy = alternate_from_mem_type; + using container_type_variant = + std::variant::container_type, + typename raw_container_policy_at_index<1>::container_type, + typename raw_container_policy_at_index<2>::container_type, + typename raw_container_policy_at_index<3>::container_type>; + + template + using container_type = alternate_from_mem_type; + + using accessor_policy_variant = + std::variant::accessor_policy, + static_cast(0)>, + host_device_accessor::accessor_policy, + static_cast(1)>, + host_device_accessor::accessor_policy, + static_cast(2)>, + host_device_accessor::accessor_policy, + static_cast(3)>>; + + template + using accessor_policy = alternate_from_mem_type; + + using const_accessor_policy_variant = std::variant< + host_device_accessor::const_accessor_policy, + static_cast(0)>, + host_device_accessor::const_accessor_policy, + static_cast(1)>, + host_device_accessor::const_accessor_policy, + static_cast(2)>, + host_device_accessor::const_accessor_policy, + static_cast(3)>>; + template + using const_accessor_policy = alternate_from_mem_type; + + template + auto create(raft::resources const& res, size_t n) + { + return container_type(res, n); + } + + auto create(raft::resources const& res, size_t n, raft::memory_type mem_type) + { + auto result = container_type_variant{}; + switch (mem_type) { + case raft::memory_type::host: result = create(res, n); break; + case raft::memory_type::device: result = create(res, n); break; + case raft::memory_type::managed: result = create(res, n); break; + case 
raft::memory_type::pinned: result = create(res, n); break; + } + return result; + } + + private: + template + auto static constexpr has_stream() -> decltype(std::declval().stream(), bool()) + { + return true; + }; + auto static constexpr has_stream(...) -> bool { return false; }; + + public: + template + [[nodiscard]] auto make_accessor_policy() noexcept + { + return accessor_policy{}; + } + template + [[nodiscard]] auto make_accessor_policy() const noexcept + { + return const_accessor_policy{}; + } + + [[nodiscard]] auto make_accessor_policy(memory_type mem_type) noexcept + { + auto result = accessor_policy_variant{}; + switch (mem_type) { + case memory_type::host: result = make_accessor_policy(); break; + case memory_type::device: result = make_accessor_policy(); break; + case memory_type::managed: result = make_accessor_policy(); break; + case memory_type::pinned: result = make_accessor_policy(); break; + } + return result; + } + [[nodiscard]] auto make_accessor_policy(memory_type mem_type) const noexcept + { + auto result = const_accessor_policy_variant{}; + switch (mem_type) { + case memory_type::host: result = make_accessor_policy(); break; + case memory_type::device: result = make_accessor_policy(); break; + case memory_type::managed: result = make_accessor_policy(); break; + case memory_type::pinned: result = make_accessor_policy(); break; + } + return result; + } +}; + +/** + * @brief A type representing multi-dimensional data which may or may not own + * its underlying storage. `raft::mdbuffer` is used to conveniently perform + * copies of data _only_ when necessary to ensure that the data are accessible + * in the desired memory space and format. + * + * When developing functions that interact with the GPU, it is often necessary + * to ensure that the data are in a particular memory space (e.g. device, + * host, managed, pinned), but those functions may be called with data that + * may or may not already be in the desired memory space. 
For instance, when + * called in one workflow, the data may have been previously transferred to + * device, rendering a copy unnecessary. In another, the function may be + * directly invoked on host data. + * + * Even when working strictly with host memory, it is often necessary to + * ensure that the data are in a particular layout for efficient access (e.g. + * column major vs row major) or that the data are of a particular type + * (e.g. double) even though we wish to call the function with data of + * another compatible type (e.g. float). + * + * `mdbuffer` is a tool for ensuring that the data are represented in exactly + * the desired format and location while flexibly supporting data which may + * not already be in that format or location. It does so by providing a + * non-owning view on data which are already in the required form, but it + * allocates (owned) memory and performs a copy if and only if it is + * necessary. + * + * Usage example: + * @code{.cpp} + * template + * void foo_device(raft::resources const& res, mdspan_type data) { + * auto buf = raft::mdbuffer{res, raft::mdbuffer{data}, raft::memory_type::device}; + * // Data in buf is now guaranteed to be accessible from device. + * // If it was already accessible from device, no copy was performed. If it + * // was not, a copy was performed. + * + * some_kernel<<<...>>>(buf.view()); + * + * // It is sometimes useful to know whether or not a copy was performed to + * // e.g. determine whether the transformed data should be copied back to its original + * // location. This can be checked via the `is_owning()` method. + * if (buf.is_owning()) { + * raft::copy(res, data, buf.view()); + * } + * } + * @endcode + * + * Note that in this example, the `foo_device` template can be correctly + * instantiated for both host and device mdspans. 
Similarly we can use + * `mdbuffer` to coerce data to a particular memory layout and data-type, as in + * the following example: + * @code{.cpp} + * template + * void foo_device(raft::resources const& res, mdspan_type data) { + * auto buf = raft::mdbuffer, raft::row_major>{res, + * raft::mdbuffer{data}, raft::memory_type::device}; + * // Data in buf is now guaranteed to be accessible from device, and + * // represented by floats in row-major order. + * + * some_kernel<<<...>>>(buf.view()); + * + * // The same check can be used to determine whether or not a copy was + * // required, regardless of the cause. I.e. if the data were already on + * // device but in column-major order, the is_owning() method would still + * // return true because new storage needed to be allocated. + * if (buf.is_owning()) { + * raft::copy(res, data, buf.view()); + * } + * } + * @endcode + * + * Note that in this example, the `foo_device` template can accept data of + * any float-convertible type in any layout and of any memory type and coerce + * it to the desired device-accessible representation. + * + * Because `mdspan` types can be implicitly converted to `mdbuffer`, it is even + * possible to avoid multiple template instantiations by directly accepting an + * `mdbuffer` as argument, as in the following example: + * @code{.cpp} + * void foo_device(raft::resources const& res, raft::mdbuffer>&& + * data) { auto buf = raft::mdbuffer{res, data, raft::memory_type::device}; + * // Data in buf is now guaranteed to be accessible from device. + * + * some_kernel<<<...>>>(buf.view()); + * } + * @endcode + * + * In this example, `foo_device` can now accept any row-major mdspan of floats + * regardless of memory type without requiring separate template instantiations + * for each type. + * + * While the view method takes an optional compile-time memory type parameter, + * omitting this parameter will return a std::variant of mdspan types. 
This + * allows for straightforward runtime dispatching based on the memory type + * using std::visit, as in the following example: + * + * @code{.cpp} + * void foo(raft::resources const& res, raft::mdbuffer>&& data) { + * std::visit([](auto&& view) { + * // Do something with the view, including (possibly) dispatching based on + * // whether it is a host, device, managed, or pinned mdspan + * }, data.view()); + * } + * @endcode + * + * For convenience, runtime memory-type dispatching can also be performed + * without explicit use of `mdbuffer` using `raft::memory_type_dispatcher`, as + * described in @ref memory_type_dispatcher. Please see the full documentation + * of that function template for more extensive discussion of the many ways it + * can be used. To illustrate its connection to `mdbuffer`, however, consider + * the following example, which performs a similar task to the above + * `std::visit` call: + * + * @code{.cpp} + * void foo_device(raft::resources const& res, raft::device_matrix_view data) { + * // Implement foo solely for device data + * }; + * + * // Call foo with data of any memory type: + * template + * void foo(raft::resources const& res, mdspan_type data) { + * raft::memory_type_dispatcher(res, + * [&res](raft::device_matrix_view dev_data) {foo_device(res, dev_data);}, + * data + * ); + * } + * @endcode + * + * Here, the `memory_type_dispatcher` implicitly constructs an `mdbuffer` from + * the input and performs any necessary conversions before passing the input to + * `foo_device`. While `mdbuffer` does not require the use of + * `memory_type_dispatcher`, there are many common use cases in which explicit + * invocations of `mdbuffer` can be elided with `memory_type_dispatcher`. + * + * Finally, we should note that `mdbuffer` should almost never be passed as a + * const reference. 
To indicate const-ness of the underlying data, the + * `mdbuffer` should be constructed with a const element type, but the mdbuffer + * itself should generally be passed as an rvalue reference in function + * arguments. Using an `mdbuffer` that is itself `const` is not strictly + * incorrect, but it indicates a likely misuse of the type. + * + * @tparam ElementType element type stored in the buffer + * @tparam Extents specifies the number of dimensions and their sizes + * @tparam LayoutPolicy specifies how data should be laid out in memory + * @tparam ContainerPolicy specifies how data should be allocated if necessary + * and how it should be accessed. This should very rarely need to be + * customized. For those cases where it must be customized, it is recommended + * to instantiate default_buffer_container_policy with a std::variant of + * container policies for each memory type. Note that the accessor policy of + * each container policy variant is used as the accessor policy for the mdspan + * view of the buffer for the corresponding memory type. 
+ */ +template > +struct mdbuffer { + using extents_type = Extents; + using layout_type = LayoutPolicy; + using mapping_type = typename layout_type::template mapping; + using element_type = ElementType; + + using value_type = std::remove_cv_t; + using index_type = typename extents_type::index_type; + using difference_type = std::ptrdiff_t; + using rank_type = typename extents_type::rank_type; + + using container_policy_type = ContainerPolicy; + using accessor_policy_variant = typename ContainerPolicy::accessor_policy_variant; + + template + using accessor_policy = alternate_from_mem_type; + + using container_type_variant = typename container_policy_type::container_type_variant; + + template + using container_type = typename container_policy_type::template container_type; + + template + using owning_type = mdarray>; + // We use the static cast here to ensure that the memory types appear in the + // order expected for retrieving the correct variant alternative based on + // memory type. Even if the memory types are re-arranged in the enum and + // assigned different values, the logic should remain correct. 
+ using owning_type_variant = std::variant(0)>, + owning_type(1)>, + owning_type(2)>, + owning_type(3)>>; + + template + using view_type = std::conditional_t, + typename owning_type::const_view_type, + typename owning_type::view_type>; + + using view_type_variant = std::variant(0)>, + view_type(1)>, + view_type(2)>, + view_type(3)>>; + + template + using const_view_type = typename owning_type::const_view_type; + using const_view_type_variant = std::variant(0)>, + const_view_type(1)>, + const_view_type(2)>, + const_view_type(3)>>; + + using storage_type_variant = concatenated_variant_t; + + // Non-owning types are stored first in the variant. Thus, if we want to access the + // owning type corresponding to device memory, we would need to skip over the + // non-owning types and then go to the index which corresponds to the memory + // type: is_owning * num_non_owning_types + index = 1 * 4 + 1 = 5 + template + using storage_type = + std::variant_alternative_t + + std::size_t{variant_index_from_memory_type(MemType)}, + storage_type_variant>; + + /** + * @brief Construct an empty, uninitialized buffer + */ + constexpr mdbuffer() = default; + + private: + container_policy_type cp_{}; + storage_type_variant data_{}; + + // This template is used to determine whether or not it is possible to copy from + // the mdspan returned by the view method of a FromT type mdbuffer with + // memory type indicated by FromIndex to the mdspan returned by this mdbuffer + // at ToIndex + template + auto static constexpr is_copyable_combination() + { + return detail::mdspan_copyable_v< + decltype(std::declval>().view()), + std::variant_alternative_t().view())>>; + } + + // Using an index_sequence to iterate over the possible memory types of this + // mdbuffer, we construct an array of bools to determine whether or not the + // mdspan returned by the view method of a FromT type mdbuffer with memory + // type indicated by FromIndex can be copied to the mdspan returned by this + // mdbuffer's view 
method at each memory type + template + auto static constexpr get_to_copyable_combinations(std::index_sequence) + { + return std::array{is_copyable_combination()...}; + } + + // Using an index_sequence to iterate over the possible memory types of the + // FromT type mdbuffer, we construct an array of arrays indicating whether it + // is possible to copy from any mdspan that can be returned from the FromT + // mdbuffer to any mdspan that can be returned from this mdbuffer + template + auto static constexpr get_from_copyable_combinations(std::index_sequence) + { + return std::array{get_to_copyable_combinations( + std::make_index_sequence>())...}; + } + + // Get an array of arrays indicating whether or not it is possible to copy + // from any given memory type of a FromT mdbuffer to any memory type of this + // mdbuffer + template + auto static constexpr get_copyable_combinations() + { + return get_from_copyable_combinations( + std::make_index_sequence().view())>>()); + } + + template + auto static constexpr is_copyable_from(std::index_sequence) + { + return (... || get_copyable_combinations()[FromIndex][Is]); + } + + template + auto static constexpr is_copyable_from(bool, std::index_sequence) + { + return (... 
|| is_copyable_from( + std::make_index_sequence>())); + } + + template + auto static constexpr is_copyable_from() + { + return is_copyable_from( + true, + std::make_index_sequence().view())>>()); + } + + template + auto static is_copyable_from(FromT&& other, memory_type mem_type) + { + auto static copyable_combinations = get_copyable_combinations(); + return copyable_combinations[variant_index_from_memory_type(other.mem_type())] + [variant_index_from_memory_type(mem_type)]; + } + + template + auto static copy_from(raft::resources const& res, FromT&& other, memory_type mem_type) + { + auto result = storage_type_variant{}; + switch (mem_type) { + case memory_type::host: { + result = std::visit( + [&res](auto&& other_view) { + auto tmp_result = owning_type{ + res, + mapping_type{other_view.extents()}, + typename container_policy_type::template container_policy{}}; + raft::copy(res, tmp_result.view(), other_view); + return tmp_result; + }, + other.view()); + break; + } + case memory_type::device: { + result = std::visit( + [&res](auto&& other_view) { + auto tmp_result = owning_type{ + res, + mapping_type{other_view.extents()}, + typename container_policy_type::template container_policy{}}; + raft::copy(res, tmp_result.view(), other_view); + return tmp_result; + }, + other.view()); + break; + } + case memory_type::managed: { + result = std::visit( + [&res](auto&& other_view) { + auto tmp_result = owning_type{ + res, + mapping_type{other_view.extents()}, + typename container_policy_type::template container_policy{}}; + raft::copy(res, tmp_result.view(), other_view); + return tmp_result; + }, + other.view()); + break; + } + case memory_type::pinned: { + result = std::visit( + [&res](auto&& other_view) { + auto tmp_result = owning_type{ + res, + mapping_type{other_view.extents()}, + typename container_policy_type::template container_policy{}}; + raft::copy(res, tmp_result.view(), other_view); + return tmp_result; + }, + other.view()); + break; + } + } + return result; + } 
+ + public: + /** + * @brief Construct an mdbuffer wrapping an existing mdspan. The resulting + * mdbuffer will be non-owning and match the memory type, layout, and + * element type of the mdspan. + */ + template < + typename OtherAccessorPolicy, + std::enable_if_t>* = nullptr> + mdbuffer(mdspan other) : data_{other} + { + } + + /** + * @brief Construct an mdbuffer of const elements wrapping an existing mdspan + * with non-const elements. The resulting mdbuffer will be non-owning and match the memory type, + * layout, and element type of the mdspan. + */ + template < + typename OtherElementType, + typename OtherAccessorPolicy, + std::enable_if_t && + std::is_same_v && + is_type_in_variant_v>* = nullptr> + mdbuffer(mdspan other) + : data_{raft::make_const_mdspan(other)} + { + } + + /** + * @brief Construct an mdbuffer to hold an existing mdarray rvalue. The + * mdarray will be moved into the mdbuffer, and the mdbuffer will be owning. + */ + template , + typename container_policy_type::container_policy_variant>>* = nullptr> + mdbuffer(mdarray&& other) + : data_{std::move(other)} + { + } + + /** + * @brief Construct an mdbuffer from an existing mdarray lvalue. An mdspan + * view will be taken from the mdarray in order to construct the mdbuffer, + * and the mdbuffer will be non-owning + */ + template , + typename container_policy_type::container_policy_variant>>* = nullptr> + mdbuffer(mdarray& other) + : mdbuffer{other.view()} + { + } + + /** + * @brief Construct one mdbuffer from another mdbuffer rvalue with matching + * element type, extents, layout, and container policy. + * + * If the existing mdbuffer is owning and of the correct memory type, + * the new mdbuffer will take ownership of the underlying memory + * (preventing a view on memory owned by a moved-from object). The memory + * type of the new mdbuffer may be specified explicitly, in which case a copy + * will be performed if and only if it is necessary to do so. 
+ */ + mdbuffer(raft::resources const& res, + mdbuffer&& other, + std::optional specified_mem_type = std::nullopt) + : data_{[&res, &other, specified_mem_type, this]() { + auto other_mem_type = other.mem_type(); + auto mem_type = specified_mem_type.value_or(other_mem_type); + auto result = storage_type_variant{}; + if (mem_type == other.mem_type()) { + result = std::move(other.data_); + } else if (!other.is_owning() && has_compatible_accessibility(other_mem_type, mem_type) && + !is_host_device_accessible(mem_type)) { + switch (mem_type) { + case (memory_type::host): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::device): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::managed): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::pinned): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + } + } else { + result = copy_from(res, other, mem_type); + } + return result; + }()} + { + } + + /** + * @brief Construct one mdbuffer from another mdbuffer lvalue with matching + * element type, extents, layout, and container policy. + * + * Unlike when constructing from an rvalue, the new mdbuffer will take a + * non-owning view whenever possible, since it is assumed that the caller + * will manage the lifetime of the lvalue input. 
Note that the mdbuffer + * passed here must itself be non-const in order to allow this constructor to + * provide an equivalent view of the underlying data. To indicate const-ness + * of the underlying data, mdbuffers should be constructed with a const + * ElementType. + */ + mdbuffer(raft::resources const& res, + mdbuffer& other, /* NOLINT */ + std::optional specified_mem_type = std::nullopt) + : data_{[&res, &other, specified_mem_type, this]() { + auto mem_type = specified_mem_type.value_or(other.mem_type()); + auto result = storage_type_variant{}; + auto other_mem_type = other.mem_type(); + if (mem_type == other_mem_type) { + std::visit([&result](auto&& other_view) { result = other_view; }, other.view()); + } else if (has_compatible_accessibility(other_mem_type, mem_type) && + !is_host_device_accessible(mem_type)) { + switch (mem_type) { + case (memory_type::host): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::device): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::managed): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + case (memory_type::pinned): { + result = std::visit( + [&result, this](auto&& other_view) { + return view_type{ + other_view.data_handle(), + other_view.mapping(), + cp_.template make_accessor_policy()}; + }, + other.view()); + break; + } + } + } else { + result = copy_from(res, other, mem_type); + } + return result; + }()} + { + } + + /** + * @brief Construct an mdbuffer from an existing mdbuffer with arbitrary but + 
* compatible element type, extents, layout, and container policy. This + * constructor is used to coerce data to specific element types, layouts, + * or extents as well as specifying a memory type. + */ + template < + typename OtherElementType, + typename OtherExtents, + typename OtherLayoutPolicy, + typename OtherContainerPolicy, + std::enable_if_t>()>* = + nullptr> + mdbuffer( + raft::resources const& res, + mdbuffer const& other, + std::optional specified_mem_type = std::nullopt) + : data_{[&res, &other, specified_mem_type]() { + auto mem_type = specified_mem_type.value_or(other.mem_type()); + // Note: We perform this check at runtime because it is possible for two + // mdbuffers to have storage types which may be copied to each other for + // some memory types but not for others. This is an unusual situation, but + // we still need to guard against it. + RAFT_EXPECTS( + is_copyable_from(other, mem_type), + "mdbuffer cannot be constructed from other mdbuffer with indicated memory type"); + return copy_from(res, other, mem_type); + }()} + { + } + + /** + * @brief Return the memory type of the underlying data referenced by the + * mdbuffer + */ + [[nodiscard]] auto constexpr mem_type() const + { + return static_cast(data_.index() % std::variant_size_v); + }; + + /** + * @brief Return a boolean indicating whether or not the mdbuffer owns its + * storage + */ + [[nodiscard]] auto constexpr is_owning() const + { + return data_.index() >= std::variant_size_v; + }; + + private: + template + [[nodiscard]] auto view() + { + if constexpr (MemTypeConstant::value.has_value()) { + if (is_owning()) { + if constexpr (std::is_const_v) { + return std::as_const(std::get>(data_)).view(); + } else { + return std::get>(data_).view(); + } + } else { + return std::get>(data_); + } + } else { + return std::visit( + [](auto&& inner) { + if constexpr (is_mdspan_v>) { + return view_type_variant{inner}; + } else { + if constexpr (std::is_const_v) { + return 
view_type_variant{std::as_const(inner).view()}; + } else { + return view_type_variant{inner.view()}; + } + } + }, + data_); + } + } + + template + [[nodiscard]] auto view() const + { + if constexpr (MemTypeConstant::value.has_value()) { + if (is_owning()) { + return make_const_mdspan( + std::get>(data_).view()); + } else { + return make_const_mdspan(std::get>(data_)); + } + } else { + return std::visit( + [](auto&& inner) { + if constexpr (is_mdspan_v>) { + return const_view_type_variant{make_const_mdspan(inner)}; + } else { + return const_view_type_variant{make_const_mdspan(inner.view())}; + } + }, + data_); + } + } + + public: + /** + * @brief Return an mdspan of the indicated memory type representing a view + * on the stored data. If the mdbuffer does not contain data of the indicated + * memory type, a std::bad_variant_access will be thrown. + */ + template + [[nodiscard]] auto view() + { + return view>(); + } + /** + * @brief Return an mdspan containing const elements of the indicated memory type representing a + * view on the stored data. If the mdbuffer does not contain data of the indicated memory type, a + * std::bad_variant_access will be thrown. + */ + template + [[nodiscard]] auto view() const + { + return view>(); + } + /** + * @brief Return a std::variant representing the possible mdspan types that + * could be returned as views on the mdbuffer. The variant will contain the mdspan + * corresponding to its current memory type. + * + * This method is useful for writing generic code to handle any memory type + * that might be contained in an mdbuffer at a particular point in a + * workflow. By performing a `std::visit` on the returned value, the caller + * can easily dispatch to the correct code path for the memory type. + */ + [[nodiscard]] auto view() { return view>(); } + /** + * @brief Return a std::variant representing the possible mdspan types that + * could be returned as const views on the mdbuffer. 
The variant will contain the mdspan + * corresponding to its current memory type. + * + * This method is useful for writing generic code to handle any memory type + * that might be contained in an mdbuffer at a particular point in a + * workflow. By performing a `std::visit` on the returned value, the caller + * can easily dispatch to the correct code path for the memory type. + */ + [[nodiscard]] auto view() const { return view>(); } +}; + +/** + * @\brief Template checks and helpers to determine if type T is an mdbuffer + * or a derived type + */ + +template +void __takes_an_mdbuffer_ptr(mdbuffer*); + +template +struct is_mdbuffer : std::false_type {}; +template +struct is_mdbuffer()))>> + : std::true_type {}; + +template +struct is_input_mdbuffer : std::false_type {}; +template +struct is_input_mdbuffer()))>> + : std::bool_constant> {}; + +template +struct is_output_mdbuffer : std::false_type {}; +template +struct is_output_mdbuffer()))>> + : std::bool_constant> {}; + +template +using is_mdbuffer_t = is_mdbuffer>; + +template +using is_input_mdbuffer_t = is_input_mdbuffer; + +template +using is_output_mdbuffer_t = is_output_mdbuffer; + +/** + * @\brief Boolean to determine if variadic template types Tn are + * raft::mdbuffer or derived types + */ +template +inline constexpr bool is_mdbuffer_v = std::conjunction_v...>; + +template +using enable_if_mdbuffer = std::enable_if_t>; + +template +inline constexpr bool is_input_mdbuffer_v = std::conjunction_v...>; + +template +using enable_if_input_mdbuffer = std::enable_if_t>; + +template +inline constexpr bool is_output_mdbuffer_v = std::conjunction_v...>; + +template +using enable_if_output_mdbuffer = std::enable_if_t>; + +/** @} */ + +} // namespace raft diff --git a/cpp/include/raft/core/mdbuffer.hpp b/cpp/include/raft/core/mdbuffer.hpp new file mode 100644 index 0000000000..8281b5c6d6 --- /dev/null +++ b/cpp/include/raft/core/mdbuffer.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_DISABLE_CUDA +#pragma message(__FILE__ \ + " should only be used in CUDA-disabled RAFT builds." \ + " Please use equivalent .cuh header instead.") +#else +// It is safe to include this cuh file in an hpp header because all CUDA code +// is ifdef'd out for CUDA-disabled builds. +#include +#endif diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index cd37a0ee50..7849cd67ab 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,28 @@ * limitations under the License. 
*/ #pragma once +#include +#include +#ifndef RAFT_DISABLE_CUDA +#include +#include +#include +#else +#include +#endif namespace raft { -enum class memory_type { host, device, managed, pinned }; +enum class memory_type : std::uint8_t { + host = std::uint8_t{0}, + pinned = std::uint8_t{1}, + device = std::uint8_t{2}, + managed = std::uint8_t{3} +}; auto constexpr is_device_accessible(memory_type mem_type) { - return (mem_type == memory_type::device || mem_type == memory_type::managed); + return (mem_type == memory_type::device || mem_type == memory_type::managed || + mem_type == memory_type::pinned); } auto constexpr is_host_accessible(memory_type mem_type) { @@ -32,6 +47,22 @@ auto constexpr is_host_device_accessible(memory_type mem_type) return is_device_accessible(mem_type) && is_host_accessible(mem_type); } +auto constexpr has_compatible_accessibility(memory_type old_mem_type, memory_type new_mem_type) +{ + return ((!is_device_accessible(new_mem_type) || is_device_accessible(old_mem_type)) && + (!is_host_accessible(new_mem_type) || is_host_accessible(old_mem_type))); +} + +template +struct memory_type_constant { + static_assert(sizeof...(mem_types) < 2, "At most one memory type can be specified"); + auto static constexpr value = []() { + auto result = std::optional{}; + if constexpr (sizeof...(mem_types) == 1) { result = std::make_optional(mem_types...); } + return result; + }(); +}; + namespace detail { template @@ -49,4 +80,23 @@ auto constexpr memory_type_from_access() } } // end namespace detail + +template +auto memory_type_from_pointer(T* ptr) +{ + auto result = memory_type::host; +#ifndef RAFT_DISABLE_CUDA + auto attrs = cudaPointerAttributes{}; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attrs, ptr)); + switch (attrs.type) { + case cudaMemoryTypeDevice: result = memory_type::device; break; + case cudaMemoryTypeHost: result = memory_type::host; break; + case cudaMemoryTypeManaged: result = memory_type::managed; break; + default: result = memory_type::host; + 
} +#else + RAFT_LOG_DEBUG("RAFT compiled without CUDA support, assuming pointer is host pointer"); +#endif + return result; +} } // end namespace raft diff --git a/cpp/include/raft/core/pinned_container_policy.hpp b/cpp/include/raft/core/pinned_container_policy.hpp new file mode 100644 index 0000000000..51451deadb --- /dev/null +++ b/cpp/include/raft/core/pinned_container_policy.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#ifndef RAFT_DISABLE_CUDA +#include +#include +#include +#else +#include +#endif + +namespace raft { +#ifndef RAFT_DISABLE_CUDA + +/** + * @brief A thin wrapper over thrust::host_vector for implementing the pinned mdarray container + * policy. 
+ * + */ +template +struct pinned_container { + using value_type = T; + using allocator_type = + thrust::mr::stateless_resource_allocator; + + private: + using underlying_container_type = thrust::host_vector; + underlying_container_type data_; + + public: + using size_type = std::size_t; + + using reference = value_type&; + using const_reference = value_type const&; + + using pointer = value_type*; + using const_pointer = value_type const*; + + using iterator = pointer; + using const_iterator = const_pointer; + + ~pinned_container() = default; + pinned_container(pinned_container&&) noexcept = default; + pinned_container(pinned_container const& that) : data_{that.data_} {} + + auto operator=(pinned_container const& that) -> pinned_container& + { + data_ = underlying_container_type{that.data_}; + return *this; + } + auto operator=(pinned_container&& that) noexcept -> pinned_container& = default; + + /** + * @brief Ctor that accepts a size and an allocator. + */ + explicit pinned_container(std::size_t size, allocator_type const& alloc) : data_{size, alloc} {} + /** + * @brief Index operator that returns a reference to the actual data. + */ + template + auto operator[](Index i) noexcept -> reference + { + return data_[i]; + } + /** + * @brief Index operator that returns a reference to the actual data. + */ + template + auto operator[](Index i) const noexcept + { + return data_[i]; + } + + // thrust::host_vector has no stream() member, so resize takes only the new size. + void resize(size_type size) { data_.resize(size); } + + [[nodiscard]] auto data() noexcept -> pointer { return data_.data().get(); } + [[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data().get(); } +}; + +/** + * @brief A container policy for pinned mdarray. 
+ */ +template +struct pinned_vector_policy { + using element_type = ElementType; + using container_type = pinned_container; + using allocator_type = typename container_type::allocator_type; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + using accessor_policy = std::experimental::default_accessor; + using const_accessor_policy = std::experimental::default_accessor; + + auto create(raft::resources const&, size_t n) -> container_type + { + return container_type(n, allocator_); + } + + constexpr pinned_vector_policy() noexcept(std::is_nothrow_default_constructible_v) + : allocator_{} + { + } + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } + + private: + allocator_type allocator_; +}; +#else +template +using pinned_vector_policy = detail::fail_container_policy; +#endif +} // namespace raft diff --git a/cpp/include/raft/core/pinned_mdarray.hpp b/cpp/include/raft/core/pinned_mdarray.hpp new file mode 100644 index 0000000000..72b8d52e0d --- /dev/null +++ b/cpp/include/raft/core/pinned_mdarray.hpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief mdarray with pinned container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using pinned_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim pinned mdarray (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using pinned_scalar = pinned_mdarray>; + +/** + * @brief Shorthand for 1-dim pinned mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using pinned_vector = pinned_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous pinned matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using pinned_matrix = pinned_mdarray, LayoutPolicy>; + +/** + * @brief Create a pinned mdarray. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::resources + * @param exts dimensionality of the array (series of integers) + * @return raft::pinned_mdarray + */ +template +auto make_pinned_mdarray(raft::resources const& handle, extents exts) +{ + using mdarray_t = pinned_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{}; + + return mdarray_t{handle, layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous pinned mdarray. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows number of rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::pinned_matrix + */ +template +auto make_pinned_matrix(raft::resources const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_pinned_mdarray( + handle, make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a pinned scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] v scalar to wrap in pinned memory + * @return raft::pinned_scalar + */ +template +auto make_pinned_scalar(raft::resources const& handle, ElementType const& v) +{ + scalar_extent extents; + using policy_t = typename pinned_scalar::container_policy_type; + policy_t policy{}; + auto scalar = pinned_scalar{handle, extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim pinned mdarray. 
+ * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] n number of elements in vector + * @return raft::pinned_vector + */ +template +auto make_pinned_vector(raft::resources const& handle, IndexType n) +{ + return make_pinned_mdarray(handle, + make_extents(n)); +} + +} // end namespace raft diff --git a/cpp/include/raft/core/pinned_mdspan.hpp b/cpp/include/raft/core/pinned_mdspan.hpp new file mode 100644 index 0000000000..e764101d1c --- /dev/null +++ b/cpp/include/raft/core/pinned_mdspan.hpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace raft { + +template +using pinned_accessor = host_device_accessor; + +/** + * @brief std::experimental::mdspan with pinned tag to indicate host/device accessibility + */ +template > +using pinned_mdspan = mdspan>; + +template +struct is_pinned_mdspan : std::false_type {}; +template +struct is_pinned_mdspan + : std::bool_constant {}; + +/** + * @\brief Boolean to determine if template type T is either raft::pinned_mdspan or a derived type + */ +template +using is_pinned_mdspan_t = is_pinned_mdspan>; + +template +using is_input_pinned_mdspan_t = is_pinned_mdspan>; + +template +using is_output_pinned_mdspan_t = is_pinned_mdspan>; + +/** + * @\brief Boolean to determine if variadic template types Tn are either raft::pinned_mdspan or a + * derived type + */ +template +inline constexpr bool is_pinned_mdspan_v = std::conjunction_v...>; + +template +inline constexpr bool is_input_pinned_mdspan_v = + std::conjunction_v...>; + +template +inline constexpr bool is_output_pinned_mdspan_v = + std::conjunction_v...>; + +template +using enable_if_pinned_mdspan = std::enable_if_t>; + +template +using enable_if_input_pinned_mdspan = std::enable_if_t>; + +template +using enable_if_output_pinned_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim pinned mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using pinned_scalar_view = pinned_mdspan>; + +/** + * @brief Shorthand for 1-dim pinned mdspan. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using pinned_vector_view = pinned_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous pinned matrix view. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using pinned_matrix_view = pinned_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for 128 byte aligned pinned matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy must be of type layout_{left/right}_padded + */ +template , + typename = enable_if_layout_padded> +using pinned_aligned_matrix_view = + pinned_mdspan, + LayoutPolicy, + std::experimental::aligned_accessor>; + +/** + * @brief Create a 2-dim 128 byte aligned mdspan instance for pinned pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy must be of type layout_{left/right}_padded + * @tparam IndexType the index type of the extents + * @param[in] ptr to pinned memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template > +auto constexpr make_pinned_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + using data_handle_type = + typename std::experimental::aligned_accessor::data_handle_type; + static_assert(std::is_same>::value || + std::is_same>::value); + assert(reinterpret_cast(ptr) == + std::experimental::details::alignTo(reinterpret_cast(ptr), + detail::alignment::value)); + + data_handle_type aligned_pointer = ptr; + + matrix_extent extents{n_rows, n_cols}; + return pinned_aligned_matrix_view{aligned_pointer, extents}; +} + +/** + * @brief Create a 0-dim (scalar) mdspan instance for pinned value. 
+ * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] ptr to pinned memory to wrap + */ +template +auto constexpr make_pinned_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return pinned_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for pinned pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam IndexType the index type of the extents + * @param[in] ptr to pinned memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto constexpr make_pinned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return pinned_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim mdspan instance for pinned pointer with a strided layout + * that is restricted to stride 1 in the trailing dimension. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to pinned memory to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + * @param[in] stride leading dimension / stride of data + */ +template +auto constexpr make_pinned_strided_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols, + IndexType stride) +{ + constexpr auto is_row_major = std::is_same_v; + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); + + assert(is_row_major ? 
stride0 >= n_cols : stride1 >= n_rows); + matrix_extent extents{n_rows, n_cols}; + + auto layout = make_strided_layout(extents, std::array{stride0, stride1}); + return pinned_matrix_view{ptr, layout}; +} + +/** + * @brief Create a 1-dim mdspan instance for pinned pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to pinned memory to wrap + * @param[in] n number of elements in pointer + * @return raft::pinned_vector_view + */ +template +auto constexpr make_pinned_vector_view(ElementType* ptr, IndexType n) +{ + return pinned_vector_view{ptr, n}; +} + +/** + * @brief Create a 1-dim mdspan instance for pinned pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr to pinned memory to wrap + * @param[in] mapping The layout mapping to use for this vector + * @return raft::pinned_vector_view + */ +template +auto constexpr make_pinned_vector_view( + ElementType* ptr, + const typename LayoutPolicy::template mapping>& mapping) +{ + return pinned_vector_view{ptr, mapping}; +} + +/** + * @brief Create a raft::pinned_mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::pinned_mdspan + */ +template +auto constexpr make_pinned_mdspan(ElementType* ptr, extents exts) +{ + return make_mdspan(ptr, exts); +} +} // end namespace raft diff --git a/cpp/include/raft/core/serialize.hpp b/cpp/include/raft/core/serialize.hpp index b2fef8c6ef..7e3aab8b89 100644 --- a/cpp/include/raft/core/serialize.hpp +++ b/cpp/include/raft/core/serialize.hpp @@ -1,5 
+1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include diff --git a/cpp/include/raft/core/stream_view.hpp b/cpp/include/raft/core/stream_view.hpp index f7e7934dbf..128050c414 100644 --- a/cpp/include/raft/core/stream_view.hpp +++ b/cpp/include/raft/core/stream_view.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once #include #include #include diff --git a/cpp/include/raft/util/memory_type_dispatcher.cuh b/cpp/include/raft/util/memory_type_dispatcher.cuh new file mode 100644 index 0000000000..94d838415a --- /dev/null +++ b/cpp/include/raft/util/memory_type_dispatcher.cuh @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include +#include +#include + +namespace raft { + +namespace detail { + +template +struct is_callable : std::false_type {}; + +template +struct is_callable()(std::declval()))>> + : std::true_type {}; + +template * = nullptr> +auto static constexpr is_callable_for_memory_type = + is_callable().template view())>::value; + +} // namespace detail + +/** + * @defgroup memory_type_dispatcher Dispatch functor based on memory type + * @{ + */ + +/** + * @brief Dispatch to various specializations of a functor which accepts an + * mdspan based on the mdspan's memory type + * + * This function template is used to dispatch to one or more implementations + * of a function based on memory type. For instance, if a functor has been + * implemented with an operator that accepts only a `device_mdspan`, input data + * can be passed to that functor with minimal copies or allocations by wrapping + * the functor in this template. + * + * More specifically, host memory data will be copied to device before being + * passed to the functor as a `device_mdspan`. Device, managed, and pinned data + * will be passed directly to the functor as a `device_mdspan`. + * + * If the functor's operator were _also_ specialized for `host_mdspan`, then + * this wrapper would pass an input `host_mdspan` directly to the corresponding + * specialization. + * + * If a functor explicitly specializes for managed/pinned memory and receives + * managed/pinned input, the corresponding specialization will be invoked. If the functor does not + * specialize for either, it will preferentially invoke the device + * specialization if available and then the host specialization. Managed input + * will never be dispatched to an explicit specialization for pinned memory and + * vice versa. + * + * Dispatching is performed by coercing the input mdspan to an mdbuffer of the + * correct type. If it is necessary to coerce the input data to a different + * data type (e.g. 
floats to doubles) or to a different memory layout, this can + * be done by passing an explicit mdbuffer type to the `memory_type_dispatcher` + * template. + * + * Usage example: + * @code{.cpp} + * // Functor which accepts only a `device_mdspan` or `managed_mdspan` of + * // doubles in C-contiguous layout. We wish to be able to call this + * // functor on any compatible data, regardless of data type, memory type, + * // or layout. + * struct functor { + * auto operator()(device_matrix_view data) { + * // Do something with data on device + * }; + * auto operator()(managed_matrix_view data) { + * // Do something with data, taking advantage of knowledge that + * // underlying memory is managed + * }; + * }; + * + * auto rows = 3; + * auto cols = 5; + * auto res = raft::device_resources{}; + * + * auto host_data = raft::make_host_matrix(rows, cols); + * // functor{}(host_data.view()); // This would fail to compile + * auto device_data = raft::make_device_matrix(res, rows, cols); + * functor{}(device_data.view()); // Functor accepts device mdspan + * auto managed_data = raft::make_managed_matrix(res, rows, cols); + * functor{}(managed_data.view()); // Functor accepts managed mdspan + * auto pinned_data = raft::make_pinned_matrix(res, rows, cols); + * // functor{}(pinned_data.view()); // This would fail to compile + * auto float_data = raft::make_device_matrix(res, rows, cols); + * // functor{}(float_data.view()); // This would fail to compile + * auto f_data = raft::make_device_matrix(res, rows, cols); + * // functor{}(f_data.view()); // This would fail to compile + * + * // `memory_type_dispatcher` lets us call this functor on all of the above + * raft::memory_type_dispatcher(res, functor{}, host_data.view()); + * raft::memory_type_dispatcher(res, functor{}, device_data.view()); + * raft::memory_type_dispatcher(res, functor{}, managed_data.view()); + * raft::memory_type_dispatcher(res, functor{}, pinned_data.view()); + * // Here, we use the mdbuffer type template 
parameter to ensure that the data + * // type and layout are as expected by the functor + * raft::memory_type_dispatcher>>(res, functor{}, + * float_data.view()); raft::memory_type_dispatcher>>(res, functor{}, f_data.view()); + * @endcode + * + * As this example shows, `memory_type_dispatcher` can be used to dispatch any + * compatible mdspan input to a functor, regardless of the mdspan type(s) that + * functor supports. + */ +template * = nullptr> +decltype(auto) memory_type_dispatcher(raft::resources const& res, lambda_t&& f, mdbuffer_type&& buf) +{ + if (is_host_device_accessible(buf.mem_type())) { + // First see if functor has been specialized for this exact memory type + if constexpr (detail:: + is_callable_for_memory_type) { + if (buf.mem_type() == memory_type::managed) { + return f(buf.template view()); + } + } + if constexpr (detail:: + is_callable_for_memory_type) { + if (buf.mem_type() == memory_type::pinned) { + return f(buf.template view()); + } + } + } + // If the functor is specialized for device and the data are + // device-accessible, use the device specialization + if constexpr (detail::is_callable_for_memory_type) { + if (is_device_accessible(buf.mem_type())) { + return f(mdbuffer{res, buf, memory_type::device}.template view()); + } + // If there is no host specialization, still use the device specialization + if constexpr (!detail:: + is_callable_for_memory_type) { + return f(mdbuffer{res, buf, memory_type::device}.template view()); + } + } + + // If nothing else has worked, use the host specialization + if constexpr (detail::is_callable_for_memory_type) { + return f(mdbuffer{res, buf, memory_type::host}.template view()); + } + + // In the extremely rare case that the functor has been specialized _only_ + // for either pinned memory, managed memory, or both, and the input data are + // neither pinned nor managed, we must perform a copy. 
In this situation, if + // we have specializations for both pinned and managed memory, we arbitrarily + // prefer the managed specialization. Note that if the data _are_ either + // pinned or managed already, we will have already invoked the correct + // specialization above. + if constexpr (detail:: + is_callable_for_memory_type) { + return f(mdbuffer{res, buf, memory_type::managed}.template view()); + } else if constexpr (detail::is_callable_for_memory_type) { + return f(mdbuffer{res, buf, memory_type::pinned}.template view()); + } + + // Suppress warning for unreachable loop. In general, it is a desirable thing + // for this to be unreachable, but some functors may be specialized in such a + // way that this is not the case. +#pragma nv_diag_suppress 128 + RAFT_FAIL("The given functor could not be invoked on the provided data"); +#pragma nv_diag_default 128 +} + +template * = nullptr> +decltype(auto) memory_type_dispatcher(raft::resources const& res, lambda_t&& f, mdspan_type view) +{ + return memory_type_dispatcher(res, std::forward(f), mdbuffer{view}); +} + +template * = nullptr, + enable_if_mdspan* = nullptr> +decltype(auto) memory_type_dispatcher(raft::resources const& res, lambda_t&& f, mdspan_type view) +{ + return memory_type_dispatcher(res, std::forward(f), mdbuffer_type{res, mdbuffer{view}}); +} + +/** @} */ + +} // namespace raft diff --git a/cpp/include/raft/util/variant_utils.hpp b/cpp/include/raft/util/variant_utils.hpp new file mode 100644 index 0000000000..26ca2b7eb4 --- /dev/null +++ b/cpp/include/raft/util/variant_utils.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft { + +template +struct concatenated_variant; + +template +struct concatenated_variant, std::variant> { + using type = std::variant; +}; + +template +using concatenated_variant_t = typename concatenated_variant::type; + +template +auto fast_visit(visitor_t&& visitor, variant_t&& variant) +{ + using return_t = decltype(std::forward(visitor)(std::get<0>(variant))); + auto result = return_t{}; + + if constexpr (index == + std::variant_size_v>>) { + __builtin_unreachable(); + } else { + if (index == variant.index()) { + result = std::forward(visitor)(std::get(std::forward(variant))); + } else { + result = fast_visit(std::forward(visitor), + std::forward(variant)); + } + } + return result; +} + +template +struct is_type_in_variant; + +template +struct is_type_in_variant> { + static constexpr bool value = (std::is_same_v || ...); +}; + +template +auto static constexpr is_type_in_variant_v = is_type_in_variant::value; + +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index f043442840..6e32281ec0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at @@ -124,6 +124,7 @@ if(BUILD_TESTS) test/core/interruptible.cu test/core/nvtx.cpp test/core/mdarray.cu + test/core/mdbuffer.cu test/core/mdspan_copy.cpp test/core/mdspan_copy.cu test/core/mdspan_utils.cu @@ -460,6 +461,7 @@ if(BUILD_TESTS) test/util/device_atomics.cu test/util/integer_utils.cpp test/util/integer_utils.cu + test/util/memory_type_dispatcher.cu test/util/pow2_utils.cu test/util/reduction.cu ) diff --git a/cpp/test/core/mdarray.cu b/cpp/test/core/mdarray.cu index 86e51be2e4..b0ab36c6e3 100644 --- a/cpp/test/core/mdarray.cu +++ b/cpp/test/core/mdarray.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/test/core/mdbuffer.cu b/cpp/test/core/mdbuffer.cu new file mode 100644 index 0000000000..d93d532938 --- /dev/null +++ b/cpp/test/core/mdbuffer.cu @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +TEST(MDBuffer, FromHost) +{ + auto res = device_resources{}; + auto constexpr depth = std::uint32_t{5}; + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{2}; + auto data = make_host_mdarray( + res, extents{}); + + auto buffer = mdbuffer(data); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::host); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + EXPECT_EQ(buffer.view().index(), variant_index_from_memory_type(memory_type::host)); + + buffer = mdbuffer(data.view()); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::host); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + + auto original_data_handle = data.data_handle(); + buffer = mdbuffer(std::move(data)); + EXPECT_TRUE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::host); + EXPECT_EQ(buffer.view().data_handle(), original_data_handle); + + auto buffer2 = mdbuffer(res, buffer); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::host); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::host); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::host); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::device); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::device); + 
EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::managed); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::managed); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::pinned); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::pinned); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); +} + +TEST(MDBuffer, FromDevice) +{ + auto res = device_resources{}; + auto constexpr depth = std::uint32_t{5}; + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{2}; + auto data = make_device_mdarray( + res, extents{}); + + auto buffer = mdbuffer(data); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::device); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + EXPECT_EQ(buffer.view().index(), variant_index_from_memory_type(memory_type::device)); + + buffer = mdbuffer(data.view()); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::device); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + + auto original_data_handle = data.data_handle(); + buffer = mdbuffer(std::move(data)); + EXPECT_TRUE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::device); + EXPECT_EQ(buffer.view().data_handle(), original_data_handle); + + auto buffer2 = mdbuffer(res, buffer); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::device); + EXPECT_EQ(buffer2.view().data_handle(), + 
buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::host); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::host); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::device); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::device); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::managed); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::managed); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::pinned); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::pinned); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); +} + +TEST(MDBuffer, FromManaged) +{ + auto res = device_resources{}; + auto constexpr depth = std::uint32_t{5}; + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{2}; + auto data = make_managed_mdarray( + res, extents{}); + + auto buffer = mdbuffer(data); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::managed); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + EXPECT_EQ(buffer.view().index(), variant_index_from_memory_type(memory_type::managed)); + + buffer = mdbuffer(data.view()); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::managed); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + + auto original_data_handle = 
data.data_handle(); + buffer = mdbuffer(std::move(data)); + EXPECT_TRUE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::managed); + EXPECT_EQ(buffer.view().data_handle(), original_data_handle); + + auto buffer2 = mdbuffer(res, buffer); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::managed); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::host); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::host); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::device); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::device); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::managed); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::managed); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::pinned); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::pinned); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); +} + +TEST(MDBuffer, FromPinned) +{ + auto res = device_resources{}; + auto constexpr depth = std::uint32_t{5}; + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{2}; + auto data = make_pinned_mdarray( + res, extents{}); + + auto buffer = mdbuffer(data); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::pinned); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + EXPECT_EQ(buffer.view().index(), variant_index_from_memory_type(memory_type::pinned)); + 
+ buffer = mdbuffer(data.view()); + EXPECT_FALSE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::pinned); + EXPECT_EQ(buffer.view().data_handle(), data.data_handle()); + EXPECT_EQ(std::as_const(buffer).view().data_handle(), data.data_handle()); + EXPECT_EQ(buffer.view().data_handle(), + std::as_const(buffer).view().data_handle()); + + auto original_data_handle = data.data_handle(); + buffer = mdbuffer(std::move(data)); + EXPECT_TRUE(buffer.is_owning()); + EXPECT_EQ(buffer.mem_type(), memory_type::pinned); + EXPECT_EQ(buffer.view().data_handle(), original_data_handle); + + auto buffer2 = mdbuffer(res, buffer); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::pinned); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::host); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::host); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::device); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::device); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::managed); + EXPECT_TRUE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::managed); + EXPECT_NE(buffer2.view().data_handle(), + buffer.view().data_handle()); + + buffer2 = mdbuffer(res, buffer, memory_type::pinned); + EXPECT_FALSE(buffer2.is_owning()); + EXPECT_EQ(buffer2.mem_type(), memory_type::pinned); + EXPECT_EQ(buffer2.view().data_handle(), + buffer.view().data_handle()); +} + +TEST(MDBuffer, ImplicitMdspanConversion) +{ + auto res = device_resources{}; + auto constexpr depth = std::uint32_t{5}; + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{2}; + + using extents_type = extents; + auto shared_extents = extents_type{}; + + auto data_host = 
make_host_mdarray( + res, shared_extents); + auto data_device = + make_device_mdarray(res, + shared_extents); + auto data_managed = + make_managed_mdarray( + res, shared_extents); + auto data_pinned = + make_pinned_mdarray(res, + shared_extents); + + auto test_function = [shared_extents](mdbuffer&& buf) { + std::visit([shared_extents](auto view) { EXPECT_EQ(view.extents(), shared_extents); }, + buf.view()); + }; + + test_function(data_host); + test_function(data_device); + test_function(data_managed); + test_function(data_pinned); + test_function(data_host.view()); + test_function(data_device.view()); + test_function(data_managed.view()); + test_function(data_pinned.view()); + + auto test_const_function = [shared_extents](mdbuffer&& buf) { + std::visit([shared_extents](auto view) { EXPECT_EQ(view.extents(), shared_extents); }, + buf.view()); + }; + + test_const_function(data_host.view()); + test_const_function(data_device.view()); + test_const_function(data_managed.view()); + test_const_function(data_pinned.view()); +} + +} // namespace raft diff --git a/cpp/test/core/memory_type.cpp b/cpp/test/core/memory_type.cpp index 02aa8caa6c..cd8aa6bd9e 100644 --- a/cpp/test/core/memory_type.cpp +++ b/cpp/test/core/memory_type.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include #include #include @@ -22,7 +23,7 @@ TEST(MemoryType, IsDeviceAccessible) static_assert(!is_device_accessible(memory_type::host)); static_assert(is_device_accessible(memory_type::device)); static_assert(is_device_accessible(memory_type::managed)); - static_assert(!is_device_accessible(memory_type::pinned)); + static_assert(is_device_accessible(memory_type::pinned)); } TEST(MemoryType, IsHostAccessible) @@ -38,6 +39,33 @@ TEST(MemoryType, IsHostDeviceAccessible) static_assert(!is_host_device_accessible(memory_type::host)); static_assert(!is_host_device_accessible(memory_type::device)); static_assert(is_host_device_accessible(memory_type::managed)); - static_assert(!is_host_device_accessible(memory_type::pinned)); + static_assert(is_host_device_accessible(memory_type::pinned)); } + +TEST(MemoryTypeFromPointer, Host) +{ + auto ptr1 = static_cast(nullptr); + cudaMallocHost(&ptr1, 1); + EXPECT_EQ(memory_type_from_pointer(ptr1), memory_type::host); + cudaFreeHost(ptr1); + auto ptr2 = static_cast(nullptr); + EXPECT_EQ(memory_type_from_pointer(ptr2), memory_type::host); +} + +#ifndef RAFT_DISABLE_CUDA +TEST(MemoryTypeFromPointer, Device) +{ + auto ptr = static_cast(nullptr); + cudaMalloc(&ptr, 1); + EXPECT_EQ(memory_type_from_pointer(ptr), memory_type::device); + cudaFree(ptr); +} +TEST(MemoryTypeFromPointer, Managed) +{ + auto ptr = static_cast(nullptr); + cudaMallocManaged(&ptr, 1); + EXPECT_EQ(memory_type_from_pointer(ptr), memory_type::managed); + cudaFree(ptr); +} +#endif } // namespace raft diff --git a/cpp/test/core/numpy_serializer.cu b/cpp/test/core/numpy_serializer.cu index 0d12b97555..5c562d68f7 100644 --- a/cpp/test/core/numpy_serializer.cu +++ b/cpp/test/core/numpy_serializer.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,6 +17,7 @@ #include #include +#include #include #include diff --git a/cpp/test/util/memory_type_dispatcher.cu b/cpp/test/util/memory_type_dispatcher.cu new file mode 100644 index 0000000000..5e24ff5719 --- /dev/null +++ b/cpp/test/util/memory_type_dispatcher.cu @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +namespace dispatch_test { +struct functor_h { + template + auto static constexpr expected_output() + { + return memory_type::host; + } + auto operator()(host_matrix_view input) { return memory_type::host; } +}; +struct functor_d { + template + auto static constexpr expected_output() + { + return memory_type::device; + } + auto operator()(host_matrix_view input) { return memory_type::device; } +}; +struct functor_m { + template + auto static constexpr expected_output() + { + return memory_type::managed; + } + auto operator()(host_matrix_view input) { return memory_type::managed; } +}; +struct functor_p { + template + auto static constexpr expected_output() + { + return memory_type::pinned; + } + auto operator()(host_matrix_view input) { return memory_type::pinned; } +}; + +struct functor_hd { + template + auto static constexpr expected_output() + { + if constexpr 
(input_memory_type == memory_type::host) { + return memory_type::host; + } else { + return memory_type::device; + } + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto operator()(device_matrix_view input) { return memory_type::device; } +}; +struct functor_hm { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::managed) { + return memory_type::managed; + } else { + return memory_type::host; + } + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto operator()(managed_matrix_view input) { return memory_type::managed; } +}; +struct functor_hp { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::pinned) { + return memory_type::pinned; + } else { + return memory_type::host; + } + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; +struct functor_dm { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::managed) { + return memory_type::managed; + } else { + return memory_type::device; + } + } + auto operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(managed_matrix_view input) { return memory_type::managed; } +}; +struct functor_dp { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::pinned) { + return memory_type::pinned; + } else { + return memory_type::device; + } + } + auto operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; +struct functor_mp { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::pinned) { + return memory_type::pinned; + } else { + return memory_type::managed; + } + } + auto 
operator()(managed_matrix_view input) { return memory_type::managed; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; + +struct functor_hdm { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::host) { + return memory_type::host; + } else if constexpr (input_memory_type == memory_type::managed) { + return memory_type::managed; + } else { + return memory_type::device; + } + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(managed_matrix_view input) { return memory_type::managed; } +}; +struct functor_hdp { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::host) { + return memory_type::host; + } else if constexpr (input_memory_type == memory_type::pinned) { + return memory_type::pinned; + } else { + return memory_type::device; + } + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; +struct functor_dmp { + template + auto static constexpr expected_output() + { + if constexpr (input_memory_type == memory_type::managed) { + return memory_type::managed; + } else if constexpr (input_memory_type == memory_type::pinned) { + return memory_type::pinned; + } else { + return memory_type::device; + } + } + auto operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(managed_matrix_view input) { return memory_type::managed; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; + +struct functor_hdmp { + template + auto static constexpr expected_output() + { + return input_memory_type; + } + auto operator()(host_matrix_view input) { return memory_type::host; } + auto 
operator()(device_matrix_view input) { return memory_type::device; } + auto operator()(managed_matrix_view input) { return memory_type::managed; } + auto operator()(pinned_matrix_view input) { return memory_type::pinned; } +}; + +template +auto generate_input(raft::resources const& res) +{ + auto constexpr rows = std::uint32_t{3}; + auto constexpr cols = std::uint32_t{5}; + if constexpr (input_memory_type == raft::memory_type::host) { + return raft::make_host_matrix(rows, cols); + } else if constexpr (input_memory_type == raft::memory_type::device) { + return raft::make_device_matrix(res, rows, cols); + } else if constexpr (input_memory_type == raft::memory_type::managed) { + return raft::make_managed_matrix(res, rows, cols); + } else if constexpr (input_memory_type == raft::memory_type::pinned) { + return raft::make_pinned_matrix(res, rows, cols); + } +} + +template +void test_memory_type_dispatcher() +{ + auto res = raft::device_resources{}; + auto data = generate_input(res); + auto data_float = generate_input(res); + auto data_f = generate_input(res); + auto data_f_float = generate_input(res); + + EXPECT_EQ(memory_type_dispatcher(res, functor_h{}, data.view()), + functor_h::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_d{}, data.view()), + functor_d::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_m{}, data.view()), + functor_m::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_p{}, data.view()), + functor_p::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hd{}, data.view()), + functor_hd::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hm{}, data.view()), + functor_hm::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hp{}, data.view()), + functor_hp::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_dm{}, data.view()), + functor_dm::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_dp{}, data.view()), + 
functor_dp::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_mp{}, data.view()), + functor_mp::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hdm{}, data.view()), + functor_hdm::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hdp{}, data.view()), + functor_hdp::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_dmp{}, data.view()), + functor_dmp::expected_output()); + EXPECT_EQ(memory_type_dispatcher(res, functor_hdmp{}, data.view()), + functor_hdmp::expected_output()); + + // Functor expects double; input is float + auto out = memory_type_dispatcher>>( + res, functor_h{}, data_float.view()); + EXPECT_EQ(out, functor_h::expected_output()); + out = memory_type_dispatcher>>( + res, functor_d{}, data_float.view()); + EXPECT_EQ(out, functor_d::expected_output()); + out = memory_type_dispatcher>>( + res, functor_m{}, data_float.view()); + EXPECT_EQ(out, functor_m::expected_output()); + out = memory_type_dispatcher>>( + res, functor_p{}, data_float.view()); + EXPECT_EQ(out, functor_p::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hd{}, data_float.view()); + EXPECT_EQ(out, functor_hd::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hm{}, data_float.view()); + EXPECT_EQ(out, functor_hm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hp{}, data_float.view()); + EXPECT_EQ(out, functor_hp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dm{}, data_float.view()); + EXPECT_EQ(out, functor_dm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dp{}, data_float.view()); + EXPECT_EQ(out, functor_dp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_mp{}, data_float.view()); + EXPECT_EQ(out, functor_mp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdm{}, data_float.view()); + EXPECT_EQ(out, functor_hdm::expected_output()); + out = memory_type_dispatcher>>( + 
res, functor_hdp{}, data_float.view()); + EXPECT_EQ(out, functor_hdp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dmp{}, data_float.view()); + EXPECT_EQ(out, functor_dmp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdmp{}, data_float.view()); + EXPECT_EQ(out, functor_hdmp::expected_output()); + + // Functor expects C-contiguous; input is F-contiguous + out = memory_type_dispatcher>>( + res, functor_h{}, data_f.view()); + EXPECT_EQ(out, functor_h::expected_output()); + out = memory_type_dispatcher>>( + res, functor_d{}, data_f.view()); + EXPECT_EQ(out, functor_d::expected_output()); + out = memory_type_dispatcher>>( + res, functor_m{}, data_f.view()); + EXPECT_EQ(out, functor_m::expected_output()); + out = memory_type_dispatcher>>( + res, functor_p{}, data_f.view()); + EXPECT_EQ(out, functor_p::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hd{}, data_f.view()); + EXPECT_EQ(out, functor_hd::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hm{}, data_f.view()); + EXPECT_EQ(out, functor_hm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hp{}, data_f.view()); + EXPECT_EQ(out, functor_hp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dm{}, data_f.view()); + EXPECT_EQ(out, functor_dm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dp{}, data_f.view()); + EXPECT_EQ(out, functor_dp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_mp{}, data_f.view()); + EXPECT_EQ(out, functor_mp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdm{}, data_f.view()); + EXPECT_EQ(out, functor_hdm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdp{}, data_f.view()); + EXPECT_EQ(out, functor_hdp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dmp{}, data_f.view()); + EXPECT_EQ(out, functor_dmp::expected_output()); + out = memory_type_dispatcher>>( + 
res, functor_hdmp{}, data_f.view()); + EXPECT_EQ(out, functor_hdmp::expected_output()); + + // Functor expects C-contiguous double; input is F-contiguous float + out = memory_type_dispatcher>>( + res, functor_h{}, data_f_float.view()); + EXPECT_EQ(out, functor_h::expected_output()); + out = memory_type_dispatcher>>( + res, functor_d{}, data_f_float.view()); + EXPECT_EQ(out, functor_d::expected_output()); + out = memory_type_dispatcher>>( + res, functor_m{}, data_f_float.view()); + EXPECT_EQ(out, functor_m::expected_output()); + out = memory_type_dispatcher>>( + res, functor_p{}, data_f_float.view()); + EXPECT_EQ(out, functor_p::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hd{}, data_f_float.view()); + EXPECT_EQ(out, functor_hd::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hm{}, data_f_float.view()); + EXPECT_EQ(out, functor_hm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hp{}, data_f_float.view()); + EXPECT_EQ(out, functor_hp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dm{}, data_f_float.view()); + EXPECT_EQ(out, functor_dm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dp{}, data_f_float.view()); + EXPECT_EQ(out, functor_dp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_mp{}, data_f_float.view()); + EXPECT_EQ(out, functor_mp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdm{}, data_f_float.view()); + EXPECT_EQ(out, functor_hdm::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdp{}, data_f_float.view()); + EXPECT_EQ(out, functor_hdp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_dmp{}, data_f_float.view()); + EXPECT_EQ(out, functor_dmp::expected_output()); + out = memory_type_dispatcher>>( + res, functor_hdmp{}, data_f_float.view()); + EXPECT_EQ(out, functor_hdmp::expected_output()); +} + +} // namespace dispatch_test + +TEST(MemoryTypeDispatcher, 
FromHost) +{ + dispatch_test::test_memory_type_dispatcher(); +} + +TEST(MemoryTypeDispatcher, FromDevice) +{ + dispatch_test::test_memory_type_dispatcher(); +} + +TEST(MemoryTypeDispatcher, FromManaged) +{ + dispatch_test::test_memory_type_dispatcher(); +} + +TEST(MemoryTypeDispatcher, FromPinned) +{ + dispatch_test::test_memory_type_dispatcher(); +} + +} // namespace raft diff --git a/docs/source/cpp_api/mdspan.rst b/docs/source/cpp_api/mdspan.rst index 3fc0db7b96..b311020049 100644 --- a/docs/source/cpp_api/mdspan.rst +++ b/docs/source/cpp_api/mdspan.rst @@ -16,4 +16,6 @@ This page provides C++ class references for the RAFT's 1d span and multi-dimensi mdspan_mdspan.rst mdspan_mdarray.rst mdspan_span.rst + mdspan_mdbuffer.rst + memory_type_dispatcher.rst mdspan_temporary_device_buffer.rst diff --git a/docs/source/cpp_api/mdspan_mdarray.rst b/docs/source/cpp_api/mdspan_mdarray.rst index bcc2254204..af3943065d 100644 --- a/docs/source/cpp_api/mdspan_mdarray.rst +++ b/docs/source/cpp_api/mdspan_mdarray.rst @@ -68,4 +68,68 @@ Host Factories .. doxygengroup:: host_mdarray_factories :project: RAFT :members: - :content-only: \ No newline at end of file + :content-only: + +Managed Vocabulary +------------------ + +``#include `` + +.. doxygentypedef:: raft::managed_mdarray + :project: RAFT + +.. doxygentypedef:: raft::managed_matrix + :project: RAFT + +.. doxygentypedef:: raft::managed_vector + :project: RAFT + +.. doxygentypedef:: raft::managed_scalar + :project: RAFT + + +Managed Factories +----------------- + +``#include `` + +.. doxygenfunction:: raft::make_managed_matrix + :project: RAFT + +.. doxygenfunction:: raft::make_managed_vector + :project: RAFT + +.. doxygenfunction:: raft::make_managed_scalar + :project: RAFT + +Pinned Vocabulary +----------------- + +``#include `` + +.. doxygentypedef:: raft::pinned_mdarray + :project: RAFT + +.. doxygentypedef:: raft::pinned_matrix + :project: RAFT + +.. doxygentypedef:: raft::pinned_vector + :project: RAFT + +.. 
doxygentypedef:: raft::pinned_scalar + :project: RAFT + + +Pinned Factories +---------------- + +``#include `` + +.. doxygenfunction:: raft::make_pinned_matrix + :project: RAFT + +.. doxygenfunction:: raft::make_pinned_vector + :project: RAFT + +.. doxygenfunction:: raft::make_pinned_scalar + :project: RAFT diff --git a/docs/source/cpp_api/mdspan_mdbuffer.rst b/docs/source/cpp_api/mdspan_mdbuffer.rst new file mode 100644 index 0000000000..40fe066a2e --- /dev/null +++ b/docs/source/cpp_api/mdspan_mdbuffer.rst @@ -0,0 +1,13 @@ +mdbuffer: Multi-dimensional Maybe-Owning Container +================================================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygengroup:: mdbuffer_apis + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/mdspan_mdspan.rst b/docs/source/cpp_api/mdspan_mdspan.rst index f9f972aa74..28d06b5323 100644 --- a/docs/source/cpp_api/mdspan_mdspan.rst +++ b/docs/source/cpp_api/mdspan_mdspan.rst @@ -92,9 +92,9 @@ Device Factories Managed Vocabulary ------------------ -``#include `` +``#include `` -..doxygentypedef:: raft::managed_mdspan +.. doxygentypedef:: raft::managed_mdspan :project: RAFT .. doxygenstruct:: raft::is_managed_mdspan @@ -122,7 +122,7 @@ Managed Vocabulary Managed Factories ----------------- -``#include `` +``#include `` .. doxygenfunction:: make_managed_mdspan(ElementType* ptr, extents exts) :project: RAFT @@ -177,7 +177,38 @@ Host Factories .. doxygenfunction:: raft::make_host_vector_view :project: RAFT -.. doxygenfunction:: raft::make_device_scalar_view +.. doxygenfunction:: raft::make_host_scalar_view + :project: RAFT + +Pinned Vocabulary +----------------- + +``#include `` + +.. doxygentypedef:: raft::pinned_mdspan + :project: RAFT + +.. doxygentypedef:: raft::pinned_matrix_view + :project: RAFT + +.. doxygentypedef:: raft::pinned_vector_view + :project: RAFT + +..
doxygentypedef:: raft::pinned_scalar_view + :project: RAFT + +Pinned Factories +---------------- + +``#include `` + +.. doxygenfunction:: raft::make_pinned_matrix_view + :project: RAFT + +.. doxygenfunction:: raft::make_pinned_vector_view + :project: RAFT + +.. doxygenfunction:: raft::make_pinned_scalar_view :project: RAFT diff --git a/docs/source/cpp_api/memory_type_dispatcher.rst b/docs/source/cpp_api/memory_type_dispatcher.rst new file mode 100644 index 0000000000..687a872967 --- /dev/null +++ b/docs/source/cpp_api/memory_type_dispatcher.rst @@ -0,0 +1,13 @@ +memory_type_dispatcher +====================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygengroup:: memory_type_dispatcher + :project: RAFT + :members: + :content-only: