diff --git a/cpp/include/raft/common/span.hpp b/cpp/include/raft/common/span.hpp index 1e2c36e4fb..7e98d01759 100644 --- a/cpp/include/raft/common/span.hpp +++ b/cpp/include/raft/common/span.hpp @@ -21,13 +21,14 @@ #include // numeric_limits #include +#include #include namespace raft::common { constexpr std::size_t dynamic_extent = std::numeric_limits::max(); -template +template class Span; namespace detail { @@ -39,7 +40,7 @@ namespace detail { * - Otherwise, dynamic_extent. */ template -struct ExtentValue +struct extent_value_t : public std::integral_constant< std::size_t, Count != dynamic_extent ? Count @@ -51,39 +52,28 @@ struct ExtentValue * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. */ template -struct ExtentAsBytesValue +struct extent_as_bytes_value_t : public std::integral_constant< std::size_t, Extent == dynamic_extent ? Extent : sizeof(T) * Extent> {}; template -struct IsAllowedExtentConversion +struct is_allowed_extent_conversion_t : public std::integral_constant {}; template -struct IsAllowedElementTypeConversion +struct is_allowed_element_type_conversion_t : public std::integral_constant< bool, std::is_convertible::value> {}; template -struct IsSpanOracle : std::false_type {}; +struct is_span_oracle_t : std::false_type {}; -template -struct IsSpanOracle> : std::true_type {}; +template +struct is_span_oracle_t> : std::true_type {}; template -struct IsSpan : public IsSpanOracle::type> {}; - -// Re-implement std algorithms here to adopt CUDA. -template -struct Less { - constexpr bool operator()(const T& _x, const T& _y) const { return _x < _y; } -}; - -template -struct Greater { - constexpr bool operator()(const T& _x, const T& _y) const { return _x > _y; } -}; +struct is_span_t : public is_span_oracle_t::type> {}; template __host__ __device__ constexpr auto lexicographical_compare( @@ -105,7 +95,7 @@ __host__ __device__ constexpr auto lexicographical_compare( * \brief The span class defined in ISO C++20. Iterator is defined as plain pointer and * most of the methods have bound check on debug build. */ -template +template class Span { public: using element_type = T; @@ -139,11 +129,12 @@ class Span { template constexpr Span(element_type (&arr)[N]) noexcept : size_(N), data_(&arr[0]) {} - template ::value && - detail::IsAllowedExtentConversion::value>> - constexpr Span(const Span& _other) noexcept + template < + class U, std::size_t OtherExtent, + class = typename std::enable_if< + detail::is_allowed_element_type_conversion_t::value && + detail::is_allowed_extent_conversion_t::value>> + constexpr Span(const Span& _other) noexcept : size_(_other.size()), data_(_other.data()) {} constexpr Span(const Span& _other) noexcept @@ -208,25 +199,25 @@ class Span { // Subviews template - constexpr auto first() const -> Span { + constexpr auto first() const -> Span { assert(Count <= size()); return {data(), Count}; } constexpr auto first(std::size_t _count) const - -> Span { + -> Span { assert(_count <= size()); return {data(), _count}; } template - constexpr auto last() const -> Span { + constexpr auto last() const -> Span { assert(Count <= size()); return {data() + size() - Count, Count}; } constexpr auto last(std::size_t _count) const - -> Span { + -> Span { assert(_count <= size()); return subspan(size() - _count, _count); } @@ -237,7 +228,8 @@ class Span { */ template constexpr auto subspan() const - -> Span::value> { + -> Span::value> { assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; @@ -245,7 +237,7 @@ class Span { constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const - -> Span { + -> Span { assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); return {data() + _offset, @@ -257,8 +249,15 @@ class Span { pointer data_{nullptr}; }; -template -constexpr auto operator==(Span l, Span r) -> bool { +template +using host_span = Span; + +template +using device_span = Span; + +template +constexpr auto operator==(Span l, Span r) + -> bool { if (l.size() != r.size()) { return false; } @@ -271,46 +270,49 @@ constexpr auto operator==(Span l, Span r) -> bool { return true; } -template -constexpr auto operator!=(Span l, Span r) { +template +constexpr auto operator!=(Span l, Span r) { return !(l == r); } -template -constexpr auto operator<(Span l, Span r) { +template +constexpr auto operator<(Span l, Span r) { return detail::lexicographical_compare< - typename Span::iterator, typename Span::iterator, - detail::Less::element_type>>(l.begin(), l.end(), - r.begin(), r.end()); + typename Span::iterator, + typename Span::iterator, + thrust::less::element_type>>( + l.begin(), l.end(), r.begin(), r.end()); } -template -constexpr auto operator<=(Span l, Span r) { +template +constexpr auto operator<=(Span l, Span r) { return !(l > r); } -template -constexpr auto operator>(Span l, Span r) { +template +constexpr auto operator>(Span l, Span r) { return detail::lexicographical_compare< - typename Span::iterator, typename Span::iterator, - detail::Greater::element_type>>(l.begin(), l.end(), - r.begin(), r.end()); + typename Span::iterator, + typename Span::iterator, + thrust::greater::element_type>>( + l.begin(), l.end(), r.begin(), r.end()); } -template -constexpr auto operator>=(Span l, Span r) { +template +constexpr auto operator>=(Span l, Span r) { return !(l < r); } -template -auto as_bytes(Span s) noexcept - -> Span::value> { +template +auto as_bytes(Span s) noexcept + -> Span::value> { return {reinterpret_cast(s.data()), s.size_bytes()}; } -template -auto as_writable_bytes(Span s) noexcept - -> Span::value> { +template +auto as_writable_bytes(Span s) noexcept + -> Span::value> { return {reinterpret_cast(s.data()), s.size_bytes()}; } } // namespace raft::common diff --git a/cpp/test/common/span.cpp b/cpp/test/common/span.cpp index 700eb20b1d..7659f7462b 100644 --- a/cpp/test/common/span.cpp +++ b/cpp/test/common/span.cpp @@ -1,39 +1,54 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include // iota #include -#include "test_span.h" +#include "test_span.hpp" namespace raft::common { TEST(Span, DlfConstructors) { // Dynamic extent { - Span s; + host_span s; ASSERT_EQ(s.size(), 0); ASSERT_EQ(s.data(), nullptr); - Span cs; + host_span cs; ASSERT_EQ(cs.size(), 0); ASSERT_EQ(cs.data(), nullptr); } // Static extent { - Span s; + host_span s; ASSERT_EQ(s.size(), 0); ASSERT_EQ(s.data(), nullptr); - Span cs; + host_span cs; ASSERT_EQ(cs.size(), 0); ASSERT_EQ(cs.data(), nullptr); } // Init list. { - Span s{}; + host_span s{}; ASSERT_EQ(s.size(), 0); ASSERT_EQ(s.data(), nullptr); - Span cs{}; + host_span cs{}; ASSERT_EQ(cs.size(), 0); ASSERT_EQ(cs.data(), nullptr); } @@ -42,21 +57,23 @@ TEST(Span, DlfConstructors) { TEST(Span, FromNullPtr) { // dynamic extent { - Span s{nullptr, static_cast::size_type>(0)}; + host_span s{nullptr, static_cast::size_type>(0)}; ASSERT_EQ(s.size(), 0); ASSERT_EQ(s.data(), nullptr); - Span cs{nullptr, static_cast::size_type>(0)}; + host_span cs{nullptr, + static_cast::size_type>(0)}; ASSERT_EQ(cs.size(), 0); ASSERT_EQ(cs.data(), nullptr); } // static extent { - Span s{nullptr, static_cast::size_type>(0)}; + host_span s{nullptr, static_cast::size_type>(0)}; ASSERT_EQ(s.size(), 0); ASSERT_EQ(s.data(), nullptr); - Span cs{nullptr, static_cast::size_type>(0)}; + host_span cs{nullptr, + static_cast::size_type>(0)}; ASSERT_EQ(cs.size(), 0); ASSERT_EQ(cs.data(), nullptr); } @@ -68,26 +85,26 @@ TEST(Span, FromPtrLen) { // static extent { - Span s(arr, 16); + host_span s(arr, 16); ASSERT_EQ(s.size(), 16); ASSERT_EQ(s.data(), arr); - for (Span::size_type i = 0; i < 16; ++i) { + for (host_span::size_type i = 0; i < 16; ++i) { ASSERT_EQ(s[i], arr[i]); } - Span cs(arr, 16); + host_span cs(arr, 16); ASSERT_EQ(cs.size(), 16); ASSERT_EQ(cs.data(), arr); - for (Span::size_type i = 0; i < 16; ++i) { + for (host_span::size_type i = 0; i < 16; ++i) { ASSERT_EQ(cs[i], arr[i]); } } // dynamic extent { - Span s(arr, 16); + host_span s(arr, 16); ASSERT_EQ(s.size(), 16); ASSERT_EQ(s.data(), arr); @@ -95,11 +112,11 @@ TEST(Span, FromPtrLen) { ASSERT_EQ(s[i], arr[i]); } - Span cs(arr, 16); + host_span cs(arr, 16); ASSERT_EQ(cs.size(), 16); ASSERT_EQ(cs.data(), arr); - for (Span::size_type i = 0; i < 16; ++i) { + for (host_span::size_type i = 0; i < 16; ++i) { ASSERT_EQ(cs[i], arr[i]); } } @@ -107,47 +124,47 @@ TEST(Span, FromPtrLen) { TEST(Span, FromFirstLast) { float arr[16]; - InitializeRange(arr, arr+16); + InitializeRange(arr, arr + 16); // dynamic extent { - Span s (arr, arr + 16); - ASSERT_EQ (s.size(), 16); - ASSERT_EQ (s.data(), arr); - ASSERT_EQ (s.data() + s.size(), arr + 16); + host_span s(arr, arr + 16); + ASSERT_EQ(s.size(), 16); + ASSERT_EQ(s.data(), arr); + ASSERT_EQ(s.data() + s.size(), arr + 16); for (size_t i = 0; i < 16; ++i) { - ASSERT_EQ (s[i], arr[i]); + ASSERT_EQ(s[i], arr[i]); } - Span cs (arr, arr + 16); - ASSERT_EQ (cs.size(), 16); - ASSERT_EQ (cs.data(), arr); - ASSERT_EQ (cs.data() + cs.size(), arr + 16); + host_span cs(arr, arr + 16); + ASSERT_EQ(cs.size(), 16); + ASSERT_EQ(cs.data(), arr); + ASSERT_EQ(cs.data() + cs.size(), arr + 16); for (size_t i = 0; i < 16; ++i) { - ASSERT_EQ (cs[i], arr[i]); + ASSERT_EQ(cs[i], arr[i]); } } // static extent { - Span s (arr, arr + 16); - ASSERT_EQ (s.size(), 16); - ASSERT_EQ (s.data(), arr); - ASSERT_EQ (s.data() + s.size(), arr + 16); + host_span s(arr, arr + 16); + ASSERT_EQ(s.size(), 16); + ASSERT_EQ(s.data(), arr); + ASSERT_EQ(s.data() + s.size(), arr + 16); for (size_t i = 0; i < 16; ++i) { - ASSERT_EQ (s[i], arr[i]); + ASSERT_EQ(s[i], arr[i]); } - Span cs (arr, arr + 16); - ASSERT_EQ (cs.size(), 16); - ASSERT_EQ (cs.data(), arr); - ASSERT_EQ (cs.data() + cs.size(), arr + 16); + host_span cs(arr, arr + 16); + ASSERT_EQ(cs.size(), 16); + ASSERT_EQ(cs.data(), arr); + ASSERT_EQ(cs.data() + cs.size(), arr + 16); for (size_t i = 0; i < 16; ++i) { - ASSERT_EQ (cs[i], arr[i]); + ASSERT_EQ(cs[i], arr[i]); } } } @@ -160,11 +177,10 @@ struct DerivedClass : public BaseClass { }; TEST(Span, FromOther) { - // convert constructor { - Span derived; - Span base { derived }; + host_span derived; + host_span base{derived}; ASSERT_EQ(base.size(), derived.size()); ASSERT_EQ(base.data(), derived.data()); } @@ -174,8 +190,8 @@ TEST(Span, FromOther) { // default copy constructor { - Span s0 (arr); - Span s1 (s0); + host_span s0(arr); + host_span s1(s0); ASSERT_EQ(s0.size(), s1.size()); ASSERT_EQ(s0.data(), s1.data()); } @@ -186,7 +202,7 @@ TEST(Span, FromArray) { InitializeRange(arr, arr + 16); { - Span s (arr); + host_span s(arr); ASSERT_EQ(&arr[0], s.data()); ASSERT_EQ(s.size(), 16); for (size_t i = 0; i < 16; ++i) { @@ -195,7 +211,7 @@ TEST(Span, FromArray) { } { - Span s (arr); + host_span s(arr); ASSERT_EQ(&arr[0], s.data()); ASSERT_EQ(s.size(), 16); for (size_t i = 0; i < 16; ++i) { @@ -206,13 +222,13 @@ TEST(Span, FromArray) { TEST(Span, Assignment) { int status = 1; - TestAssignment{&status}(); + TestAssignment{&status}(); ASSERT_EQ(status, 1); } TEST(Span, BeginEnd) { int status = 1; - TestBeginEnd{&status}(); + TestBeginEnd{&status}(); ASSERT_EQ(status, 1); } @@ -220,7 +236,7 @@ TEST(Span, ElementAccess) { float arr[16]; InitializeRange(arr, arr + 16); - Span s (arr); + host_span s(arr); size_t j = 0; for (auto i : s) { ASSERT_EQ(i, arr[j]); @@ -228,49 +244,35 @@ TEST(Span, ElementAccess) { } } - TEST(Span, Obversers) { int status = 1; - TestObservers{&status}(); + TestObservers{&status}(); ASSERT_EQ(status, 1); } TEST(Span, FrontBack) { { float arr[4]{0, 1, 2, 3}; - Span s(arr); + host_span s(arr); ASSERT_EQ(s.front(), 0); ASSERT_EQ(s.back(), 3); } { std::vector arr{0, 1, 2, 3}; - Span s(arr.data(), arr.size()); + host_span s(arr.data(), arr.size()); ASSERT_EQ(s.front(), 0); ASSERT_EQ(s.back(), 3); } } -TEST(SpanDeathTest, FrontBack) { - { - Span s; - EXPECT_DEATH(s.front(), ""); - EXPECT_DEATH(s.back(), ""); - } - { - Span s; - EXPECT_DEATH(s.front(), ""); - EXPECT_DEATH(s.back(), ""); - } -} - TEST(Span, FirstLast) { // static extent { float arr[16]; InitializeRange(arr, arr + 16); - Span s (arr); - Span first = s.first<4>(); + host_span s(arr); + host_span first = s.first<4>(); ASSERT_EQ(first.size(), 4); ASSERT_EQ(first.data(), arr); @@ -284,14 +286,14 @@ TEST(Span, FirstLast) { float arr[16]; InitializeRange(arr, arr + 16); - Span s (arr); - Span last = s.last<4>(); + host_span s(arr); + host_span last = s.last<4>(); ASSERT_EQ(last.size(), 4); ASSERT_EQ(last.data(), arr + 12); for (size_t i = 0; i < last.size(); ++i) { - ASSERT_EQ(last[i], arr[i+12]); + ASSERT_EQ(last[i], arr[i + 12]); } } @@ -299,8 +301,8 @@ TEST(Span, FirstLast) { { float *arr = new float[16]; InitializeRange(arr, arr + 16); - Span s (arr, 16); - Span first = s.first(4); + host_span s(arr, 16); + host_span first = s.first(4); ASSERT_EQ(first.size(), 4); ASSERT_EQ(first.data(), s.data()); @@ -309,14 +311,14 @@ TEST(Span, FirstLast) { ASSERT_EQ(first[i], s[i]); } - delete [] arr; + delete[] arr; } { float *arr = new float[16]; InitializeRange(arr, arr + 16); - Span s (arr, 16); - Span last = s.last(4); + host_span s(arr, 16); + host_span last = s.last(4); ASSERT_EQ(last.size(), 4); ASSERT_EQ(last.data(), s.data() + 12); @@ -325,13 +327,13 @@ TEST(Span, FirstLast) { ASSERT_EQ(s[12 + i], last[i]); } - delete [] arr; + delete[] arr; } } TEST(Span, Subspan) { - int arr[16] {0}; - Span s1 (arr); + int arr[16]{0}; + host_span s1(arr); auto s2 = s1.subspan<4>(); ASSERT_EQ(s1.size() - 4, s2.size()); @@ -346,25 +348,25 @@ TEST(Span, Subspan) { TEST(Span, Compare) { int status = 1; - TestCompare{&status}(); + TestCompare{&status}(); ASSERT_EQ(status, 1); } TEST(Span, AsBytes) { int status = 1; - TestAsBytes{&status}(); + TestAsBytes{&status}(); ASSERT_EQ(status, 1); } TEST(Span, AsWritableBytes) { int status = 1; - TestAsWritableBytes{&status}(); + TestAsWritableBytes{&status}(); ASSERT_EQ(status, 1); } TEST(Span, Empty) { { - Span s {nullptr, static_cast::size_type>(0)}; + host_span s{nullptr, static_cast::size_type>(0)}; auto res = s.subspan(0); ASSERT_EQ(res.data(), nullptr); ASSERT_EQ(res.size(), 0); @@ -375,7 +377,7 @@ TEST(Span, Empty) { } { - Span s {nullptr, static_cast::size_type>(0)}; + host_span s{nullptr, static_cast::size_type>(0)}; auto res = s.subspan(0); ASSERT_EQ(res.data(), nullptr); ASSERT_EQ(res.size(), 0); diff --git a/cpp/test/common/span.cu b/cpp/test/common/span.cu index 108a7ee583..3356fe40f4 100644 --- a/cpp/test/common/span.cu +++ b/cpp/test/common/span.cu @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include #include @@ -5,7 +20,7 @@ #include // iota #include #include -#include "test_span.h" +#include "test_span.hpp" namespace raft::common { struct TestStatus { @@ -31,7 +46,7 @@ struct TestStatus { int* Data() { return status_; } }; -__global__ void TestFromOtherKernel(Span span) { +__global__ void TestFromOtherKernel(device_span span) { // don't get optimized out size_t idx = threadIdx.x + blockIdx.x * blockDim.x; @@ -40,7 +55,7 @@ __global__ void TestFromOtherKernel(Span span) { } } // Test converting different T -__global__ void TestFromOtherKernelConst(Span span) { +__global__ void TestFromOtherKernelConst(device_span span) { // don't get optimized out size_t idx = threadIdx.x + blockIdx.x * blockDim.x; @@ -60,20 +75,20 @@ TEST(GPUSpan, FromOther) { thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin()); // dynamic extent { - Span span(d_vec.data().get(), d_vec.size()); + device_span span(d_vec.data().get(), d_vec.size()); TestFromOtherKernel<<<1, 16>>>(span); } { - Span span(d_vec.data().get(), d_vec.size()); + device_span span(d_vec.data().get(), d_vec.size()); TestFromOtherKernelConst<<<1, 16>>>(span); } // static extent { - Span span(d_vec.data().get(), d_vec.data().get() + 16); + device_span span(d_vec.data().get(), d_vec.data().get() + 16); TestFromOtherKernel<<<1, 16>>>(span); } { - Span span(d_vec.data().get(), d_vec.data().get() + 16); + device_span span(d_vec.data().get(), d_vec.data().get() + 16); TestFromOtherKernelConst<<<1, 16>>>(span); } } @@ -82,7 +97,7 @@ TEST(GPUSpan, Assignment) { CUDA_CHECK(cudaSetDevice(0)); TestStatus status; thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, - TestAssignment{status.Data()}); + TestAssignment{status.Data()}); ASSERT_EQ(status.Get(), 1); } @@ -123,7 +138,7 @@ TEST(GPUSpan, WithTrust) { // Can't initialize span with device_vector, since d_vec.data() is not raw // pointer { - Span s(d_vec.data().get(), d_vec.size()); + device_span s(d_vec.data().get(), d_vec.size()); ASSERT_EQ(d_vec.size(), s.size()); ASSERT_EQ(d_vec.data().get(), s.data()); @@ -133,7 +148,7 @@ TEST(GPUSpan, WithTrust) { TestStatus status; thrust::device_vector d_vec1(d_vec.size()); thrust::copy(thrust::device, d_vec.begin(), d_vec.end(), d_vec1.begin()); - Span s(d_vec1.data().get(), d_vec.size()); + device_span s(d_vec1.data().get(), d_vec.size()); thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, TestEqual{thrust::raw_pointer_cast(d_vec1.data()), @@ -146,7 +161,7 @@ TEST(GPUSpan, BeginEnd) { CUDA_CHECK(cudaSetDevice(0)); TestStatus status; thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, - TestBeginEnd{status.Data()}); + TestBeginEnd{status.Data()}); ASSERT_EQ(status.Get(), 1); } @@ -154,11 +169,11 @@ TEST(GPUSpan, RBeginREnd) { CUDA_CHECK(cudaSetDevice(0)); TestStatus status; thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, - TestRBeginREnd{status.Data()}); + TestRBeginREnd{status.Data()}); ASSERT_EQ(status.Get(), 1); } -__global__ void TestModifyKernel(Span span) { +__global__ void TestModifyKernel(device_span span) { size_t idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= span.size()) { @@ -174,7 +189,7 @@ TEST(GPUSpan, Modify) { thrust::device_vector d_vec(h_vec.size()); thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin()); - Span span(d_vec.data().get(), d_vec.size()); + device_span span(d_vec.data().get(), d_vec.size()); TestModifyKernel<<<1, 16>>>(span); @@ -187,16 +202,16 @@ TEST(GPUSpan, Observers) { CUDA_CHECK(cudaSetDevice(0)); TestStatus status; thrust::for_each_n(thrust::make_counting_iterator(0ul), 16, - TestObservers{status.Data()}); + TestObservers{status.Data()}); ASSERT_EQ(status.Get(), 1); } struct TestElementAccess { private: - Span span_; + device_span span_; public: - HD explicit TestElementAccess(Span _span) : span_(_span) {} + HD explicit TestElementAccess(device_span _span) : span_(_span) {} HD float operator()(size_t _idx) { float tmp = span_[_idx]; diff --git a/cpp/test/common/test_span.h b/cpp/test/common/test_span.hpp similarity index 67% rename from cpp/test/common/test_span.h rename to cpp/test/common/test_span.hpp index 6ec377085d..0a03e378b7 100644 --- a/cpp/test/common/test_span.h +++ b/cpp/test/common/test_span.hpp @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once #include namespace raft::common { @@ -30,6 +46,7 @@ struct TestTestStatus { } }; +template struct TestAssignment { int* status_; @@ -37,11 +54,11 @@ struct TestAssignment { __host__ __device__ void operator()() { this->operator()(0); } __host__ __device__ void operator()(int _idx) { - Span s1; + Span s1; float arr[] = {3, 4, 5}; - Span s2 = arr; + Span s2 = arr; SPAN_ASSERT_TRUE(s2.size() == 3, status_); SPAN_ASSERT_TRUE(s2.data() == &arr[0], status_); @@ -50,6 +67,7 @@ struct TestAssignment { } }; +template struct TestBeginEnd { int* status_; @@ -60,9 +78,9 @@ struct TestBeginEnd { float arr[16]; InitializeRange(arr, arr + 16); - Span s(arr); - Span::iterator beg{s.begin()}; - Span::iterator end{s.end()}; + Span s(arr); + typename Span::iterator beg{s.begin()}; + typename Span::iterator end{s.end()}; SPAN_ASSERT_TRUE(end == beg + 16, status_); SPAN_ASSERT_TRUE(*beg == arr[0], status_); @@ -70,6 +88,7 @@ struct TestBeginEnd { } }; +template struct TestRBeginREnd { int* status_; @@ -80,10 +99,10 @@ struct TestRBeginREnd { float arr[16]; InitializeRange(arr, arr + 16); - Span s(arr); + Span s(arr); s.rbegin(); - Span::reverse_iterator rbeg{s.rbegin()}; - Span::reverse_iterator rend{s.rend()}; + typename Span::reverse_iterator rbeg{s.rbegin()}; + typename Span::reverse_iterator rend{s.rend()}; SPAN_ASSERT_TRUE(rbeg + 16 == rend, status_); SPAN_ASSERT_TRUE(*(rbeg) == arr[15], status_); @@ -91,6 +110,7 @@ struct TestRBeginREnd { } }; +template struct TestObservers { int* status_; @@ -101,14 +121,15 @@ struct TestObservers { // empty { float* arr = nullptr; - Span s(arr, static_cast::size_type>(0)); + Span s( + arr, static_cast::size_type>(0)); SPAN_ASSERT_TRUE(s.empty(), status_); } // size, size_types { float* arr = new float[16]; - Span s(arr, 16); + Span s(arr, 16); SPAN_ASSERT_TRUE(s.size() == 16, status_); SPAN_ASSERT_TRUE(s.size_bytes() == 16 * sizeof(float), status_); delete[] arr; @@ -116,6 +137,7 @@ struct TestObservers { } }; +template struct TestCompare { int* status_; @@ -127,8 +149,8 @@ struct TestCompare { InitializeRange(lhs_arr, lhs_arr + 16); InitializeRange(rhs_arr, rhs_arr + 16); - Span lhs(lhs_arr); - Span rhs(rhs_arr); + Span lhs(lhs_arr); + Span rhs(rhs_arr); SPAN_ASSERT_TRUE(lhs == rhs, status_); SPAN_ASSERT_FALSE(lhs != rhs, status_); @@ -144,6 +166,7 @@ struct TestCompare { } }; +template struct TestAsBytes { int* status_; @@ -155,8 +178,8 @@ struct TestAsBytes { InitializeRange(arr, arr + 16); { - const Span s{arr}; - const Span bs = as_bytes(s); + const Span s{arr}; + const Span bs = as_bytes(s); SPAN_ASSERT_TRUE(bs.size() == s.size_bytes(), status_); SPAN_ASSERT_TRUE(static_cast(bs.data()) == static_cast(s.data()), @@ -164,8 +187,8 @@ struct TestAsBytes { } { - Span s; - const Span bs = as_bytes(s); + Span s; + const Span bs = as_bytes(s); SPAN_ASSERT_TRUE(bs.size() == s.size(), status_); SPAN_ASSERT_TRUE(bs.size() == 0, status_); SPAN_ASSERT_TRUE(bs.size_bytes() == 0, status_); @@ -177,6 +200,7 @@ struct TestAsBytes { } }; +template struct TestAsWritableBytes { int* status_; @@ -188,20 +212,21 @@ struct TestAsWritableBytes { InitializeRange(arr, arr + 16); { - Span s; - Span bs = as_writable_bytes(s); - SPAN_ASSERT_TRUE(bs.size() == s.size(), status_); - SPAN_ASSERT_TRUE(bs.size_bytes() == s.size_bytes(), status_); - SPAN_ASSERT_TRUE(bs.size() == 0, status_); - SPAN_ASSERT_TRUE(bs.size_bytes() == 0, status_); - SPAN_ASSERT_TRUE(bs.data() == nullptr, status_); + Span s; + Span byte_s = as_writable_bytes(s); + SPAN_ASSERT_TRUE(byte_s.size() == s.size(), status_); + SPAN_ASSERT_TRUE(byte_s.size_bytes() == s.size_bytes(), status_); + SPAN_ASSERT_TRUE(byte_s.size() == 0, status_); + SPAN_ASSERT_TRUE(byte_s.size_bytes() == 0, status_); + SPAN_ASSERT_TRUE(byte_s.data() == nullptr, status_); SPAN_ASSERT_TRUE( - static_cast(bs.data()) == static_cast(s.data()), status_); + static_cast(byte_s.data()) == static_cast(s.data()), + status_); } { - Span s{arr}; - Span bs{as_writable_bytes(s)}; + Span s{arr}; + Span bs{as_writable_bytes(s)}; SPAN_ASSERT_TRUE(s.size_bytes() == bs.size_bytes(), status_); SPAN_ASSERT_TRUE( static_cast(bs.data()) == static_cast(s.data()), status_);