Implement maybe-owning multi-dimensional container (mdbuffer) (rapidsai#1999)
### What is mdbuffer?

This PR introduces a maybe-owning multi-dimensional abstraction called `mdbuffer` to help simplify code that _may_ require an `mdarray` but only if the data are not already in a desired form or location.

As a concrete example, consider a function `foo_device` which operates on memory accessible from the device. If we wish to pass it data originating on the host, a separate code path must be created in which a `device_mdarray` is created and the data are explicitly copied from host to device. This leads to a proliferation of branches as `foo_device` interacts with other functions with similar requirements.

As an initial simplification, `mdbuffer` allows us to write a single template that accepts an `mdspan` pointing to memory on either host _or_ device and routes it through the same code:
```c++
template <typename mdspan_type>
void foo_device(raft::resources const& res, mdspan_type data) {
  auto buf = raft::mdbuffer{res, raft::mdbuffer{data}, raft::memory_type::device};
  // Data in buf is now guaranteed to be accessible from device.
  // If it was already accessible from device, no copy was performed. If it
  // was not, a copy was performed.

  some_kernel<<<...>>>(buf.view<raft::memory_type::device>());

  // It is sometimes useful to know whether or not a copy was performed to
  // e.g. determine whether the transformed data should be copied back to its original
  // location. This can be checked via the `is_owning()` method.
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view<raft::memory_type::device>());
  }
}

foo_device(res, some_host_mdspan);  // Still works; memory is allocated and copy is performed
foo_device(res, some_device_mdspan);  // Still works and no allocation or copy is required
foo_device(res, some_managed_mdspan);  // Still works and no allocation or copy is required
```
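The core mechanic — view the data if they are already where they need to be, allocate and copy otherwise — relies on nothing RAFT-specific. A minimal, self-contained analogue (all names here are hypothetical; the real `mdbuffer` tracks memory types and extents through template parameters rather than a boolean):

```c++
#include <cassert>
#include <cstddef>
#include <vector>

enum class memory_type { host, device };

// Toy maybe-owning buffer: views the input when it is already in the
// requested location, otherwise allocates owned storage and copies.
struct toy_buffer {
  toy_buffer(float* data, std::size_t n, memory_type src, memory_type dst)
  {
    if (src == dst) {
      ptr_ = data;  // no copy performed: just a non-owning view
    } else {
      owned_.assign(data, data + n);  // stand-in for a host<->device copy
      ptr_    = owned_.data();
      owning_ = true;
    }
  }
  [[nodiscard]] bool is_owning() const noexcept { return owning_; }
  [[nodiscard]] float* data() noexcept { return ptr_; }

 private:
  float* ptr_ = nullptr;
  std::vector<float> owned_;  // populated only when a copy was required
  bool owning_ = false;
};
```

As in the example above, `is_owning()` tells the caller whether a copy happened and therefore whether results may need to be copied back.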

While this is a useful simplification, it still leads to a proliferation of template instantiations. If this is undesirable, `mdbuffer` permits a further consolidation through implicit conversion of an mdspan to an mdbuffer:

```c++
void foo_device(raft::resources const& res, raft::mdbuffer<float, raft::matrix_extent<int>>&& data)
{
  auto buf = raft::mdbuffer{res, data, raft::memory_type::device};
  some_kernel<<<...>>>(buf.view<raft::memory_type::device>());
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view<raft::memory_type::device>());
  }
}

// All of the following work exactly as before but no longer require separate template instantiations
foo_device(res, some_host_mdspan);
foo_device(res, some_device_mdspan);
foo_device(res, some_managed_mdspan);
```

`mdbuffer` also offers a simple way to perform runtime dispatch based on the memory type passed to it, using standard C++ patterns. `mdbuffer`'s `.view()` method takes an optional template parameter indicating the mdspan type to retrieve as a view. If that parameter is omitted, `.view()` instead returns a `std::variant` of all the mdspan types which may provide a view on the `mdbuffer`'s data (depending on its memory type). We can then use `std::visit` to dispatch at runtime based on where the data are stored:

```c++
void foo(raft::resources const& res, raft::mdbuffer<float, raft::matrix_extent<int>>&& data) {
  std::visit([](auto view) {
    if constexpr (decltype(view)::accessor_type::is_device_accessible) {
      // Do something with these data on device
    } else {
      // Do something with these data on host
    }
  }, data.view());
}
```
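Because this dispatch relies only on `std::variant` and `std::visit`, the pattern can be demonstrated without RAFT. A self-contained sketch, with hypothetical accessor tags standing in for mdspan's accessor policies:

```c++
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins for host/device accessor policies.
struct host_accessor { static constexpr bool is_device_accessible = false; };
struct device_accessor { static constexpr bool is_device_accessible = true; };

template <typename Accessor>
struct toy_view {  // stands in for an mdspan specialization
  using accessor_type = Accessor;
};

using any_view = std::variant<toy_view<host_accessor>, toy_view<device_accessor>>;

// Runtime dispatch with a compile-time branch per variant alternative.
std::string where(any_view v) {
  return std::visit(
    [](auto view) -> std::string {
      if constexpr (decltype(view)::accessor_type::is_device_accessible) {
        return "device";
      } else {
        return "host";
      }
    },
    v);
}
```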

In addition to moving data among various memory types (host, device, managed, and pinned currently), `mdbuffer` can be used to coerce data to a desired in-memory layout or to a compatible data type (e.g. floats to doubles). As with changes in the memory type, a copy will be performed if and only if it is necessary.

```c++
template <typename mdspan_type>
void foo_device(raft::resources const& res, mdspan_type data) {
  auto buf = raft::mdbuffer<float, raft::matrix_extent<int>, raft::row_major>{
    res, raft::mdbuffer{data}, raft::memory_type::device};
  // Data in buf is now guaranteed to be accessible from device, and
  // represented by floats in row-major order.

  some_kernel<<<...>>>(buf.view<raft::memory_type::device>());

  // The same check can be used to determine whether or not a copy was
  // required, regardless of the cause. I.e. if the data were already on
  // device but in column-major order, the is_owning() method would still
  // return true because new storage needed to be allocated.
  if (buf.is_owning()) {
    raft::copy(res, data, buf.view<raft::memory_type::device>());
  }
}
```
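The copy-if-and-only-if-necessary rule can be summarized as a predicate over the three properties involved. A simplified sketch (names hypothetical; it assumes managed and pinned memory are accessible from both host and device, matching the managed-mdspan example earlier, and ignores layout subtleties beyond row/column order):

```c++
#include <cassert>

enum class memory_type { host, device, managed, pinned };
enum class layout { row_major, col_major };
enum class dtype { f32, f64 };

// Managed and pinned allocations are accessible from both host and device,
// so they satisfy either access request without a copy.
constexpr bool mem_satisfies(memory_type src, memory_type dst) {
  return src == dst || src == memory_type::managed || src == memory_type::pinned;
}

// A new allocation (and copy) is needed iff any requested property
// cannot be satisfied by the source data as-is.
constexpr bool needs_copy(memory_type src_mem, layout src_lay, dtype src_t,
                          memory_type dst_mem, layout dst_lay, dtype dst_t) {
  return !mem_satisfies(src_mem, dst_mem) || src_lay != dst_lay || src_t != dst_t;
}
```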

### What mdbuffer is **not**
`mdbuffer` is **not** a replacement for either `mdspan` or `mdarray`. `mdspan` remains the standard object for passing data views throughout the RAFT codebase, and `mdarray` remains the standard object for allocating new multi-dimensional data. This is reflected in the fact that `mdbuffer` can _only_ be constructed from an existing `mdspan` or `mdarray` or another `mdbuffer`. `mdbuffer` is intended to be used solely to simplify code where data _may_ need to be copied to a different location.

### Follow-ups

- I have omitted the mdbuffer-based replacement for and generalization of `temporary_device_buffer`, since this PR is already enormous. I have it partially written, however, and I'll post a link to its current state to help motivate the changes here.
- For all necessary copies, `mdbuffer` uses `raft::copy`. For _some_ transformations that require a change in data type or layout, `raft::copy` is not fully optimized. See rapidsai#1842 for more information. Optimizing this will be an important change to ensure that `mdbuffer` can be used with absolutely minimal overhead in all cases. These non-optimized cases represent a small fraction of the real-world use cases we can expect for `mdbuffer`, however, so there should be little concern about beginning to use it as is.
- `std::visit`'s performance for a small number of variants is sometimes suboptimal. As a follow-up, it would be good to benchmark `mdbuffer`'s current performance and compare it against an internal `visit` implementation that uses a `switch` over the available memory types.
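The switch-based alternative mentioned in the last follow-up could look roughly like the following (a hypothetical sketch, not RAFT code). Each `case` hands the callable a distinct `std::integral_constant`, so the memory type remains available at compile time inside the callable, just as with `std::visit`:

```c++
#include <cassert>
#include <type_traits>

enum class memory_type { host, device, managed, pinned };

// Switch-based "visit" over the four memory types. Each branch passes a
// distinct integral_constant, so `f` can branch with `if constexpr`.
template <typename F>
auto visit_memory_type(memory_type mt, F&& f) {
  switch (mt) {
    case memory_type::host:
      return f(std::integral_constant<memory_type, memory_type::host>{});
    case memory_type::device:
      return f(std::integral_constant<memory_type, memory_type::device>{});
    case memory_type::managed:
      return f(std::integral_constant<memory_type, memory_type::managed>{});
    case memory_type::pinned:
      return f(std::integral_constant<memory_type, memory_type::pinned>{});
  }
  return f(std::integral_constant<memory_type, memory_type::host>{});  // unreachable
}
```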

Resolve rapidsai#1602

Authors:
  - William Hicks (https://github.com/wphicks)
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Tamas Bela Feher (https://github.com/tfeher)
  - Ben Frederickson (https://github.com/benfred)

URL: rapidsai#1999
wphicks authored and ChristinaZ committed Jan 17, 2024
1 parent a92c462 commit aac3b4f
Showing 30 changed files with 3,563 additions and 102 deletions.
146 changes: 146 additions & 0 deletions cpp/include/raft/core/detail/fail_container_policy.hpp
@@ -0,0 +1,146 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <raft/core/error.hpp>
#include <raft/core/logger-macros.hpp>
#include <raft/core/resources.hpp>
#include <raft/thirdparty/mdspan/include/experimental/mdspan>
#include <stddef.h>

namespace raft {
namespace detail {

template <typename T>
struct fail_reference {
  using value_type    = typename std::remove_cv_t<T>;
  using pointer       = T*;
  using const_pointer = T const*;

  fail_reference() = default;
  template <typename StreamViewType>
  fail_reference(T* ptr, StreamViewType stream)
  {
    throw non_cuda_build_error{"Attempted to construct reference to device data in non-CUDA build"};
  }

  operator value_type() const  // NOLINT
  {
    throw non_cuda_build_error{"Attempted to dereference device data in non-CUDA build"};
    return value_type{};
  }
  auto operator=(T const& other) -> fail_reference&
  {
    throw non_cuda_build_error{"Attempted to assign to device data in non-CUDA build"};
    return *this;
  }
};

/** A placeholder container which throws an exception on use
 *
 * This placeholder is used in non-CUDA builds for container types that would
 * otherwise be provided with CUDA code. Attempting to construct a non-empty
 * container of this type throws an exception indicating that there was an
 * attempt to use the device from a non-CUDA build. An example of when this
 * might happen is if a downstream application attempts to allocate a device
 * mdarray using a library built with non-CUDA RAFT.
 */
template <typename T>
struct fail_container {
  using value_type = T;
  using size_type  = std::size_t;

  using reference       = fail_reference<T>;
  using const_reference = fail_reference<T const>;

  using pointer       = value_type*;
  using const_pointer = value_type const*;

  using iterator       = pointer;
  using const_iterator = const_pointer;

  explicit fail_container(size_t n = size_t{})
  {
    if (n != size_t{}) {
      throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"};
    }
  }

  template <typename Index>
  auto operator[](Index i) noexcept -> reference
  {
    RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build");
    return reference{};
  }

  template <typename Index>
  auto operator[](Index i) const noexcept -> const_reference
  {
    RAFT_LOG_ERROR("Attempted to access device data in non-CUDA build");
    return const_reference{};
  }
  void resize(size_t n)
  {
    if (n != size_t{}) {
      throw non_cuda_build_error{"Attempted to allocate device container in non-CUDA build"};
    }
  }

  [[nodiscard]] auto data() noexcept -> pointer { return nullptr; }
  [[nodiscard]] auto data() const noexcept -> const_pointer { return nullptr; }
};

/** A placeholder container policy which throws an exception on use
 *
 * This placeholder is used in non-CUDA builds for container types that would
 * otherwise be provided with CUDA code. Attempting to construct a non-empty
 * container of this type throws an exception indicating that there was an
 * attempt to use the device from a non-CUDA build. An example of when this
 * might happen is if a downstream application attempts to allocate a device
 * mdarray using a library built with non-CUDA RAFT.
 */
template <typename ElementType>
struct fail_container_policy {
  using element_type    = ElementType;
  using container_type  = fail_container<element_type>;
  using pointer         = typename container_type::pointer;
  using const_pointer   = typename container_type::const_pointer;
  using reference       = typename container_type::reference;
  using const_reference = typename container_type::const_reference;

  using accessor_policy       = std::experimental::default_accessor<element_type>;
  using const_accessor_policy = std::experimental::default_accessor<element_type const>;

  auto create(raft::resources const& res, size_t n) -> container_type { return container_type(n); }

  fail_container_policy() = default;

  [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
  {
    return c[n];
  }
  [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept
    -> const_reference
  {
    return c[n];
  }

  [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; }
  [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }
};

} // namespace detail
} // namespace raft
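The placeholder pattern above can be exercised in isolation. A minimal sketch, with a plain `std::runtime_error` standing in for `raft::non_cuda_build_error` (which this header pulls in via `raft/core/error.hpp`):

```c++
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Simplified stand-in for fail_container: constructing or resizing to a
// non-zero size throws, while empty construction is allowed so that
// non-CUDA builds still compile and link.
struct toy_fail_container {
  explicit toy_fail_container(std::size_t n = 0) { check(n); }
  void resize(std::size_t n) { check(n); }
  [[nodiscard]] float* data() noexcept { return nullptr; }

 private:
  static void check(std::size_t n)
  {
    if (n != 0) {
      throw std::runtime_error{"Attempted to allocate device container in non-CUDA build"};
    }
  }
};
```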
22 changes: 21 additions & 1 deletion cpp/include/raft/core/device_container_policy.hpp
@@ -6,7 +6,7 @@
*/

/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
* limitations under the License.
*/
#pragma once
#ifndef RAFT_DISABLE_CUDA
#include <raft/core/device_mdspan.hpp>
#include <raft/util/cudart_utils.hpp>

@@ -196,3 +197,22 @@ class device_uvector_policy {
};

} // namespace raft
#else
#include <raft/core/detail/fail_container_policy.hpp>
namespace raft {

// Provide placeholders that will allow CPU-GPU interoperable codebases to
// compile in non-CUDA mode but which will throw exceptions at runtime on any
// attempt to touch device data

template <typename T>
using device_reference = detail::fail_reference<T>;

template <typename T>
using device_uvector = detail::fail_container<T>;

template <typename ElementType>
using device_uvector_policy = detail::fail_container_policy<ElementType>;

} // namespace raft
#endif
88 changes: 10 additions & 78 deletions cpp/include/raft/core/device_mdspan.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,9 +26,6 @@ namespace raft {
template <typename AccessorPolicy>
using device_accessor = host_device_accessor<AccessorPolicy, memory_type::device>;

template <typename AccessorPolicy>
using managed_accessor = host_device_accessor<AccessorPolicy, memory_type::managed>;

/**
* @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location.
*/
@@ -38,12 +35,6 @@ template <typename ElementType,
typename AccessorPolicy = std::experimental::default_accessor<ElementType>>
using device_mdspan = mdspan<ElementType, Extents, LayoutPolicy, device_accessor<AccessorPolicy>>;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = std::experimental::default_accessor<ElementType>>
using managed_mdspan = mdspan<ElementType, Extents, LayoutPolicy, managed_accessor<AccessorPolicy>>;

template <typename T, bool B>
struct is_device_mdspan : std::false_type {};
template <typename T>
@@ -61,23 +52,6 @@ using is_input_device_mdspan_t = is_device_mdspan<T, is_input_mdspan_v<T>>;
template <typename T>
using is_output_device_mdspan_t = is_device_mdspan<T, is_output_mdspan_v<T>>;

template <typename T, bool B>
struct is_managed_mdspan : std::false_type {};
template <typename T>
struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_managed_accessible> {};

/**
* @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type
*/
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;

template <typename T>
using is_input_managed_mdspan_t = is_managed_mdspan<T, is_input_mdspan_v<T>>;

template <typename T>
using is_output_managed_mdspan_t = is_managed_mdspan<T, is_output_mdspan_v<T>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
@@ -102,30 +76,6 @@ using enable_if_input_device_mdspan = std::enable_if_t<is_input_device_mdspan_v<
template <typename... Tn>
using enable_if_output_device_mdspan = std::enable_if_t<is_output_device_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<is_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_input_managed_mdspan_v =
std::conjunction_v<is_input_managed_mdspan_t<Tn>...>;

template <typename... Tn>
inline constexpr bool is_output_managed_mdspan_v =
std::conjunction_v<is_output_managed_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_managed_mdspan = std::enable_if_t<is_input_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_managed_mdspan = std::enable_if_t<is_output_managed_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
@@ -186,7 +136,7 @@ using device_aligned_matrix_view =
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_right_padded<ElementType>>
auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
auto constexpr make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
{
using data_handle_type =
typename std::experimental::aligned_accessor<ElementType,
@@ -203,24 +153,6 @@ auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexTy
return device_aligned_matrix_view<ElementType, IndexType, LayoutPolicy>{aligned_pointer, extents};
}

/**
* @brief Create a raft::managed_mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::managed_mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
return make_mdspan<ElementType, IndexType, LayoutPolicy, true, true>(ptr, exts);
}

/**
* @brief Create a 0-dim (scalar) mdspan instance for device value.
*
@@ -229,7 +161,7 @@ auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
* @param[in] ptr on device to wrap
*/
template <typename ElementType, typename IndexType = std::uint32_t>
auto make_device_scalar_view(ElementType* ptr)
auto constexpr make_device_scalar_view(ElementType* ptr)
{
scalar_extent<IndexType> extents;
return device_scalar_view<ElementType, IndexType>{ptr, extents};
@@ -249,7 +181,7 @@ auto make_device_scalar_view(ElementType* ptr)
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
auto constexpr make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols)
{
matrix_extent<IndexType> extents{n_rows, n_cols};
return device_matrix_view<ElementType, IndexType, LayoutPolicy>{ptr, extents};
@@ -269,10 +201,10 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
* @param[in] stride leading dimension / stride of data
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
auto constexpr make_device_strided_matrix_view(ElementType* ptr,
IndexType n_rows,
IndexType n_cols,
IndexType stride)
{
constexpr auto is_row_major = std::is_same_v<LayoutPolicy, layout_c_contiguous>;
IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1;
@@ -295,7 +227,7 @@ auto make_device_strided_matrix_view(ElementType* ptr,
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
auto constexpr make_device_vector_view(ElementType* ptr, IndexType n)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}
@@ -310,7 +242,7 @@ auto make_device_vector_view(ElementType* ptr, IndexType n)
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(
auto constexpr make_device_vector_view(
ElementType* ptr,
const typename LayoutPolicy::template mapping<vector_extent<IndexType>>& mapping)
{
3 changes: 2 additions & 1 deletion cpp/include/raft/core/host_container_policy.hpp
@@ -6,7 +6,7 @@
*/

/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,4 +62,5 @@ class host_vector_policy {
[[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; }
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }
};

} // namespace raft
12 changes: 11 additions & 1 deletion cpp/include/raft/core/host_device_accessor.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -42,6 +42,16 @@ struct host_device_accessor : public AccessorPolicy {
  using AccessorPolicy::AccessorPolicy;
  using offset_policy = host_device_accessor;
  host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {}  // NOLINT

  // Prevent implicit conversion from incompatible host_device_accessor types
  template <memory_type OtherMemType>
  host_device_accessor(host_device_accessor<AccessorPolicy, OtherMemType> const& that) = delete;

  template <memory_type OtherMemType, typename = std::enable_if_t<mem_type == OtherMemType>>
  host_device_accessor(host_device_accessor<AccessorPolicy, OtherMemType> const& that)
    : AccessorPolicy{that}
  {
  }
};

} // namespace raft