-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Temporary buffer to view host or device memory in device (#1313)
closes #1299 Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - William Hicks (https://github.com/wphicks) - Corey J. Nolet (https://github.com/cjnolet) URL: #1313
- Loading branch information
Showing
6 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "device_mdarray.hpp" | ||
#include "device_mdspan.hpp" | ||
|
||
#include <raft/util/cudart_utils.hpp> | ||
|
||
#include <variant> | ||
|
||
namespace raft { | ||
|
||
/** | ||
* \defgroup TemporaryDeviceBuffer `raft::temporary_device_buffer` and associated factories | ||
* @{ | ||
*/ | ||
|
||
/** | ||
* @brief An object which provides temporary access on-device to memory from either a host or device | ||
* pointer. This object provides a `view()` method that will provide a `raft::device_mdspan` that | ||
* may be read-only depending on const-qualified nature of the input pointer. | ||
* | ||
* @tparam ElementType type of the input | ||
* @tparam Extents raft::extents | ||
* @tparam LayoutPolicy layout of the input | ||
* @tparam ContainerPolicy container to be used to own device memory if needed | ||
*/ | ||
template <typename ElementType, | ||
typename Extents, | ||
typename LayoutPolicy = layout_c_contiguous, | ||
template <typename> typename ContainerPolicy = detail::device_uvector_policy> | ||
class temporary_device_buffer { | ||
using view_type = device_mdspan<ElementType, Extents, LayoutPolicy>; | ||
using index_type = typename Extents::index_type; | ||
using element_type = std::remove_cv_t<ElementType>; | ||
using container_policy = ContainerPolicy<element_type>; | ||
using owning_device_buffer = | ||
device_mdarray<element_type, Extents, LayoutPolicy, container_policy>; | ||
using data_store = std::variant<ElementType*, owning_device_buffer>; | ||
static constexpr bool is_const_pointer_ = std::is_const_v<ElementType>; | ||
|
||
public: | ||
temporary_device_buffer(temporary_device_buffer const&) = delete; | ||
temporary_device_buffer& operator=(temporary_device_buffer const&) = delete; | ||
|
||
constexpr temporary_device_buffer(temporary_device_buffer&&) = default; | ||
constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default; | ||
|
||
/** | ||
* @brief Construct a new temporary device buffer object | ||
* | ||
* @param handle raft::device_resources | ||
* @param data input pointer | ||
* @param extents dimensions of input array | ||
* @param write_back if true, any writes to the `view()` of this object will be copid | ||
* back if the original pointer was in host memory | ||
*/ | ||
temporary_device_buffer(device_resources const& handle, | ||
ElementType* data, | ||
Extents extents, | ||
bool write_back = false) | ||
: stream_(handle.get_stream()), | ||
original_data_(data), | ||
extents_{extents}, | ||
write_back_(write_back), | ||
length_([this]() { | ||
std::size_t length = 1; | ||
for (std::size_t i = 0; i < extents_.rank(); ++i) { | ||
length *= extents_.extent(i); | ||
} | ||
return length; | ||
}()), | ||
device_id_{get_device_for_address(data)} | ||
{ | ||
if (device_id_ == -1) { | ||
typename owning_device_buffer::mapping_type layout{extents_}; | ||
typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; | ||
|
||
owning_device_buffer device_data{layout, policy}; | ||
raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); | ||
data_ = data_store{std::in_place_index<1>, std::move(device_data)}; | ||
} else { | ||
data_ = data_store{std::in_place_index<0>, data}; | ||
} | ||
} | ||
|
||
~temporary_device_buffer() noexcept(is_const_pointer_) | ||
{ | ||
// only need to write data back for non const pointers | ||
// when write_back=true and original pointer is in | ||
// host memory | ||
if constexpr (not is_const_pointer_) { | ||
if (write_back_ && device_id_ == -1) { | ||
raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* @brief Returns a `raft::device_mdspan` | ||
* | ||
* @return raft::device_mdspan | ||
*/ | ||
auto view() -> view_type | ||
{ | ||
if (device_id_ == -1) { | ||
return std::get<1>(data_).view(); | ||
} else { | ||
return make_mdspan<ElementType, index_type, LayoutPolicy, false, true>(original_data_, | ||
extents_); | ||
} | ||
} | ||
|
||
private: | ||
rmm::cuda_stream_view stream_; | ||
ElementType* original_data_; | ||
data_store data_; | ||
Extents extents_; | ||
bool write_back_; | ||
std::size_t length_; | ||
int device_id_; | ||
}; | ||
|
||
/** | ||
* @brief Factory to create a `raft::temporary_device_buffer` | ||
* | ||
* @code{.cpp} | ||
* #include <raft/core/device_resources.hpp> | ||
* | ||
* raft::device_resources handle; | ||
* | ||
* // Initialize raft::device_mdarray and raft::extents | ||
* // Can be either raft::device_mdarray or raft::host_mdarray | ||
* auto exts = raft::make_extents<int>(5); | ||
* auto array = raft::make_device_mdarray<int, int>(handle, exts); | ||
* | ||
* auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); | ||
* @endcode | ||
* | ||
* @tparam ElementType type of the input | ||
* @tparam IndexType index type of `raft::extents` | ||
* @tparam LayoutPolicy layout of the input | ||
* @tparam ContainerPolicy container to be used to own device memory if needed | ||
* @tparam Extents variadic dimensions for `raft::extents` | ||
* @param handle raft::device_resources | ||
* @param data input pointer | ||
* @param extents dimensions of input array | ||
* @param write_back if true, any writes to the `view()` of this object will be copid | ||
* back if the original pointer was in host memory | ||
* @return raft::temporary_device_buffer | ||
*/ | ||
template <typename ElementType, | ||
typename IndexType = std::uint32_t, | ||
typename LayoutPolicy = layout_c_contiguous, | ||
template <typename> typename ContainerPolicy = detail::device_uvector_policy, | ||
size_t... Extents> | ||
auto make_temporary_device_buffer(raft::device_resources const& handle, | ||
ElementType* data, | ||
raft::extents<IndexType, Extents...> extents, | ||
bool write_back = false) | ||
{ | ||
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>( | ||
handle, data, extents, write_back); | ||
} | ||
|
||
/** | ||
* @brief Factory to create a `raft::temporary_device_buffer` which produces a | ||
* read-only `raft::device_mdspan` from `view()` method with | ||
* `write_back=false` | ||
* | ||
* @code{.cpp} | ||
* #include <raft/core/device_resources.hpp> | ||
* | ||
* raft::device_resources handle; | ||
* | ||
* // Initialize raft::device_mdarray and raft::extents | ||
* // Can be either raft::device_mdarray or raft::host_mdarray | ||
* auto exts = raft::make_extents<int>(5); | ||
* auto array = raft::make_device_mdarray<int, int>(handle, exts); | ||
* | ||
* auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); | ||
* @endcode | ||
* | ||
* @tparam ElementType type of the input | ||
* @tparam IndexType index type of `raft::extents` | ||
* @tparam LayoutPolicy layout of the input | ||
* @tparam ContainerPolicy container to be used to own device memory if needed | ||
* @tparam Extents variadic dimensions for `raft::extents` | ||
* @param handle raft::device_resources | ||
* @param data input pointer | ||
* @param extents dimensions of input array | ||
* @return raft::temporary_device_buffer | ||
*/ | ||
template <typename ElementType, | ||
typename IndexType = std::uint32_t, | ||
typename LayoutPolicy = layout_c_contiguous, | ||
template <typename> typename ContainerPolicy = detail::device_uvector_policy, | ||
size_t... Extents> | ||
auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, | ||
ElementType* data, | ||
raft::extents<IndexType, Extents...> extents) | ||
{ | ||
return temporary_device_buffer<std::add_const_t<ElementType>, | ||
decltype(extents), | ||
LayoutPolicy, | ||
ContainerPolicy>(handle, data, extents, false); | ||
} | ||
|
||
/** | ||
* @brief Factory to create a `raft::temporary_device_buffer` which produces a | ||
* writeable `raft::device_mdspan` from `view()` method with | ||
* `write_back=true` | ||
* | ||
* @code{.cpp} | ||
* #include <raft/core/device_resources.hpp> | ||
* | ||
* raft::device_resources handle; | ||
* | ||
* // Initialize raft::host_mdarray and raft::extents | ||
* // Can be either raft::device_mdarray or raft::host_mdarray | ||
* auto exts = raft::make_extents<int>(5); | ||
* auto array = raft::make_host_mdarray<int, int>(handle, exts); | ||
* | ||
* auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); | ||
* @endcode | ||
* | ||
* @tparam ElementType type of the input | ||
* @tparam IndexType index type of `raft::extents` | ||
* @tparam LayoutPolicy layout of the input | ||
* @tparam ContainerPolicy container to be used to own device memory if needed | ||
* @tparam Extents variadic dimensions for `raft::extents` | ||
* @param handle raft::device_resources | ||
* @param data input pointer | ||
* @param extents dimensions of input array | ||
* @return raft::temporary_device_buffer | ||
*/ | ||
template <typename ElementType, | ||
typename IndexType = std::uint32_t, | ||
typename LayoutPolicy = layout_c_contiguous, | ||
template <typename> typename ContainerPolicy = detail::device_uvector_policy, | ||
size_t... Extents, | ||
typename = std::enable_if_t<not std::is_const_v<ElementType>>> | ||
auto make_writeback_temporary_device_buffer(raft::device_resources const& handle, | ||
ElementType* data, | ||
raft::extents<IndexType, Extents...> extents) | ||
{ | ||
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>( | ||
handle, data, extents, true); | ||
} | ||
|
||
/**@}*/ | ||
|
||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "../test_utils.cuh" | ||
|
||
#include <raft/core/host_mdarray.hpp> | ||
#include <raft/core/temporary_device_buffer.hpp> | ||
|
||
#include <rmm/device_uvector.hpp> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace raft { | ||
|
||
TEST(TemporaryDeviceBuffer, DevicePointer) | ||
{ | ||
{ | ||
raft::device_resources handle; | ||
auto exts = raft::make_extents<int>(5); | ||
auto array = raft::make_device_mdarray<int, int>(handle, exts); | ||
|
||
auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); | ||
|
||
ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); | ||
static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>, | ||
"element_type should not be const"); | ||
} | ||
|
||
{ | ||
raft::device_resources handle; | ||
auto exts = raft::make_extents<int>(5); | ||
auto array = raft::make_device_mdarray<int, int>(handle, exts); | ||
|
||
auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); | ||
|
||
ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); | ||
static_assert(std::is_const_v<typename decltype(d_buf.view())::element_type>, | ||
"element_type should be const"); | ||
} | ||
} | ||
|
||
TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack) | ||
{ | ||
raft::device_resources handle; | ||
auto exts = raft::make_extents<int>(5); | ||
auto array = raft::make_host_mdarray<int, int>(exts); | ||
thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1); | ||
rmm::device_uvector<int> result(5, handle.get_stream()); | ||
|
||
{ | ||
auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); | ||
auto d_view = d_buf.view(); | ||
|
||
thrust::fill(rmm::exec_policy(handle.get_stream()), | ||
d_view.data_handle(), | ||
d_view.data_handle() + d_view.extent(0), | ||
10); | ||
raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream()); | ||
|
||
static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>, | ||
"element_type should not be const"); | ||
} | ||
|
||
ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(), | ||
result.data(), | ||
array.extent(0), | ||
raft::Compare<int>(), | ||
handle.get_stream())); | ||
} | ||
|
||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.