From cdb8dc88152be24036a74b8b24da3d9fc97fdb44 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 1 Mar 2023 14:20:34 -0800 Subject: [PATCH 1/8] passing tests --- cpp/include/raft/core/device_buffer.hpp | 108 ++++++++++++++++++++++++ cpp/test/core/mdarray.cu | 60 +++++++++++++ 2 files changed, 168 insertions(+) create mode 100644 cpp/include/raft/core/device_buffer.hpp diff --git a/cpp/include/raft/core/device_buffer.hpp b/cpp/include/raft/core/device_buffer.hpp new file mode 100644 index 0000000000..2853af8bb4 --- /dev/null +++ b/cpp/include/raft/core/device_buffer.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "device_mdarray.hpp" +#include "device_mdspan.hpp" + +#include + +#include + +namespace raft { + +template >> +class device_buffer { + using view_type = device_mdspan; + using index_type = typename Extents::index_type; + using element_type = std::remove_cv_t; + using owning_device_buffer = device_mdarray; + using data_store = std::variant; + + public: + device_buffer(device_buffer const&) = delete; + device_buffer& operator=(device_buffer const&) = delete; + + constexpr device_buffer(device_buffer&&) = default; + constexpr device_buffer& operator=(device_buffer&&) = default; + + device_buffer(device_resources const& handle, + ElementType* data, + Extents extents, + bool write_back = false) + : stream_(handle.get_stream()), + original_data_(data), + extents_{extents}, + write_back_(write_back), + length_([this]() { + std::size_t length = 1; + for (std::size_t i = 0; i < extents_.rank(); ++i) { + length *= extents_.extent(i); + } + return length; + }()), + device_id_{get_device_for_address(data)} + { + if (device_id_ == -1) { + typename owning_device_buffer::mapping_type layout{extents_}; + typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; + + owning_device_buffer device_data{layout, policy}; + raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); + data_ = data_store{std::in_place_index<1>, std::move(device_data)}; + } else { + data_ = data_store{std::in_place_index<0>, data}; + } + } + + ~device_buffer() + { + // only need to write data back for non const pointers + // when write_back=true and original pointer is in + // host memory + if constexpr (not is_const_pointer_) { + if (write_back_ && device_id_ == -1) { + raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_); + } + } + } + + auto view() -> view_type + { + if (device_id_ == -1) { + return std::get<1>(data_).view(); + } else { + return make_mdspan(original_data_, + extents_); + } + } + + private: + static constexpr bool is_const_pointer_ = std::is_const_v; + rmm::cuda_stream_view stream_; + ElementType* original_data_; + data_store data_; + Extents extents_; + bool write_back_; + std::size_t length_; + int device_id_; +}; + +} // namespace raft diff --git a/cpp/test/core/mdarray.cu b/cpp/test/core/mdarray.cu index 018b8a4e5a..0daad272f7 100644 --- a/cpp/test/core/mdarray.cu +++ b/cpp/test/core/mdarray.cu @@ -13,7 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include "../test_utils.cuh" + #include +#include #include #include #include @@ -947,4 +951,60 @@ void test_mdarray_unravel() TEST(MDArray, Unravel) { test_mdarray_unravel(); } +TEST(DeviceBuffer, DevicePointer) +{ + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_device_mdarray(handle, exts); + + raft::device_buffer d_buf{handle, array.data_handle(), exts}; + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto const array = raft::make_device_mdarray(handle, exts); + + raft::device_buffer d_buf{handle, array.data_handle(), exts}; + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(std::is_const_v, + "element_type should be const"); + } +} + +TEST(DeviceBuffer, HostPointerWithWriteBack) +{ + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_host_mdarray(exts); + thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1); + rmm::device_uvector result(5, handle.get_stream()); + + { + raft::device_buffer d_buf{handle, array.data_handle(), exts, true}; + auto d_view = d_buf.view(); + + thrust::fill(rmm::exec_policy(handle.get_stream()), + d_view.data_handle(), + d_view.data_handle() + d_view.extent(0), + 10); + raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream()); + + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(), + result.data(), + array.extent(0), + raft::Compare(), + handle.get_stream())); +} + } // namespace raft From f41a2a6cfb5fc0b1480208fb2ad2ff5d1aa8ff16 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 09:30:04 -0800 Subject: [PATCH 2/8] review feedback, documentation --- cpp/include/raft/core/device_buffer.hpp | 108 --------- .../raft/core/temporary_device_buffer.hpp | 220 ++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/core/mdarray.cu | 57 ----- cpp/test/core/temporary_device_buffer.cu | 87 +++++++ 5 files changed, 308 insertions(+), 165 deletions(-) delete mode 100644 cpp/include/raft/core/device_buffer.hpp create mode 100644 cpp/include/raft/core/temporary_device_buffer.hpp create mode 100644 cpp/test/core/temporary_device_buffer.cu diff --git a/cpp/include/raft/core/device_buffer.hpp b/cpp/include/raft/core/device_buffer.hpp deleted file mode 100644 index 2853af8bb4..0000000000 --- a/cpp/include/raft/core/device_buffer.hpp +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "device_mdarray.hpp" -#include "device_mdspan.hpp" - -#include - -#include - -namespace raft { - -template >> -class device_buffer { - using view_type = device_mdspan; - using index_type = typename Extents::index_type; - using element_type = std::remove_cv_t; - using owning_device_buffer = device_mdarray; - using data_store = std::variant; - - public: - device_buffer(device_buffer const&) = delete; - device_buffer& operator=(device_buffer const&) = delete; - - constexpr device_buffer(device_buffer&&) = default; - constexpr device_buffer& operator=(device_buffer&&) = default; - - device_buffer(device_resources const& handle, - ElementType* data, - Extents extents, - bool write_back = false) - : stream_(handle.get_stream()), - original_data_(data), - extents_{extents}, - write_back_(write_back), - length_([this]() { - std::size_t length = 1; - for (std::size_t i = 0; i < extents_.rank(); ++i) { - length *= extents_.extent(i); - } - return length; - }()), - device_id_{get_device_for_address(data)} - { - if (device_id_ == -1) { - typename owning_device_buffer::mapping_type layout{extents_}; - typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; - - owning_device_buffer device_data{layout, policy}; - raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); - data_ = data_store{std::in_place_index<1>, std::move(device_data)}; - } else { - data_ = data_store{std::in_place_index<0>, data}; - } - } - - ~device_buffer() - { - // only need to write data back for non const pointers - // when write_back=true and original pointer is in - // host memory - if constexpr (not is_const_pointer_) { - if (write_back_ && device_id_ == -1) { - raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_); - } - } - } - - auto view() -> view_type - { - if (device_id_ == -1) { - return std::get<1>(data_).view(); - } else { - return make_mdspan(original_data_, - extents_); - } - } - - private: - static constexpr bool is_const_pointer_ = std::is_const_v; - rmm::cuda_stream_view stream_; - ElementType* original_data_; - data_store data_; - Extents extents_; - bool write_back_; - std::size_t length_; - int device_id_; -}; - -} // namespace raft diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp new file mode 100644 index 0000000000..6bf59717e0 --- /dev/null +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "device_mdarray.hpp" +#include "device_mdspan.hpp" + +#include + +#include + +namespace raft { + +/** + * @brief An object to have temporary representation of either a host/device pointer in device + * memory. This object provides a `view()` method that will provide a `raft::device_mdspan` that may + * be read-only depending on const-qualified nature of the input pointer. + * + * @tparam ElementType type of the input + * @tparam Extents raft::extents + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + */ +template >> +class temporary_device_buffer { + using view_type = device_mdspan; + using index_type = typename Extents::index_type; + using element_type = std::remove_cv_t; + using owning_device_buffer = device_mdarray; + using data_store = std::variant; + static constexpr bool is_const_pointer_ = std::is_const_v; + + public: + temporary_device_buffer(temporary_device_buffer const&) = delete; + temporary_device_buffer& operator=(temporary_device_buffer const&) = delete; + + constexpr temporary_device_buffer(temporary_device_buffer&&) = default; + constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default; + + /** + * @brief Construct a new temporary device buffer object + * + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @param write_back if true, any writes to the `view()` of this object will be copid + * back if the original pointer was in host memory + */ + temporary_device_buffer(device_resources const& handle, + ElementType* data, + Extents extents, + bool write_back = false) + : stream_(handle.get_stream()), + original_data_(data), + extents_{extents}, + write_back_(write_back), + length_([this]() { + std::size_t length = 1; + for (std::size_t i = 0; i < extents_.rank(); ++i) { + length *= extents_.extent(i); + } + return length; + }()), + device_id_{get_device_for_address(data)} + { + if (device_id_ == -1) { + typename owning_device_buffer::mapping_type layout{extents_}; + typename owning_device_buffer::container_policy_type policy{handle.get_stream()}; + + owning_device_buffer device_data{layout, policy}; + raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); + data_ = data_store{std::in_place_index<1>, std::move(device_data)}; + } else { + data_ = data_store{std::in_place_index<0>, data}; + } + } + + ~temporary_device_buffer() noexcept(is_const_pointer_) + { + // only need to write data back for non const pointers + // when write_back=true and original pointer is in + // host memory + if constexpr (not is_const_pointer_) { + if (write_back_ && device_id_ == -1) { + raft::copy(original_data_, std::get<1>(data_).data_handle(), length_, stream_); + } + } + } + + /** + * @brief Returns a `raft::device_mdspan` + * + * @return raft::device_mdspan + */ + auto view() -> view_type + { + if (device_id_ == -1) { + return std::get<1>(data_).view(); + } else { + return make_mdspan(original_data_, + extents_); + } + } + + private: + rmm::cuda_stream_view stream_; + ElementType* original_data_; + data_store data_; + Extents extents_; + bool write_back_; + std::size_t length_; + int device_id_; +}; + +/** + * @brief Factory to create a `raft::temporary_device_buffer` + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @param write_back if true, any writes to the `view()` of this object will be copid + * back if the original pointer was in host memory + * @return raft::temporary_device_buffer + */ +template >, + size_t... Extents> +auto make_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents, + bool write_back = false) +{ + return temporary_device_buffer( + handle, data, extents, write_back); +} + +/** + * @brief Factory to create a `raft::temporary_device_buffer` which produces a + * read-only `raft::device_mdspan` from `view()` method with + * `write_back=false` + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @return raft::temporary_device_buffer + */ +template >, + size_t... Extents> +auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents) +{ + return temporary_device_buffer(handle, data, extents, false); +} + +/** + * @brief Factory to create a `raft::temporary_device_buffer` which produces a + * writeable `raft::device_mdspan` from `view()` method with + * `write_back=true` + * + * @tparam ElementType type of the input + * @tparam IndexType index type of `raft::extents` + * @tparam LayoutPolicy layout of the input + * @tparam ContainerPolicy container to be used to own device memory if needed + * @tparam Extents variadic dimensions for `raft::extents` + * @param handle raft::device_resources + * @param data input pointer + * @param extents dimensions of input array + * @return raft::temporary_device_buffer + */ +template , + size_t... Extents, + typename = std::enable_if_t>> +auto make_writeback_temporary_device_buffer(raft::device_resources const& handle, + ElementType* data, + raft::extents extents) +{ + return temporary_device_buffer( + handle, data, extents, true); +} + +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4b633864a3..26ec8ebf74 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -109,6 +109,7 @@ if(BUILD_TESTS) test/core/memory_type.cpp test/core/span.cpp test/core/span.cu + test/core/temporary_device_buffer.cu test/test.cpp ) diff --git a/cpp/test/core/mdarray.cu b/cpp/test/core/mdarray.cu index 0daad272f7..5eff6e7539 100644 --- a/cpp/test/core/mdarray.cu +++ b/cpp/test/core/mdarray.cu @@ -17,7 +17,6 @@ #include "../test_utils.cuh" #include -#include #include #include #include @@ -951,60 +950,4 @@ void test_mdarray_unravel() TEST(MDArray, Unravel) { test_mdarray_unravel(); } -TEST(DeviceBuffer, DevicePointer) -{ - { - raft::device_resources handle; - auto exts = raft::make_extents(5); - auto array = raft::make_device_mdarray(handle, exts); - - raft::device_buffer d_buf{handle, array.data_handle(), exts}; - - ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); - static_assert(!std::is_const_v, - "element_type should not be const"); - } - - { - raft::device_resources handle; - auto exts = raft::make_extents(5); - auto const array = raft::make_device_mdarray(handle, exts); - - raft::device_buffer d_buf{handle, array.data_handle(), exts}; - - ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); - static_assert(std::is_const_v, - "element_type should be const"); - } -} - -TEST(DeviceBuffer, HostPointerWithWriteBack) -{ - raft::device_resources handle; - auto exts = raft::make_extents(5); - auto array = raft::make_host_mdarray(exts); - thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1); - rmm::device_uvector result(5, handle.get_stream()); - - { - raft::device_buffer d_buf{handle, array.data_handle(), exts, true}; - auto d_view = d_buf.view(); - - thrust::fill(rmm::exec_policy(handle.get_stream()), - d_view.data_handle(), - d_view.data_handle() + d_view.extent(0), - 10); - raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream()); - - static_assert(!std::is_const_v, - "element_type should not be const"); - } - - ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(), - result.data(), - array.extent(0), - raft::Compare(), - handle.get_stream())); -} - } // namespace raft diff --git a/cpp/test/core/temporary_device_buffer.cu b/cpp/test/core/temporary_device_buffer.cu new file mode 100644 index 0000000000..b1530e4d9d --- /dev/null +++ b/cpp/test/core/temporary_device_buffer.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include + +#include + +#include + +namespace raft { + +TEST(TemporaryDeviceBuffer, DevicePointer) +{ + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_device_mdarray(handle, exts); + + // raft::device_buffer d_buf{handle, array.data_handle(), exts}; + auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + { + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_device_mdarray(handle, exts); + + // raft::device_buffer d_buf{handle, array.data_handle(), exts}; + auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); + + ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); + static_assert(std::is_const_v, + "element_type should be const"); + } +} + +TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack) +{ + raft::device_resources handle; + auto exts = raft::make_extents(5); + auto array = raft::make_host_mdarray(exts); + thrust::fill(array.data_handle(), array.data_handle() + array.extent(0), 1); + rmm::device_uvector result(5, handle.get_stream()); + + { + // raft::device_buffer d_buf{handle, array.data_handle(), exts, true}; + auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); + auto d_view = d_buf.view(); + + thrust::fill(rmm::exec_policy(handle.get_stream()), + d_view.data_handle(), + d_view.data_handle() + d_view.extent(0), + 10); + raft::copy(result.data(), d_view.data_handle(), d_view.extent(0), handle.get_stream()); + + static_assert(!std::is_const_v, + "element_type should not be const"); + } + + ASSERT_TRUE(raft::devArrMatchHost(array.data_handle(), + result.data(), + array.extent(0), + raft::Compare(), + handle.get_stream())); +} + +} // namespace raft From 920282447674979635556e0cff5b337879264829 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 10:14:34 -0800 Subject: [PATCH 3/8] remove comments --- cpp/test/core/temporary_device_buffer.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/test/core/temporary_device_buffer.cu b/cpp/test/core/temporary_device_buffer.cu index b1530e4d9d..52a2ec4c9b 100644 --- a/cpp/test/core/temporary_device_buffer.cu +++ b/cpp/test/core/temporary_device_buffer.cu @@ -32,7 +32,6 @@ TEST(TemporaryDeviceBuffer, DevicePointer) auto exts = raft::make_extents(5); auto array = raft::make_device_mdarray(handle, exts); - // raft::device_buffer d_buf{handle, array.data_handle(), exts}; auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); @@ -45,7 +44,6 @@ TEST(TemporaryDeviceBuffer, DevicePointer) auto exts = raft::make_extents(5); auto array = raft::make_device_mdarray(handle, exts); - // raft::device_buffer d_buf{handle, array.data_handle(), exts}; auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); ASSERT_EQ(array.data_handle(), d_buf.view().data_handle()); @@ -63,7 +61,6 @@ TEST(TemporaryDeviceBuffer, HostPointerWithWriteBack) rmm::device_uvector result(5, handle.get_stream()); { - // raft::device_buffer d_buf{handle, array.data_handle(), exts, true}; auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); auto d_view = d_buf.view(); From c10f4276f06a4eaf2d58a2ac0735338285c67a6e Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 11:03:49 -0800 Subject: [PATCH 4/8] usage example and more docs --- .../raft/core/temporary_device_buffer.hpp | 50 ++++++++++++++++++- docs/source/cpp_api/mdspan.rst | 1 + .../mdspan_temporary_device_buffer.rst | 23 +++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 docs/source/cpp_api/mdspan_temporary_device_buffer.rst diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 6bf59717e0..1851fe6173 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -25,6 +25,11 @@ namespace raft { +/** + * \defgroup TemporaryDeviceBuffer `raft::temporary_device_buffer` and associated factories + * @{ + */ + /** * @brief An object to have temporary representation of either a host/device pointer in device * memory. This object provides a `view()` method that will provide a `raft::device_mdspan` that may @@ -132,6 +137,19 @@ class temporary_device_buffer { /** * @brief Factory to create a `raft::temporary_device_buffer` * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::device_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_device_mdarray(handle, exts); + * + * auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input @@ -162,7 +180,20 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, * @brief Factory to create a `raft::temporary_device_buffer` which produces a * read-only `raft::device_mdspan` from `view()` method with * `write_back=false` - * + * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::device_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_device_mdarray(handle, exts); + * + * auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input @@ -182,7 +213,7 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, ElementType* data, raft::extents extents) { - return temporary_device_buffer, decltype(extents), LayoutPolicy, ContainerPolicy>(handle, data, extents, false); @@ -193,6 +224,19 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, * writeable `raft::device_mdspan` from `view()` method with * `write_back=true` * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // Initialize raft::host_mdarray and raft::extents + * // Can be either raft::device_mdarray or raft::host_mdarray + * auto exts = raft::make_extents(5); + * auto array = raft::make_host_mdarray(handle, exts); + * + * auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); + * @endcode + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input @@ -217,4 +261,6 @@ auto make_writeback_temporary_device_buffer(raft::device_resources const& handle handle, data, extents, true); } +/**@}*/ + } // namespace raft diff --git a/docs/source/cpp_api/mdspan.rst b/docs/source/cpp_api/mdspan.rst index af38247c01..3fc0db7b96 100644 --- a/docs/source/cpp_api/mdspan.rst +++ b/docs/source/cpp_api/mdspan.rst @@ -16,3 +16,4 @@ This page provides C++ class references for the RAFT's 1d span and multi-dimensi mdspan_mdspan.rst mdspan_mdarray.rst mdspan_span.rst + mdspan_temporary_device_buffer.rst diff --git a/docs/source/cpp_api/mdspan_temporary_device_buffer.rst b/docs/source/cpp_api/mdspan_temporary_device_buffer.rst new file mode 100644 index 0000000000..90d08ac5bb --- /dev/null +++ b/docs/source/cpp_api/mdspan_temporary_device_buffer.rst @@ -0,0 +1,23 @@ +temporary_device_buffer: Temporary raft::device_mdspan Producing Object +=========================================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygenclass:: raft::temporary_device_buffer + :project: RAFT + :members: + +Factories +--------- +.. doxygenfunction:: raft::make_temporary_device_buffer + :project: RAFT + +.. doxygenfunction:: raft::make_readonly_temporary_device_buffer + :project: RAFT + +.. doxygenfunction:: raft::make_writeback_temporary_device_buffer + :project: RAFT From 3512e7bb965209779b59039082b7a9b14136ce09 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 11:16:42 -0800 Subject: [PATCH 5/8] style changes --- .../raft/core/temporary_device_buffer.hpp | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 1851fe6173..79fd609f67 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -139,17 +139,17 @@ class temporary_device_buffer { * * @code{.cpp} * #include - * + * * raft::device_resources handle; - * + * * // Initialize raft::device_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray * auto exts = raft::make_extents(5); * auto array = raft::make_device_mdarray(handle, exts); - * + * * auto d_buf = raft::make_temporary_device_buffer(handle, array.data_handle(), exts); * @endcode - * + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input @@ -180,20 +180,20 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, * @brief Factory to create a `raft::temporary_device_buffer` which produces a * read-only `raft::device_mdspan` from `view()` method with * `write_back=false` - * + * * @code{.cpp} * #include - * + * * raft::device_resources handle; - * + * * // Initialize raft::device_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray * auto exts = raft::make_extents(5); * auto array = raft::make_device_mdarray(handle, exts); - * + * * auto d_buf = raft::make_readonly_temporary_device_buffer(handle, array.data_handle(), exts); * @endcode - * + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input @@ -226,17 +226,17 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, * * @code{.cpp} * #include - * + * * raft::device_resources handle; - * + * * // Initialize raft::host_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray * auto exts = raft::make_extents(5); * auto array = raft::make_host_mdarray(handle, exts); - * + * * auto d_buf = raft::make_writeback_temporary_device_buffer(handle, array.data_handle(), exts); * @endcode - * + * * @tparam ElementType type of the input * @tparam IndexType index type of `raft::extents` * @tparam LayoutPolicy layout of the input From 06062371b8cb6f8e09a5da4b05371a9106fc6145 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Tue, 7 Mar 2023 07:58:02 -0500 Subject: [PATCH 6/8] Update doc Co-authored-by: William Hicks --- cpp/include/raft/core/temporary_device_buffer.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 79fd609f67..69903a424d 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -31,8 +31,8 @@ namespace raft { */ /** - * @brief An object to have temporary representation of either a host/device pointer in device - * memory. This object provides a `view()` method that will provide a `raft::device_mdspan` that may + * @brief An object which provides temporary access on-device to memory from either a host or device pointer. + * This object provides a `view()` method that will provide a `raft::device_mdspan` that may * be read-only depending on const-qualified nature of the input pointer. * * @tparam ElementType type of the input From 1af9046f98723362cf72de0014d3299ed5ede904 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 7 Mar 2023 05:07:34 -0800 Subject: [PATCH 7/8] style change --- cpp/include/raft/core/temporary_device_buffer.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 69903a424d..999861c443 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -31,9 +31,9 @@ namespace raft { */ /** - * @brief An object which provides temporary access on-device to memory from either a host or device pointer. - * This object provides a `view()` method that will provide a `raft::device_mdspan` that may - * be read-only depending on const-qualified nature of the input pointer. + * @brief An object which provides temporary access on-device to memory from either a host or device + * pointer. This object provides a `view()` method that will provide a `raft::device_mdspan` that + * may be read-only depending on const-qualified nature of the input pointer. * * @tparam ElementType type of the input * @tparam Extents raft::extents From 4b61f8383b7cfd1e37c9abbd3f5d6e717950525f Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 7 Mar 2023 11:17:33 -0800 Subject: [PATCH 8/8] use partial type for container policy --- .../raft/core/temporary_device_buffer.hpp | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 999861c443..5e6ae84eb5 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -42,14 +42,16 @@ namespace raft { */ template >> + typename LayoutPolicy = layout_c_contiguous, + template typename ContainerPolicy = detail::device_uvector_policy> class temporary_device_buffer { - using view_type = device_mdspan; - using index_type = typename Extents::index_type; - using element_type = std::remove_cv_t; - using owning_device_buffer = device_mdarray; - using data_store = std::variant; + using view_type = device_mdspan; + using index_type = typename Extents::index_type; + using element_type = std::remove_cv_t; + using container_policy = ContainerPolicy; + using owning_device_buffer = + device_mdarray; + using data_store = std::variant; static constexpr bool is_const_pointer_ = std::is_const_v; public: @@ -163,9 +165,9 @@ class temporary_device_buffer { * @return raft::temporary_device_buffer */ template >, + typename IndexType = std::uint32_t, + typename LayoutPolicy = layout_c_contiguous, + template typename ContainerPolicy = detail::device_uvector_policy, size_t... Extents> auto make_temporary_device_buffer(raft::device_resources const& handle, ElementType* data, @@ -205,9 +207,9 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, * @return raft::temporary_device_buffer */ template >, + typename IndexType = std::uint32_t, + typename LayoutPolicy = layout_c_contiguous, + template typename ContainerPolicy = detail::device_uvector_policy, size_t... Extents> auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, ElementType* data, @@ -248,9 +250,9 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, * @return raft::temporary_device_buffer */ template , + typename IndexType = std::uint32_t, + typename LayoutPolicy = layout_c_contiguous, + template typename ContainerPolicy = detail::device_uvector_policy, size_t... Extents, typename = std::enable_if_t>> auto make_writeback_temporary_device_buffer(raft::device_resources const& handle,