Skip to content

Commit

Permalink
Merge pull request NVIDIA#1805 from harrism/fix-transform_output_iter…
Browse files Browse the repository at this point in the history
…ator-default-ctor

Add transform_output_iterator and transform_input_output_iterator default constructors
  • Loading branch information
gevtushenko authored Dec 2, 2022
2 parents b2cd968 + abd0bed commit fb758b8
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 3 deletions.
4 changes: 4 additions & 0 deletions testing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ thrust_declare_test_restrictions(future CPP.CUDA OMP.CUDA TBB.CUDA)
# for CUDA.
thrust_declare_test_restrictions(unittest_static_assert CPP.CPP CPP.CUDA)

# In the TBB backend, reduce_by_key does not currently work with transform_output_iterator
# https://github.com/NVIDIA/thrust/issues/1811
thrust_declare_test_restrictions(transform_output_iterator_reduce_by_key CPP.CPP CPP.OMP CPP.CUDA)

## thrust_add_test
#
# Add a test executable and register it with ctest.
Expand Down
8 changes: 5 additions & 3 deletions testing/transform_output_iterator.cu
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#include <unittest/unittest.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <thrust/copy.h>
#include <thrust/reduce.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>

template <class Vector>
void TestTransformOutputIterator(void)
Expand Down
51 changes: 51 additions & 0 deletions testing/transform_output_iterator_reduce_by_key.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include <unittest/unittest.h>

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>


template <typename T>
struct TestTransformOutputIteratorReduceByKey
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_keys = unittest::random_samples<T>(n);
thrust::sort(h_keys.begin(), h_keys.end());
thrust::device_vector<T> d_keys = h_keys;

thrust::host_vector<T> h_values = unittest::random_samples<T>(n);
thrust::device_vector<T> d_values = h_values;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::reduce_by_key(thrust::host,
h_keys.begin(),
h_keys.end(),
thrust::make_transform_iterator(h_values.begin(), thrust::negate<T>()),
thrust::discard_iterator<T>{},
h_result.begin());
// run on device
thrust::reduce_by_key(thrust::device,
d_keys.begin(),
d_keys.end(),
d_values.begin(),
thrust::discard_iterator<T>{},
thrust::make_transform_output_iterator(d_result.begin(),
thrust::negate<T>()));

ASSERT_EQUAL(h_result, d_result);
}
};
VariableUnitTest<TestTransformOutputIteratorReduceByKey, SignedIntegralTypes>
TestTransformOutputIteratorReduceByKeyInstance;

2 changes: 2 additions & 0 deletions thrust/iterator/transform_input_output_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ template <typename InputFunction, typename OutputFunction, typename Iterator>
/*! \endcond
*/

transform_input_output_iterator() = default;

/*! This constructor takes as argument a \c Iterator an \c InputFunction and an
* \c OutputFunction and copies them to a new \p transform_input_output_iterator
*
Expand Down
2 changes: 2 additions & 0 deletions thrust/iterator/transform_output_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ template <typename UnaryFunction, typename OutputIterator>
/*! \endcond
*/

transform_output_iterator() = default;

/*! This constructor takes as argument an \c OutputIterator and an \c
* UnaryFunction and copies them to a new \p transform_output_iterator
*
Expand Down

0 comments on commit fb758b8

Please sign in to comment.