[ENH] mdspan-ify iterator-based API for linalg::reduce_rows_by_keys #925

Open
Nyrio opened this issue Oct 19, 2022 · 2 comments
Labels: 0 - Backlog (In queue waiting for assignment), cpp, improvement (Improvement / enhancement to an existing function)

Comments

@Nyrio
Contributor

Nyrio commented Oct 19, 2022

Following #909, linalg::reduce_rows_by_keys can take custom iterators, which lets ann_kmeans_balanced avoid intermediate steps. However, raw-pointer/iterator APIs are being deprecated in favor of mdspan.

We should provide appropriate helpers and types for iterator-based mdspan and change this API accordingly. See the discussion on #909.

@Nyrio added the 0 - Backlog, cpp, and improvement labels on Oct 19, 2022 (the feature request label was added and then removed).
@Nyrio
Contributor Author

Nyrio commented Oct 19, 2022

cc @cjnolet @mhoemmen who were part of the discussion about this.

@mhoemmen
Contributor

mhoemmen commented Oct 19, 2022

Hi @Nyrio and @cjnolet! I can think of two likely meanings for "iterator-based mdspan."

  1. "An mdspan whose data_handle_type is a random access iterator"
  2. "Implement begin and end for generic mdspan"

Given the context, I'm guessing that you mean (1). You'll need at least forward iterators, because mdspan depends on the multipass guarantee. You can imitate default_accessor to write a generic iterator accessor. Use std::iterator_traits to get the iterator's value and reference types, and use std::advance in the access and offset member functions if you want to support forward iterators that are not also random access iterators. For example:

using reference = typename std::iterator_traits<Iterator>::reference;
reference access(Iterator iter, std::size_t index) const {
  std::advance(iter, index);  // std::advance returns void, so advance the copy first...
  return *iter;               // ...then dereference it.
}
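Filling in the rest of the accessor requirements, a complete generic accessor could look roughly like the sketch below. The name iterator_accessor is hypothetical (not RAFT API), and element_type is simplified to std::remove_reference_t<reference>, which is only right when the iterator's reference type is T& or const T&.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <type_traits>

// Hypothetical accessor modeled on default_accessor; needs at least a forward iterator.
template <class Iterator>
struct iterator_accessor {
  using traits = std::iterator_traits<Iterator>;
  using difference_type = typename traits::difference_type;
  using offset_policy = iterator_accessor;  // offsetting keeps the same iterator type
  using data_handle_type = Iterator;
  using reference = typename traits::reference;
  // Simplification: correct only for lvalue-reference iterators (T& or const T&).
  using element_type = std::remove_reference_t<reference>;

  constexpr iterator_accessor() noexcept = default;

  reference access(data_handle_type iter, std::size_t index) const {
    // std::advance returns void, so advance the local copy, then dereference it.
    std::advance(iter, static_cast<difference_type>(index));
    return *iter;
  }

  data_handle_type offset(data_handle_type iter, std::size_t index) const {
    // std::next returns the advanced iterator directly.
    return std::next(iter, static_cast<difference_type>(index));
  }
};

inline void iterator_accessor_example() {
  std::list<int> values{10, 20, 30};  // bidirectional, not random access
  iterator_accessor<std::list<int>::iterator> acc;
  assert(acc.access(values.begin(), 2) == 30);
  acc.access(values.begin(), 1) = 25;  // writes through the non-const iterator
  assert(acc.access(acc.offset(values.begin(), 1), 0) == 25);
  // A const_iterator yields a const element_type.
  static_assert(std::is_same_v<
      iterator_accessor<std::list<int>::const_iterator>::element_type, const int>);
}
```

Each access on a non-random-access iterator walks index steps from the handle, so it costs O(index); that is the price of supporting plain forward iterators.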

To define the accessor's required type aliases (see [mdspan.accessor.reqmts]), the tricky part is element_type: it's const value_type if the iterator is an iterator-of-const, else value_type. If you can figure out how to do that, say via an alias template template<std::forward_iterator Iterator> using iterator_element_t = /* ... */, then you'll find it useful below.
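One possible definition of iterator_element_t, under the assumption that an "iterator-of-const" is one whose reference type is a reference to const. It starts from value_type instead of just stripping the reference, so proxy iterators such as std::vector<bool>::iterator still map to bool. Iterators that return by value (a non-reference reference type) come out non-const under this rule, which may not be what you want for read-only iterators.

```cpp
#include <iterator>
#include <type_traits>
#include <vector>

// Hypothetical alias: const value_type if dereferencing yields a reference-to-const,
// otherwise plain value_type.
template <class Iterator>
using iterator_element_t = std::conditional_t<
    std::is_const_v<std::remove_reference_t<
        typename std::iterator_traits<Iterator>::reference>>,
    const typename std::iterator_traits<Iterator>::value_type,
    typename std::iterator_traits<Iterator>::value_type>;

static_assert(std::is_same_v<iterator_element_t<int*>, int>);
static_assert(std::is_same_v<iterator_element_t<const int*>, const int>);
static_assert(std::is_same_v<iterator_element_t<std::vector<int>::const_iterator>, const int>);
// std::vector<bool>::reference is a non-const proxy class, so this yields bool,
// whereas std::remove_reference_t<reference> alone would yield the proxy type.
static_assert(std::is_same_v<iterator_element_t<std::vector<bool>::iterator>, bool>);
```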

The question then becomes how to apply __host__ or __device__ (or both, for managed or pinned allocations) to access, so that you get the type-safety benefits of RAFT's host_accessor and device_accessor. (In that case you might also need to require random access iterators and replace std::advance with iter[index], unless libcu++ has a __host__ __device__ equivalent of std::advance.) If you have different custom iterator types for host and device access, you can use a trait on the iterator type to decide whether the access function is __host__, __device__, or both (e.g., via specialization of a base class).
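A sketch of the base-class specialization idea that also compiles as plain C++. Everything in it is hypothetical: the memory_kind enum, the iterator_memory trait, the device_iterator stand-in, the SKETCH_HOST/SKETCH_DEVICE macros (which expand to __host__/__device__ only under nvcc), and the is_host_accessible/is_device_accessible flags. It uses iter[index] rather than std::advance, so it requires random access.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical macros: CUDA annotations under nvcc, empty in host-only builds.
#ifdef __CUDACC__
#define SKETCH_HOST __host__
#define SKETCH_DEVICE __device__
#else
#define SKETCH_HOST
#define SKETCH_DEVICE
#endif

enum class memory_kind { host, device, managed };

// Hypothetical trait: iterators point into host memory unless specialized.
template <class Iterator>
struct iterator_memory { static constexpr memory_kind value = memory_kind::host; };

// Hypothetical minimal stand-in for a device-memory iterator (only [] and +).
template <class T>
struct device_iterator {
  T* ptr;
  SKETCH_HOST SKETCH_DEVICE T& operator[](std::size_t i) const { return ptr[i]; }
  SKETCH_HOST SKETCH_DEVICE device_iterator operator+(std::size_t n) const { return {ptr + n}; }
};
template <class T>
struct iterator_memory<device_iterator<T>> { static constexpr memory_kind value = memory_kind::device; };

// The annotation on access() is selected by specializing the base on the memory kind.
template <class Iterator, memory_kind Kind>
struct iterator_access_base;

template <class Iterator>
struct iterator_access_base<Iterator, memory_kind::host> {
  using reference = decltype(std::declval<const Iterator&>()[0]);
  static constexpr bool is_host_accessible = true;
  static constexpr bool is_device_accessible = false;
  SKETCH_HOST reference access(Iterator iter, std::size_t index) const { return iter[index]; }
};

template <class Iterator>
struct iterator_access_base<Iterator, memory_kind::device> {
  using reference = decltype(std::declval<const Iterator&>()[0]);
  static constexpr bool is_host_accessible = false;
  static constexpr bool is_device_accessible = true;
  SKETCH_DEVICE reference access(Iterator iter, std::size_t index) const { return iter[index]; }
};

template <class Iterator>
struct iterator_access_base<Iterator, memory_kind::managed> {
  using reference = decltype(std::declval<const Iterator&>()[0]);
  static constexpr bool is_host_accessible = true;
  static constexpr bool is_device_accessible = true;
  SKETCH_HOST SKETCH_DEVICE reference access(Iterator iter, std::size_t index) const { return iter[index]; }
};

template <class Iterator>
struct annotated_iterator_accessor
    : iterator_access_base<Iterator, iterator_memory<Iterator>::value> {
  using data_handle_type = Iterator;
  // Offsetting does not dereference, so it is safe on both host and device.
  SKETCH_HOST SKETCH_DEVICE data_handle_type offset(Iterator iter, std::size_t index) const {
    return iter + index;
  }
};

inline void memory_space_example() {
  int data[3] = {1, 2, 3};
  annotated_iterator_accessor<int*> host_acc;
  static_assert(annotated_iterator_accessor<int*>::is_host_accessible, "host pointer");
  assert(host_acc.access(data, 2) == 3);
  // Host code can check the flag at compile time before calling access().
  static_assert(!annotated_iterator_accessor<device_iterator<int>>::is_host_accessible,
                "device iterator");
}
```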

At this point, you can concisely convert your iterator ranges into rank-1 mdspans:

template<std::forward_iterator Iterator, std::integral IndexType, std::size_t Extent = dynamic_extent>
using iterator_mdspan_t = mdspan<
  iterator_element_t<Iterator>, // (possibly const value_type)
  extents<IndexType, Extent>,
  iterator_accessor<Iterator>>;

// For pre-C++20, remove the Sentinel template parameter,
// make the second function parameter an Iterator,
// replace std::ranges::distance with std::distance, and
// drop the std::forward_iterator/std::integral constraints above.
template<class Iterator, class Sentinel>
auto range_to_mdspan(Iterator begin, Sentinel end) {
  auto distance = std::ranges::distance(begin, end);
  using distance_type = decltype(distance);
  return iterator_mdspan_t<Iterator, distance_type, dynamic_extent>{begin, distance};
}
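To see what a rank-1 mdspan built this way does with the iterator, without pulling in an mdspan implementation, here is a hypothetical stand-in following the pre-C++20 shape from the comment (two Iterator parameters and std::distance). With mdspan's default layout_right, the rank-1 mapping is the identity, so indexing reduces to accessor.access(handle, i). The names simple_iterator_accessor, rank1_iterator_view, and range_to_view are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>

// Hypothetical minimal accessor: access() walks `index` steps from the handle.
template <class Iterator>
struct simple_iterator_accessor {
  using traits = std::iterator_traits<Iterator>;
  using reference = typename traits::reference;
  reference access(Iterator iter, std::size_t index) const {
    std::advance(iter, static_cast<typename traits::difference_type>(index));
    return *iter;
  }
};

// Hypothetical rank-1 stand-in for mdspan<..., extents<IndexType, dynamic_extent>,
// layout_right, Accessor>: the rank-1 mapping is mapping(i) == i.
template <class Iterator, class IndexType>
struct rank1_iterator_view {
  Iterator handle;
  IndexType extent;
  simple_iterator_accessor<Iterator> acc{};

  typename simple_iterator_accessor<Iterator>::reference operator[](IndexType i) const {
    return acc.access(handle, static_cast<std::size_t>(i));
  }
};

// Pre-C++20 variant: both ends are Iterators; std::distance gives the extent.
template <class Iterator>
auto range_to_view(Iterator begin, Iterator end) {
  auto distance = std::distance(begin, end);
  using distance_type = decltype(distance);
  return rank1_iterator_view<Iterator, distance_type>{begin, distance};
}

inline void range_to_view_example() {
  std::list<int> values{4, 5, 6};
  auto view = range_to_view(values.begin(), values.end());
  assert(view.extent == 3);
  assert(view[1] == 5);
}
```

As in range_to_mdspan, the extent type is the iterator's signed difference type, and computing it with std::distance is O(n) for non-random-access iterators.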
