diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp new file mode 100644 index 0000000000..5e6ae84eb5 --- /dev/null +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "device_mdarray.hpp" +#include "device_mdspan.hpp" + +#include + +#include + +namespace raft { + +/** + * \defgroup TemporaryDeviceBuffer `raft::temporary_device_buffer` and associated factories + * @{ + */ + +/** + * @brief An object which provides temporary access on-device to memory from either a host or device + * pointer. This object provides a `view()` method that will provide a `raft::device_mdspan` that + * may be read-only depending on const-qualified nature of the input pointer. + * + * @tparam ElementType type of the input + * @tparam Extents raft::extents + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + */ +template typename ContainerPolicy = detail::device_uvector_policy> +class temporary_device_buffer { + using view_type = device_mdspan; + using index_type = typename Extents::index_type; + using element_type = std::remove_cv_t; + using container_policy = ContainerPolicy; + using owning_device_buffer = + device_mdarray; + using data_store = std::variant; + static constexpr bool is_const_pointer_ = std::is_const_v; + + public: + temporary_device_buffer(temporary_device_buffer const&) = delete; + temporary_device_buffer& operator=(temporary_device_buffer const&) = delete; + + constexpr temporary_device_buffer(temporary_device_buffer&&) = default; + constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default; + + /** + * @brief Construct a new temporary device buffer object + * + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @param write_back if true, any writes to the `view()` of this object will be copid + * back if the original pointer was in host memory + */ + temporary_device_buffer(device_resources const& handle, + ElementType* data, + Extents extents, + bool write_back = false) + : stream_(handle.get_stream()), + original_data_(data), + extents_{extents}, + write_back_(write_back), + length_([this]() { + std::size_t length = 1; + for (std::size_t i = 0; i < extents_.rank(); ++i) { + length *= extents_.extent(i); + } + return length; + }()), + device_id_{get_device_for_address(data)} + { + if (device_id_ == -1) { + typename owning_device_buffer::mapping_type layout{extents_}; + typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; + + owning_device_buffer device_data{layout, policy}; + raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); + data_ = data_store{std::in_place_index<1>, std::move(device_data)}; + } else { + data_ = data_store{std::in_place_index<0>, data}; + } + } + + ~temporary_device_buffer() noexcept(is_const_pointer_) + { + // only need to write data back for non const pointers + // when write_back=true and original pointer is in + // host memory + if constexpr (not is_const_pointer_) { + if (write_back_ && device_id_ == -1) { + raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_); + } + } + } + + /** + * @brief Returns a `raft::device_mdspan` + * + * @return raft::device_mdspan + */ + auto view() -> view_type + { + if (device_id_ == -1) { + return std::get<1>(data_).view(); + } else { + return make_mdspan(original_data_, + extents_); + } + } + + private: + rmm::cuda_stream_view stream_; + ElementType* original_data_; + data_store data_; + Extents extents_; + bool write_back_; + std::size_t length_; + int device_id_; +}; + +/** + * @brief Factory to create a `raft::temporary_device_buffer` + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::device_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_device_mdarray(handle, exts); + * + * auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @param write_back if true, any writes to the `view()` of this object will be copid + * back if the original pointer was in host memory + * @return raft::temporary_device_buffer + */ +template typename ContainerPolicy = detail::device_uvector_policy, + size_t... Extents> +auto make_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents, + bool write_back = false) +{ + return temporary_device_buffer( + handle, data, extents, write_back); +} + +/** + * @brief Factory to create a `raft::temporary_device_buffer` which produces a + * read-only `raft::device_mdspan` from `view()` method with + * `write_back=false` + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::device_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_device_mdarray(handle, exts); + * + * auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @return raft::temporary_device_buffer + */ +template typename ContainerPolicy = detail::device_uvector_policy, + size_t... Extents> +auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents) +{ + return temporary_device_buffer, + decltype(extents), + LayoutPolicy, + ContainerPolicy>(handle, data, extents, false); +} + +/** + * @brief Factory to create a `raft::temporary_device_buffer` which produces a + * writeable `raft::device_mdspan` from `view()` method with + * `write_back=true` + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::host_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_host_mdarray(handle, exts); + * + * auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @return raft::temporary_device_buffer + */ +template typename ContainerPolicy = detail::device_uvector_policy, + size_t... Extents, + typename = std::enable_if_t>> +auto make_writeback_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents) +{ + return temporary_device_buffer( + handle, data, extents, true); +} + +/**@}*/ + +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4b633864a3..26ec8ebf74 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -109,6 +109,7 @@ if(BUILD_TESTS) test/core/memory_type.cpp test/core/span.cpp test/core/span.cu + test/core/temporary_device_buffer.cu test/test.cpp ) diff --git a/cpp/test/core/mdarray.cu b/cpp/test/core/mdarray.cu index 018b8a4e5a..5eff6e7539 100644 --- a/cpp/test/core/mdarray.cu +++ b/cpp/test/core/mdarray.cu @@ -13,6 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include "../test_utils.cuh" + #include #include #include diff --git a/cpp/test/core/temporary_device_buffer.cu b/cpp/test/core/temporary_device_buffer.cu new file mode 100644 index 0000000000..52a2ec4c9b --- /dev/null +++ b/cpp/test/core/temporary_device_buffer.cu @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include + +#include + +#include + +namespace raft { + +TEST(TemporaryDeviceBuffer, DevicePointer) +{ + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_device_mdarray(handle, exts); + + auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_device_mdarray(handle, exts); + + auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(std::is_const_v, + "element_type should be const"); + } +} + +TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack) +{ + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_host_mdarray(exts); + thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1); + rmm::device_uvector result(5, handle.get_stream()); + + { + auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); + auto d_view = d_buf.view(); + + thrust::fill(rmm::exec_policy(handle.get_stream()), + d_view.data_handle(), + d_view.data_handle() + d_view.extent(0), + 10); + raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream()); + + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(), + result.data(), + array.extent(0), + raft::Compare(), + handle.get_stream())); +} + +} // namespace raft diff --git a/docs/source/cpp_api/mdspan.rst b/docs/source/cpp_api/mdspan.rst index af38247c01..3fc0db7b96 100644 --- a/docs/source/cpp_api/mdspan.rst +++ b/docs/source/cpp_api/mdspan.rst @@ -16,3 +16,4 @@ This page provides C++ class references for the RAFT's 1d span and multi-dimensi mdspan_mdspan.rst mdspan_mdarray.rst mdspan_span.rst + mdspan_temporary_device_buffer.rst diff --git a/docs/source/cpp_api/mdspan_temporary_device_buffer.rst b/docs/source/cpp_api/mdspan_temporary_device_buffer.rst new file mode 100644 index 0000000000..90d08ac5bb --- /dev/null +++ b/docs/source/cpp_api/mdspan_temporary_device_buffer.rst @@ -0,0 +1,23 @@ +temporary_device_buffer: Temporary raft::device_mdspan Producing Object +=========================================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygenclass:: raft::temporary_device_buffer + :project: RAFT + :members: + +Factories +--------- +.. doxygenfunction:: raft::make_temporary_device_buffer + :project: RAFT + +.. doxygenfunction:: raft::make_readonly_temporary_device_buffer + :project: RAFT + +.. doxygenfunction:: raft::make_writeback_temporary_device_buffer + :project: RAFT