Skip to content

Commit

Permalink
Distinguish host/device.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 14, 2021
1 parent d31cb7f commit 7ecad11
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 183 deletions.
114 changes: 58 additions & 56 deletions cpp/include/raft/common/span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
#include <limits> // numeric_limits
#include <type_traits>

#include <thrust/functional.h>
#include <thrust/iterator/reverse_iterator.h>

namespace raft::common {

constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();

template <class ElementType, std::size_t Extent>
template <class ElementType, bool is_device, std::size_t Extent>
class Span;

namespace detail {
Expand All @@ -39,7 +40,7 @@ namespace detail {
* - Otherwise, dynamic_extent.
*/
template <std::size_t Extent, std::size_t Offset, std::size_t Count>
struct ExtentValue
struct extent_value_t
: public std::integral_constant<
std::size_t, Count != dynamic_extent
? Count
Expand All @@ -51,39 +52,28 @@ struct ExtentValue
* dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N.
*/
template <typename T, std::size_t Extent>
struct ExtentAsBytesValue
struct extent_as_bytes_value_t
: public std::integral_constant<
std::size_t, Extent == dynamic_extent ? Extent : sizeof(T) * Extent> {};

template <std::size_t From, std::size_t To>
struct IsAllowedExtentConversion
struct is_allowed_extent_conversion_t
: public std::integral_constant<bool, From == To || From == dynamic_extent ||
To == dynamic_extent> {};

template <class From, class To>
struct IsAllowedElementTypeConversion
struct is_allowed_element_type_conversion_t
: public std::integral_constant<
bool, std::is_convertible<From (*)[], To (*)[]>::value> {};

template <class T>
struct IsSpanOracle : std::false_type {};
struct is_span_oracle_t : std::false_type {};

template <class T, std::size_t Extent>
struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
template <class T, bool is_device, std::size_t Extent>
struct is_span_oracle_t<Span<T, is_device, Extent>> : std::true_type {};

template <class T>
struct IsSpan : public IsSpanOracle<typename std::remove_cv<T>::type> {};

// Re-implement std algorithms here to adopt CUDA.
template <typename T>
struct Less {
constexpr bool operator()(const T& _x, const T& _y) const { return _x < _y; }
};

template <typename T>
struct Greater {
constexpr bool operator()(const T& _x, const T& _y) const { return _x > _y; }
};
struct is_span_t : public is_span_oracle_t<typename std::remove_cv<T>::type> {};

template <class InputIt1, class InputIt2, class Compare>
__host__ __device__ constexpr auto lexicographical_compare(
Expand All @@ -105,7 +95,7 @@ __host__ __device__ constexpr auto lexicographical_compare(
* \brief The span class defined in ISO C++20. Iterator is defined as plain pointer and
* most of the methods have bound check on debug build.
*/
template <typename T, std::size_t Extent = dynamic_extent>
template <typename T, bool is_device, std::size_t Extent = dynamic_extent>
class Span {
public:
using element_type = T;
Expand Down Expand Up @@ -139,11 +129,12 @@ class Span {
template <std::size_t N>
constexpr Span(element_type (&arr)[N]) noexcept : size_(N), data_(&arr[0]) {}

template <class U, std::size_t OtherExtent,
class = typename std::enable_if<
detail::IsAllowedElementTypeConversion<U, T>::value &&
detail::IsAllowedExtentConversion<OtherExtent, Extent>::value>>
constexpr Span(const Span<U, OtherExtent>& _other) noexcept
template <
class U, std::size_t OtherExtent,
class = typename std::enable_if<
detail::is_allowed_element_type_conversion_t<U, T>::value &&
detail::is_allowed_extent_conversion_t<OtherExtent, Extent>::value>>
constexpr Span(const Span<U, is_device, OtherExtent>& _other) noexcept
: size_(_other.size()), data_(_other.data()) {}

constexpr Span(const Span& _other) noexcept
Expand Down Expand Up @@ -208,25 +199,25 @@ class Span {

// Subviews
template <std::size_t Count>
constexpr auto first() const -> Span<element_type, Count> {
constexpr auto first() const -> Span<element_type, is_device, Count> {
assert(Count <= size());
return {data(), Count};
}

constexpr auto first(std::size_t _count) const
-> Span<element_type, dynamic_extent> {
-> Span<element_type, is_device, dynamic_extent> {
assert(_count <= size());
return {data(), _count};
}

template <std::size_t Count>
constexpr auto last() const -> Span<element_type, Count> {
constexpr auto last() const -> Span<element_type, is_device, Count> {
assert(Count <= size());
return {data() + size() - Count, Count};
}

constexpr auto last(std::size_t _count) const
-> Span<element_type, dynamic_extent> {
-> Span<element_type, is_device, dynamic_extent> {
assert(_count <= size());
return subspan(size() - _count, _count);
}
Expand All @@ -237,15 +228,16 @@ class Span {
*/
template <std::size_t Offset, std::size_t Count = dynamic_extent>
constexpr auto subspan() const
-> Span<element_type, detail::ExtentValue<Extent, Offset, Count>::value> {
-> Span<element_type, is_device,
detail::extent_value_t<Extent, Offset, Count>::value> {
assert((Count == dynamic_extent) ? (Offset <= size())
: (Offset + Count <= size()));
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}

constexpr auto subspan(size_type _offset,
size_type _count = dynamic_extent) const
-> Span<element_type, dynamic_extent> {
-> Span<element_type, is_device, dynamic_extent> {
assert((_count == dynamic_extent) ? (_offset <= size())
: (_offset + _count <= size()));
return {data() + _offset,
Expand All @@ -257,8 +249,15 @@ class Span {
pointer data_{nullptr};
};

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator==(Span<T, X> l, Span<U, Y> r) -> bool {
template <typename T, size_t extent = dynamic_extent>
using host_span = Span<T, false, extent>;

template <typename T, size_t extent = dynamic_extent>
using device_span = Span<T, true, extent>;

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator==(Span<T, is_device, X> l, Span<U, is_device, Y> r)
-> bool {
if (l.size() != r.size()) {
return false;
}
Expand All @@ -271,46 +270,49 @@ constexpr auto operator==(Span<T, X> l, Span<U, Y> r) -> bool {
return true;
}

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator!=(Span<T, X> l, Span<U, Y> r) {
template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator!=(Span<T, is_device, X> l, Span<U, is_device, Y> r) {
return !(l == r);
}

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator<(Span<T, X> l, Span<U, Y> r) {
template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator<(Span<T, is_device, X> l, Span<U, is_device, Y> r) {
return detail::lexicographical_compare<
typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
detail::Less<typename Span<T, X>::element_type>>(l.begin(), l.end(),
r.begin(), r.end());
typename Span<T, is_device, X>::iterator,
typename Span<U, is_device, Y>::iterator,
thrust::less<typename Span<T, is_device, X>::element_type>>(
l.begin(), l.end(), r.begin(), r.end());
}

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator<=(Span<T, X> l, Span<U, Y> r) {
template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator<=(Span<T, is_device, X> l, Span<U, is_device, Y> r) {
return !(l > r);
}

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator>(Span<T, X> l, Span<U, Y> r) {
template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator>(Span<T, is_device, X> l, Span<U, is_device, Y> r) {
return detail::lexicographical_compare<
typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
detail::Greater<typename Span<T, X>::element_type>>(l.begin(), l.end(),
r.begin(), r.end());
typename Span<T, is_device, X>::iterator,
typename Span<U, is_device, Y>::iterator,
thrust::greater<typename Span<T, is_device, X>::element_type>>(
l.begin(), l.end(), r.begin(), r.end());
}

template <class T, std::size_t X, class U, std::size_t Y>
constexpr auto operator>=(Span<T, X> l, Span<U, Y> r) {
template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator>=(Span<T, is_device, X> l, Span<U, is_device, Y> r) {
return !(l < r);
}

template <class T, std::size_t E>
auto as_bytes(Span<T, E> s) noexcept
-> Span<const std::byte, detail::ExtentAsBytesValue<T, E>::value> {
template <class T, bool is_device, std::size_t E>
auto as_bytes(Span<T, is_device, E> s) noexcept
-> Span<const std::byte, is_device,
detail::extent_as_bytes_value_t<T, E>::value> {
return {reinterpret_cast<const std::byte*>(s.data()), s.size_bytes()};
}

template <class T, std::size_t E>
auto as_writable_bytes(Span<T, E> s) noexcept
-> Span<std::byte, detail::ExtentAsBytesValue<T, E>::value> {
template <class T, bool is_device, std::size_t E>
auto as_writable_bytes(Span<T, is_device, E> s) noexcept
-> Span<std::byte, is_device, detail::extent_as_bytes_value_t<T, E>::value> {
return {reinterpret_cast<std::byte*>(s.data()), s.size_bytes()};
}
} // namespace raft::common
Loading

0 comments on commit 7ecad11

Please sign in to comment.