diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 61c1b500e6..88f90485dd 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -29,7 +29,6 @@ #include #include #include -#include namespace raft { /** @@ -45,11 +44,11 @@ namespace raft { template class array_interface { /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() noexcept { return static_cast(this)->view(); } /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() const noexcept { return static_cast(this)->view(); } }; @@ -108,7 +107,8 @@ inline constexpr bool is_array_interface_v = is_array_interface::value; * template. * * - Most of the constructors from the reference implementation is removed to make sure - * CUDA stream is honorred. + * CUDA stream is honored. Note that this class is not coupled to CUDA and therefore + * will only be used in the case where the device variant is used. * * - unique_size is not implemented, which is still working in progress in the proposal * @@ -220,11 +220,11 @@ class mdarray #undef RAFT_MDARRAY_CTOR_CONSTEXPR /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() const noexcept {