diff --git a/include/oneapi/dpl/pstl/ranges_defs.h b/include/oneapi/dpl/pstl/ranges_defs.h index 39f67ed858..aab2aad2db 100644 --- a/include/oneapi/dpl/pstl/ranges_defs.h +++ b/include/oneapi/dpl/pstl/ranges_defs.h @@ -23,6 +23,9 @@ #include #endif +//oneapi::dpl::ranges::zip_view support for C++20 +#include "zip_view_impl.h" + #include "utils_ranges.h" #if _ONEDPL_BACKEND_SYCL # include "hetero/dpcpp/utils_ranges_sycl.h" diff --git a/include/oneapi/dpl/pstl/tuple_impl.h b/include/oneapi/dpl/pstl/tuple_impl.h index 239734d486..01b5237bdf 100644 --- a/include/oneapi/dpl/pstl/tuple_impl.h +++ b/include/oneapi/dpl/pstl/tuple_impl.h @@ -500,6 +500,15 @@ struct tuple next = other.next; return *this; } + + template + tuple& + operator=(const tuple& other) const + { + holder.value = other.holder.value; + next = other.next; + return *this; + } // if T1 is deduced with reference, compiler generates deleted operator= and, // since "template operator=" is not considered as operator= overload diff --git a/include/oneapi/dpl/pstl/zip_view_impl.h b/include/oneapi/dpl/pstl/zip_view_impl.h new file mode 100644 index 0000000000..2a7d88f8e6 --- /dev/null +++ b/include/oneapi/dpl/pstl/zip_view_impl.h @@ -0,0 +1,408 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Copyright (C) Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file incorporates work covered by the following copyright and permission +// notice: +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#ifndef _ONEDPL_ZIP_VIEW_IMPL_H +#define _ONEDPL_ZIP_VIEW_IMPL_H + +#if _ONEDPL_CPP20_RANGES_PRESENT + +#include +#include +#include + +#include "tuple_impl.h" +#include "iterator_impl.h" + +namespace oneapi +{ +namespace dpl +{ + +namespace ranges +{ + +template +concept all_forward = ( std::ranges::forward_range> && ... ); + +template +concept all_bidirectional = ( std::ranges::bidirectional_range> && ... ); + +template +concept all_random_access = ( std::ranges::random_access_range> && ... ); + +template +concept zip_is_common = + (sizeof...(Rs) == 1 && ( std::ranges::common_range && ... )) + || + (!(std::ranges::bidirectional_range && ...) && (std::ranges::common_range && ...)) + || + ((std::ranges::random_access_range && ...) && (std::ranges::sized_range && ...)); + +template +struct declare_iterator_category {}; + +template + requires all_forward +struct declare_iterator_category { + using iterator_category = std::input_iterator_tag; +}; + +template + requires ((std::ranges::view && ... ) && (sizeof...(Views) > 0)) +class zip_view : public std::ranges::view_interface> { + template + using tuple_type = oneapi::dpl::__internal::tuple; + + template + static decltype(auto) + apply_to_tuple_impl(_ReturnAdapter __tr, _F __f, _Tuple& __t, std::index_sequence<_Ip...>) + { + return __tr(__f(std::get<_Ip>(__t))...); + } + + template + static decltype(auto) + apply_to_tuple(_ReturnAdapter __tr, _F __f, _Tuple& __t) + { + return apply_to_tuple_impl(__tr, __f, __t, std::make_index_sequence{}); + } + + template + static void + apply_to_tuple(_F __f, _Tuple& __t) + { + apply_to_tuple([](auto...){}, __f, __t); + } + +public: + zip_view() = default; + constexpr zip_view(Views... views) : views_(std::move(views)...) {} + + template + class iterator : declare_iterator_category { + public: + using iterator_concept = std::conditional_t, + std::random_access_iterator_tag, + std::conditional_t, + std::bidirectional_iterator_tag, + std::conditional_t, + std::forward_iterator_tag, + std::input_iterator_tag>>>; + + using value_type = std::conditional_t...>, + tuple_type...>>; + + using difference_type = std::conditional_t...>, + std::common_type_t...>>; + + iterator() = default; + + constexpr iterator(iterator i) + requires Const && (std::convertible_to, std::ranges::iterator_t> && ...) + : current_(std::move(i.current_)) {} + + private: + template + constexpr iterator(const Iterators&... iterators) + : current_(iterators...) {} + public: + + template + operator oneapi::dpl::zip_iterator() const + { + auto __tr = [](auto&&... __args) -> decltype(auto) { return oneapi::dpl::make_zip_iterator(__args...);}; + return apply_to_tuple(__tr, [](auto it) -> decltype(auto) { return it;}, current_); + } + + constexpr decltype(auto) operator*() const { + auto __tr = [](auto&&... __args) -> decltype(auto) { + using return_tuple_type = std::conditional_t< + !Const, tuple_type...>, + tuple_type...>>; + return return_tuple_type(std::forward(__args)...); + }; + return apply_to_tuple(__tr, [](auto it) -> decltype(auto) { return *it;}, current_); + } + + constexpr decltype(auto) operator[]( difference_type n ) const + requires all_random_access + { + return *(*this + n); + } + + constexpr iterator& operator++() { + zip_view::apply_to_tuple([](auto& it) { return ++it; }, current_); + return *this; + } + + constexpr void operator++(int) { + ++*this; + } + + constexpr iterator operator++(int) requires all_forward { + auto tmp = *this; + ++*this; + return tmp; + } + + constexpr iterator& operator--() requires all_bidirectional { + zip_view::apply_to_tuple([](auto& it) { return --it; }, current_); + return *this; + } + + constexpr iterator operator--(int) requires all_bidirectional { + auto tmp = *this; + --*this; + return tmp; + } + + constexpr iterator& operator+=(difference_type n) + requires all_random_access + { + zip_view::apply_to_tuple([n](auto& it) { return it += n; }, current_); + return *this; + } + + constexpr iterator& operator-=(difference_type n) + requires all_random_access + { + zip_view::apply_to_tuple([n](auto& it) { return it -= n; }, current_); + return *this; + } + + friend constexpr bool operator==(const iterator& x, const iterator& y) + requires ( std::equality_comparable, + std::ranges::iterator_t>> && ... ) + { + if constexpr (all_bidirectional) { + return x.current_ == y.current_; + } else { + return x.compare_equal(y, std::make_index_sequence()); + } + } + + friend constexpr auto operator<=>(const iterator& x, const iterator& y) + requires all_random_access + { + if (x.current_ < y.current_) + return -1; + else if (x.current_ == y.current_) + return 0; + return 1; //x.current > y.current_ + } + + friend constexpr auto operator-(const iterator& x, const iterator& y) + requires all_random_access + { + return y.distance_to_it(x.current_, std::make_index_sequence()); + } + + friend constexpr iterator operator+(iterator it, difference_type n) + { + return it += n; + } + + friend constexpr iterator operator+(difference_type n, iterator it) + { + return it += n; + } + + friend constexpr iterator operator-(iterator it, difference_type n) + { + return it -= n; + } + + private: + template + constexpr bool compare_equal(iterator y, std::index_sequence) { + return ((std::get(current_) == std::get(y.current_)) && ...); + } + + template + constexpr bool compare_with_sentinels(const SentinelsTuple& sentinels, std::index_sequence) const { + return ( (std::get(current_) == std::get(sentinels)) || ... ); + } + + template + constexpr std::common_type_t, + std::ranges::range_difference_t>...> + distance_to_sentinels(const SentinelsTuple& sentinels, std::index_sequence<0, In...>) { + auto min = std::get<0>(current_) - std::get<0>(sentinels); + + ( (min = std::min(min, (std::get(current_) - std::get(sentinels)))) , ... ); + return min; + } + template + constexpr std::common_type_t, + std::ranges::range_difference_t>...> + distance_to_it(const iterator it, std::index_sequence<0, In...>) const { + auto min = std::get<0>(it.current_) - std::get<0>(current_); + + ( (min = std::min(min, (std::get(it.current_) - std::get(current_)))) , ... ); + return min; + } + + friend class zip_view; + + using current_type = std::conditional_t...>, + tuple_type...>>; + + current_type current_; + }; // class iterator + + template + class sentinel { + public: + sentinel() = default; + constexpr sentinel(sentinel i) + requires Const && + ( std::convertible_to, std::ranges::sentinel_t> && ... ) + : end_(std::move(i.end_)) {} + + private: + template + constexpr sentinel(const Sentinels&... sentinels) + : end_(sentinels...) {} + public: + template + requires (std::sentinel_for, + std::ranges::sentinel_t>, + std::conditional_t, + std::ranges::iterator_t>> && ...) + friend constexpr bool operator==(const iterator& x, const sentinel& y) + { + return x.compare_with_sentinels(y.end_, std::make_index_sequence()); + } + + template + requires (std::sized_sentinel_for, + std::ranges::sentinel_t>, + std::conditional_t, + std::ranges::iterator_t>> && ...) + friend constexpr std::common_type_t, + std::ranges::range_difference_t>...> + operator-(const iterator& x, const sentinel& y) { + return x.distance_to_sentinels(y.end_, std::make_index_sequence()); + } + + template + requires (std::sized_sentinel_for, + std::ranges::sentinel_t>, + std::conditional_t, + std::ranges::iterator_t>> && ...) + friend constexpr std::common_type_t, + std::ranges::range_difference_t>...> + operator-(const sentinel& y, const iterator& x) { + return -(x - y); + } + + private: + friend class zip_view; + + using end_type = std::conditional_t...>, + tuple_type...>>; + + end_type end_; + }; // class sentinel + + constexpr auto begin() requires (std::ranges::range && ...) + { + auto __tr = [](auto... __args) { return iterator(__args...);}; + return apply_to_tuple(__tr, std::ranges::begin, views_); + } + + constexpr auto begin() const requires ( std::ranges::range && ... ) + { + auto __tr = [](auto... __args) { return iterator(__args...);}; + return apply_to_tuple(__tr, std::ranges::begin, views_); + } + + constexpr auto end() requires (std::ranges::range && ...) + { + if constexpr (!zip_is_common) + { + auto __tr = [](auto... __args) { return sentinel(__args...);}; + return apply_to_tuple(__tr, std::ranges::end, views_); + } + else if constexpr ((std::ranges::random_access_range && ...)) + { + auto it = begin(); + it += size(); + return it; + } + else + { + auto __tr = [](auto... __args) { return iterator(__args...);}; + return apply_to_tuple(__tr, std::ranges::end, views_); + } + } + + constexpr auto end() const requires (std::ranges::range && ...) + { + if constexpr (!zip_is_common) + { + auto __tr = [](auto... __args) { return sentinel(__args...);}; + return apply_to_tuple(__tr, std::ranges::end, views_); + } + else if constexpr ((std::ranges::random_access_range && ...)) + { + auto it = begin(); + it += size(); + return it; + } + else + { + auto __tr = [](auto... __args) { return iterator(__args...);}; + return apply_to_tuple(__tr, std::ranges::end, views_); + } + } + + constexpr auto size() requires (std::ranges::sized_range && ...) + { + auto __tr = [](auto... __args) { + using CT = std::make_unsigned_t>; + return std::ranges::min({CT(__args)...}); + }; + + return apply_to_tuple(__tr, std::ranges::size, views_); + } + + constexpr auto size() const requires (std::ranges::sized_range && ...) + { + return const_cast(this)->size(); + } +private: + tuple_type views_; +}; // class zip_view + +template +zip_view(Rs&&...) -> zip_view...>; + +struct zip_fn { + template + constexpr auto operator()( Rs&&... rs ) const { + return zip_view...>(std::forward(rs)...); + } +}; + +inline constexpr zip_fn zip{}; + +} // namespace ranges +} // namespace dpl +} // namespace oneapi + +#endif //_ONEDPL_CPP20_RANGES_PRESENT + +#endif //_ONEDPL_ZIP_VIEW_IMPL_H diff --git a/test/parallel_api/ranges/std_ranges_test.h b/test/parallel_api/ranges/std_ranges_test.h index b022dea294..0efc433648 100644 --- a/test/parallel_api/ranges/std_ranges_test.h +++ b/test/parallel_api/ranges/std_ranges_test.h @@ -144,6 +144,14 @@ template static constexpr bool is_range().begin())>> = true; +void call_with_host_policies(auto algo, auto... args) +{ + algo(oneapi::dpl::execution::seq, args...); + algo(oneapi::dpl::execution::unseq, args...); + algo(oneapi::dpl::execution::par, args...); + algo(oneapi::dpl::execution::par_unseq, args...); +} + template struct test { diff --git a/test/parallel_api/ranges/std_ranges_zip_view.pass.cpp b/test/parallel_api/ranges/std_ranges_zip_view.pass.cpp new file mode 100644 index 0000000000..3e4bc3366a --- /dev/null +++ b/test/parallel_api/ranges/std_ranges_zip_view.pass.cpp @@ -0,0 +1,97 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Copyright (C) Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file incorporates work covered by the following copyright and permission +// notice: +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#include "std_ranges_test.h" +#include + +#if _ENABLE_STD_RANGES_TESTING +#include + +void test_zip_view_base_op() +{ + namespace dpl_ranges = oneapi::dpl::ranges; + + constexpr int max_n = 100; + std::vector vec1(max_n); + std::vector vec2(max_n/2); + + auto zip_view = dpl_ranges::zip(vec1, vec2); + + static_assert(std::random_access_iterator); + static_assert(std::sentinel_for); + + EXPECT_TRUE(zip_view.end() - zip_view.begin() == max_n/2, + "Difference operation between an iterator and a sentinel (zip_view) returns a wrong result."); + + EXPECT_TRUE(zip_view[2] == *(zip_view.begin() + 2), + "Subscription or dereferencing operation for zip_view returns a wrong result."); + + EXPECT_TRUE(std::ranges::size(zip_view) == max_n/2, "zip_view::size method returns a wrong result."); + EXPECT_TRUE((bool)zip_view, "zip_view::operator bool() method returns a wrong result."); + + EXPECT_TRUE(zip_view[0] == zip_view.front(), "zip_view::front method returns a wrong result."); + EXPECT_TRUE(zip_view[zip_view.size() - 1] == zip_view.back(), "zip_view::back method returns a wrong result."); + EXPECT_TRUE(!zip_view.empty(), "zip_view::empty() method returns a wrong result."); + + using zip_view_t = dpl_ranges::zip_view>; + auto zip_view_0 = zip_view_t(); + EXPECT_TRUE(!zip_view_0.empty(), "zip_view::empty() method returns a wrong result."); +} +#endif //_ENABLE_STD_RANGES_TESTING + +std::int32_t +main() +{ +#if _ENABLE_STD_RANGES_TESTING + + test_zip_view_base_op(); + + namespace dpl_ranges = oneapi::dpl::ranges; + + constexpr int max_n = 10; + int data[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + auto zip_view = dpl_ranges::zip(data, std::views::iota(0, max_n)) | std::views::take(5); + std::ranges::for_each(zip_view, test_std_ranges::f_mutuable, [](const auto& val) { return std::get<1>(val); }); + + test_std_ranges::call_with_host_policies(dpl_ranges::for_each, zip_view, test_std_ranges::f_mutuable, [](const auto& val) { return std::get<1>(val); }); + +#if TEST_DPCPP_BACKEND_PRESENT + dpl_ranges::for_each(test_std_ranges::dpcpp_policy(), zip_view, test_std_ranges::f_mutuable, [](const auto& val) { return std::get<1>(val); }); +#endif + + auto zip_view_sort = dpl_ranges::zip(data, data); + + oneapi::dpl::zip_iterator zip_it = zip_view_sort.begin(); //check conversion to oneapi::dpl::zip_iterator + + std::sort(zip_view_sort.begin(), zip_view_sort.begin() + max_n, [](const auto& val1, const auto& val2) { return std::get<0>(val1) > std::get<0>(val2); }); + for(int i = 0; i < max_n; ++i) + EXPECT_TRUE(std::get<0>(zip_view_sort[i]) == max_n - 1 - i, "Wrong effect for std::sort with zip_view."); + + std::ranges::sort(zip_view_sort, std::less{}, [](auto&& val) { return std::get<0>(val); }); + for(int i = 0; i < max_n; ++i) + EXPECT_TRUE(std::get<0>(zip_view_sort[i]) == i, "Wrong effect for std::ranges::sort with zip_view."); + + static_assert(std::ranges::random_access_range); + static_assert(std::random_access_iterator); + + test_std_ranges::call_with_host_policies(dpl_ranges::sort, zip_view_sort, std::greater{}, [](const auto& val) { return std::get<0>(val); }); + for(int i = 0; i < max_n; ++i) + assert(std::get<0>(zip_view_sort[i]) == max_n - 1 - i); + +#endif //_ENABLE_STD_RANGES_TESTING + + return TestUtils::done(_ENABLE_STD_RANGES_TESTING); +}