From e60cd1cd23fda55470772ab4c0c0385074d15a90 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Tue, 8 Nov 2022 21:39:49 -0500 Subject: [PATCH] Provide memory_type enum (#984) This PR introduces an enum to specify a memory type (e.g. host, device, managed...). This allows us to provide a template parameter which always indicates a valid memory type, which is extensible for possible future memory types, and which is less verbose than alternatives. The most serious shortcoming of existing alternatives is the possibility of indicating invalid memory states. E.g. by templating on `is_device` and `is_host`, we introduce the possible state `is_host=false, is_device=false`. Authors: - William Hicks (https://github.com/wphicks) Approvers: - Rajesh Gandham (https://github.com/rg20) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/984 --- cpp/include/raft/core/device_mdspan.hpp | 7 +-- .../raft/core/host_device_accessor.hpp | 22 +++++--- cpp/include/raft/core/host_mdspan.hpp | 5 +- cpp/include/raft/core/mdarray.hpp | 12 ++--- cpp/include/raft/core/mdspan.hpp | 32 ++++++++++-- cpp/include/raft/core/memory_type.hpp | 52 +++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/memory_type.cpp | 43 +++++++++++++++ 8 files changed, 152 insertions(+), 22 deletions(-) create mode 100644 cpp/include/raft/core/memory_type.hpp create mode 100644 cpp/test/memory_type.cpp diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 394ea228b4..ae66f315d9 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -19,14 +19,15 @@ #include #include #include +#include namespace raft { template -using device_accessor = host_device_accessor; +using device_accessor = host_device_accessor; template -using managed_accessor = host_device_accessor; +using managed_accessor = host_device_accessor; /** * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. @@ -276,4 +277,4 @@ auto make_device_vector_view(ElementType* ptr, IndexType n) return device_vector_view{ptr, n}; } -} // end namespace raft \ No newline at end of file +} // end namespace raft diff --git a/cpp/include/raft/core/host_device_accessor.hpp b/cpp/include/raft/core/host_device_accessor.hpp index 4f6f559be4..81bf015f2e 100644 --- a/cpp/include/raft/core/host_device_accessor.hpp +++ b/cpp/include/raft/core/host_device_accessor.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include +#include namespace raft { @@ -23,15 +25,19 @@ namespace raft { * accessor used throught RAFT's APIs to denote whether an underlying pointer * is accessible from device, host, or both. */ -template +template struct host_device_accessor : public AccessorPolicy { - using accessor_type = AccessorPolicy; - using is_host_type = std::conditional_t; - using is_device_type = std::conditional_t; - using is_managed_type = std::conditional_t; - static constexpr bool is_host_accessible = is_host; - static constexpr bool is_device_accessible = is_device; - static constexpr bool is_managed_accessible = is_device && is_host; + using accessor_type = AccessorPolicy; + auto static constexpr const mem_type = MemType; + using is_host_type = + std::conditional_t; + using is_device_type = + std::conditional_t; + using is_managed_type = + std::conditional_t; + static constexpr bool is_host_accessible = raft::is_host_accessible(mem_type); + static constexpr bool is_device_accessible = raft::is_device_accessible(mem_type); + static constexpr bool is_managed_accessible = raft::is_host_device_accessible(mem_type); // make sure the explicit ctor can fall through using AccessorPolicy::AccessorPolicy; using offset_policy = host_device_accessor; diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 0b49ca9945..d3d6c53df3 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -18,13 +18,14 @@ #include #include +#include #include namespace raft { template -using host_accessor = host_device_accessor; +using host_accessor = host_device_accessor; /** * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. @@ -205,4 +206,4 @@ auto make_host_vector_view(ElementType* ptr, IndexType n) { return host_vector_view{ptr, n}; } -} // end namespace raft \ No newline at end of file +} // end namespace raft diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index ae5d236395..12d186c03f 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace raft { @@ -154,12 +155,11 @@ class mdarray std::conditional_t, typename container_policy_type::const_accessor_policy, typename container_policy_type::accessor_policy>> - using view_type_impl = mdspan>; + using view_type_impl = + mdspan>; public: /** diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 1faac44cc8..db131ff6fa 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -183,9 +184,34 @@ template auto make_mdspan(ElementType* ptr, extents exts) { - using accessor_type = host_device_accessor, - is_host_accessible, - is_device_accessible>; + using accessor_type = host_device_accessor< + std::experimental::default_accessor, + detail::memory_type_from_access()>; + /*using accessor_type = host_device_accessor, + mem_type>; */ + + return mdspan{ptr, exts}; +} + +/** + * @brief Create a raft::mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam MemType the raft::memory_type for where the data are stored + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::mdspan + */ +template +auto make_mdspan(ElementType* ptr, extents exts) +{ + using accessor_type = + host_device_accessor, MemType>; return mdspan{ptr, exts}; } diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp new file mode 100644 index 0000000000..cd37a0ee50 --- /dev/null +++ b/cpp/include/raft/core/memory_type.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft { +enum class memory_type { host, device, managed, pinned }; + +auto constexpr is_device_accessible(memory_type mem_type) +{ + return (mem_type == memory_type::device || mem_type == memory_type::managed); +} +auto constexpr is_host_accessible(memory_type mem_type) +{ + return (mem_type == memory_type::host || mem_type == memory_type::managed || + mem_type == memory_type::pinned); +} +auto constexpr is_host_device_accessible(memory_type mem_type) +{ + return is_device_accessible(mem_type) && is_host_accessible(mem_type); +} + +namespace detail { + +template +auto constexpr memory_type_from_access() +{ + if constexpr (is_host_accessible && is_device_accessible) { + return memory_type::managed; + } else if constexpr (is_host_accessible) { + return memory_type::host; + } else if constexpr (is_device_accessible) { + return memory_type::device; + } + static_assert(is_host_accessible || is_device_accessible, + "Must be either host or device accessible to return a valid memory type"); +} + +} // end namespace detail +} // end namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index c7bf166439..088b15aaf1 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -98,6 +98,7 @@ if(BUILD_TESTS) test/nvtx.cpp test/mdarray.cu test/mdspan_utils.cu + test/memory_type.cpp test/span.cpp test/span.cu test/test.cpp diff --git a/cpp/test/memory_type.cpp b/cpp/test/memory_type.cpp new file mode 100644 index 0000000000..57d44ceefe --- /dev/null +++ b/cpp/test/memory_type.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace raft { +TEST(MemoryType, IsDeviceAccessible) +{ + static_assert(!is_device_accessible(memory_type::host)); + static_assert(is_device_accessible(memory_type::device)); + static_assert(is_device_accessible(memory_type::managed)); + static_assert(!is_device_accessible(memory_type::pinned)); +} + +TEST(MemoryType, IsHostAccessible) +{ + static_assert(is_host_accessible(memory_type::host)); + static_assert(!is_host_accessible(memory_type::device)); + static_assert(is_host_accessible(memory_type::managed)); + static_assert(is_host_accessible(memory_type::pinned)); +} + +TEST(MemoryType, IsHostDeviceAccessible) +{ + static_assert(!is_host_device_accessible(memory_type::host)); + static_assert(!is_host_device_accessible(memory_type::device)); + static_assert(is_host_device_accessible(memory_type::managed)); + static_assert(!is_host_device_accessible(memory_type::pinned)); +} +} // namespace raft