From 2a63aff614a0d7f417144d0a13bfa9acaeefface Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 3 Nov 2022 14:28:41 -0400 Subject: [PATCH 1/6] Provide memory_type enum Provide an enum class used to distinguish among various memory types. This allows decoupling of the concept of memory type from a specific mdspan accessor policy and ensures that templates receive only a logically-consistent memory type. --- cpp/include/raft/core/memory_type.hpp | 36 ++++++++++++++++++++++ cpp/test/memory_type.cpp | 43 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 cpp/include/raft/core/memory_type.hpp create mode 100644 cpp/test/memory_type.cpp diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp new file mode 100644 index 0000000000..e09b570fa1 --- /dev/null +++ b/cpp/include/raft/core/memory_type.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft { +enum class memory_type { + host, + device, + managed, + pinned +}; + +auto constexpr is_device_accessible(memory_type mem_type) { + return (mem_type == memory_type::device || mem_type == memory_type::managed); +} +auto constexpr is_host_accessible(memory_type mem_type) { + return (mem_type == memory_type::host || mem_type == memory_type::managed || mem_type == memory_type::pinned); +} +auto constexpr is_host_device_accessible(memory_type mem_type) { + return mem_type == memory_type::managed; +} + +} // end namespace raft diff --git a/cpp/test/memory_type.cpp b/cpp/test/memory_type.cpp new file mode 100644 index 0000000000..57d44ceefe --- /dev/null +++ b/cpp/test/memory_type.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace raft { +TEST(MemoryType, IsDeviceAccessible) +{ + static_assert(!is_device_accessible(memory_type::host)); + static_assert(is_device_accessible(memory_type::device)); + static_assert(is_device_accessible(memory_type::managed)); + static_assert(!is_device_accessible(memory_type::pinned)); +} + +TEST(MemoryType, IsHostAccessible) +{ + static_assert(is_host_accessible(memory_type::host)); + static_assert(!is_host_accessible(memory_type::device)); + static_assert(is_host_accessible(memory_type::managed)); + static_assert(is_host_accessible(memory_type::pinned)); +} + +TEST(MemoryType, IsHostDeviceAccessible) +{ + static_assert(!is_host_device_accessible(memory_type::host)); + static_assert(!is_host_device_accessible(memory_type::device)); + static_assert(is_host_device_accessible(memory_type::managed)); + static_assert(!is_host_device_accessible(memory_type::pinned)); +} +} // namespace raft From 5760977689749bbda55e6f447270367b3c7d036d Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 3 Nov 2022 14:33:01 -0400 Subject: [PATCH 2/6] Update style --- cpp/include/raft/core/memory_type.hpp | 19 +++++++++---------- cpp/test/CMakeLists.txt | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index e09b570fa1..bb3185d8c6 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -16,20 +16,19 @@ #pragma once namespace raft { -enum class memory_type { - host, - device, - managed, - pinned -}; +enum class memory_type { host, device, managed, pinned }; -auto constexpr is_device_accessible(memory_type mem_type) { +auto constexpr is_device_accessible(memory_type mem_type) +{ return (mem_type == memory_type::device || mem_type == memory_type::managed); } -auto constexpr is_host_accessible(memory_type mem_type) { - return (mem_type == memory_type::host || mem_type == memory_type::managed || mem_type == memory_type::pinned); +auto constexpr is_host_accessible(memory_type mem_type) +{ + return (mem_type == memory_type::host || mem_type == memory_type::managed || + mem_type == memory_type::pinned); } -auto constexpr is_host_device_accessible(memory_type mem_type) { +auto constexpr is_host_device_accessible(memory_type mem_type) +{ return mem_type == memory_type::managed; } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index c7bf166439..088b15aaf1 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -98,6 +98,7 @@ if(BUILD_TESTS) test/nvtx.cpp test/mdarray.cu test/mdspan_utils.cu + test/memory_type.cpp test/span.cpp test/span.cu test/test.cpp From 241f3530ae13d4efdd844b35f1f4681b8bdc2096 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 3 Nov 2022 15:08:31 -0400 Subject: [PATCH 3/6] Reuse accessibility functions for host/device check Co-authored-by: Rajesh Gandham --- cpp/include/raft/core/memory_type.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index bb3185d8c6..f0483f6ba0 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -29,7 +29,7 @@ auto constexpr is_host_accessible(memory_type mem_type) } auto constexpr is_host_device_accessible(memory_type mem_type) { - return mem_type == memory_type::managed; + return is_device_accessible() && is_host_accessible(); } } // end namespace raft From 14d97bf7d2d05d4ac9383b65e80969147337c4f7 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Thu, 3 Nov 2022 15:23:00 -0400 Subject: [PATCH 4/6] Add missing arguments to refactor --- cpp/include/raft/core/memory_type.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index f0483f6ba0..9694bf65a5 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -29,7 +29,7 @@ auto constexpr is_host_accessible(memory_type mem_type) } auto constexpr is_host_device_accessible(memory_type mem_type) { - return is_device_accessible() && is_host_accessible(); + return is_device_accessible(mem_type) && is_host_accessible(mem_type); } } // end namespace raft From 43c4695896187d9308ac5feaa480fb9225d32706 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Tue, 8 Nov 2022 14:20:44 -0500 Subject: [PATCH 5/6] Use memory_type in mdspan Resolve #994 --- cpp/include/raft/core/device_mdspan.hpp | 7 +++-- .../raft/core/host_device_accessor.hpp | 17 ++++++----- cpp/include/raft/core/host_mdspan.hpp | 5 ++-- cpp/include/raft/core/mdarray.hpp | 4 +-- cpp/include/raft/core/mdspan.hpp | 28 +++++++++++++++++-- cpp/include/raft/core/memory_type.hpp | 15 ++++++++++ 6 files changed, 60 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 394ea228b4..ae66f315d9 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -19,14 +19,15 @@ #include #include #include +#include namespace raft { template -using device_accessor = host_device_accessor; +using device_accessor = host_device_accessor; template -using managed_accessor = host_device_accessor; +using managed_accessor = host_device_accessor; /** * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. @@ -276,4 +277,4 @@ auto make_device_vector_view(ElementType* ptr, IndexType n) return device_vector_view{ptr, n}; } -} // end namespace raft \ No newline at end of file +} // end namespace raft diff --git a/cpp/include/raft/core/host_device_accessor.hpp b/cpp/include/raft/core/host_device_accessor.hpp index 4f6f559be4..d80b32a4c0 100644 --- a/cpp/include/raft/core/host_device_accessor.hpp +++ b/cpp/include/raft/core/host_device_accessor.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include +#include namespace raft { @@ -23,15 +25,16 @@ namespace raft { * accessor used throught RAFT's APIs to denote whether an underlying pointer * is accessible from device, host, or both. */ -template +template struct host_device_accessor : public AccessorPolicy { using accessor_type = AccessorPolicy; - using is_host_type = std::conditional_t; - using is_device_type = std::conditional_t; - using is_managed_type = std::conditional_t; - static constexpr bool is_host_accessible = is_host; - static constexpr bool is_device_accessible = is_device; - static constexpr bool is_managed_accessible = is_device && is_host; + auto static constexpr const mem_type = MemType; + using is_host_type = std::conditional_t; + using is_device_type = std::conditional_t; + using is_managed_type = std::conditional_t; + static constexpr bool is_host_accessible = raft::is_host_accessible(mem_type); + static constexpr bool is_device_accessible = raft::is_device_accessible(mem_type); + static constexpr bool is_managed_accessible = raft::is_host_device_accessible(mem_type); // make sure the explicit ctor can fall through using AccessorPolicy::AccessorPolicy; using offset_policy = host_device_accessor; diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 0b49ca9945..d3d6c53df3 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -18,13 +18,14 @@ #include #include +#include #include namespace raft { template -using host_accessor = host_device_accessor; +using host_accessor = host_device_accessor; /** * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. @@ -205,4 +206,4 @@ auto make_host_vector_view(ElementType* ptr, IndexType n) { return host_vector_view{ptr, n}; } -} // end namespace raft \ No newline at end of file +} // end namespace raft diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index ae5d236395..2b0210cd9b 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace raft { @@ -158,8 +159,7 @@ class mdarray extents_type, layout_type, host_device_accessor>; + container_policy_type::mem_type>>; public: /** diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 1faac44cc8..baf6e19cdb 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -182,10 +183,33 @@ template auto make_mdspan(ElementType* ptr, extents exts) +{ + using accessor_type = host_device_accessor, detail::memory_type_from_access()>; + /*using accessor_type = host_device_accessor, + mem_type>; */ + + return mdspan{ptr, exts}; +} + +/** + * @brief Create a raft::mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam MemType the raft::memory_type for where the data are stored + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::mdspan + */ +template +auto make_mdspan(ElementType* ptr, extents exts) { using accessor_type = host_device_accessor, - is_host_accessible, - is_device_accessible>; + MemType>; return mdspan{ptr, exts}; } diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index 9694bf65a5..e5b2fe0039 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -32,4 +32,19 @@ auto constexpr is_host_device_accessible(memory_type mem_type) return is_device_accessible(mem_type) && is_host_accessible(mem_type); } +namespace detail { + +template +auto constexpr memory_type_from_access() { + if constexpr (is_host_accessible && is_device_accessible) { + return memory_type::managed; + } else if constexpr (is_host_accessible) { + return memory_type::host; + } else if constexpr (is_device_accessible) { + return memory_type::device; + } + static_assert(is_host_accessible || is_device_accessible); +} + +} // end namespace detail } // end namespace raft From fb16969ebbd6631f226e0b6ebd5ed0615cc006d2 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Tue, 8 Nov 2022 15:52:56 -0500 Subject: [PATCH 6/6] Provide error message on failed memory_type derivation Also update style --- cpp/include/raft/core/host_device_accessor.hpp | 13 ++++++++----- cpp/include/raft/core/mdarray.hpp | 10 +++++----- cpp/include/raft/core/mdspan.hpp | 14 ++++++++------ cpp/include/raft/core/memory_type.hpp | 8 +++++--- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/core/host_device_accessor.hpp b/cpp/include/raft/core/host_device_accessor.hpp index d80b32a4c0..81bf015f2e 100644 --- a/cpp/include/raft/core/host_device_accessor.hpp +++ b/cpp/include/raft/core/host_device_accessor.hpp @@ -15,8 +15,8 @@ */ #pragma once -#include #include +#include namespace raft { @@ -27,11 +27,14 @@ namespace raft { */ template struct host_device_accessor : public AccessorPolicy { - using accessor_type = AccessorPolicy; + using accessor_type = AccessorPolicy; auto static constexpr const mem_type = MemType; - using is_host_type = std::conditional_t; - using is_device_type = std::conditional_t; - using is_managed_type = std::conditional_t; + using is_host_type = + std::conditional_t; + using is_device_type = + std::conditional_t; + using is_managed_type = + std::conditional_t; static constexpr bool is_host_accessible = raft::is_host_accessible(mem_type); static constexpr bool is_device_accessible = raft::is_device_accessible(mem_type); static constexpr bool is_managed_accessible = raft::is_host_device_accessible(mem_type); diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 2b0210cd9b..12d186c03f 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -155,11 +155,11 @@ class mdarray std::conditional_t, typename container_policy_type::const_accessor_policy, typename container_policy_type::accessor_policy>> - using view_type_impl = mdspan>; + using view_type_impl = + mdspan>; public: /** diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index baf6e19cdb..db131ff6fa 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -184,7 +184,9 @@ template auto make_mdspan(ElementType* ptr, extents exts) { - using accessor_type = host_device_accessor, detail::memory_type_from_access()>; + using accessor_type = host_device_accessor< + std::experimental::default_accessor, + detail::memory_type_from_access()>; /*using accessor_type = host_device_accessor, mem_type>; */ @@ -202,14 +204,14 @@ auto make_mdspan(ElementType* ptr, extents exts) * @return raft::mdspan */ template auto make_mdspan(ElementType* ptr, extents exts) { - using accessor_type = host_device_accessor, - MemType>; + using accessor_type = + host_device_accessor, MemType>; return mdspan{ptr, exts}; } diff --git a/cpp/include/raft/core/memory_type.hpp b/cpp/include/raft/core/memory_type.hpp index e5b2fe0039..cd37a0ee50 100644 --- a/cpp/include/raft/core/memory_type.hpp +++ b/cpp/include/raft/core/memory_type.hpp @@ -34,8 +34,9 @@ auto constexpr is_host_device_accessible(memory_type mem_type) namespace detail { -template -auto constexpr memory_type_from_access() { +template +auto constexpr memory_type_from_access() +{ if constexpr (is_host_accessible && is_device_accessible) { return memory_type::managed; } else if constexpr (is_host_accessible) { @@ -43,7 +44,8 @@ auto constexpr memory_type_from_access() { } else if constexpr (is_device_accessible) { return memory_type::device; } - static_assert(is_host_accessible || is_device_accessible); + static_assert(is_host_accessible || is_device_accessible, + "Must be either host or device accessible to return a valid memory type"); } } // end namespace detail