-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
mdspan/mdarray template functions and utilities #601
mdspan/mdarray template functions and utilities #601
Conversation
Removing |
Co-authored-by: Jiaming Yuan <[email protected]>
…-22.06-mdspan_utils
cpp/include/raft/core/mdarray.hpp
Outdated
/** | ||
* @brief Get an implicitly constructed mdspan that can be passed down to CUDA kernels. | ||
*/ | ||
operator view_type() noexcept { return view(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if these operators are needed anymore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably no, if you don't want implicit conversion.
* @\brief Boolean to determine if template type T is raft::array_interface or derived type | ||
*/ | ||
template <typename T> | ||
inline constexpr bool is_array_interface_v = __is_array_interface<std::remove_const_t<T>>::value; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I propose that we give users two options here:
- They directly implement
array_interface
- For users with existing owning types, we just ask that they add a
view()
method that returns a host_mdspan or device_mdspan or a derived type
Point 1 is already implemented, and point 2 allows users to update their existing their code minimally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Point 2 is solved with recent updates
template <typename array_interface_type, | ||
size_t... Extents, | ||
std::enable_if_t<is_array_interface_v<array_interface_type>>* = nullptr> | ||
auto reshape(const array_interface_type& mda, extents<Extents...> new_shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this function accepts (or will potentially accept in the future) an owning type, and might change the size of the underlying memory buffer, one might want to accept a raft handle as a parameter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we are going to be explicit about size changes with a different function called resize
. Either way, wouldn't that have to be a member function to be able to access the underlying buffer?
cpp/include/raft/core/mdarray.hpp
Outdated
@@ -176,15 +317,24 @@ class mdarray { | |||
/** | |||
* @brief Get a mdspan that can be passed down to CUDA kernels. | |||
*/ | |||
auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } | |||
view_type view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's the best to avoid override and use curiously recurring template pattern:
template <Base>
class array_interface : public Base {
auto view() { return Base::view(); }
}
class mdarray : public array_interface<mdarray> {}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, this looks clean.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks a little difficult to implement for my template checks though. Right now I check at compile-time if my object can be converted to a pointer of type array_interface
(mdarray
being its derived type). I'm not sure how this will be checked with this new pattern. Unless the recursion means that it still works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm .. why do you need to make that check? If you want to emit error early, you can use:
template <typename ArrayInterface, typename std::enable_if_t<is_mdspan_v<decltype(std::declval<ArrayInterface>().view())>* = nullptr>
auto my_func(ArrayInterface const& arr) {
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having said that, I'm just trying to avoid c++ virtual table. Might not be the most rational choice but from my experience template is easier to handle than virtual functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, std::is_convertible
might help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std:is_convertible
does not help since mdarray
is a class-template and we do not know ahead of time if a template-type T
could possibly be a class-template of mdarray
. I like your solution though, we can check if the view()
returns a type which satisfies is_mdspan_v
/** | ||
* @\brief Dimensions extents for raft::host_mdspan or raft::device_mdspan | ||
*/ | ||
template <size_t... ExtentsPack> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include stddef.h
for size_t
or include cstddef
and use std::size_t
where possible
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We were largely discussing these designs throughout your progress on this PR and I don't see anything particularly concerning. LGTM
rerun tests |
@gpucibot merge |
No description provided.