Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Span implementation. #399

Merged
merged 6 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions cpp/include/raft/common/span.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cassert>
#include <cinttypes> // size_t
#include <cstddef> // std::byte
#include <limits> // numeric_limits
#include <type_traits>

#include <thrust/functional.h>
#include <thrust/iterator/reverse_iterator.h>

namespace raft::common {
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();

template <class ElementType, bool is_device, std::size_t Extent>
class span;

namespace detail {
/*!
* The extent E of the span returned by subspan is determined as follows:
*
* - If Count is not dynamic_extent, Count;
* - Otherwise, if Extent is not dynamic_extent, Extent - Offset;
* - Otherwise, dynamic_extent.
*/
template <std::size_t Extent, std::size_t Offset, std::size_t Count>
struct extent_value_t
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a side thought- it would be super useful to be able to implicitly convert between a _mdspan_vector_t and a _span_t. For example, this would allow the RAFT API to accept an _mdspan_vector_t and pass it a _span_t

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjnolet We have some integration work to do. For instance, the dynamic_extent in this PR is duplicated with the one in stdex.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can polish after having the base implementation and make a list of additional features we want, like conversion between span and 1-dim mdspan, padding for arrays, resize for arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trivialfis, I've looked over your changes and it's looking like this is about ready to merge. What do you think? Can you capture the follow-on items in Github issues so we don't lose them?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: public std::integral_constant<
std::size_t,
Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {
};

/*!
* If N is dynamic_extent, the extent of the returned span E is also
* dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N.
*/
template <typename T, std::size_t Extent>
struct extent_as_bytes_value_t
: public std::integral_constant<std::size_t,
Extent == dynamic_extent ? Extent : sizeof(T) * Extent> {
};

template <std::size_t From, std::size_t To>
struct is_allowed_extent_conversion_t
: public std::integral_constant<bool,
From == To || From == dynamic_extent || To == dynamic_extent> {
};

template <class From, class To>
struct is_allowed_element_type_conversion_t
: public std::integral_constant<bool, std::is_convertible<From (*)[], To (*)[]>::value> {
};

template <class T>
struct is_span_oracle_t : std::false_type {
};

template <class T, bool is_device, std::size_t Extent>
struct is_span_oracle_t<span<T, is_device, Extent>> : std::true_type {
};

template <class T>
struct is_span_t : public is_span_oracle_t<typename std::remove_cv<T>::type> {
};

template <class InputIt1, class InputIt2, class Compare>
__host__ __device__ constexpr auto lexicographical_compare(InputIt1 first1,
InputIt1 last1,
InputIt2 first2,
InputIt2 last2) -> bool
{
Compare comp;
for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
if (comp(*first1, *first2)) { return true; }
if (comp(*first2, *first1)) { return false; }
}
return first1 == last1 && first2 != last2;
}
} // namespace detail

/**
* \brief The span class defined in ISO C++20. Iterator is defined as plain pointer and
* most of the methods have bound check on debug build.
*/
template <typename T, bool is_device, std::size_t Extent = dynamic_extent>
class span {
public:
using element_type = T;
using value_type = typename std::remove_cv<T>::type;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using const_pointer = T const*;
using reference = T&;
using const_reference = T const&;

using iterator = pointer;
using const_iterator = const_pointer;
using reverse_iterator = thrust::reverse_iterator<iterator>;
using const_reverse_iterator = thrust::reverse_iterator<const_iterator>;

// constructors
constexpr span() noexcept = default;

constexpr span(pointer _ptr, size_type _count) noexcept : size_(_count), data_(_ptr)
{
assert(!(Extent != dynamic_extent && _count != Extent));
assert(_ptr || _count == 0);
}

constexpr span(pointer _first, pointer _last) noexcept : size_(_last - _first), data_(_first)
{
assert(data_ || size_ == 0);
}

template <std::size_t N>
constexpr span(element_type (&arr)[N]) noexcept : size_(N), data_(&arr[0])
{
}

template <class U,
std::size_t OtherExtent,
class = typename std::enable_if<
detail::is_allowed_element_type_conversion_t<U, T>::value &&
detail::is_allowed_extent_conversion_t<OtherExtent, Extent>::value>>
constexpr span(const span<U, is_device, OtherExtent>& _other) noexcept
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
: size_(_other.size()), data_(_other.data())
{
}

constexpr span(const span& _other) noexcept : size_(_other.size()), data_(_other.data()) {}

constexpr auto operator=(const span& _other) noexcept -> span&
{
size_ = _other.size();
data_ = _other.data();
return *this;
}

constexpr auto begin() const noexcept -> iterator { return data(); }

constexpr auto end() const noexcept -> iterator { return data() + size(); }

constexpr auto cbegin() const noexcept -> const_iterator { return data(); }

constexpr auto cend() const noexcept -> const_iterator { return data() + size(); }

__host__ __device__ constexpr auto rbegin() const noexcept -> reverse_iterator
{
return reverse_iterator{end()};
}

__host__ __device__ constexpr auto rend() const noexcept -> reverse_iterator
{
return reverse_iterator{begin()};
}

__host__ __device__ constexpr auto crbegin() const noexcept -> const_reverse_iterator
{
return const_reverse_iterator{cend()};
}

__host__ __device__ constexpr auto crend() const noexcept -> const_reverse_iterator
{
return const_reverse_iterator{cbegin()};
}

// element access
constexpr auto front() const -> reference { return (*this)[0]; }

constexpr auto back() const -> reference { return (*this)[size() - 1]; }

template <typename Index>
constexpr auto operator[](Index _idx) const -> reference
{
assert(_idx < size());
return data()[_idx];
}

constexpr auto data() const noexcept -> pointer { return data_; }

// Observers
[[nodiscard]] constexpr auto size() const noexcept -> size_type { return size_; }
[[nodiscard]] constexpr auto size_bytes() const noexcept -> size_type
{
return size() * sizeof(T);
}

constexpr auto empty() const noexcept { return size() == 0; }

// Subviews
template <std::size_t Count>
constexpr auto first() const -> span<element_type, is_device, Count>
{
assert(Count <= size());
return {data(), Count};
}

constexpr auto first(std::size_t _count) const -> span<element_type, is_device, dynamic_extent>
{
assert(_count <= size());
return {data(), _count};
}

template <std::size_t Count>
constexpr auto last() const -> span<element_type, is_device, Count>
{
assert(Count <= size());
return {data() + size() - Count, Count};
}

constexpr auto last(std::size_t _count) const -> span<element_type, is_device, dynamic_extent>
{
assert(_count <= size());
return subspan(size() - _count, _count);
}

/*!
* If Count is std::dynamic_extent, r.size() == this->size() - Offset;
* Otherwise r.size() == Count.
*/
template <std::size_t Offset, std::size_t Count = dynamic_extent>
constexpr auto subspan() const
-> span<element_type, is_device, detail::extent_value_t<Extent, Offset, Count>::value>
{
assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size()));
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}

constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const
-> span<element_type, is_device, dynamic_extent>
{
assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size()));
return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
}

private:
size_type size_{0};
pointer data_{nullptr};
};

template <typename T, size_t extent = dynamic_extent>
using host_span = span<T, false, extent>;
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

template <typename T, size_t extent = dynamic_extent>
using device_span = span<T, true, extent>;

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator==(span<T, is_device, X> l, span<U, is_device, Y> r) -> bool
{
if (l.size() != r.size()) { return false; }
for (auto l_beg = l.cbegin(), r_beg = r.cbegin(); l_beg != l.cend(); ++l_beg, ++r_beg) {
if (*l_beg != *r_beg) { return false; }
}
return true;
}

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator!=(span<T, is_device, X> l, span<U, is_device, Y> r)
{
return !(l == r);
}

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator<(span<T, is_device, X> l, span<U, is_device, Y> r)
{
return detail::lexicographical_compare<
typename span<T, is_device, X>::iterator,
typename span<U, is_device, Y>::iterator,
thrust::less<typename span<T, is_device, X>::element_type>>(
l.begin(), l.end(), r.begin(), r.end());
}

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator<=(span<T, is_device, X> l, span<U, is_device, Y> r)
{
return !(l > r);
}

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator>(span<T, is_device, X> l, span<U, is_device, Y> r)
{
return detail::lexicographical_compare<
typename span<T, is_device, X>::iterator,
typename span<U, is_device, Y>::iterator,
thrust::greater<typename span<T, is_device, X>::element_type>>(
l.begin(), l.end(), r.begin(), r.end());
}

template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
constexpr auto operator>=(span<T, is_device, X> l, span<U, is_device, Y> r)
{
return !(l < r);
}

template <class T, bool is_device, std::size_t E>
auto as_bytes(span<T, is_device, E> s) noexcept
-> span<const std::byte, is_device, detail::extent_as_bytes_value_t<T, E>::value>
{
return {reinterpret_cast<const std::byte*>(s.data()), s.size_bytes()};
}

template <class T, bool is_device, std::size_t E>
auto as_writable_bytes(span<T, is_device, E> s) noexcept
-> span<std::byte, is_device, detail::extent_as_bytes_value_t<T, E>::value>
{
return {reinterpret_cast<std::byte*>(s.data()), s.size_bytes()};
}
} // namespace raft::common
2 changes: 2 additions & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
add_executable(test_raft
test/cudart_utils.cpp
test/cluster_solvers.cu
test/common/span.cpp
test/common/span.cu
test/distance/dist_adj.cu
test/distance/dist_canberra.cu
test/distance/dist_chebyshev.cu
Expand Down
Loading