Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide memory_type enum #984

Merged
merged 6 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
#include <cstdint>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/memory_type.hpp>

namespace raft {

template <typename AccessorPolicy>
using device_accessor = host_device_accessor<AccessorPolicy, false, true>;
using device_accessor = host_device_accessor<AccessorPolicy, memory_type::device>;

template <typename AccessorPolicy>
using managed_accessor = host_device_accessor<AccessorPolicy, true, true>;
using managed_accessor = host_device_accessor<AccessorPolicy, memory_type::managed>;

/**
* @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location.
Expand Down Expand Up @@ -276,4 +277,4 @@ auto make_device_vector_view(ElementType* ptr, IndexType n)
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}

} // end namespace raft
} // end namespace raft
17 changes: 10 additions & 7 deletions cpp/include/raft/core/host_device_accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#pragma once
#include <type_traits>
#include <raft/core/memory_type.hpp>

namespace raft {

Expand All @@ -23,15 +25,16 @@ namespace raft {
* accessor used throught RAFT's APIs to denote whether an underlying pointer
* is accessible from device, host, or both.
*/
template <typename AccessorPolicy, bool is_host, bool is_device>
template <typename AccessorPolicy, memory_type MemType>
struct host_device_accessor : public AccessorPolicy {
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using is_device_type = std::conditional_t<is_device, std::true_type, std::false_type>;
using is_managed_type = std::conditional_t<is_device && is_host, std::true_type, std::false_type>;
static constexpr bool is_host_accessible = is_host;
static constexpr bool is_device_accessible = is_device;
static constexpr bool is_managed_accessible = is_device && is_host;
auto static constexpr const mem_type = MemType;
using is_host_type = std::conditional_t<raft::is_host_accessible(mem_type), std::true_type, std::false_type>;
using is_device_type = std::conditional_t<raft::is_device_accessible(mem_type), std::true_type, std::false_type>;
using is_managed_type = std::conditional_t<raft::is_host_device_accessible(mem_type), std::true_type, std::false_type>;
static constexpr bool is_host_accessible = raft::is_host_accessible(mem_type);
static constexpr bool is_device_accessible = raft::is_device_accessible(mem_type);
static constexpr bool is_managed_accessible = raft::is_host_device_accessible(mem_type);
// make sure the explicit ctor can fall through
using AccessorPolicy::AccessorPolicy;
using offset_policy = host_device_accessor;
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/core/host_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

#include <cstdint>
#include <raft/core/mdspan.hpp>
#include <raft/core/memory_type.hpp>

#include <raft/core/host_device_accessor.hpp>

namespace raft {

template <typename AccessorPolicy>
using host_accessor = host_device_accessor<AccessorPolicy, true, false>;
using host_accessor = host_device_accessor<AccessorPolicy, memory_type::host>;

/**
* @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location.
Expand Down Expand Up @@ -205,4 +206,4 @@ auto make_host_vector_view(ElementType* ptr, IndexType n)
{
return host_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}
} // end namespace raft
} // end namespace raft
4 changes: 2 additions & 2 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/memory_type.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace raft {
Expand Down Expand Up @@ -158,8 +159,7 @@ class mdarray
extents_type,
layout_type,
host_device_accessor<ViewAccessorPolicy,
container_policy_type::is_host_accessible,
container_policy_type::is_device_accessible>>;
container_policy_type::mem_type>>;

public:
/**
Expand Down
28 changes: 26 additions & 2 deletions cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <raft/core/error.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/memory_type.hpp>

#include <raft/core/detail/macros.hpp>
#include <raft/core/detail/mdspan_util.cuh>
Expand Down Expand Up @@ -182,10 +183,33 @@ template <typename ElementType,
bool is_device_accessible = true,
size_t... Extents>
auto make_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
using accessor_type = host_device_accessor<std::experimental::default_accessor<ElementType>, detail::memory_type_from_access<is_host_accessible, is_device_accessible>()>;
/*using accessor_type = host_device_accessor<std::experimental::default_accessor<ElementType>,
mem_type>; */

return mdspan<ElementType, decltype(exts), LayoutPolicy, accessor_type>{ptr, exts};
}

/**
* @brief Create a raft::mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @tparam MemType the raft::memory_type for where the data are stored
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
memory_type MemType = memory_type::device,
size_t... Extents>
auto make_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
using accessor_type = host_device_accessor<std::experimental::default_accessor<ElementType>,
is_host_accessible,
is_device_accessible>;
MemType>;

return mdspan<ElementType, decltype(exts), LayoutPolicy, accessor_type>{ptr, exts};
}
Expand Down
50 changes: 50 additions & 0 deletions cpp/include/raft/core/memory_type.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like consolidating the different memory variations into it's own enum. I'm wondering if we should integrate / consolidate this further into the raft::host_device_accessor here? or maybe there's no value in doing that quite yet?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's certainly worth doing as a follow-up.


namespace raft {
enum class memory_type { host, device, managed, pinned };

auto constexpr is_device_accessible(memory_type mem_type)
{
return (mem_type == memory_type::device || mem_type == memory_type::managed);
}
auto constexpr is_host_accessible(memory_type mem_type)
{
return (mem_type == memory_type::host || mem_type == memory_type::managed ||
mem_type == memory_type::pinned);
}
auto constexpr is_host_device_accessible(memory_type mem_type)
{
return is_device_accessible(mem_type) && is_host_accessible(mem_type);
}

namespace detail {

template<bool is_host_accessible, bool is_device_accessible>
auto constexpr memory_type_from_access() {
if constexpr (is_host_accessible && is_device_accessible) {
return memory_type::managed;
} else if constexpr (is_host_accessible) {
return memory_type::host;
} else if constexpr (is_device_accessible) {
return memory_type::device;
}
static_assert(is_host_accessible || is_device_accessible);
wphicks marked this conversation as resolved.
Show resolved Hide resolved
}

} // end namespace detail
} // end namespace raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ if(BUILD_TESTS)
test/nvtx.cpp
test/mdarray.cu
test/mdspan_utils.cu
test/memory_type.cpp
test/span.cpp
test/span.cu
test/test.cpp
Expand Down
43 changes: 43 additions & 0 deletions cpp/test/memory_type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <raft/core/memory_type.hpp>

namespace raft {
TEST(MemoryType, IsDeviceAccessible)
{
static_assert(!is_device_accessible(memory_type::host));
static_assert(is_device_accessible(memory_type::device));
static_assert(is_device_accessible(memory_type::managed));
static_assert(!is_device_accessible(memory_type::pinned));
}

TEST(MemoryType, IsHostAccessible)
{
static_assert(is_host_accessible(memory_type::host));
static_assert(!is_host_accessible(memory_type::device));
static_assert(is_host_accessible(memory_type::managed));
static_assert(is_host_accessible(memory_type::pinned));
}

TEST(MemoryType, IsHostDeviceAccessible)
{
static_assert(!is_host_device_accessible(memory_type::host));
static_assert(!is_host_device_accessible(memory_type::device));
static_assert(is_host_device_accessible(memory_type::managed));
static_assert(!is_host_device_accessible(memory_type::pinned));
}
} // namespace raft