Skip to content

Commit

Permalink
Add default ctor to transform[_input]_output_iterator and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
harrism committed Oct 5, 2022
1 parent d3e6fa1 commit 6cdb69d
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 136 deletions.
168 changes: 105 additions & 63 deletions testing/transform_output_iterator.cu
Original file line number Diff line number Diff line change
@@ -1,91 +1,133 @@
#include <unittest/unittest.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <thrust/copy.h>
#include <thrust/reduce.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

#include <unittest/random.h>
#include <unittest/unittest.h>

template <class Vector>
void TestTransformOutputIterator(void)
{
typedef typename Vector::value_type T;
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;
typedef typename Vector::iterator Iterator;

typedef thrust::square<T> UnaryFunction;
typedef typename Vector::iterator Iterator;
Vector input(4);
Vector output(4);

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), T{1});

// construct transform_iterator
thrust::transform_output_iterator<UnaryFunction, Iterator> output_iter(output.begin(), UnaryFunction());
// initialize input
thrust::sequence(input.begin(), input.end(), T{1});

thrust::copy(input.begin(), input.end(), output_iter);
// construct transform_iterator
thrust::transform_output_iterator<UnaryFunction, Iterator> output_iter(output.begin(),
UnaryFunction());

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
thrust::copy(input.begin(), input.end(), output_iter);

ASSERT_EQUAL(output, gold_output);
Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;

ASSERT_EQUAL(output, gold_output);
}
DECLARE_VECTOR_UNITTEST(TestTransformOutputIterator);

template <class Vector>
void TestMakeTransformOutputIterator(void)
{
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), 1);

thrust::copy(input.begin(), input.end(),
thrust::make_transform_output_iterator(output.begin(), UnaryFunction()));

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
ASSERT_EQUAL(output, gold_output);
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), 1);

thrust::copy(input.begin(),
input.end(),
thrust::make_transform_output_iterator(output.begin(), UnaryFunction()));

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
ASSERT_EQUAL(output, gold_output);
}
DECLARE_VECTOR_UNITTEST(TestMakeTransformOutputIterator);

template <typename T>
struct TestTransformOutputIteratorScan
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_data = unittest::random_samples<T>(n);
thrust::device_vector<T> d_data = h_data;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::inclusive_scan(thrust::make_transform_iterator(h_data.begin(), thrust::negate<T>()),
thrust::make_transform_iterator(h_data.end(), thrust::negate<T>()),
h_result.begin());
// run on device
thrust::inclusive_scan(d_data.begin(), d_data.end(),
thrust::make_transform_output_iterator(
d_result.begin(), thrust::negate<T>()));


ASSERT_EQUAL(h_result, d_result);
}
void operator()(const size_t n)
{
thrust::host_vector<T> h_data = unittest::random_samples<T>(n);
thrust::device_vector<T> d_data = h_data;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::inclusive_scan(thrust::make_transform_iterator(h_data.begin(), thrust::negate<T>()),
thrust::make_transform_iterator(h_data.end(), thrust::negate<T>()),
h_result.begin());
// run on device
thrust::inclusive_scan(d_data.begin(),
d_data.end(),
thrust::make_transform_output_iterator(d_result.begin(),
thrust::negate<T>()));

ASSERT_EQUAL(h_result, d_result);
}
};
VariableUnitTest<TestTransformOutputIteratorScan, SignedIntegralTypes> TestTransformOutputIteratorScanInstance;
VariableUnitTest<TestTransformOutputIteratorScan, SignedIntegralTypes>
TestTransformOutputIteratorScanInstance;

template <typename T>
struct TestTransformOutputIteratorReduceByKey
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_keys = unittest::random_samples<T>(n);
thrust::sort(h_keys.begin(), h_keys.end());
thrust::device_vector<T> d_keys = h_keys;

thrust::host_vector<T> h_values = unittest::random_samples<T>(n);
thrust::device_vector<T> d_values = h_values;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::reduce_by_key(thrust::host,
h_keys.begin(),
h_keys.end(),
thrust::make_transform_iterator(h_values.begin(), thrust::negate<T>()),
thrust::discard_iterator<T>{},
h_result.begin());
// run on device
thrust::reduce_by_key(thrust::device,
d_keys.begin(),
d_keys.end(),
d_values.begin(),
thrust::discard_iterator<T>{},
thrust::make_transform_output_iterator(d_result.begin(),
thrust::negate<T>()));

ASSERT_EQUAL(h_result, d_result);
}
};
VariableUnitTest<TestTransformOutputIteratorReduceByKey, SignedIntegralTypes>
TestTransformOutputIteratorReduceByKeyInstance;
75 changes: 40 additions & 35 deletions thrust/iterator/transform_input_output_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ THRUST_NAMESPACE_BEGIN
* // Iterator that returns negated values and writes squared values
* auto iter = thrust::make_transform_input_output_iterator(v.begin(),
* thrust::negate<float>{}, thrust::square<float>{});
*
*
* // Iterator negates values when reading
* std::cout << iter[0] << " "; // -1.0f;
* std::cout << iter[1] << " "; // -2.0f;
Expand All @@ -85,23 +85,25 @@ THRUST_NAMESPACE_BEGIN
*/

template <typename InputFunction, typename OutputFunction, typename Iterator>
class transform_input_output_iterator
class transform_input_output_iterator
: public detail::transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type
{

/*! \cond
*/

public:

typedef typename
detail::transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type
super_t;
public:
typedef typename detail::
transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type super_t;

friend class thrust::iterator_core_access;
friend class thrust::iterator_core_access;
/*! \endcond
*/

/*! Null constructor does nothing.
*/
__host__ __device__ transform_input_output_iterator() {}

/*! This constructor takes as argument a \c Iterator an \c InputFunction and an
* \c OutputFunction and copies them to a new \p transform_input_output_iterator
*
Expand All @@ -110,29 +112,30 @@ template <typename InputFunction, typename OutputFunction, typename Iterator>
* \param input_function An \c InputFunction to be executed on values read from the iterator
* \param output_function An \c OutputFunction to be executed on values written to the iterator
*/
__host__ __device__
transform_input_output_iterator(Iterator const& io, InputFunction input_function, OutputFunction output_function)
: super_t(io), input_function(input_function), output_function(output_function)
{
}

/*! \cond
*/
private:

__host__ __device__
typename super_t::reference dereference() const
{
return detail::transform_input_output_iterator_proxy<
InputFunction, OutputFunction, Iterator
>(this->base_reference(), input_function, output_function);
}

InputFunction input_function;
OutputFunction output_function;

/*! \endcond
*/
__host__ __device__ transform_input_output_iterator(Iterator const &io,
InputFunction input_function,
OutputFunction output_function)
: super_t(io)
, input_function(input_function)
, output_function(output_function)
{}

/*! \cond
*/
private:
__host__ __device__ typename super_t::reference dereference() const
{
return detail::transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator>(
this->base_reference(),
input_function,
output_function);
}

InputFunction input_function;
OutputFunction output_function;

/*! \endcond
*/
}; // end transform_input_output_iterator

/*! \p make_transform_input_output_iterator creates a \p transform_input_output_iterator from
Expand All @@ -146,10 +149,13 @@ template <typename InputFunction, typename OutputFunction, typename Iterator>
*/
template <typename InputFunction, typename OutputFunction, typename Iterator>
transform_input_output_iterator<InputFunction, OutputFunction, Iterator>
__host__ __device__
make_transform_input_output_iterator(Iterator io, InputFunction input_function, OutputFunction output_function)
__host__ __device__ make_transform_input_output_iterator(Iterator io,
InputFunction input_function,
OutputFunction output_function)
{
return transform_input_output_iterator<InputFunction, OutputFunction, Iterator>(io, input_function, output_function);
return transform_input_output_iterator<InputFunction, OutputFunction, Iterator>(io,
input_function,
output_function);
} // end make_transform_input_output_iterator

/*! \} // end fancyiterators
Expand All @@ -159,4 +165,3 @@ make_transform_input_output_iterator(Iterator io, InputFunction input_function,
*/

THRUST_NAMESPACE_END

Loading

0 comments on commit 6cdb69d

Please sign in to comment.