Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporary buffer to view host or device memory in device #1313

Merged
merged 12 commits into from
Mar 8, 2023
268 changes: 268 additions & 0 deletions cpp/include/raft/core/temporary_device_buffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "device_mdarray.hpp"
#include "device_mdspan.hpp"

#include <raft/util/cudart_utils.hpp>

#include <variant>

namespace raft {

/**
* \defgroup TemporaryDeviceBuffer `raft::temporary_device_buffer` and associated factories
* @{
*/

/**
* @brief An object which provides temporary access on-device to memory from either a host or device
* pointer. This object provides a `view()` method that will provide a `raft::device_mdspan` that
* may be read-only depending on const-qualified nature of the input pointer.
*
* @tparam ElementType type of the input
* @tparam Extents raft::extents
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
*/
template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy>
class temporary_device_buffer {
using view_type = device_mdspan<ElementType, Extents, LayoutPolicy>;
using index_type = typename Extents::index_type;
using element_type = std::remove_cv_t<ElementType>;
using container_policy = ContainerPolicy<element_type>;
using owning_device_buffer =
device_mdarray<element_type, Extents, LayoutPolicy, container_policy>;
using data_store = std::variant<ElementType*, owning_device_buffer>;
static constexpr bool is_const_pointer_ = std::is_const_v<ElementType>;

public:
temporary_device_buffer(temporary_device_buffer const&) = delete;
temporary_device_buffer& operator=(temporary_device_buffer const&) = delete;

constexpr temporary_device_buffer(temporary_device_buffer&&) = default;
constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default;

/**
* @brief Construct a new temporary device buffer object
*
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @param write_back if true, any writes to the `view()` of this object will be copid
* back if the original pointer was in host memory
*/
temporary_device_buffer(device_resources const& handle,
ElementType* data,
Extents extents,
bool write_back = false)
: stream_(handle.get_stream()),
original_data_(data),
extents_{extents},
write_back_(write_back),
length_([this]() {
std::size_t length = 1;
for (std::size_t i = 0; i < extents_.rank(); ++i) {
length *= extents_.extent(i);
}
return length;
}()),
device_id_{get_device_for_address(data)}
{
if (device_id_ == -1) {
typename owning_device_buffer::mapping_type layout{extents_};
typename owning_device_buffer::container_policy_type policy{handle.get_stream()};

owning_device_buffer device_data{layout, policy};
raft::copy(device_data.data_handle(), data, length_, handle.get_stream());
data_ = data_store{std::in_place_index<1>, std::move(device_data)};
} else {
data_ = data_store{std::in_place_index<0>, data};
}
}

~temporary_device_buffer() noexcept(is_const_pointer_)
{
// only need to write data back for non const pointers
// when write_back=true and original pointer is in
// host memory
if constexpr (not is_const_pointer_) {
if (write_back_ && device_id_ == -1) {
raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_);
}
}
}

/**
* @brief Returns a `raft::device_mdspan`
*
* @return raft::device_mdspan
*/
auto view() -> view_type
{
if (device_id_ == -1) {
return std::get<1>(data_).view();
} else {
return make_mdspan<ElementType, index_type, LayoutPolicy, false, true>(original_data_,
extents_);
}
}

private:
rmm::cuda_stream_view stream_;
ElementType* original_data_;
data_store data_;
Extents extents_;
bool write_back_;
std::size_t length_;
int device_id_;
};

/**
* @brief Factory to create a `raft::temporary_device_buffer`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::device_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_device_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
divyegala marked this conversation as resolved.
Show resolved Hide resolved
* @param data input pointer
* @param extents dimensions of input array
* @param write_back if true, any writes to the `view()` of this object will be copid
* back if the original pointer was in host memory
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents>
auto make_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents,
bool write_back = false)
{
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>(
handle, data, extents, write_back);
}

/**
* @brief Factory to create a `raft::temporary_device_buffer` which produces a
* read-only `raft::device_mdspan` from `view()` method with
* `write_back=false`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::device_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_device_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents>
auto make_readonly_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents)
{
return temporary_device_buffer<std::add_const_t<ElementType>,
decltype(extents),
LayoutPolicy,
ContainerPolicy>(handle, data, extents, false);
}

/**
* @brief Factory to create a `raft::temporary_device_buffer` which produces a
* writeable `raft::device_mdspan` from `view()` method with
* `write_back=true`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::host_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_host_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents,
typename = std::enable_if_t<not std::is_const_v<ElementType>>>
auto make_writeback_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents)
{
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>(
handle, data, extents, true);
}

/**@}*/

} // namespace raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ if(BUILD_TESTS)
test/core/memory_type.cpp
test/core/span.cpp
test/core/span.cu
test/core/temporary_device_buffer.cu
test/test.cpp
)

Expand Down
3 changes: 3 additions & 0 deletions cpp/test/core/mdarray.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"

#include <gtest/gtest.h>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdarray.hpp>
Expand Down
84 changes: 84 additions & 0 deletions cpp/test/core/temporary_device_buffer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"

#include <raft/core/host_mdarray.hpp>
#include <raft/core/temporary_device_buffer.hpp>

#include <rmm/device_uvector.hpp>

#include <gtest/gtest.h>

namespace raft {

TEST(TemporaryDeviceBuffer, DevicePointer)
{
{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_device_mdarray<int, int>(handle, exts);

auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts);

ASSERT_EQ(array.data_handle(), d_buf.view().data_handle());
static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should not be const");
}

{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_device_mdarray<int, int>(handle, exts);

auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts);

ASSERT_EQ(array.data_handle(), d_buf.view().data_handle());
static_assert(std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should be const");
}
}

TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack)
{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_host_mdarray<int, int>(exts);
thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1);
rmm::device_uvector<int> result(5, handle.get_stream());

{
auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts);
auto d_view = d_buf.view();

thrust::fill(rmm::exec_policy(handle.get_stream()),
d_view.data_handle(),
d_view.data_handle() + d_view.extent(0),
10);
raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream());

static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should not be const");
}

ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(),
result.data(),
array.extent(0),
raft::Compare<int>(),
handle.get_stream()));
}

} // namespace raft
1 change: 1 addition & 0 deletions docs/source/cpp_api/mdspan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ This page provides C++ class references for the RAFT's 1d span and multi-dimensi
mdspan_mdspan.rst
mdspan_mdarray.rst
mdspan_span.rst
mdspan_temporary_device_buffer.rst
Loading