-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Separating mdspan/mdarray infra into host_* and device_* variants (#810)
This is a breaking change as it provides users with a more granular set of headers to import `host` separately from `device` and `managed` versions. It also separates the headers for `mdspan` and `mdarray`. As an example, the following public headers can now be imported individually: ```c++ raft/core/host_mdspan.hpp raft/core/device_mdspan.hpp raft/core/host_mdarray.hpp raft/core/device_mdarray.hpp ``` cc @rg20 @afender @akifcorduk for awareness. Closes #806 Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) - Mark Hoemmen (https://github.com/mhoemmen) - William Hicks (https://github.com/wphicks) URL: #810
- Loading branch information
Showing
30 changed files
with
1,300 additions
and
1,016 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
namespace raft::detail { | ||
|
||
/** | ||
* @brief A mixin to distinguish host and device memory. | ||
*/ | ||
template <typename AccessorPolicy, bool is_host, bool is_device> | ||
struct host_device_accessor : public AccessorPolicy { | ||
using accessor_type = AccessorPolicy; | ||
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>; | ||
using is_device_type = std::conditional_t<is_device, std::true_type, std::false_type>; | ||
using is_managed_type = std::conditional_t<is_device && is_host, std::true_type, std::false_type>; | ||
static constexpr bool is_host_accessible = is_host; | ||
static constexpr bool is_device_accessible = is_device; | ||
static constexpr bool is_managed_accessible = is_device && is_host; | ||
// make sure the explicit ctor can fall through | ||
using AccessorPolicy::AccessorPolicy; | ||
using offset_policy = host_device_accessor; | ||
host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT | ||
}; | ||
|
||
} // namespace raft::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright (2019) Sandia Corporation | ||
* | ||
* The source code is licensed under the 3-clause BSD license found in the LICENSE file | ||
* thirdparty/LICENSES/mdarray.license | ||
*/ | ||
|
||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
#include <raft/core/mdspan_types.hpp> | ||
#include <vector> | ||
|
||
namespace raft::detail { | ||
|
||
/** | ||
* @brief A container policy for host mdarray. | ||
*/ | ||
template <typename ElementType, typename Allocator = std::allocator<ElementType>> | ||
class host_vector_policy { | ||
public: | ||
using element_type = ElementType; | ||
using container_type = std::vector<element_type, Allocator>; | ||
using allocator_type = typename container_type::allocator_type; | ||
using pointer = typename container_type::pointer; | ||
using const_pointer = typename container_type::const_pointer; | ||
using reference = element_type&; | ||
using const_reference = element_type const&; | ||
using accessor_policy = std::experimental::default_accessor<element_type>; | ||
using const_accessor_policy = std::experimental::default_accessor<element_type const>; | ||
|
||
public: | ||
auto create(size_t n) -> container_type { return container_type(n); } | ||
|
||
constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v<ElementType>) = | ||
default; | ||
explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( | ||
std::is_nothrow_default_constructible_v<ElementType>) | ||
: host_vector_policy() | ||
{ | ||
} | ||
|
||
[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference | ||
{ | ||
return c[n]; | ||
} | ||
[[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept | ||
-> const_reference | ||
{ | ||
return c[n]; | ||
} | ||
|
||
[[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } | ||
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } | ||
}; | ||
} // namespace raft::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#ifndef _RAFT_HAS_CUDA | ||
#if defined(__CUDACC__) | ||
#define _RAFT_HAS_CUDA __CUDACC__ | ||
#endif | ||
#endif | ||
|
||
#ifndef _RAFT_HOST_DEVICE | ||
#if defined(_RAFT_HAS_CUDA) | ||
#define _RAFT_HOST_DEVICE __host__ __device__ | ||
#else | ||
#define _RAFT_HOST_DEVICE | ||
#endif | ||
#endif | ||
|
||
#ifndef RAFT_INLINE_FUNCTION | ||
#define RAFT_INLINE_FUNCTION inline _RAFT_HOST_DEVICE | ||
#endif |
Oops, something went wrong.