Implement maybe-owning multi-dimensional container (mdbuffer) #1999

Merged
merged 176 commits into branch-24.02 from fea-mdbuffer on Jan 4, 2024

Changes from 168 commits
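
The sketch below is a purely hypothetical illustration of the "maybe-owning" idea in the PR title: a buffer built on top of existing data either aliases it as a non-owning view or, when the data has to be materialized somewhere else, allocates storage of its own and copies into it. The struct and its members are invented for exposition and are not the mdbuffer interface added by this PR.

#include <algorithm>
#include <cstddef>
#include <memory>

// Hypothetical sketch only; not the raft::mdbuffer API.
struct maybe_owning_buffer {
  float* data = nullptr;           // always points at usable storage
  std::unique_ptr<float[]> owned;  // non-null only when a copy was required

  // View path: the data is already where the caller needs it, so just alias it.
  explicit maybe_owning_buffer(float* existing) : data{existing} {}

  // Owning path: the data must be copied, so allocate and take ownership.
  maybe_owning_buffer(float const* src, std::size_t n)
    : owned{std::make_unique<float[]>(n)}
  {
    std::copy(src, src + n, owned.get());
    data = owned.get();
  }
};
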
Commits (176)
e24fd2e
Initial commit
tarang-jain Apr 3, 2023
b8cda77
Merge branch 'branch-23.04' of https://github.com/rapidsai/raft into …
tarang-jain Apr 3, 2023
07dabfe
New commit
tarang-jain Apr 6, 2023
64eb461
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 6, 2023
21c2641
Update
tarang-jain Apr 6, 2023
c84daa6
Merge
tarang-jain Apr 6, 2023
4ad421b
Merge
tarang-jain Apr 6, 2023
ea11b07
Merge
tarang-jain Apr 6, 2023
ab19410
build
tarang-jain Apr 7, 2023
9870e9d
Test start
tarang-jain Apr 7, 2023
51a2581
Test start
tarang-jain Apr 7, 2023
552b21e
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 7, 2023
d0e7b2c
style changes
tarang-jain Apr 7, 2023
f72f7f8
merge
tarang-jain Apr 7, 2023
05f9daa
merge dependencies.yaml
tarang-jain Apr 7, 2023
0250931
Updates
tarang-jain Apr 10, 2023
057743d
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 10, 2023
20042b0
Debugging
tarang-jain Apr 12, 2023
2d189c3
Update gtest
tarang-jain Apr 19, 2023
53c4557
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 25, 2023
de753ae
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 27, 2023
2f8b294
Some updates after reviews
tarang-jain Apr 27, 2023
6539ef4
Use raft::resources
tarang-jain Apr 28, 2023
1709521
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 28, 2023
008bb5b
move exception
tarang-jain Apr 28, 2023
5b97273
Updates after PR Reviews
tarang-jain May 2, 2023
5be6ec2
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 2, 2023
838bfef
Add container policy
tarang-jain May 8, 2023
e035e2e
further changes with container policy
tarang-jain May 10, 2023
cd91a88
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain May 10, 2023
338c1a6
Some updates
tarang-jain May 12, 2023
6468c24
update container_policy
tarang-jain Jun 7, 2023
1bd5455
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 7, 2023
81c6a81
Working build
tarang-jain Jun 9, 2023
77ae593
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 9, 2023
451815e
Update buffer accessor policy
tarang-jain Jun 12, 2023
b553369
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 12, 2023
b410f36
Style changes
tarang-jain Jun 12, 2023
4731620
minor changes
tarang-jain Jun 13, 2023
238d010
combine owning buffer cpu/gpu
tarang-jain Jun 14, 2023
75cfcf1
update tests
tarang-jain Jun 20, 2023
7b1909f
Updates
tarang-jain Jul 3, 2023
5c041c4
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 3, 2023
0bf6f87
Merge branch 'branch-23.08' into tarbuf
wphicks Jul 3, 2023
1a1143f
Temporarily remove new files to bring back necessary ones
wphicks Jul 3, 2023
acceb61
Begin refactoring buffer container policies
wphicks Jul 5, 2023
fdefc34
Add placeholder resource for stream view in CUDA-free builds
wphicks Jul 10, 2023
24223ed
Add infrastructure for CUDA-free build
wphicks Jul 11, 2023
c6f6354
Merge branch 'branch-23.08' into fea-mdbuffer
wphicks Jul 11, 2023
4689052
Add initial set of CUDA-free tests
wphicks Jul 11, 2023
1b7e1e5
Add variant types to mdbuffer
wphicks Jul 17, 2023
5416ceb
Provide all mdarray/mdspan to mdbuffer conversions
wphicks Jul 18, 2023
355b3d4
Begin creating buffer copy utilities
wphicks Jul 31, 2023
601f65d
Merge branch 'branch-23.10' into fea-mdbuffer
wphicks Aug 18, 2023
4770a83
Correct computation of dest indices
wphicks Aug 18, 2023
28e8627
Merge branch 'branch-23.10' into fea-mdbuffer
wphicks Aug 22, 2023
8237a74
Temporarily remove simd-accelerated copy
wphicks Aug 23, 2023
022cf6e
Add initial mdspan copy utility implementation
wphicks Aug 29, 2023
a1776f4
Refactor copy properties detection
wphicks Aug 31, 2023
a970dad
Correct detection of mdspan copy paths
wphicks Sep 1, 2023
9a2fa9e
Correct build errors
wphicks Sep 1, 2023
eac9de6
Provide passing 3D host transpose tests
wphicks Sep 1, 2023
39cf094
Add working tests for cuBlas based transpose
wphicks Sep 1, 2023
760b656
Add incomplete kernel tests
wphicks Sep 5, 2023
f8d435f
Remove old mdspan copy header
wphicks Sep 5, 2023
4c4fbaf
Revert "Remove old mdspan copy header"
wphicks Sep 5, 2023
ad5c786
Remove correct mdspan copy header
wphicks Sep 5, 2023
2e433ba
Correct std::apply workaround in CUDA
wphicks Sep 6, 2023
d669e42
Provide fully working copy kernel
wphicks Sep 7, 2023
ed663c8
Begin adding SIMD support
wphicks Sep 11, 2023
ab809e8
Revert "Begin adding SIMD support"
wphicks Sep 11, 2023
49d871a
Disable initial SIMD implementation
wphicks Sep 11, 2023
cb24abc
Rename mdspan copy headers
wphicks Sep 11, 2023
2a83c1b
Remove mdbuffer work and document mdspan copy
wphicks Sep 11, 2023
4193b74
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 11, 2023
624e4f3
Remove un-needed changes left over from mdbuffer
wphicks Sep 12, 2023
e9ef750
Add testing for CUDA-disabled builds
wphicks Sep 12, 2023
06fe54d
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 12, 2023
92046e0
Fix style and revert some unnecessary changes
wphicks Sep 12, 2023
a0a5b69
Remove changes related to mdbuffer
wphicks Sep 12, 2023
58389ec
Remove change related to mdbuffer
wphicks Sep 12, 2023
0a19ae5
Correctly handle proxy references in mdspan copy kernel
wphicks Sep 12, 2023
0675207
Check for unique destination layout in any parallel copy
wphicks Sep 13, 2023
8ad9434
Use perfect forwarding for copy wrappers
wphicks Sep 13, 2023
fdbc9ee
Correct comment for dimension iteration order
wphicks Sep 13, 2023
21618ea
Add warning about copying to non-unique layouts
wphicks Sep 14, 2023
c31a898
Update mdbuffer constructors for greater versatility
wphicks Sep 18, 2023
18d462e
Add benchmarks for mdspan copy
wphicks Sep 19, 2023
4700199
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 19, 2023
2cad1ed
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 19, 2023
6e91a1c
Correct check for assignability in mdspan copy
wphicks Sep 20, 2023
55e06fe
Add comment explaining intermediate storage
wphicks Sep 20, 2023
faa402a
Correct dtype compatibility test
wphicks Sep 21, 2023
2eba34d
Provide cleaner compile error for using copy with unsupported types
wphicks Sep 21, 2023
ca77cf0
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 22, 2023
4389b64
Update stream_view docs
wphicks Sep 22, 2023
7416b73
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 22, 2023
7f407ed
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks Sep 22, 2023
62ac60a
Update stream view docs
wphicks Sep 22, 2023
5bddcc8
Merge remote-tracking branch 'origin/fea-mdspan_copy' into fea-mdspan…
wphicks Sep 22, 2023
9c858fc
Merge branch 'fea-mdspan_copy' into fea-mdbuffer
wphicks Sep 22, 2023
8d2b25b
Restore changes removed in mdspan copy PR
wphicks Sep 22, 2023
21b1970
Restore fail_container_policy
wphicks Sep 22, 2023
bd5a8f8
Merge branch 'branch-23.12' into fea-mdspan_copy
wphicks Oct 2, 2023
c926653
Restore variant utils header
wphicks Oct 2, 2023
a8b17a8
Add static asserts for mdspan_copyable
wphicks Oct 2, 2023
722425c
Correct iteration in host-to-host copies
wphicks Oct 2, 2023
a539de3
Merge branch 'fea-mdspan_copy' into fea-mdbuffer
wphicks Oct 2, 2023
8835834
Correct double definition from branch merge
wphicks Oct 4, 2023
7d68a7b
Merge branch 'branch-23.12' into fea-mdbuffer
wphicks Oct 11, 2023
9a8b52e
Add remaining constructor logic
wphicks Oct 12, 2023
502dddd
Add additional mdbuffer constructors
wphicks Oct 12, 2023
f289b6e
Simplify mdbuffer implementation
wphicks Oct 12, 2023
e96d257
Create cuh/hpp split for mdbuffer
wphicks Oct 12, 2023
c344033
Fix compilation issues
wphicks Oct 13, 2023
7939c69
Add deduction guides for mdbuffer constructors
wphicks Oct 13, 2023
5ec364f
Fix pinned container policy implementation
wphicks Oct 16, 2023
03ad7f9
Rework constructors to correctly handle all cases
wphicks Oct 17, 2023
20073e1
Correct enable_ifs for construction from mdarray
wphicks Oct 17, 2023
e012d07
Correct pinned memory handling
wphicks Oct 17, 2023
84cf006
Split off managed and pinned container policies
wphicks Oct 19, 2023
7d1c93b
FIXME: Add debugging lines for managed destructor segfault
wphicks Oct 19, 2023
4acd66e
Begin fixing incorrect separation of device and managed
wphicks Oct 30, 2023
5bf79e5
Merge branch 'branch-23.12' into fea-mdbuffer
wphicks Nov 13, 2023
da0a09f
Ensure managed memory resource remains in scope
wphicks Nov 14, 2023
98c6a3f
Revert "FIXME: Add debugging lines for managed destructor segfault"
wphicks Nov 14, 2023
934aa94
Add missing includes for managed and pinned
wphicks Nov 14, 2023
fb26fd7
Fully separate managed and pinned headers
wphicks Nov 15, 2023
a0830e1
REVERT ME: Temporary workaround for serialization size issue
wphicks Nov 15, 2023
fd852bc
Update managed and pinned header splits
wphicks Nov 15, 2023
4d7602b
Add mdbuffer docs
wphicks Nov 15, 2023
a9f24da
Update docs for managed and pinned memory
wphicks Nov 15, 2023
dc390fe
Add mdspan implicit conversion test
wphicks Nov 15, 2023
c5d4f0f
Merge branch 'branch-23.12' into fea-mdbuffer
wphicks Nov 15, 2023
eb4fddf
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 15, 2023
29bd6b4
Tweak mdbuffer example code
wphicks Nov 15, 2023
b84c290
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 16, 2023
1f0ad4f
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 16, 2023
eb49608
Correct accessibility of pinned memory type
wphicks Nov 17, 2023
8e9ade2
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 20, 2023
909b786
Add memory type dispatching
wphicks Nov 22, 2023
51ab695
Correct handling of pinned memory in dispatcher
wphicks Nov 22, 2023
36bbffe
Begin writing mdspan_dispatched_functor
wphicks Nov 22, 2023
5458e5b
Remove mdspan_dispatched_functor
wphicks Nov 22, 2023
5733005
Add docs for memory_type_dispatcher
wphicks Nov 22, 2023
e6ce9c3
Respond to review
wphicks Nov 22, 2023
40d75cc
Merge remote-tracking branch 'origin/fea-mdbuffer' into fea-mdbuffer
wphicks Nov 22, 2023
3da9348
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 22, 2023
3b1f245
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 28, 2023
50032c5
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Nov 29, 2023
f270e74
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 1, 2023
6bd6abd
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 2, 2023
afb692c
Update docs to provide clearer layout-transposition example
wphicks Dec 2, 2023
eee7238
Update for increased implementation clarity based on review
wphicks Dec 4, 2023
864477e
Update cpp/include/raft/util/memory_type_dispatcher.cuh
wphicks Dec 4, 2023
edbad93
Use implicit void pointer cast
wphicks Dec 4, 2023
01b45e4
Add memory_type_dispatcher example to mdbuffer
wphicks Dec 4, 2023
93ca677
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 4, 2023
272af80
Fix style
wphicks Dec 5, 2023
fbdafd0
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 6, 2023
b0c87a5
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 9, 2023
9c48d98
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 11, 2023
f4e2e60
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 13, 2023
ccb56ab
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 14, 2023
c1db8a5
Allow implicit conversion to const mdbuffer from non-const mdspan
wphicks Dec 14, 2023
253ac7a
Safeguard default_container_policy against enum changes
wphicks Dec 14, 2023
bfdb234
Correctly mark make_*_view functions as constexpr
wphicks Dec 14, 2023
fca74aa
Remove commented-out deduction guide
wphicks Dec 14, 2023
f7d470e
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Dec 15, 2023
6c74dd0
Change spelling of policy selector
wphicks Dec 15, 2023
50373d2
Update usage of memory_type_to_default_policy_t
wphicks Dec 15, 2023
8d87cbd
Add clarifying information on const-ness
wphicks Dec 15, 2023
db06e7b
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Jan 2, 2024
0177cd4
Make enum values consistent with cudaMemoryType
wphicks Jan 2, 2024
e671bd6
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Jan 2, 2024
a993fb0
Merge branch 'branch-24.02' into fea-mdbuffer
wphicks Jan 3, 2024
146 changes: 146 additions & 0 deletions cpp/include/raft/core/detail/fail_container_policy.hpp
@@ -0,0 +1,146 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <raft/core/error.hpp>
#include <raft/core/logger-macros.hpp>
#include <raft/core/resources.hpp>
#include <raft/thirdparty/mdspan/include/experimental/mdspan>
#include <stddef.h>

namespace raft {
namespace detail {

template <typename T>
struct fail_reference {
using value_type = typename std::remove_cv_t<T>;
using pointer = T*;
using const_pointer = T const*;

fail_reference() = default;
template <typename StreamViewType>
fail_reference(T* ptr, StreamViewType stream)
{
throw non_cuda_build_error{"Attempted to construct reference to device data in non-CUDA build"};
}

operator value_type() const // NOLINT
{
throw non_cuda_build_error{"Attempted to dereference device data in non-CUDA build"};
return value_type{};
}
auto operator=(T const& other) -> fail_reference&
{
throw non_cuda_build_error{"Attempted to assign to device data in non-CUDA build"};
return *this;
}
};

/** A placeholder container which throws an exception on use
*
* This placeholder is used in non-CUDA builds for container types that would
* otherwise be provided with CUDA code. Attempting to construct a non-empty
* container of this type throws an exception indicating that there was an
* attempt to use the device from a non-CUDA build. An example of when this
* might happen is if a downstream application attempts to allocate a device
* mdarray using a library built with non-CUDA RAFT.
*/
template <typename T>
struct fail_container {
using value_type = T;
using size_type = std::size_t;

using reference = fail_reference<T>;
using const_reference = fail_reference<T const>;

using pointer = value_type*;
using const_pointer = value_type const*;

using iterator = pointer;
using const_iterator = const_pointer;

explicit fail_container(size_t n = size_t{})
{
if (n != size_t{}) {
throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"};
}
}

template <typename Index>
auto operator[](Index i) noexcept -> reference
{
RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build");
return reference{};
}

template <typename Index>
auto operator[](Index i) const noexcept -> const_reference
{
RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build");
return const_reference{};
}
void resize(size_t n)
{
if (n != size_t{}) {
throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"};
}
}

[[nodiscard]] auto data() noexcept -> pointer { return nullptr; }
[[nodiscard]] auto data() const noexcept -> const_pointer { return nullptr; }
};

/** A placeholder container policy which throws an exception on use
*
* This placeholder is used in non-CUDA builds for container types that would
* otherwise be provided with CUDA code. Attempting to construct a non-empty
* container of this type throws an exception indicating that there was an
* attempt to use the device from a non-CUDA build. An example of when this
* might happen is if a downstream application attempts to allocate a device
* mdarray using a library built with non-CUDA RAFT.
*/
template <typename ElementType>
struct fail_container_policy {
using element_type = ElementType;
using container_type = fail_container<element_type>;
using pointer = typename container_type::pointer;
using const_pointer = typename container_type::const_pointer;
using reference = typename container_type::reference;
using const_reference = typename container_type::const_reference;

using accessor_policy = std::experimental::default_accessor<element_type>;
using const_accessor_policy = std::experimental::default_accessor<element_type const>;

auto create(raft::resources const& res, size_t n) -> container_type { return container_type(n); }

fail_container_policy() = default;

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
return c[n];
}
[[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept
-> const_reference
{
return c[n];
}

[[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; }
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }
};

} // namespace detail
} // namespace raft
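
As a quick illustration of the behaviour documented above, the following sketch (not part of the diff; it assumes only the declarations in this header together with raft::resources) shows that zero-size containers remain usable while any non-empty device allocation in a non-CUDA build surfaces as an exception:

#include <raft/core/detail/fail_container_policy.hpp>
#include <raft/core/resources.hpp>

#include <exception>

void try_device_allocation(raft::resources const& res)
{
  raft::detail::fail_container_policy<float> policy{};
  auto empty = policy.create(res, 0);    // fine: empty containers never throw
  try {
    auto buf = policy.create(res, 100);  // throws: device allocation without CUDA
  } catch (std::exception const&) {
    // The concrete type thrown above is non_cuda_build_error (see raft/core/error.hpp).
  }
}
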
20 changes: 20 additions & 0 deletions cpp/include/raft/core/device_container_policy.hpp
@@ -21,6 +21,7 @@
* limitations under the License.
*/
#pragma once
#ifndef RAFT_DISABLE_CUDA
#include <raft/core/device_mdspan.hpp>
#include <raft/util/cudart_utils.hpp>

@@ -196,3 +197,22 @@ class device_uvector_policy {
};

} // namespace raft
#else
#include <raft/core/detail/fail_container_policy.hpp>
namespace raft {

// Provide placeholders that will allow CPU-GPU interoperable codebases to
// compile in non-CUDA mode but which will throw exceptions at runtime on any
// attempt to touch device data

template <typename T>
using device_reference = detail::fail_reference<T>;

template <typename T>
using device_uvector = detail::fail_container<T>;

template <typename ElementType>
using device_uvector_policy = detail::fail_container_policy<ElementType>;
Comment on lines +208 to +215

Member:
Can we add this and other such declarations to detail/fail_container_policy.hpp? Easier to find them all in one location, and just include the header here.

Contributor Author:
I just realized one disadvantage to that. It may be useful in certain contexts to use fail_* on their own without having CUDA compilation disabled. If we put the declarations in fail_container_policy.hpp there is no way to do that.


} // namespace raft
#endif
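
To make the effect of this fallback concrete, here is a small sketch of the kind of interoperable downstream code it enables. It assumes only that device_uvector_policy exposes the same create(res, n) interface in both build modes, which is the container-policy contract that the fail_container_policy above mirrors; with CUDA the call allocates a real device container, and in a RAFT_DISABLE_CUDA build the alias resolves to the placeholder and any non-empty allocation throws at runtime.

#include <raft/core/device_container_policy.hpp>
#include <raft/core/resources.hpp>

#include <cstddef>

// Compiles unchanged with or without CUDA support.
template <typename T>
auto make_device_storage(raft::resources const& res, std::size_t n)
{
  raft::device_uvector_policy<T> policy{};
  return policy.create(res, n);
}
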
86 changes: 9 additions & 77 deletions cpp/include/raft/core/device_mdspan.hpp
@@ -26,9 +26,6 @@ namespace raft {
template <typename AccessorPolicy>
using device_accessor = host_device_accessor<AccessorPolicy, memory_type::device>;

template <typename AccessorPolicy>
using managed_accessor = host_device_accessor<AccessorPolicy, memory_type::managed>;

/**
* @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location.
*/
@@ -38,12 +35,6 @@ template <typename ElementType,
typename AccessorPolicy = std::experimental::default_accessor<ElementType>>
using device_mdspan = mdspan<ElementType, Extents, LayoutPolicy, device_accessor<AccessorPolicy>>;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = std::experimental::default_accessor<ElementType>>
using managed_mdspan = mdspan<ElementType, Extents, LayoutPolicy, managed_accessor<AccessorPolicy>>;

template <typename T, bool B>
struct is_device_mdspan : std::false_type {};
template <typename T>
@@ -61,23 +52,6 @@ using is_input_device_mdspan_t = is_device_mdspan<T, is_input_mdspan_v<T>>;
template <typename T>
using is_output_device_mdspan_t = is_device_mdspan<T, is_output_mdspan_v<T>>;

template <typename T, bool B>
struct is_managed_mdspan : std::false_type {};
template <typename T>
struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_managed_accessible> {};

/**
* @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type
*/
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;

template <typename T>
using is_input_managed_mdspan_t = is_managed_mdspan<T, is_input_mdspan_v<T>>;

template <typename T>
using is_output_managed_mdspan_t = is_managed_mdspan<T, is_output_mdspan_v<T>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
@@ -102,30 +76,6 @@ using enable_if_input_device_mdspan = std::enable_if_t<is_input_device_mdspan_v<
template <typename... Tn>
using enable_if_output_device_mdspan = std::enable_if_t<is_output_device_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<is_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_input_managed_mdspan_v =
std::conjunction_v<is_input_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_output_managed_mdspan_v =
std::conjunction_v<is_output_managed_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_managed_mdspan = std::enable_if_t<is_input_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_managed_mdspan = std::enable_if_t<is_output_managed_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
@@ -186,7 +136,7 @@ using device_aligned_matrix_view =
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_right_padded<ElementType>>
auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
auto constexpr make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
{
using data_handle_type =
typename std::experimental::aligned_accessor<ElementType,
@@ -203,24 +153,6 @@ auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexTy
return device_aligned_matrix_view<ElementType, IndexType, LayoutPolicy>{aligned_pointer, extents};
}

/**
* @brief Create a raft::managed_mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::managed_mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
return make_mdspan<ElementType, IndexType, LayoutPolicy, true, true>(ptr, exts);
}

/**
* @brief Create a 0-dim (scalar) mdspan instance for device value.
*
@@ -229,7 +161,7 @@ auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
* @param[in] ptr on device to wrap
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_device_scalar_view(ElementType* ptr)
auto constexpr make_device_scalar_view(ElementType* ptr)
{
scalar_extent<IndexType> extents;
return device_scalar_view<ElementType, IndexType>{ptr, extents};
@@ -249,7 +181,7 @@ auto make_device_scalar_view(ElementType* ptr)
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
auto constexpr make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
{
matrix_extent<IndexType> extents{n_rows, n_cols};
return device_matrix_view<ElementType, IndexType, LayoutPolicy>{ptr, extents};
@@ -269,10 +201,10 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
* @param[in] stride leading dimension / stride of data
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
auto constexpr make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
{
constexpr auto is_row_major = std::is_same_v<LayoutPolicy, layout_c_contiguous>;
IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1;
@@ -295,7 +227,7 @@ auto make_device_strided_matrix_view(ElementType* ptr,
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
auto constexpr make_device_vector_view(ElementType* ptr, IndexType n)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}
@@ -310,7 +242,7 @@ auto make_device_vector_view(ElementType* ptr, IndexType n)
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(
auto constexpr make_device_vector_view(
ElementType* ptr,
const typename LayoutPolicy::template mapping<vector_extent<IndexType>>& mapping)
{
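
For reference, a short usage sketch of the factory functions touched above; this change only marks them constexpr, and their call signatures are unchanged. The device pointer and extents are assumed to come from the caller.

#include <raft/core/device_mdspan.hpp>

// Wrap a raw device allocation (at least n_rows * n_cols floats) in a
// non-owning, row-major device_matrix_view.
auto wrap_as_matrix(float* d_ptr, int n_rows, int n_cols)
{
  return raft::make_device_matrix_view<float, int>(d_ptr, n_rows, n_cols);
}
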
1 change: 1 addition & 0 deletions cpp/include/raft/core/host_container_policy.hpp
@@ -62,4 +62,5 @@ class host_vector_policy
[[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; }
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }
};

} // namespace raft
12 changes: 11 additions & 1 deletion cpp/include/raft/core/host_device_accessor.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -42,6 +42,16 @@ struct host_device_accessor : public AccessorPolicy {
using AccessorPolicy::AccessorPolicy;
using offset_policy = host_device_accessor;
host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT

// Prevent implicit conversion from incompatible host_device_accessor types
template <memory_type OtherMemType>
host_device_accessor(host_device_accessor<AccessorPolicy, OtherMemType> const& that) = delete;

template <memory_type OtherMemType, typename = std::enable_if_t<mem_type == OtherMemType>>
host_device_accessor(host_device_accessor<AccessorPolicy, OtherMemType> const& that)
: AccessorPolicy{that}
{
}
};

} // namespace raft
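
A brief sketch of what the deleted overload and its enable_if-guarded counterpart accomplish. The reasoning: host_device_accessor derives from AccessorPolicy and has a converting constructor from AccessorPolicy const&, so without the deleted overload a device-tagged accessor could reach a host-tagged one implicitly through that derived-to-base path; with it, the deleted template is the better match and the conversion is rejected at compile time. The static_asserts below are an illustrative assumption about the resulting behaviour, not a test taken from this PR.

#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>  // assumed to pull in std::experimental::default_accessor

#include <type_traits>

using base_acc   = std::experimental::default_accessor<float>;
using host_acc   = raft::host_device_accessor<base_acc, raft::memory_type::host>;
using device_acc = raft::host_device_accessor<base_acc, raft::memory_type::device>;

// Same memory type: copy/conversion behaves as before.
static_assert(std::is_convertible_v<host_acc, host_acc>);
// Different memory types: the deleted overload wins overload resolution,
// so the implicit conversion no longer compiles.
static_assert(!std::is_convertible_v<device_acc, host_acc>);
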