Skip to content

Commit

Permalink
Implement thrust::equal in terms of thrust::all_if
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jun 24, 2024
1 parent 5289892 commit 558a09d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions thrust/thrust/system/cuda/detail/equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@
#endif // no system header

#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
# include <thrust/system/cuda/config.h>

# include <thrust/system/cuda/detail/mismatch.h>
# include <thrust/logical.h>

THRUST_NAMESPACE_BEGIN
namespace cuda_cub
Expand All @@ -49,7 +48,10 @@ template <class Derived, class InputIt1, class InputIt2, class BinaryPred>
bool _CCCL_HOST_DEVICE
equal(execution_policy<Derived>& policy, InputIt1 first1, InputIt1 last1, InputIt2 first2, BinaryPred binary_pred)
{
return cuda_cub::mismatch(policy, first1, last1, first2, binary_pred).first == last1;
const auto n = distance(first1, last1);
using transform_t = transform_pair_of_input_iterators_t<bool, InputIt1, InputIt2, BinaryPred>;
transform_t transformed_first = transform_t(first1, first2, binary_pred);
return thrust::all_of(policy, transformed_first, transformed_first + n, identity{});
}

template <class Derived, class InputIt1, class InputIt2>
Expand Down

0 comments on commit 558a09d

Please sign in to comment.