Skip to content

Commit

Permalink
Temporary buffer to view host or device memory in device (#1313)
Browse files Browse the repository at this point in the history
closes #1299

Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - William Hicks (https://github.com/wphicks)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1313
  • Loading branch information
divyegala authored Mar 8, 2023
1 parent 3ca7eac commit 206611c
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 0 deletions.
268 changes: 268 additions & 0 deletions cpp/include/raft/core/temporary_device_buffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "device_mdarray.hpp"
#include "device_mdspan.hpp"

#include <raft/util/cudart_utils.hpp>

#include <variant>

namespace raft {

/**
* \defgroup TemporaryDeviceBuffer `raft::temporary_device_buffer` and associated factories
* @{
*/

/**
* @brief An object which provides temporary access on-device to memory from either a host or device
* pointer. This object provides a `view()` method that will provide a `raft::device_mdspan` that
* may be read-only depending on const-qualified nature of the input pointer.
*
* @tparam ElementType type of the input
* @tparam Extents raft::extents
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
*/
template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy>
class temporary_device_buffer {
using view_type = device_mdspan<ElementType, Extents, LayoutPolicy>;
using index_type = typename Extents::index_type;
using element_type = std::remove_cv_t<ElementType>;
using container_policy = ContainerPolicy<element_type>;
using owning_device_buffer =
device_mdarray<element_type, Extents, LayoutPolicy, container_policy>;
using data_store = std::variant<ElementType*, owning_device_buffer>;
static constexpr bool is_const_pointer_ = std::is_const_v<ElementType>;

public:
temporary_device_buffer(temporary_device_buffer const&) = delete;
temporary_device_buffer& operator=(temporary_device_buffer const&) = delete;

constexpr temporary_device_buffer(temporary_device_buffer&&) = default;
constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default;

/**
* @brief Construct a new temporary device buffer object
*
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @param write_back if true, any writes to the `view()` of this object will be copid
* back if the original pointer was in host memory
*/
temporary_device_buffer(device_resources const& handle,
ElementType* data,
Extents extents,
bool write_back = false)
: stream_(handle.get_stream()),
original_data_(data),
extents_{extents},
write_back_(write_back),
length_([this]() {
std::size_t length = 1;
for (std::size_t i = 0; i < extents_.rank(); ++i) {
length *= extents_.extent(i);
}
return length;
}()),
device_id_{get_device_for_address(data)}
{
if (device_id_ == -1) {
typename owning_device_buffer::mapping_type layout{extents_};
typename owning_device_buffer::container_policy_type policy{handle.get_stream()};

owning_device_buffer device_data{layout, policy};
raft::copy(device_data.data_handle(), data, length_, handle.get_stream());
data_ = data_store{std::in_place_index<1>, std::move(device_data)};
} else {
data_ = data_store{std::in_place_index<0>, data};
}
}

~temporary_device_buffer() noexcept(is_const_pointer_)
{
// only need to write data back for non const pointers
// when write_back=true and original pointer is in
// host memory
if constexpr (not is_const_pointer_) {
if (write_back_ && device_id_ == -1) {
raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_);
}
}
}

/**
* @brief Returns a `raft::device_mdspan`
*
* @return raft::device_mdspan
*/
auto view() -> view_type
{
if (device_id_ == -1) {
return std::get<1>(data_).view();
} else {
return make_mdspan<ElementType, index_type, LayoutPolicy, false, true>(original_data_,
extents_);
}
}

private:
rmm::cuda_stream_view stream_;
ElementType* original_data_;
data_store data_;
Extents extents_;
bool write_back_;
std::size_t length_;
int device_id_;
};

/**
* @brief Factory to create a `raft::temporary_device_buffer`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::device_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_device_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @param write_back if true, any writes to the `view()` of this object will be copid
* back if the original pointer was in host memory
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents>
auto make_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents,
bool write_back = false)
{
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>(
handle, data, extents, write_back);
}

/**
* @brief Factory to create a `raft::temporary_device_buffer` which produces a
* read-only `raft::device_mdspan` from `view()` method with
* `write_back=false`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::device_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_device_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents>
auto make_readonly_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents)
{
return temporary_device_buffer<std::add_const_t<ElementType>,
decltype(extents),
LayoutPolicy,
ContainerPolicy>(handle, data, extents, false);
}

/**
* @brief Factory to create a `raft::temporary_device_buffer` which produces a
* writeable `raft::device_mdspan` from `view()` method with
* `write_back=true`
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // Initialize raft::host_mdarray and raft::extents
* // Can be either raft::device_mdarray or raft::host_mdarray
* auto exts = raft::make_extents<int>(5);
* auto array = raft::make_host_mdarray<int, int>(handle, exts);
*
* auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts);
* @endcode
*
* @tparam ElementType type of the input
* @tparam IndexType index type of `raft::extents`
* @tparam LayoutPolicy layout of the input
* @tparam ContainerPolicy container to be used to own device memory if needed
* @tparam Extents variadic dimensions for `raft::extents`
* @param handle raft::device_resources
* @param data input pointer
* @param extents dimensions of input array
* @return raft::temporary_device_buffer
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
template <typename> typename ContainerPolicy = detail::device_uvector_policy,
size_t... Extents,
typename = std::enable_if_t<not std::is_const_v<ElementType>>>
auto make_writeback_temporary_device_buffer(raft::device_resources const& handle,
ElementType* data,
raft::extents<IndexType, Extents...> extents)
{
return temporary_device_buffer<ElementType, decltype(extents), LayoutPolicy, ContainerPolicy>(
handle, data, extents, true);
}

/**@}*/

} // namespace raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ if(BUILD_TESTS)
test/core/memory_type.cpp
test/core/span.cpp
test/core/span.cu
test/core/temporary_device_buffer.cu
test/test.cpp
)

Expand Down
3 changes: 3 additions & 0 deletions cpp/test/core/mdarray.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"

#include <gtest/gtest.h>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdarray.hpp>
Expand Down
84 changes: 84 additions & 0 deletions cpp/test/core/temporary_device_buffer.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"

#include <raft/core/host_mdarray.hpp>
#include <raft/core/temporary_device_buffer.hpp>

#include <rmm/device_uvector.hpp>

#include <gtest/gtest.h>

namespace raft {

TEST(TemporaryDeviceBuffer, DevicePointer)
{
{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_device_mdarray<int, int>(handle, exts);

auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts);

ASSERT_EQ(array.data_handle(), d_buf.view().data_handle());
static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should not be const");
}

{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_device_mdarray<int, int>(handle, exts);

auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts);

ASSERT_EQ(array.data_handle(), d_buf.view().data_handle());
static_assert(std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should be const");
}
}

TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack)
{
raft::device_resources handle;
auto exts = raft::make_extents<int>(5);
auto array = raft::make_host_mdarray<int, int>(exts);
thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1);
rmm::device_uvector<int> result(5, handle.get_stream());

{
auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts);
auto d_view = d_buf.view();

thrust::fill(rmm::exec_policy(handle.get_stream()),
d_view.data_handle(),
d_view.data_handle() + d_view.extent(0),
10);
raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream());

static_assert(!std::is_const_v<typename decltype(d_buf.view())::element_type>,
"element_type should not be const");
}

ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(),
result.data(),
array.extent(0),
raft::Compare<int>(),
handle.get_stream()));
}

} // namespace raft
1 change: 1 addition & 0 deletions docs/source/cpp_api/mdspan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ This page provides C++ class references for the RAFT's 1d span and multi-dimensi
mdspan_mdspan.rst
mdspan_mdarray.rst
mdspan_span.rst
mdspan_temporary_device_buffer.rst
Loading

0 comments on commit 206611c

Please sign in to comment.