diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp
new file mode 100644
index 0000000000..aa598caf32
--- /dev/null
+++ b/cpp/include/raft/detail/span.hpp
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <limits>               // numeric_limits
+#include <raft/cuda_utils.cuh>  // __host__ __device__
+#include <type_traits>
+
+namespace raft {
+constexpr std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();
+
+template <typename T, bool is_device, std::size_t Extent>
+class span;
+
+namespace detail {
+/*!
+ * The extent E of the span returned by subspan is determined as follows:
+ *
+ *   - If Count is not dynamic_extent, Count;
+ *   - Otherwise, if Extent is not dynamic_extent, Extent - Offset;
+ *   - Otherwise, dynamic_extent.
+ */
+template <std::size_t Extent, std::size_t Offset, std::size_t Count>
+struct extent_value_t
+  : public std::integral_constant<
+      std::size_t,
+      Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {
+};
+
+/*!
+ * If N is dynamic_extent, the extent of the returned span E is also
+ * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N.
+ */
+template <typename T, std::size_t Extent>
+struct extent_as_bytes_value_t
+  : public std::integral_constant<std::size_t,
+                                  Extent == dynamic_extent ? Extent : sizeof(T) * Extent> {
+};
+
+template <std::size_t From, std::size_t To>
+struct is_allowed_extent_conversion_t
+  : public std::integral_constant<bool,
+                                  From == To || From == dynamic_extent || To == dynamic_extent> {
+};
+
+template <class From, class To>
+struct is_allowed_element_type_conversion_t
+  : public std::integral_constant<bool, std::is_convertible<From (*)[], To (*)[]>::value> {
+};
+
+template <class T>
+struct is_span_oracle_t : std::false_type {
+};
+
+template <class T, bool is_device, std::size_t Extent>
+struct is_span_oracle_t<span<T, is_device, Extent>> : std::true_type {
+};
+
+template <class T>
+struct is_span_t : public is_span_oracle_t<typename std::remove_cv<T>::type> {
+};
+
+template <class InputIt1, class InputIt2, class Compare>
+__host__ __device__ constexpr auto lexicographical_compare(InputIt1 first1,
+                                                           InputIt1 last1,
+                                                           InputIt2 first2,
+                                                           InputIt2 last2) -> bool
+{
+  Compare comp;
+  for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
+    if (comp(*first1, *first2)) { return true; }
+    if (comp(*first2, *first1)) { return false; }
+  }
+  return first1 == last1 && first2 != last2;
+}
+}  // namespace detail
+}  // namespace raft
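The subspan extent arithmetic above is easy to sanity-check in isolation. Below is a minimal, self-contained sketch (not part of the patch; `dyn` and `sub_extent` are illustrative stand-ins for `raft::dynamic_extent` and the ternary inside `detail::extent_value_t`):

```cpp
#include <cstddef>
#include <limits>

// Stand-in for raft::dynamic_extent.
constexpr std::size_t dyn = std::numeric_limits<std::size_t>::max();

// Mirrors the ternary inside detail::extent_value_t<Extent, Offset, Count>.
constexpr std::size_t sub_extent(std::size_t extent, std::size_t offset, std::size_t count)
{
  return count != dyn ? count : (extent != dyn ? extent - offset : extent);
}

static_assert(sub_extent(16, 4, 8) == 8, "an explicit Count wins");
static_assert(sub_extent(16, 4, dyn) == 12, "Extent - Offset when only Count is dynamic");
static_assert(sub_extent(dyn, 4, dyn) == dyn, "stays dynamic when Extent is dynamic");

int main() { return 0; }
```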
diff --git a/cpp/include/raft/span.hpp b/cpp/include/raft/span.hpp
new file mode 100644
index 0000000000..389a6a2177
--- /dev/null
+++ b/cpp/include/raft/span.hpp
@@ -0,0 +1,283 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cassert>
+#include <cstddef>  // size_t
+#include <cstdint>  // std::byte
+#include <limits>
+#include <raft/detail/span.hpp>
+#include <raft/cuda_utils.cuh>  // __host__ __device__
+#include <thrust/functional.h>
+#include <thrust/iterator/reverse_iterator.h>
+
+namespace raft {
+/**
+ * @brief The span class defined in ISO C++20.  Iterators are defined as plain
+ *        pointers, and most methods are bounds-checked in debug builds.
+ *
+ * @code
+ *   rmm::device_uvector<float> uvec(10, rmm::cuda_stream_default);
+ *   auto view = device_span<float>{uvec.data(), uvec.size()};
+ * @endcode
+ */
+template <typename T, bool is_device, std::size_t Extent = dynamic_extent>
+class span {
+ public:
+  using element_type    = T;
+  using value_type      = typename std::remove_cv<T>::type;
+  using size_type       = std::size_t;
+  using difference_type = std::ptrdiff_t;
+  using pointer         = T*;
+  using const_pointer   = T const*;
+  using reference       = T&;
+  using const_reference = T const&;
+
+  using iterator               = pointer;
+  using const_iterator         = const_pointer;
+  using reverse_iterator       = thrust::reverse_iterator<iterator>;
+  using const_reverse_iterator = thrust::reverse_iterator<const_iterator>;
+
+  /**
+   * @brief Default constructor that constructs a span with size 0 and nullptr.
+   */
+  constexpr span() noexcept = default;
+
+  /**
+   * @brief Constructs a span that is a view over the range [first, first + count).
+   */
+  constexpr span(pointer ptr, size_type count) noexcept : size_(count), data_(ptr)
+  {
+    assert(!(Extent != dynamic_extent && count != Extent));
+    assert(ptr || count == 0);
+  }
+  /**
+   * @brief Constructs a span that is a view over the range [first, last).
+   */
+  constexpr span(pointer first, pointer last) noexcept : size_(last - first), data_(first)
+  {
+    assert(data_ || size_ == 0);
+  }
+  /**
+   * @brief Constructs a span that is a view over the array arr.
+   */
+  template <std::size_t N>
+  constexpr span(element_type (&arr)[N]) noexcept : size_(N), data_(&arr[0])
+  {
+  }
+
+  /**
+   * @brief Initializes a span from another one whose underlying type is
+   *        convertible to element_type.
+   */
+  template <class U,
+            std::size_t OtherExtent,
+            class = typename std::enable_if<
+              detail::is_allowed_element_type_conversion_t<U, T>::value &&
+              detail::is_allowed_extent_conversion_t<OtherExtent, Extent>::value>>
+  constexpr span(const span<U, is_device, OtherExtent>& other) noexcept
+    : size_(other.size()), data_(other.data())
+  {
+  }
+
+  constexpr span(span const& other) noexcept = default;
+  constexpr span(span&& other) noexcept      = default;
+
+  constexpr auto operator=(span const& other) noexcept -> span& = default;
+  constexpr auto operator=(span&& other) noexcept -> span& = default;
+
+  constexpr auto begin() const noexcept -> iterator { return data(); }
+
+  constexpr auto end() const noexcept -> iterator { return data() + size(); }
+
+  constexpr auto cbegin() const noexcept -> const_iterator { return data(); }
+
+  constexpr auto cend() const noexcept -> const_iterator { return data() + size(); }
+
+  __host__ __device__ constexpr auto rbegin() const noexcept -> reverse_iterator
+  {
+    return reverse_iterator{end()};
+  }
+
+  __host__ __device__ constexpr auto rend() const noexcept -> reverse_iterator
+  {
+    return reverse_iterator{begin()};
+  }
+
+  __host__ __device__ constexpr auto crbegin() const noexcept -> const_reverse_iterator
+  {
+    return const_reverse_iterator{cend()};
+  }
+
+  __host__ __device__ constexpr auto crend() const noexcept -> const_reverse_iterator
+  {
+    return const_reverse_iterator{cbegin()};
+  }
+
+  // element access
+  constexpr auto front() const -> reference { return (*this)[0]; }
+
+  constexpr auto back() const -> reference { return (*this)[size() - 1]; }
+
+  template <typename Index>
+  constexpr auto operator[](Index _idx) const -> reference
+  {
+    assert(static_cast<size_type>(_idx) < size());
+    return data()[_idx];
+  }
+
+  constexpr auto data() const noexcept -> pointer { return data_; }
+
+  // Observers
+  [[nodiscard]] constexpr auto size() const noexcept -> size_type { return size_; }
+  [[nodiscard]] constexpr auto size_bytes() const noexcept -> size_type
+  {
+    return size() * sizeof(T);
+  }
+
+  constexpr auto empty() const noexcept { return size() == 0; }
+
+  // Subviews
+  template <std::size_t Count>
+  constexpr auto first() const -> span<element_type, is_device, Count>
+  {
+    assert(Count <= size());
+    return {data(), Count};
+  }
+
+  constexpr auto first(std::size_t _count) const
+    -> span<element_type, is_device, dynamic_extent>
+  {
+    assert(_count <= size());
+    return {data(), _count};
+  }
+
+  template <std::size_t Count>
+  constexpr auto last() const -> span<element_type, is_device, Count>
+  {
+    assert(Count <= size());
+    return {data() + size() - Count, Count};
+  }
+
+  constexpr auto last(std::size_t _count) const
+    -> span<element_type, is_device, dynamic_extent>
+  {
+    assert(_count <= size());
+    return subspan(size() - _count, _count);
+  }
+
+  /*!
+   * If Count is dynamic_extent, r.size() == this->size() - Offset;
+   * otherwise r.size() == Count.
+   */
+  template <std::size_t Offset, std::size_t Count = dynamic_extent>
+  constexpr auto subspan() const
+    -> span<element_type, is_device, detail::extent_value_t<Extent, Offset, Count>::value>
+  {
+    assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size()));
+    return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
+  }
+
+  constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const
+    -> span<element_type, is_device, dynamic_extent>
+  {
+    assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size()));
+    return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
+  }
+
+ private:
+  size_type size_{0};
+  pointer data_{nullptr};
+};
+
+/**
+ * @brief A span class for host pointers.
+ */
+template <typename T, std::size_t extent = dynamic_extent>
+using host_span = span<T, false, extent>;
+
+/**
+ * @brief A span class for device pointers.
+ */
+template <typename T, std::size_t extent = dynamic_extent>
+using device_span = span<T, true, extent>;
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator==(span<T, is_device, X> l, span<U, is_device, Y> r) -> bool
+{
+  if (l.size() != r.size()) { return false; }
+  for (auto l_beg = l.cbegin(), r_beg = r.cbegin(); l_beg != l.cend(); ++l_beg, ++r_beg) {
+    if (*l_beg != *r_beg) { return false; }
+  }
+  return true;
+}
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator!=(span<T, is_device, X> l, span<U, is_device, Y> r)
+{
+  return !(l == r);
+}
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator<(span<T, is_device, X> l, span<U, is_device, Y> r)
+{
+  return detail::lexicographical_compare<
+    typename span<T, is_device, X>::iterator,
+    typename span<U, is_device, Y>::iterator,
+    thrust::less<typename span<T, is_device, X>::element_type>>(
+    l.begin(), l.end(), r.begin(), r.end());
+}
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator<=(span<T, is_device, X> l, span<U, is_device, Y> r)
+{
+  return !(l > r);
+}
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator>(span<T, is_device, X> l, span<U, is_device, Y> r)
+{
+  return detail::lexicographical_compare<
+    typename span<T, is_device, X>::iterator,
+    typename span<U, is_device, Y>::iterator,
+    thrust::greater<typename span<T, is_device, X>::element_type>>(
+    l.begin(), l.end(), r.begin(), r.end());
+}
+
+template <class T, std::size_t X, class U, std::size_t Y, bool is_device>
+constexpr auto operator>=(span<T, is_device, X> l, span<U, is_device, Y> r)
+{
+  return !(l < r);
+}
+
+/**
+ * @brief Converts a span into a view of its underlying bytes.
+ */
+template <class T, bool is_device, std::size_t E>
+auto as_bytes(span<T, is_device, E> s) noexcept
+  -> span<const std::byte, is_device, detail::extent_as_bytes_value_t<T, E>::value>
+{
+  return {reinterpret_cast<const std::byte*>(s.data()), s.size_bytes()};
+}
+
+/**
+ * @brief Converts a span into a mutable view of its underlying bytes.
+ */
+template <class T, bool is_device, std::size_t E>
+auto as_writable_bytes(span<T, is_device, E> s) noexcept
+  -> span<std::byte, is_device, detail::extent_as_bytes_value_t<T, E>::value>
+{
+  return {reinterpret_cast<std::byte*>(s.data()), s.size_bytes()};
+}
+}  // namespace raft
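For reference while reviewing, here is a small host-side usage sketch (not part of the patch, and assuming `raft/span.hpp` and its CUDA-toolkit dependencies are on the include path):

```cpp
#include <raft/span.hpp>

#include <vector>

int main()
{
  std::vector<float> buf(8, 1.0f);

  // Non-owning view over the vector's storage.
  raft::host_span<float> s{buf.data(), buf.size()};

  // Subviews alias the same memory.
  auto head = s.first(4);  // elements [0, 4)
  auto tail = s.last(4);   // elements [4, 8)
  head[0]   = 2.0f;        // writes through to buf[0]

  // Read-only byte view; its size is s.size() * sizeof(float).
  auto bytes = raft::as_bytes(s);

  return (bytes.size() == buf.size() * sizeof(float) && tail.size() == 4) ? 0 : 1;
}
```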
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index b37c671525..968e60f5f9 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -69,6 +69,8 @@ add_executable(test_raft
   test/random/rng.cu
   test/random/rng_int.cu
   test/random/sample_without_replacement.cu
+  test/span.cpp
+  test/span.cu
   test/sparse/add.cu
   test/sparse/convert_coo.cu
   test/sparse/convert_csr.cu
diff --git a/cpp/test/span.cpp b/cpp/test/span.cpp
new file mode 100644
index 0000000000..6163811b95
--- /dev/null
+++ b/cpp/test/span.cpp
@@ -0,0 +1,419 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "test_span.hpp"
+#include <gtest/gtest.h>
+#include <numeric>  // iota
+#include <raft/span.hpp>
+#include <vector>
+
+namespace raft {
+TEST(Span, DlfConstructors)
+{
+  // Dynamic extent
+  {
+    host_span<float> s;
+    ASSERT_EQ(s.size(), 0);
+    ASSERT_EQ(s.data(), nullptr);
+
+    host_span<float const> cs;
+    ASSERT_EQ(cs.size(), 0);
+    ASSERT_EQ(cs.data(), nullptr);
+  }
+
+  // Static extent
+  {
+    host_span<float, 0> s;
+    ASSERT_EQ(s.size(), 0);
+    ASSERT_EQ(s.data(), nullptr);
+
+    host_span<float const, 0> cs;
+    ASSERT_EQ(cs.size(), 0);
+    ASSERT_EQ(cs.data(), nullptr);
+  }
+
+  // Init list.
+  {
+    host_span<float> s{};
+    ASSERT_EQ(s.size(), 0);
+    ASSERT_EQ(s.data(), nullptr);
+
+    host_span<float const> cs{};
+    ASSERT_EQ(cs.size(), 0);
+    ASSERT_EQ(cs.data(), nullptr);
+  }
+}
+
+TEST(Span, FromNullPtr)
+{
+  // dynamic extent
+  {
+    host_span<float> s{nullptr, static_cast<host_span<float>::size_type>(0)};
+    ASSERT_EQ(s.size(), 0);
+    ASSERT_EQ(s.data(), nullptr);
+
+    host_span<float const> cs{nullptr, static_cast<host_span<float>::size_type>(0)};
+    ASSERT_EQ(cs.size(), 0);
+    ASSERT_EQ(cs.data(), nullptr);
+  }
+  // static extent
+  {
+    host_span<float, 0> s{nullptr, static_cast<host_span<float, 0>::size_type>(0)};
+    ASSERT_EQ(s.size(), 0);
+    ASSERT_EQ(s.data(), nullptr);
+
+    host_span<float const, 0> cs{nullptr, static_cast<host_span<float, 0>::size_type>(0)};
+    ASSERT_EQ(cs.size(), 0);
+    ASSERT_EQ(cs.data(), nullptr);
+  }
+}
+
+TEST(Span, FromPtrLen)
+{
+  float arr[16];
+  std::iota(arr, arr + 16, 0);
+
+  // static extent
+  {
+    host_span<float, 16> s(arr, 16);
+    ASSERT_EQ(s.size(), 16);
+    ASSERT_EQ(s.data(), arr);
+
+    for (host_span<float, 16>::size_type i = 0; i < 16; ++i) {
+      ASSERT_EQ(s[i], arr[i]);
+    }
+
+    host_span<float const, 16> cs(arr, 16);
+    ASSERT_EQ(cs.size(), 16);
+    ASSERT_EQ(cs.data(), arr);
+
+    for (host_span<float const, 16>::size_type i = 0; i < 16; ++i) {
+      ASSERT_EQ(cs[i], arr[i]);
+    }
+  }
+
+  // dynamic extent
+  {
+    host_span<float> s(arr, 16);
+    ASSERT_EQ(s.size(), 16);
+    ASSERT_EQ(s.data(), arr);
+
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(s[i], arr[i]);
+    }
+
+    host_span<float const> cs(arr, 16);
+    ASSERT_EQ(cs.size(), 16);
+    ASSERT_EQ(cs.data(), arr);
+
+    for (host_span<float const>::size_type i = 0; i < 16; ++i) {
+      ASSERT_EQ(cs[i], arr[i]);
+    }
+  }
+}
+
+TEST(Span, FromFirstLast)
+{
+  float arr[16];
+  initialize_range(arr, arr + 16);
+
+  // dynamic extent
+  {
+    host_span<float> s(arr, arr + 16);
+    ASSERT_EQ(s.size(), 16);
+    ASSERT_EQ(s.data(), arr);
+    ASSERT_EQ(s.data() + s.size(), arr + 16);
+
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(s[i], arr[i]);
+    }
+
+    host_span<float const> cs(arr, arr + 16);
+    ASSERT_EQ(cs.size(), 16);
+    ASSERT_EQ(cs.data(), arr);
+    ASSERT_EQ(cs.data() + cs.size(), arr + 16);
+
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(cs[i], arr[i]);
+    }
+  }
+
+  // static extent
+  {
+    host_span<float, 16> s(arr, arr + 16);
+    ASSERT_EQ(s.size(), 16);
+    ASSERT_EQ(s.data(), arr);
+    ASSERT_EQ(s.data() + s.size(), arr + 16);
+
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(s[i], arr[i]);
+    }
+
+    host_span<float const, 16> cs(arr, arr + 16);
+    ASSERT_EQ(cs.size(), 16);
+    ASSERT_EQ(cs.data(), arr);
+    ASSERT_EQ(cs.data() + cs.size(), arr + 16);
+
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(cs[i], arr[i]);
+    }
+  }
+}
+
+namespace {
+struct base_class_t {
+  virtual void operator()() {}
+};
+struct derived_class_t : public base_class_t {
+  void operator()() override {}
+};
+}  // anonymous namespace
+
+TEST(Span, FromOther)
+{
+  // convert constructor
+  {
+    host_span<derived_class_t> derived;
+    host_span<base_class_t> base{derived};
+    ASSERT_EQ(base.size(), derived.size());
+    ASSERT_EQ(base.data(), derived.data());
+  }
+
+  float arr[16];
+  initialize_range(arr, arr + 16);
+
+  // default copy constructor
+  {
+    host_span<float> s0(arr);
+    host_span<float> s1(s0);
+    ASSERT_EQ(s0.size(), s1.size());
+    ASSERT_EQ(s0.data(), s1.data());
+  }
+}
+
+TEST(Span, FromArray)
+{
+  float arr[16];
+  initialize_range(arr, arr + 16);
+
+  {
+    host_span<float> s(arr);
+    ASSERT_EQ(&arr[0], s.data());
+    ASSERT_EQ(s.size(), 16);
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(arr[i], s[i]);
+    }
+  }
+
+  {
+    host_span<float, 16> s(arr);
+    ASSERT_EQ(&arr[0], s.data());
+    ASSERT_EQ(s.size(), 16);
+    for (size_t i = 0; i < 16; ++i) {
+      ASSERT_EQ(arr[i], s[i]);
+    }
+  }
+}
+
+TEST(Span, Assignment)
+{
+  int status = 1;
+  test_assignment_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, BeginEnd)
+{
+  int status = 1;
+  test_beginend_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, ElementAccess)
+{
+  float arr[16];
+  initialize_range(arr, arr + 16);
+
+  host_span<float> s(arr);
+  size_t j = 0;
+  for (auto i : s) {
+    ASSERT_EQ(i, arr[j]);
+    ++j;
+  }
+}
+
+TEST(Span, Observers)
+{
+  int status = 1;
+  test_observers_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, FrontBack)
+{
+  {
+    float arr[4]{0, 1, 2, 3};
+    host_span<float> s(arr);
+    ASSERT_EQ(s.front(), 0);
+    ASSERT_EQ(s.back(), 3);
+  }
+  {
+    std::vector<float> arr{0, 1, 2, 3};
+    host_span<float> s(arr.data(), arr.size());
+    ASSERT_EQ(s.front(), 0);
+    ASSERT_EQ(s.back(), 3);
+  }
+}
+
+TEST(Span, FirstLast)
+{
+  // static extent
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    host_span<float> s(arr);
+    host_span<float, 4> first = s.first<4>();
+
+    ASSERT_EQ(first.size(), 4);
+    ASSERT_EQ(first.data(), arr);
+
+    for (size_t i = 0; i < first.size(); ++i) {
+      ASSERT_EQ(first[i], arr[i]);
+    }
+  }
+
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    host_span<float> s(arr);
+    host_span<float, 4> last = s.last<4>();
+
+    ASSERT_EQ(last.size(), 4);
+    ASSERT_EQ(last.data(), arr + 12);
+
+    for (size_t i = 0; i < last.size(); ++i) {
+      ASSERT_EQ(last[i], arr[i + 12]);
+    }
+  }
+
+  // dynamic extent
+  {
+    float* arr = new float[16];
+    initialize_range(arr, arr + 16);
+    host_span<float> s(arr, 16);
+    host_span<float> first = s.first(4);
+
+    ASSERT_EQ(first.size(), 4);
+    ASSERT_EQ(first.data(), s.data());
+
+    for (size_t i = 0; i < first.size(); ++i) {
+      ASSERT_EQ(first[i], s[i]);
+    }
+
+    delete[] arr;
+  }
+
+  {
+    float* arr = new float[16];
+    initialize_range(arr, arr + 16);
+    host_span<float> s(arr, 16);
+    host_span<float> last = s.last(4);
+
+    ASSERT_EQ(last.size(), 4);
+    ASSERT_EQ(last.data(), s.data() + 12);
+
+    for (size_t i = 0; i < last.size(); ++i) {
+      ASSERT_EQ(s[12 + i], last[i]);
+    }
+
+    delete[] arr;
+  }
+}
+
+TEST(Span, Subspan)
+{
+  int arr[16]{0};
+  host_span<int> s1(arr);
+  auto s2 = s1.subspan<4>();
+  ASSERT_EQ(s1.size() - 4, s2.size());
+
+  auto s3 = s1.subspan(2, 4);
+  ASSERT_EQ(s1.data() + 2, s3.data());
+  ASSERT_EQ(s3.size(), 4);
+
+  auto s4 = s1.subspan(2, dynamic_extent);
+  ASSERT_EQ(s1.data() + 2, s4.data());
+  ASSERT_EQ(s4.size(), s1.size() - 2);
+}
+
+TEST(Span, Compare)
+{
+  int status = 1;
+  test_compare_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, AsBytes)
+{
+  int status = 1;
+  test_as_bytes_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, AsWritableBytes)
+{
+  int status = 1;
+  test_as_writable_bytes_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+
+TEST(Span, Empty)
+{
+  {
+    host_span<float> s{nullptr, static_cast<host_span<float>::size_type>(0)};
+    auto res = s.subspan(0);
+    ASSERT_EQ(res.data(), nullptr);
+    ASSERT_EQ(res.size(), 0);
+
+    res = s.subspan(0, 0);
+    ASSERT_EQ(res.data(), nullptr);
+    ASSERT_EQ(res.size(), 0);
+  }
+
+  {
+    host_span<float, 0> s{nullptr, static_cast<host_span<float, 0>::size_type>(0)};
+    auto res = s.subspan(0);
+    ASSERT_EQ(res.data(), nullptr);
+    ASSERT_EQ(res.size(), 0);
+
+    res = s.subspan(0, 0);
+    ASSERT_EQ(res.data(), nullptr);
+    ASSERT_EQ(res.size(), 0);
+  }
+  {
+    // Should emit a compile error.
+    // host_span<float> h{nullptr, 0ul};
+    // device_span<float> d{h};
+  }
+}
+
+TEST(Span, RBeginREnd)
+{
+  int32_t status = 1;
+  test_rbeginrend_t<false>{&status}();
+  ASSERT_EQ(status, 1);
+}
+}  // namespace raft
diff --git a/cpp/test/span.cu b/cpp/test/span.cu
new file mode 100644
index 0000000000..e121cea108
--- /dev/null
+++ b/cpp/test/span.cu
@@ -0,0 +1,213 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "test_span.hpp"
+#include <gtest/gtest.h>
+#include <numeric>  // iota
+#include <raft/cudart_utils.h>
+#include <raft/span.hpp>
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+#include <thrust/iterator/counting_iterator.h>
+
+namespace raft {
+struct TestStatus {
+ private:
+  int* status_;
+
+ public:
+  TestStatus()
+  {
+    CUDA_CHECK(cudaMalloc(&status_, sizeof(int)));
+    int h_status = 1;
+    CUDA_CHECK(cudaMemcpy(status_, &h_status, sizeof(int), cudaMemcpyHostToDevice));
+  }
+  ~TestStatus() noexcept(false) { CUDA_CHECK(cudaFree(status_)); }
+
+  int Get()
+  {
+    int h_status;
+    CUDA_CHECK(cudaMemcpy(&h_status, status_, sizeof(int), cudaMemcpyDeviceToHost));
+    return h_status;
+  }
+
+  int* Data() { return status_; }
+};
+
+__global__ void TestFromOtherKernel(device_span<float> span)
+{
+  // don't get optimized out
+  size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (idx >= span.size()) { return; }
+}
+// Test converting different T
+__global__ void TestFromOtherKernelConst(device_span<float const> span)
+{
+  // don't get optimized out
+  size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (idx >= span.size()) { return; }
+}
+
+/*!
+ * \brief Here we just test whether the code compiles.
+ */
+TEST(GPUSpan, FromOther)
+{
+  thrust::host_vector<float> h_vec(16);
+  std::iota(h_vec.begin(), h_vec.end(), 0);
+
+  thrust::device_vector<float> d_vec(h_vec.size());
+  thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin());
+  // dynamic extent
+  {
+    device_span<float> span(d_vec.data().get(), d_vec.size());
+    TestFromOtherKernel<<<1, 16>>>(span);
+  }
+  {
+    device_span<float> span(d_vec.data().get(), d_vec.size());
+    TestFromOtherKernelConst<<<1, 16>>>(span);
+  }
+  // static extent
+  {
+    device_span<float, 16> span(d_vec.data().get(), d_vec.data().get() + 16);
+    TestFromOtherKernel<<<1, 16>>>(span);
+  }
+  {
+    device_span<float, 16> span(d_vec.data().get(), d_vec.data().get() + 16);
+    TestFromOtherKernelConst<<<1, 16>>>(span);
+  }
+}
+
+TEST(GPUSpan, Assignment)
+{
+  TestStatus status;
+  thrust::for_each_n(
+    thrust::make_counting_iterator(0ul), 16, test_assignment_t<true>{status.Data()});
+  ASSERT_EQ(status.Get(), 1);
+}
+
+TEST(GPUSpan, TestStatus)
+{
+  TestStatus status;
+  thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, test_test_status_t{status.Data()});
+  ASSERT_EQ(status.Get(), -1);
+}
+
+template <typename T>
+struct TestEqual {
+ private:
+  T *lhs_, *rhs_;
+  int* status_;
+
+ public:
+  TestEqual(T* _lhs, T* _rhs, int* _status) : lhs_(_lhs), rhs_(_rhs), status_(_status) {}
+
+  HD void operator()(size_t _idx)
+  {
+    bool res = lhs_[_idx] == rhs_[_idx];
+    SPAN_ASSERT_TRUE(res, status_);
+  }
+};
+
+TEST(GPUSpan, WithThrust)
+{
+  // Not advised to initialize a span with a host_vector, since h_vec.data()
+  // is a host function.
+  thrust::host_vector<float> h_vec(16);
+  std::iota(h_vec.begin(), h_vec.end(), 0);
+
+  thrust::device_vector<float> d_vec(h_vec.size());
+  thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin());
+
+  // Can't initialize a span with a device_vector, since d_vec.data() is not
+  // a raw pointer.
+  {
+    device_span<float> s(d_vec.data().get(), d_vec.size());
+
+    ASSERT_EQ(d_vec.size(), s.size());
+    ASSERT_EQ(d_vec.data().get(), s.data());
+  }
+
+  {
+    TestStatus status;
+    thrust::device_vector<float> d_vec1(d_vec.size());
+    thrust::copy(thrust::device, d_vec.begin(), d_vec.end(), d_vec1.begin());
+    device_span<float> s(d_vec1.data().get(), d_vec.size());
+
+    thrust::for_each_n(
+      thrust::make_counting_iterator(0ul),
+      16,
+      TestEqual<float>{thrust::raw_pointer_cast(d_vec1.data()), s.data(), status.Data()});
+    ASSERT_EQ(status.Get(), 1);
+  }
+}
+
+TEST(GPUSpan, BeginEnd)
+{
+  TestStatus status;
+  thrust::for_each_n(
+    thrust::make_counting_iterator(0ul), 16, test_beginend_t<true>{status.Data()});
+  ASSERT_EQ(status.Get(), 1);
+}
+
+TEST(GPUSpan, RBeginREnd)
+{
+  TestStatus status;
+  thrust::for_each_n(
+    thrust::make_counting_iterator(0ul), 16, test_rbeginrend_t<true>{status.Data()});
+  ASSERT_EQ(status.Get(), 1);
+}
+
+__global__ void TestModifyKernel(device_span<float> span)
+{
+  size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (idx >= span.size()) { return; }
+  span[idx] = span.size() - idx;
+}
+
+TEST(GPUSpan, Modify)
+{
+  thrust::host_vector<float> h_vec(16);
+  std::iota(h_vec.begin(), h_vec.end(), 0);
+
+  thrust::device_vector<float> d_vec(h_vec.size());
+  thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin());
+
+  device_span<float> span(d_vec.data().get(), d_vec.size());
+
+  TestModifyKernel<<<1, 16>>>(span);
+
+  for (size_t i = 0; i < d_vec.size(); ++i) {
+    ASSERT_EQ(d_vec[i], d_vec.size() - i);
+  }
+}
+
+TEST(GPUSpan, Observers)
+{
+  TestStatus status;
+  thrust::for_each_n(
+    thrust::make_counting_iterator(0ul), 16, test_observers_t<true>{status.Data()});
+  ASSERT_EQ(status.Get(), 1);
+}
+
+TEST(GPUSpan, Compare)
+{
+  TestStatus status;
+  thrust::for_each_n(thrust::make_counting_iterator(0), 1,
+                     test_compare_t<true>{status.Data()});
+  ASSERT_EQ(status.Get(), 1);
+}
+}  // namespace raft
diff --git a/cpp/test/test_span.hpp b/cpp/test/test_span.hpp
new file mode 100644
index 0000000000..254c89f91c
--- /dev/null
+++ b/cpp/test/test_span.hpp
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <raft/span.hpp>
+
+namespace raft {
+
+template <typename Iter>
+__host__ __device__ void initialize_range(Iter _begin, Iter _end)
+{
+  float j = 0;
+  for (Iter i = _begin; i != _end; ++i, ++j) {
+    *i = j;
+  }
+}
+#define SPAN_ASSERT_TRUE(cond, status) \
+  if (!(cond)) { *(status) = -1; }
+
+#define SPAN_ASSERT_FALSE(cond, status) \
+  if ((cond)) { *(status) = -1; }
+
+struct test_test_status_t {
+  int* status;
+
+  explicit test_test_status_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx) { SPAN_ASSERT_TRUE(false, status); }
+};
+
+template <bool is_device>
+struct test_assignment_t {
+  int* status;
+
+  explicit test_assignment_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    span<float, is_device> s1;
+
+    float arr[] = {3, 4, 5};
+
+    span<float, is_device> s2 = arr;
+    SPAN_ASSERT_TRUE(s2.size() == 3, status);
+    SPAN_ASSERT_TRUE(s2.data() == &arr[0], status);
+
+    s2 = s1;
+    SPAN_ASSERT_TRUE(s2.empty(), status);
+  }
+};
+
+template <bool is_device>
+struct test_beginend_t {
+  int* status;
+
+  explicit test_beginend_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    span<float, is_device> s(arr);
+    typename span<float, is_device>::iterator beg{s.begin()};
+    typename span<float, is_device>::iterator end{s.end()};
+
+    SPAN_ASSERT_TRUE(end == beg + 16, status);
+    SPAN_ASSERT_TRUE(*beg == arr[0], status);
+    SPAN_ASSERT_TRUE(*(end - 1) == arr[15], status);
+  }
+};
+
+template <bool is_device>
+struct test_rbeginrend_t {
+  int* status;
+
+  explicit test_rbeginrend_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    span<float, is_device> s(arr);
+    typename span<float, is_device>::reverse_iterator rbeg{s.rbegin()};
+    typename span<float, is_device>::reverse_iterator rend{s.rend()};
+
+    SPAN_ASSERT_TRUE(rbeg + 16 == rend, status);
+    SPAN_ASSERT_TRUE(*(rbeg) == arr[15], status);
+    SPAN_ASSERT_TRUE(*(rend - 1) == arr[0], status);
+
+    typename span<float, is_device>::const_reverse_iterator crbeg{s.crbegin()};
+    typename span<float, is_device>::const_reverse_iterator crend{s.crend()};
+
+    SPAN_ASSERT_TRUE(crbeg + 16 == crend, status);
+    SPAN_ASSERT_TRUE(*(crbeg) == arr[15], status);
+    SPAN_ASSERT_TRUE(*(crend - 1) == arr[0], status);
+  }
+};
+
+template <bool is_device>
+struct test_observers_t {
+  int* status;
+
+  explicit test_observers_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__
+    __device__ void operator()(int _idx)
+  {
+    // empty
+    {
+      float* arr = nullptr;
+      span<float, is_device> s(arr, static_cast<typename span<float, is_device>::size_type>(0));
+      SPAN_ASSERT_TRUE(s.empty(), status);
+    }
+
+    // size, size_bytes
+    {
+      float* arr = new float[16];
+      span<float, is_device> s(arr, 16);
+      SPAN_ASSERT_TRUE(s.size() == 16, status);
+      SPAN_ASSERT_TRUE(s.size_bytes() == 16 * sizeof(float), status);
+      delete[] arr;
+    }
+  }
+};
+
+template <bool is_device>
+struct test_compare_t {
+  int* status;
+
+  explicit test_compare_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    float lhs_arr[16], rhs_arr[16];
+    initialize_range(lhs_arr, lhs_arr + 16);
+    initialize_range(rhs_arr, rhs_arr + 16);
+
+    span<float, is_device> lhs(lhs_arr);
+    span<float, is_device> rhs(rhs_arr);
+
+    SPAN_ASSERT_TRUE(lhs == rhs, status);
+    SPAN_ASSERT_FALSE(lhs != rhs, status);
+
+    SPAN_ASSERT_TRUE(lhs <= rhs, status);
+    SPAN_ASSERT_TRUE(lhs >= rhs, status);
+
+    lhs[2] -= 1;
+
+    SPAN_ASSERT_FALSE(lhs == rhs, status);
+    SPAN_ASSERT_TRUE(lhs < rhs, status);
+    SPAN_ASSERT_FALSE(lhs > rhs, status);
+  }
+};
+
+template <bool is_device>
+struct test_as_bytes_t {
+  int* status;
+
+  explicit test_as_bytes_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    {
+      const span<float, is_device> s{arr};
+      const span<const std::byte, is_device> bs = as_bytes(s);
+      SPAN_ASSERT_TRUE(bs.size() == s.size_bytes(), status);
+      SPAN_ASSERT_TRUE(static_cast<const void*>(bs.data()) == static_cast<const void*>(s.data()),
+                       status);
+    }
+
+    {
+      span<float, is_device> s;
+      const span<const std::byte, is_device> bs = as_bytes(s);
+      SPAN_ASSERT_TRUE(bs.size() == s.size(), status);
+      SPAN_ASSERT_TRUE(bs.size() == 0, status);
+      SPAN_ASSERT_TRUE(bs.size_bytes() == 0, status);
+      SPAN_ASSERT_TRUE(static_cast<const void*>(bs.data()) == static_cast<const void*>(s.data()),
+                       status);
+      SPAN_ASSERT_TRUE(bs.data() == nullptr, status);
+    }
+  }
+};
+
+template <bool is_device>
+struct test_as_writable_bytes_t {
+  int* status;
+
+  explicit test_as_writable_bytes_t(int* _status) : status(_status) {}
+
+  __host__ __device__ void operator()() { this->operator()(0); }
+  __host__ __device__ void operator()(int _idx)
+  {
+    float arr[16];
+    initialize_range(arr, arr + 16);
+
+    {
+      span<float, is_device> s;
+      span<std::byte, is_device> byte_s = as_writable_bytes(s);
+      SPAN_ASSERT_TRUE(byte_s.size() == s.size(), status);
+      SPAN_ASSERT_TRUE(byte_s.size_bytes() == s.size_bytes(), status);
+      SPAN_ASSERT_TRUE(byte_s.size() == 0, status);
+      SPAN_ASSERT_TRUE(byte_s.size_bytes() == 0, status);
+      SPAN_ASSERT_TRUE(byte_s.data() == nullptr, status);
+      SPAN_ASSERT_TRUE(static_cast<void*>(byte_s.data()) == static_cast<void*>(s.data()), status);
+    }
+
+    {
+      span<float, is_device> s{arr};
+      span<std::byte, is_device> bs{as_writable_bytes(s)};
+      SPAN_ASSERT_TRUE(s.size_bytes() == bs.size_bytes(), status);
+      SPAN_ASSERT_TRUE(static_cast<void*>(bs.data()) == static_cast<void*>(s.data()), status);
+    }
+  }
+};
+}  // namespace raft
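The functors above double as a template for new cases: each check is written once against `span<T, is_device>`, then driven directly from span.cpp on the host and through `thrust::for_each_n` from span.cu on the device. As a hypothetical illustration (not part of the patch; `test_frontback_t` is an invented name), a `front()`/`back()` check in the same style, placed inside `namespace raft` in test_span.hpp, would look like this:

```cpp
// Hypothetical sketch following the pattern of test_span.hpp above;
// test_frontback_t is not part of the patch.
template <bool is_device>
struct test_frontback_t {
  int* status;

  explicit test_frontback_t(int* _status) : status(_status) {}

  __host__ __device__ void operator()() { this->operator()(0); }
  __host__ __device__ void operator()(int _idx)
  {
    float arr[4];
    initialize_range(arr, arr + 4);

    span<float, is_device> s(arr);
    SPAN_ASSERT_TRUE(s.front() == arr[0], status);
    SPAN_ASSERT_TRUE(s.back() == arr[3], status);
  }
};

// Host side (span.cpp):  test_frontback_t<false>{&status}();
// Device side (span.cu): thrust::for_each_n(thrust::make_counting_iterator(0), 1,
//                                           test_frontback_t<true>{status.Data()});
```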