From b30a85620eb13a436107a9f58094f4012c5246cd Mon Sep 17 00:00:00 2001 From: Allison Vacanti Date: Wed, 27 May 2020 22:53:13 -0400 Subject: [PATCH] Update scan accum / binary_op edgecase handling. TBB's scan was implemented differently than the other backends, leading to some failing unit tests. This patch fixes these inconsistencies by making the following changes: - Follow P0571's guidance regarding accumulator variable type. - https://wg21.link/P0571 - The accumulator's type is now: - The type of the user-supplied initial value (if provided), or - The input iterator's value type if no initial value. - Follow C++ standard guidance for default binary operator type. - https://eel.is/c++draft/exclusive.scan#1 - Thrust binary/unary functors now specialize a default void template parameter. Types are deduced and forwarded transparently. - Updated the scan's default binary operator to the new `thrust::plus<>` specialization. - The `intermediate_type_from_function_and_iterators` helper is no longer needed and has been removed. Closes #1170. --- testing/scan.cu | 73 ++-- ...mediate_type_from_function_and_iterators.h | 61 ---- thrust/functional.h | 311 +++++++++++++++--- thrust/system/cuda/detail/transform_scan.h | 49 +-- .../system/detail/generic/reduce_by_key.inl | 23 +- thrust/system/detail/generic/scan.inl | 31 +- thrust/system/detail/generic/scan_by_key.inl | 6 +- .../system/detail/generic/transform_scan.inl | 50 +-- .../system/detail/sequential/reduce_by_key.h | 8 +- thrust/system/detail/sequential/scan.h | 50 +-- thrust/system/tbb/detail/scan.inl | 36 +- 11 files changed, 335 insertions(+), 363 deletions(-) delete mode 100644 thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h diff --git a/testing/scan.cu b/testing/scan.cu index 347b1c126..925c7bc8f 100644 --- a/testing/scan.cu +++ b/testing/scan.cu @@ -250,48 +250,49 @@ void TestScanMixedTypes(void) IntVector int_output(4); FloatVector float_output(4); - - // float -> int should use using plus operator by default + + // float -> int should use plus operator and float accumulator by default thrust::inclusive_scan(float_input.begin(), float_input.end(), int_output.begin()); - ASSERT_EQUAL(int_output[0], 1); - ASSERT_EQUAL(int_output[1], 3); - ASSERT_EQUAL(int_output[2], 6); - ASSERT_EQUAL(int_output[3], 10); - - // float -> float with plus operator (int accumulator) + ASSERT_EQUAL(int_output[0], 1); // in: 1.5 accum: 1.5f out: 1 + ASSERT_EQUAL(int_output[1], 4); // in: 2.5 accum: 4.0f out: 4 + ASSERT_EQUAL(int_output[2], 7); // in: 3.5 accum: 7.5f out: 7 + ASSERT_EQUAL(int_output[3], 12); // in: 4.5 accum: 12.f out: 12 + + // float -> float with plus operator (float accumulator) thrust::inclusive_scan(float_input.begin(), float_input.end(), float_output.begin(), thrust::plus()); - ASSERT_EQUAL(float_output[0], 1.5); - ASSERT_EQUAL(float_output[1], 3.0); - ASSERT_EQUAL(float_output[2], 6.0); - ASSERT_EQUAL(float_output[3], 10.0); - - // float -> int should use using plus operator by default + ASSERT_EQUAL(float_output[0], 1.5f); // in: 1.5 accum: 1.5f out: 1.5f + ASSERT_EQUAL(float_output[1], 3.0f); // in: 2.5 accum: 3.0f out: 3.0f + ASSERT_EQUAL(float_output[2], 6.0f); // in: 3.5 accum: 6.0f out: 6.0f + ASSERT_EQUAL(float_output[3], 10.0f); // in: 4.5 accum: 10.f out: 10.f + + // float -> int should use plus operator and float accumulator by default thrust::exclusive_scan(float_input.begin(), float_input.end(), int_output.begin()); - ASSERT_EQUAL(int_output[0], 0); - ASSERT_EQUAL(int_output[1], 1); - ASSERT_EQUAL(int_output[2], 3); - ASSERT_EQUAL(int_output[3], 6); - - // float -> int should use using plus operator by default + ASSERT_EQUAL(int_output[0], 0); // out: 0.0f in: 1.5 accum: 1.5f + ASSERT_EQUAL(int_output[1], 1); // out: 1.5f in: 2.5 accum: 4.0f + ASSERT_EQUAL(int_output[2], 4); // out: 4.0f in: 3.5 accum: 7.5f + ASSERT_EQUAL(int_output[3], 7); // out: 7.5f in: 4.5 accum: 12.f + + // float -> int should use plus<> operator and float accumulator by default thrust::exclusive_scan(float_input.begin(), float_input.end(), int_output.begin(), (float) 5.5); - ASSERT_EQUAL(int_output[0], 5); - ASSERT_EQUAL(int_output[1], 7); - ASSERT_EQUAL(int_output[2], 9); - ASSERT_EQUAL(int_output[3], 13); - - // int -> float should use using plus operator by default + ASSERT_EQUAL(int_output[0], 5); // out: 5.5f in: 1.5 accum: 7.0f + ASSERT_EQUAL(int_output[1], 7); // out: 7.0f in: 2.5 accum: 9.5f + ASSERT_EQUAL(int_output[2], 9); // out: 9.5f in: 3.5 accum: 13.0f + ASSERT_EQUAL(int_output[3], 13); // out: 13.f in: 4.5 accum: 17.4f + + // int -> float should use using plus<> operator and int accumulator by default thrust::inclusive_scan(int_input.begin(), int_input.end(), float_output.begin()); - ASSERT_EQUAL(float_output[0], 1.0); - ASSERT_EQUAL(float_output[1], 3.0); - ASSERT_EQUAL(float_output[2], 6.0); - ASSERT_EQUAL(float_output[3], 10.0); - - // int -> float should use using plus operator by default + ASSERT_EQUAL(float_output[0], 1.f); // in: 1 accum: 1 out: 1 + ASSERT_EQUAL(float_output[1], 3.f); // in: 2 accum: 3 out: 3 + ASSERT_EQUAL(float_output[2], 6.f); // in: 3 accum: 6 out: 6 + ASSERT_EQUAL(float_output[3], 10.f); // in: 4 accum: 10 out: 10 + + // int -> float + float init_value should use using plus<> operator and + // float accumulator by default thrust::exclusive_scan(int_input.begin(), int_input.end(), float_output.begin(), (float) 5.5); - ASSERT_EQUAL(float_output[0], 5.5); - ASSERT_EQUAL(float_output[1], 6.5); - ASSERT_EQUAL(float_output[2], 8.5); - ASSERT_EQUAL(float_output[3], 11.5); + ASSERT_EQUAL(float_output[0], 5.5f); // out: 5.5f in: 1 accum: 6.5f + ASSERT_EQUAL(float_output[1], 6.5f); // out: 6.0f in: 2 accum: 8.5f + ASSERT_EQUAL(float_output[2], 8.5f); // out: 8.0f in: 3 accum: 11.5f + ASSERT_EQUAL(float_output[3], 11.5f); // out: 11.f in: 4 accum: 15.5f } void TestScanMixedTypesHost(void) { diff --git a/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h b/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h deleted file mode 100644 index f221c915f..000000000 --- a/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2008-2013 NVIDIA Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace thrust -{ - -namespace detail -{ - -// this trait reports what type should be used as a temporary in certain algorithms -// which aggregate intermediate results from a function before writing to an output iterator - -// the pseudocode for deducing the type of the temporary used below: -// -// if Function is an AdaptableFunction -// result = Function::result_type -// else if OutputIterator2 is a "pure" output iterator -// result = InputIterator2::value_type -// else -// result = OutputIterator2::value_type -// -// XXX upon c++0x, TemporaryType needs to be: -// result_of_adaptable_function::type -template - struct intermediate_type_from_function_and_iterators - : eval_if< - has_result_type::value, - result_type, - eval_if< - is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - > -{ -}; // end intermediate_type_from_function_and_iterators - -} // end detail - -} // end thrust - diff --git a/thrust/functional.h b/thrust/functional.h index a550afddb..2a62539d2 100644 --- a/thrust/functional.h +++ b/thrust/functional.h @@ -139,6 +139,41 @@ struct binary_function * \{ */ +#define THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION(func, impl) \ + template <> \ + struct func \ + { \ + using is_transparent = void; \ + __thrust_exec_check_disable__ \ + template \ + __host__ __device__ \ + constexpr auto operator()(T&& x) const \ + noexcept(noexcept(impl)) -> decltype(impl) \ + { \ + return impl; \ + } \ + } + +#define THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION(func, impl) \ + template <> \ + struct func \ + { \ + using is_transparent = void; \ + __thrust_exec_check_disable__ \ + template \ + __host__ __device__ \ + constexpr auto operator()(T1&& t1, T2&& t2) const \ + noexcept(noexcept(impl)) -> decltype(impl) \ + { \ + return impl; \ + } \ + } + +#define THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(func, op) \ + THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION( \ + func, THRUST_FWD(t1) op THRUST_FWD(t2)) + + /*! \p plus is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class plus, and \c x and \c y are objects * of class \c T, then f(x,y) returns x+y. @@ -172,7 +207,7 @@ struct binary_function * \see http://www.sgi.com/tech/stl/plus.html * \see binary_function */ -template +template struct plus { /*! \typedef first_argument_type @@ -193,9 +228,15 @@ struct plus /*! Function call operator. The return value is lhs + rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs + rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs + rhs; + } }; // end plus +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(plus, +); + /*! \p minus is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class minus, and \c x and \c y are objects * of class \c T, then f(x,y) returns x-y. @@ -229,7 +270,7 @@ struct plus * \see http://www.sgi.com/tech/stl/minus.html * \see binary_function */ -template +template struct minus { /*! \typedef first_argument_type @@ -250,9 +291,15 @@ struct minus /*! Function call operator. The return value is lhs - rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs - rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs - rhs; + } }; // end minus +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(minus, -); + /*! \p multiplies is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class multiplies, and \c x and \c y are objects * of class \c T, then f(x,y) returns x*y. @@ -286,7 +333,7 @@ struct minus * \see http://www.sgi.com/tech/stl/multiplies.html * \see binary_function */ -template +template struct multiplies { /*! \typedef first_argument_type @@ -307,9 +354,15 @@ struct multiplies /*! Function call operator. The return value is lhs * rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs * rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs * rhs; + } }; // end multiplies +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(multiplies, *); + /*! \p divides is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class divides, and \c x and \c y are objects * of class \c T, then f(x,y) returns x/y. @@ -343,7 +396,7 @@ struct multiplies * \see http://www.sgi.com/tech/stl/divides.html * \see binary_function */ -template +template struct divides { /*! \typedef first_argument_type @@ -364,9 +417,15 @@ struct divides /*! Function call operator. The return value is lhs / rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs / rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs / rhs; + } }; // end divides +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(divides, /); + /*! \p modulus is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class modulus, and \c x and \c y are objects * of class \c T, then f(x,y) returns x \% y. @@ -400,7 +459,7 @@ struct divides * \see http://www.sgi.com/tech/stl/modulus.html * \see binary_function */ -template +template struct modulus { /*! \typedef first_argument_type @@ -421,9 +480,15 @@ struct modulus /*! Function call operator. The return value is lhs % rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs % rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs % rhs; + } }; // end modulus +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(modulus, %); + /*! \p negate is a function object. Specifically, it is an Adaptable Unary Function. * If \c f is an object of class negate, and \c x is an object * of class \c T, then f(x) returns -x. @@ -454,7 +519,7 @@ struct modulus * \see http://www.sgi.com/tech/stl/negate.html * \see unary_function */ -template +template struct negate { /*! \typedef argument_type @@ -470,9 +535,15 @@ struct negate /*! Function call operator. The return value is -x. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &x) const {return -x;} + __host__ __device__ + constexpr T operator()(const T &x) const + { + return -x; + } }; // end negate +THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION(negate, -THRUST_FWD(x)); + /*! \p square is a function object. Specifically, it is an Adaptable Unary Function. * If \c f is an object of class square, and \c x is an object * of class \c T, then f(x) returns x*x. @@ -502,7 +573,7 @@ struct negate * * \see unary_function */ -template +template struct square { /*! \typedef argument_type @@ -518,9 +589,15 @@ struct square /*! Function call operator. The return value is x*x. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &x) const {return x*x;} + __host__ __device__ + constexpr T operator()(const T &x) const + { + return x*x; + } }; // end square +THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION(square, x*x); + /*! \} */ @@ -540,7 +617,7 @@ struct square * \see http://www.sgi.com/tech/stl/equal_to.html * \see binary_function */ -template +template struct equal_to { /*! \typedef first_argument_type @@ -561,9 +638,15 @@ struct equal_to /*! Function call operator. The return value is lhs == rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs == rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs == rhs; + } }; // end equal_to +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(equal_to, ==); + /*! \p not_equal_to is a function object. Specifically, it is an Adaptable Binary * Predicate, which means it is a function object that tests the truth or falsehood * of some condition. If \c f is an object of class not_equal_to and \c x @@ -575,7 +658,7 @@ struct equal_to * \see http://www.sgi.com/tech/stl/not_equal_to.html * \see binary_function */ -template +template struct not_equal_to { /*! \typedef first_argument_type @@ -596,9 +679,15 @@ struct not_equal_to /*! Function call operator. The return value is lhs != rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs != rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs != rhs; + } }; // end not_equal_to +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(not_equal_to, !=); + /*! \p greater is a function object. Specifically, it is an Adaptable Binary * Predicate, which means it is a function object that tests the truth or falsehood * of some condition. If \c f is an object of class greater and \c x @@ -610,7 +699,7 @@ struct not_equal_to * \see http://www.sgi.com/tech/stl/greater.html * \see binary_function */ -template +template struct greater { /*! \typedef first_argument_type @@ -631,9 +720,15 @@ struct greater /*! Function call operator. The return value is lhs > rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs > rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs > rhs; + } }; // end greater +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(greater, >); + /*! \p less is a function object. Specifically, it is an Adaptable Binary * Predicate, which means it is a function object that tests the truth or falsehood * of some condition. If \c f is an object of class less and \c x @@ -645,7 +740,7 @@ struct greater * \see http://www.sgi.com/tech/stl/less.html * \see binary_function */ -template +template struct less { /*! \typedef first_argument_type @@ -666,9 +761,15 @@ struct less /*! Function call operator. The return value is lhs < rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs < rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs < rhs; + } }; // end less +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(less, <); + /*! \p greater_equal is a function object. Specifically, it is an Adaptable Binary * Predicate, which means it is a function object that tests the truth or falsehood * of some condition. If \c f is an object of class greater_equal and \c x @@ -680,7 +781,7 @@ struct less * \see http://www.sgi.com/tech/stl/greater_equal.html * \see binary_function */ -template +template struct greater_equal { /*! \typedef first_argument_type @@ -701,9 +802,15 @@ struct greater_equal /*! Function call operator. The return value is lhs >= rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs >= rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs >= rhs; + } }; // end greater_equal +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(greater_equal, >=); + /*! \p less_equal is a function object. Specifically, it is an Adaptable Binary * Predicate, which means it is a function object that tests the truth or falsehood * of some condition. If \c f is an object of class less_equal and \c x @@ -715,7 +822,7 @@ struct greater_equal * \see http://www.sgi.com/tech/stl/less_equal.html * \see binary_function */ -template +template struct less_equal { /*! \typedef first_argument_type @@ -736,9 +843,15 @@ struct less_equal /*! Function call operator. The return value is lhs <= rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs <= rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs <= rhs; + } }; // end less_equal +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(less_equal, <=); + /*! \} */ @@ -759,7 +872,7 @@ struct less_equal * \see http://www.sgi.com/tech/stl/logical_and.html * \see binary_function */ -template +template struct logical_and { /*! \typedef first_argument_type @@ -780,9 +893,15 @@ struct logical_and /*! Function call operator. The return value is lhs && rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs && rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs && rhs; + } }; // end logical_and +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(logical_and, &&); + /*! \p logical_or is a function object. Specifically, it is an Adaptable Binary Predicate, * which means it is a function object that tests the truth or falsehood of some condition. * If \c f is an object of class logical_or and \c x and \c y are objects of @@ -794,7 +913,7 @@ struct logical_and * \see http://www.sgi.com/tech/stl/logical_or.html * \see binary_function */ -template +template struct logical_or { /*! \typedef first_argument_type @@ -815,9 +934,15 @@ struct logical_or /*! Function call operator. The return value is lhs || rhs. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &lhs, const T &rhs) const {return lhs || rhs;} + __host__ __device__ + constexpr bool operator()(const T &lhs, const T &rhs) const + { + return lhs || rhs; + } }; // end logical_or +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(logical_or, ||); + /*! \p logical_not is a function object. Specifically, it is an Adaptable Predicate, * which means it is a function object that tests the truth or falsehood of some condition. * If \c f is an object of class logical_not and \c x is an object of @@ -843,7 +968,7 @@ struct logical_or * \see http://www.sgi.com/tech/stl/logical_not.html * \see unary_function */ -template +template struct logical_not { /*! \typedef first_argument_type @@ -864,9 +989,15 @@ struct logical_not /*! Function call operator. The return value is !x. */ __thrust_exec_check_disable__ - __host__ __device__ bool operator()(const T &x) const {return !x;} + __host__ __device__ + constexpr bool operator()(const T &x) const + { + return !x; + } }; // end logical_not +THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION(logical_not, !THRUST_FWD(x)); + /*! \} */ @@ -907,7 +1038,7 @@ struct logical_not * * \see binary_function */ -template +template struct bit_and { /*! \typedef first_argument_type @@ -928,9 +1059,15 @@ struct bit_and /*! Function call operator. The return value is lhs & rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs & rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs & rhs; + } }; // end bit_and +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(bit_and, &); + /*! \p bit_or is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class bit_and, and \c x and \c y are objects * of class \c T, then f(x,y) returns x|y. @@ -963,7 +1100,7 @@ struct bit_and * * \see binary_function */ -template +template struct bit_or { /*! \typedef first_argument_type @@ -984,9 +1121,15 @@ struct bit_or /*! Function call operator. The return value is lhs | rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs | rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs | rhs; + } }; // end bit_or +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(bit_or, |); + /*! \p bit_xor is a function object. Specifically, it is an Adaptable Binary Function. * If \c f is an object of class bit_and, and \c x and \c y are objects * of class \c T, then f(x,y) returns x^y. @@ -1019,7 +1162,7 @@ struct bit_or * * \see binary_function */ -template +template struct bit_xor { /*! \typedef first_argument_type @@ -1040,9 +1183,15 @@ struct bit_xor /*! Function call operator. The return value is lhs ^ rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs ^ rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs ^ rhs; + } }; // end bit_xor +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP(bit_xor, ^); + /*! \} */ @@ -1071,7 +1220,7 @@ struct bit_xor * \see http://www.sgi.com/tech/stl/identity.html * \see unary_function */ -template +template struct identity { /*! \typedef argument_type @@ -1087,9 +1236,15 @@ struct identity /*! Function call operator. The return value is x. */ __thrust_exec_check_disable__ - __host__ __device__ const T &operator()(const T &x) const {return x;} + __host__ __device__ + constexpr const T &operator()(const T &x) const + { + return x; + } }; // end identity +THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION(identity, THRUST_FWD(x)); + /*! \p maximum is a function object that takes two arguments and returns the greater * of the two. Specifically, it is an Adaptable Binary Function. If \c f is an * object of class maximum and \c x and \c y are objects of class \c T @@ -1114,7 +1269,7 @@ struct identity * \see min * \see binary_function */ -template +template struct maximum { /*! \typedef first_argument_type @@ -1135,9 +1290,17 @@ struct maximum /*! Function call operator. The return value is rhs < lhs ? lhs : rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs < rhs ? rhs : lhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs < rhs ? rhs : lhs; + } }; // end maximum +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION(maximum, + t1 < t2 ? THRUST_FWD(t2) + : THRUST_FWD(t1)); + /*! \p minimum is a function object that takes two arguments and returns the lesser * of the two. Specifically, it is an Adaptable Binary Function. If \c f is an * object of class minimum and \c x and \c y are objects of class \c T @@ -1162,7 +1325,7 @@ struct maximum * \see max * \see binary_function */ -template +template struct minimum { /*! \typedef first_argument_type @@ -1183,10 +1346,18 @@ struct minimum /*! Function call operator. The return value is lhs < rhs ? lhs : rhs. */ __thrust_exec_check_disable__ - __host__ __device__ T operator()(const T &lhs, const T &rhs) const {return lhs < rhs ? lhs : rhs;} + __host__ __device__ + constexpr T operator()(const T &lhs, const T &rhs) const + { + return lhs < rhs ? lhs : rhs; + } }; // end minimum -/*! \p project1st is a function object that takes two arguments and returns +THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION(minimum, + t1 < t2 ? THRUST_FWD(t1) + : THRUST_FWD(t2)); + +/*! \p project1st is a function object that takes two arguments and returns * its first argument; the second argument is unused. It is essentially a * generalization of identity to the case of a Binary Function. * @@ -1204,7 +1375,7 @@ struct minimum * \see project2nd * \see binary_function */ -template +template struct project1st { /*! \typedef first_argument_type @@ -1224,10 +1395,28 @@ struct project1st /*! Function call operator. The return value is lhs. */ - __host__ __device__ const T1 &operator()(const T1 &lhs, const T2 & /*rhs*/) const {return lhs;} + __host__ __device__ + constexpr const T1 &operator()(const T1 &lhs, const T2 & /*rhs*/) const + { + return lhs; + } }; // end project1st -/*! \p project2nd is a function object that takes two arguments and returns +template <> +struct project1st +{ + using is_transparent = void; + __thrust_exec_check_disable__ + template + __host__ __device__ + constexpr auto operator()(T1&& t1, T2&&) const + noexcept(noexcept(THRUST_FWD(t1))) -> decltype(THRUST_FWD(t1)) + { + return THRUST_FWD(t1); + } +}; + +/*! \p project2nd is a function object that takes two arguments and returns * its second argument; the first argument is unused. It is essentially a * generalization of identity to the case of a Binary Function. * @@ -1245,7 +1434,7 @@ struct project1st * \see project1st * \see binary_function */ -template +template struct project2nd { /*! \typedef first_argument_type @@ -1265,13 +1454,30 @@ struct project2nd /*! Function call operator. The return value is rhs. */ - __host__ __device__ const T2 &operator()(const T1 &/*lhs*/, const T2 &rhs) const {return rhs;} + __host__ __device__ + constexpr const T2 &operator()(const T1 &/*lhs*/, const T2 &rhs) const + { + return rhs; + } }; // end project2nd +template <> +struct project2nd +{ + using is_transparent = void; + __thrust_exec_check_disable__ + template + __host__ __device__ + constexpr auto operator()(T1&&, T2&& t2) const + noexcept(noexcept(THRUST_FWD(t2))) -> decltype(THRUST_FWD(t2)) + { + return THRUST_FWD(t2); + } +}; + /*! \} */ - // odds and ends /*! \addtogroup function_object_adaptors @@ -1502,6 +1708,9 @@ THRUST_INLINE_CONSTANT thrust::detail::functional::placeholder<9>::type _10; /*! \} // placeholder_objects */ +#undef THRUST_UNARY_FUNCTOR_VOID_SPECIALIZATION +#undef THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION +#undef THRUST_BINARY_FUNCTOR_VOID_SPECIALIZATION_OP } // end thrust diff --git a/thrust/system/cuda/detail/transform_scan.h b/thrust/system/cuda/detail/transform_scan.h index 500152190..4e26f5c0f 100644 --- a/thrust/system/cuda/detail/transform_scan.h +++ b/thrust/system/cuda/detail/transform_scan.h @@ -50,26 +50,8 @@ transform_inclusive_scan(execution_policy &policy, TransformOp transform_op, ScanOp scan_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if UnaryFunction is AdaptableUnaryFunction - // TemporaryType = AdaptableUnaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - typedef typename thrust::detail::eval_if< - thrust::detail::has_result_type::value, - thrust::detail::result_type, - thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - iterator_value, - iterator_value - > - >::type result_type; + // Use the input iterator's value type per https://wg21.link/P0571 + using result_type = typename thrust::iterator_value::type; typedef typename iterator_traits::difference_type size_type; size_type num_items = static_cast(thrust::distance(first, last)); @@ -89,7 +71,7 @@ template OutputIt __host__ __device__ transform_exclusive_scan(execution_policy &policy, @@ -97,30 +79,11 @@ transform_exclusive_scan(execution_policy &policy, InputIt last, OutputIt result, TransformOp transform_op, - T init, + InitialValueType init, ScanOp scan_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if UnaryFunction is AdaptableUnaryFunction - // TemporaryType = AdaptableUnaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - - typedef typename thrust::detail::eval_if< - thrust::detail::has_result_type::value, - thrust::detail::result_type, - thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type result_type; + // Use the initial value type per https://wg21.link/P0571 + using result_type = InitialValueType; typedef typename iterator_traits::difference_type size_type; size_type num_items = static_cast(thrust::distance(first, last)); diff --git a/thrust/system/detail/generic/reduce_by_key.inl b/thrust/system/detail/generic/reduce_by_key.inl index 41c2106b0..86640ea9f 100644 --- a/thrust/system/detail/generic/reduce_by_key.inl +++ b/thrust/system/detail/generic/reduce_by_key.inl @@ -91,27 +91,8 @@ __host__ __device__ typedef unsigned int FlagType; // TODO use difference_type - // the pseudocode for deducing the type of the temporary used below: - // - // if BinaryFunction is AdaptableBinaryFunction - // TemporaryType = AdaptableBinaryFunction::result_type - // else if OutputIterator2 is a "pure" output iterator - // TemporaryType = InputIterator2::value_type - // else - // TemporaryType = OutputIterator2::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - - typedef typename thrust::detail::eval_if< - thrust::detail::has_result_type::value, - thrust::detail::result_type, - thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = typename thrust::iterator_value::type; if (keys_first == keys_last) return thrust::make_pair(keys_output, values_output); diff --git a/thrust/system/detail/generic/scan.inl b/thrust/system/detail/generic/scan.inl index 675d8f986..300b697b2 100644 --- a/thrust/system/detail/generic/scan.inl +++ b/thrust/system/detail/generic/scan.inl @@ -45,21 +45,8 @@ __host__ __device__ InputIterator last, OutputIterator result) { - // the pseudocode for deducing the type of the temporary used below: - // - // if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - - typedef typename thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - >::type ValueType; - // assume plus as the associative operator - return thrust::inclusive_scan(exec, first, last, result, thrust::plus()); + return thrust::inclusive_scan(exec, first, last, result, thrust::plus<>()); } // end inclusive_scan() @@ -72,18 +59,8 @@ __host__ __device__ InputIterator last, OutputIterator result) { - // the pseudocode for deducing the type of the temporary used below: - // - // if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - - typedef typename thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - >::type ValueType; + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = typename thrust::iterator_value::type; // assume 0 as the initialization value return thrust::exclusive_scan(exec, first, last, result, ValueType(0)); @@ -102,7 +79,7 @@ __host__ __device__ T init) { // assume plus as the associative operator - return thrust::exclusive_scan(exec, first, last, result, init, thrust::plus()); + return thrust::exclusive_scan(exec, first, last, result, init, thrust::plus<>()); } // end exclusive_scan() diff --git a/thrust/system/detail/generic/scan_by_key.inl b/thrust/system/detail/generic/scan_by_key.inl index 129cef17b..d3d1667a9 100644 --- a/thrust/system/detail/generic/scan_by_key.inl +++ b/thrust/system/detail/generic/scan_by_key.inl @@ -89,8 +89,7 @@ __host__ __device__ OutputIterator result, BinaryPredicate binary_pred) { - typedef typename thrust::iterator_traits::value_type OutputType; - return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, binary_pred, thrust::plus()); + return thrust::inclusive_scan_by_key(exec, first1, last1, first2, result, binary_pred, thrust::plus<>()); } @@ -185,8 +184,7 @@ __host__ __device__ T init, BinaryPredicate binary_pred) { - typedef typename thrust::iterator_traits::value_type OutputType; - return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, binary_pred, thrust::plus()); + return thrust::exclusive_scan_by_key(exec, first1, last1, first2, result, init, binary_pred, thrust::plus<>()); } diff --git a/thrust/system/detail/generic/transform_scan.inl b/thrust/system/detail/generic/transform_scan.inl index e411613c6..1cc48d9a1 100644 --- a/thrust/system/detail/generic/transform_scan.inl +++ b/thrust/system/detail/generic/transform_scan.inl @@ -48,27 +48,8 @@ __host__ __device__ UnaryFunction unary_op, BinaryFunction binary_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if UnaryFunction is AdaptableUnaryFunction - // TemporaryType = AdaptableUnaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - - typedef typename thrust::detail::eval_if< - thrust::detail::has_result_type::value, - thrust::detail::result_type, - thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = typename thrust::iterator_value::type; thrust::transform_iterator _first(first, unary_op); thrust::transform_iterator _last(last, unary_op); @@ -81,7 +62,7 @@ template __host__ __device__ OutputIterator transform_exclusive_scan(thrust::execution_policy &exec, @@ -89,30 +70,11 @@ __host__ __device__ InputIterator last, OutputIterator result, UnaryFunction unary_op, - T init, + InitialValueType init, AssociativeOperator binary_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if UnaryFunction is AdaptableUnaryFunction - // TemporaryType = AdaptableUnaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - - typedef typename thrust::detail::eval_if< - thrust::detail::has_result_type::value, - thrust::detail::result_type, - thrust::detail::eval_if< - thrust::detail::is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; + // Use the initial value type per https://wg21.link/P0571 + using ValueType = InitialValueType; thrust::transform_iterator _first(first, unary_op); thrust::transform_iterator _last(last, unary_op); diff --git a/thrust/system/detail/sequential/reduce_by_key.h b/thrust/system/detail/sequential/reduce_by_key.h index f19e62a29..6e0741365 100644 --- a/thrust/system/detail/sequential/reduce_by_key.h +++ b/thrust/system/detail/sequential/reduce_by_key.h @@ -19,7 +19,6 @@ #include #include #include -#include #include namespace thrust @@ -54,11 +53,8 @@ __host__ __device__ typedef typename thrust::iterator_traits::value_type InputKeyType; typedef typename thrust::iterator_traits::value_type InputValueType; - typedef typename thrust::detail::intermediate_type_from_function_and_iterators< - InputIterator2, - OutputIterator2, - BinaryFunction - >::type TemporaryType; + // Use the input iterator's value type per https://wg21.link/P0571 + using TemporaryType = typename thrust::iterator_value::type; if(keys_first != keys_last) { diff --git a/thrust/system/detail/sequential/scan.h b/thrust/system/detail/sequential/scan.h index 3ac06a9eb..3bffc93d7 100644 --- a/thrust/system/detail/sequential/scan.h +++ b/thrust/system/detail/sequential/scan.h @@ -51,29 +51,10 @@ __host__ __device__ OutputIterator result, BinaryFunction binary_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if BinaryFunction is AdaptableBinaryFunction - // TemporaryType = AdaptableBinaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - using namespace thrust::detail; - typedef typename eval_if< - has_result_type::value, - result_type, - eval_if< - is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = typename thrust::iterator_value::type; // wrap binary_op thrust::detail::wrapped_function< @@ -99,39 +80,20 @@ __thrust_exec_check_disable__ template __host__ __device__ OutputIterator exclusive_scan(sequential::execution_policy &, InputIterator first, InputIterator last, OutputIterator result, - T init, + InitialValueType init, BinaryFunction binary_op) { - // the pseudocode for deducing the type of the temporary used below: - // - // if BinaryFunction is AdaptableBinaryFunction - // TemporaryType = AdaptableBinaryFunction::result_type - // else if OutputIterator is a "pure" output iterator - // TemporaryType = InputIterator::value_type - // else - // TemporaryType = OutputIterator::value_type - // - // XXX upon c++0x, TemporaryType needs to be: - // result_of_adaptable_function::type - using namespace thrust::detail; - typedef typename eval_if< - has_result_type::value, - result_type, - eval_if< - is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; + // Use the initial value type per https://wg21.link/P0571 + using ValueType = InitialValueType; if(first != last) { diff --git a/thrust/system/tbb/detail/scan.inl b/thrust/system/tbb/detail/scan.inl index 477c04ee3..88fb999c6 100644 --- a/thrust/system/tbb/detail/scan.inl +++ b/thrust/system/tbb/detail/scan.inl @@ -208,18 +208,10 @@ template::value, - result_type, - eval_if< - is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; - - typedef typename thrust::iterator_difference::type Size; - + // Use the input iterator's value type per https://wg21.link/P0571 + using ValueType = typename thrust::iterator_value::type; + + using Size = typename thrust::iterator_difference::type; Size n = thrust::distance(first, last); if (n != 0) @@ -237,13 +229,13 @@ template OutputIterator exclusive_scan(tag, InputIterator first, InputIterator last, OutputIterator result, - T init, + InitialValueType init, BinaryFunction binary_op) { // the pseudocode for deducing the type of the temporary used below: @@ -260,18 +252,10 @@ template::value, - result_type, - eval_if< - is_output_iterator::value, - thrust::iterator_value, - thrust::iterator_value - > - >::type ValueType; - - typedef typename thrust::iterator_difference::type Size; - + // Use the initial value type per https://wg21.link/P0571 + using ValueType = InitialValueType; + + using Size = typename thrust::iterator_difference::type; Size n = thrust::distance(first, last); if (n != 0)