diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index aa598caf32..8a26a33247 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -86,5 +86,30 @@ __host__ __device__ constexpr auto lexicographical_compare(InputIt1 first1, } return first1 == last1 && first2 != last2; } + +template +struct span_storage { + private: + T* ptr_{nullptr}; + + public: + constexpr span_storage() noexcept = default; + constexpr span_storage(T* ptr, std::size_t) noexcept : ptr_{ptr} {} + [[nodiscard]] constexpr auto size() const noexcept -> std::size_t { return Extent; } + [[nodiscard]] constexpr auto data() const noexcept -> T* { return ptr_; } +}; + +template +struct span_storage { + private: + T* ptr_{nullptr}; + std::size_t size_{0}; + + public: + constexpr span_storage() noexcept = default; + constexpr span_storage(T* ptr, std::size_t size) noexcept : ptr_{ptr}, size_{size} {} + [[nodiscard]] constexpr auto size() const noexcept -> std::size_t { return size_; } + [[nodiscard]] constexpr auto data() const noexcept -> T* { return ptr_; } +}; } // namespace detail } // namespace raft diff --git a/cpp/include/raft/span.hpp b/cpp/include/raft/span.hpp index 389a6a2177..b4fbf5b63a 100644 --- a/cpp/include/raft/span.hpp +++ b/cpp/include/raft/span.hpp @@ -59,7 +59,7 @@ class span { /** * @brief Constructs a span that is a view over the range [first, first + count); */ - constexpr span(pointer ptr, size_type count) noexcept : size_(count), data_(ptr) + constexpr span(pointer ptr, size_type count) noexcept : storage_{ptr, count} { assert(!(Extent != dynamic_extent && count != Extent)); assert(ptr || count == 0); @@ -67,15 +67,15 @@ class span { /** * @brief Constructs a span that is a view over the range [first, last) */ - constexpr span(pointer first, pointer last) noexcept : size_(last - first), data_(first) + constexpr span(pointer first, pointer last) noexcept + : span{first, static_cast(thrust::distance(first, last))} { - assert(data_ || size_ == 0); } /** * @brief Constructs a span that is a view over the array arr. */ template - constexpr span(element_type (&arr)[N]) noexcept : size_(N), data_(&arr[0]) + constexpr span(element_type (&arr)[N]) noexcept : span{&arr[0], N} { } @@ -89,7 +89,7 @@ class span { detail::is_allowed_element_type_conversion_t::value && detail::is_allowed_extent_conversion_t::value>> constexpr span(const span& other) noexcept - : size_(other.size()), data_(other.data()) + : span{other.data(), other.size()} { } @@ -139,10 +139,10 @@ class span { return data()[_idx]; } - constexpr auto data() const noexcept -> pointer { return data_; } + constexpr auto data() const noexcept -> pointer { return storage_.data(); } // Observers - [[nodiscard]] constexpr auto size() const noexcept -> size_type { return size_; } + [[nodiscard]] constexpr auto size() const noexcept -> size_type { return storage_.size(); } [[nodiscard]] constexpr auto size_bytes() const noexcept -> size_type { return size() * sizeof(T); @@ -197,8 +197,7 @@ class span { } private: - size_type size_{0}; - pointer data_{nullptr}; + detail::span_storage storage_; }; /**