Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mdspan/mdarray template functions and utilities #601

Merged
merged 25 commits into from
Apr 28, 2022

Conversation

divyegala
Copy link
Member

No description provided.

@divyegala divyegala requested review from a team as code owners March 30, 2022 19:21
@divyegala divyegala changed the base branch from branch-22.04 to branch-22.06 March 30, 2022 19:21
@ajschmidt8 ajschmidt8 removed the request for review from a team April 4, 2022 13:47
@ajschmidt8
Copy link
Member

Removing ops-codeowners from the required reviews since it doesn't seem there are any file changes that we're responsible for. Feel free to add us back if necessary.

@divyegala divyegala added feature request New feature or request non-breaking Non-breaking change labels Apr 6, 2022
cpp/include/raft/mdarray.hpp Outdated Show resolved Hide resolved
cpp/include/raft/mdarray.hpp Outdated Show resolved Hide resolved
cpp/include/raft/mdarray.hpp Outdated Show resolved Hide resolved
cpp/include/raft/core/mdarray.hpp Outdated Show resolved Hide resolved
/**
* @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels.
*/
operator view_type() noexcept { return view(); }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if these operators are needed anymore

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no, if you don't want implicit conversion.

* @\brief Boolean to determine if template type T is raft::array_interface or derived type
*/
template <typename T>
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value;
Copy link
Member Author

@divyegala divyegala Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose that we give users two options here:

  1. They directly implement array_interface
  2. For users with existing owning types, we just ask that they add a view() method that returns a host_mdspan or device_mdspan or a derived type

Point 1 is already implemented, and point 2 allows users to update their existing their code minimally.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Point 2 is solved with recent updates

template <typename array_interface_type,
size_t... Extents,
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr>
auto reshape(const array_interface_type& mda, extents<Extents...> new_shape)
Copy link
Member

@trivialfis trivialfis Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this function accepts (or will potentially accept in the future) an owning type, and might change the size of the underlying memory buffer, one might want to accept a raft handle as a parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are going to be explicit about size changes with a different function called resize. Either way, wouldn't that have to be a member function to be able to access the underlying buffer?

@@ -176,15 +317,24 @@ class mdarray {
/**
* @brief Get a mdspan that can be passed down to CUDA kernels.
*/
auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); }
view_type view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); }
Copy link
Member

@trivialfis trivialfis Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's the best to avoid override and use curiously recurring template pattern:

template <Base>
class array_interface : public Base {
    auto view() { return Base::view(); }
}

class mdarray : public array_interface<mdarray> {}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, this looks clean.

Copy link
Member Author

@divyegala divyegala Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a little difficult to implement for my template checks though. Right now I check at compile-time if my object can be converted to a pointer of type array_interface (mdarray being its derived type). I'm not sure how this will be checked with this new pattern. Unless the recursion means that it still works?

Copy link
Member

@trivialfis trivialfis Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm .. why do you need to make that check? If you want to emit error early, you can use:

template <typename ArrayInterface, typename std::enable_if_t<is_mdspan_v<decltype(std::declval<ArrayInterface>().view())>* = nullptr>
auto my_func(ArrayInterface const& arr) {
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having said that, I'm just trying to avoid c++ virtual table. Might not be the most rational choice but from my experience template is easier to handle than virtual functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, std::is_convertible might help.

Copy link
Member Author

@divyegala divyegala Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std:is_convertible does not help since mdarray is a class-template and we do not know ahead of time if a template-type T could possibly be a class-template of mdarray. I like your solution though, we can check if the view() returns a type which satisfies is_mdspan_v

/**
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan
*/
template <size_t... ExtentsPack>
Copy link
Contributor

@wphicks wphicks Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include stddef.h for size_t or include cstddef and use std::size_t where possible

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were largely discussing these designs throughout your progress on this PR and I don't see anything particularly concerning. LGTM

cpp/include/raft/core/mdarray.hpp Show resolved Hide resolved
@divyegala
Copy link
Member Author

rerun tests

@cjnolet
Copy link
Member

cjnolet commented Apr 28, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 924c245 into rapidsai:branch-22.06 Apr 28, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake cpp feature request New feature or request non-breaking Non-breaking change
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants