Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add transform_output_iterator and transform_input_output_iterator default constructors #1805

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 105 additions & 63 deletions testing/transform_output_iterator.cu
Original file line number Diff line number Diff line change
@@ -1,91 +1,133 @@
#include <unittest/unittest.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <thrust/copy.h>
#include <thrust/reduce.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

#include <unittest/random.h>
#include <unittest/unittest.h>

template <class Vector>
void TestTransformOutputIterator(void)
{
typedef typename Vector::value_type T;
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;
typedef typename Vector::iterator Iterator;

typedef thrust::square<T> UnaryFunction;
typedef typename Vector::iterator Iterator;
Vector input(4);
Vector output(4);

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), T{1});

// construct transform_iterator
thrust::transform_output_iterator<UnaryFunction, Iterator> output_iter(output.begin(), UnaryFunction());
// initialize input
thrust::sequence(input.begin(), input.end(), T{1});

thrust::copy(input.begin(), input.end(), output_iter);
// construct transform_iterator
thrust::transform_output_iterator<UnaryFunction, Iterator> output_iter(output.begin(),
UnaryFunction());

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
thrust::copy(input.begin(), input.end(), output_iter);

ASSERT_EQUAL(output, gold_output);
Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;

ASSERT_EQUAL(output, gold_output);
}
DECLARE_VECTOR_UNITTEST(TestTransformOutputIterator);

template <class Vector>
void TestMakeTransformOutputIterator(void)
{
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), 1);

thrust::copy(input.begin(), input.end(),
thrust::make_transform_output_iterator(output.begin(), UnaryFunction()));

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
ASSERT_EQUAL(output, gold_output);
typedef typename Vector::value_type T;

typedef thrust::square<T> UnaryFunction;

Vector input(4);
Vector output(4);

// initialize input
thrust::sequence(input.begin(), input.end(), 1);

thrust::copy(input.begin(),
input.end(),
thrust::make_transform_output_iterator(output.begin(), UnaryFunction()));

Vector gold_output(4);
gold_output[0] = 1;
gold_output[1] = 4;
gold_output[2] = 9;
gold_output[3] = 16;
ASSERT_EQUAL(output, gold_output);
}
DECLARE_VECTOR_UNITTEST(TestMakeTransformOutputIterator);

template <typename T>
struct TestTransformOutputIteratorScan
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_data = unittest::random_samples<T>(n);
thrust::device_vector<T> d_data = h_data;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::inclusive_scan(thrust::make_transform_iterator(h_data.begin(), thrust::negate<T>()),
thrust::make_transform_iterator(h_data.end(), thrust::negate<T>()),
h_result.begin());
// run on device
thrust::inclusive_scan(d_data.begin(), d_data.end(),
thrust::make_transform_output_iterator(
d_result.begin(), thrust::negate<T>()));


ASSERT_EQUAL(h_result, d_result);
}
void operator()(const size_t n)
{
thrust::host_vector<T> h_data = unittest::random_samples<T>(n);
thrust::device_vector<T> d_data = h_data;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::inclusive_scan(thrust::make_transform_iterator(h_data.begin(), thrust::negate<T>()),
thrust::make_transform_iterator(h_data.end(), thrust::negate<T>()),
h_result.begin());
// run on device
thrust::inclusive_scan(d_data.begin(),
d_data.end(),
thrust::make_transform_output_iterator(d_result.begin(),
thrust::negate<T>()));

ASSERT_EQUAL(h_result, d_result);
}
};
VariableUnitTest<TestTransformOutputIteratorScan, SignedIntegralTypes> TestTransformOutputIteratorScanInstance;
VariableUnitTest<TestTransformOutputIteratorScan, SignedIntegralTypes>
TestTransformOutputIteratorScanInstance;

template <typename T>
struct TestTransformOutputIteratorReduceByKey
{
void operator()(const size_t n)
{
thrust::host_vector<T> h_keys = unittest::random_samples<T>(n);
thrust::sort(h_keys.begin(), h_keys.end());
thrust::device_vector<T> d_keys = h_keys;

thrust::host_vector<T> h_values = unittest::random_samples<T>(n);
thrust::device_vector<T> d_values = h_values;

thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

// run on host
thrust::reduce_by_key(thrust::host,
h_keys.begin(),
h_keys.end(),
thrust::make_transform_iterator(h_values.begin(), thrust::negate<T>()),
thrust::discard_iterator<T>{},
h_result.begin());
// run on device
thrust::reduce_by_key(thrust::device,
d_keys.begin(),
d_keys.end(),
d_values.begin(),
thrust::discard_iterator<T>{},
thrust::make_transform_output_iterator(d_result.begin(),
thrust::negate<T>()));

ASSERT_EQUAL(h_result, d_result);
}
};
VariableUnitTest<TestTransformOutputIteratorReduceByKey, SignedIntegralTypes>
TestTransformOutputIteratorReduceByKeyInstance;
75 changes: 40 additions & 35 deletions thrust/iterator/transform_input_output_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ THRUST_NAMESPACE_BEGIN
* // Iterator that returns negated values and writes squared values
* auto iter = thrust::make_transform_input_output_iterator(v.begin(),
* thrust::negate<float>{}, thrust::square<float>{});
*
*
* // Iterator negates values when reading
* std::cout << iter[0] << " "; // -1.0f;
* std::cout << iter[1] << " "; // -2.0f;
Expand All @@ -85,23 +85,25 @@ THRUST_NAMESPACE_BEGIN
*/

template <typename InputFunction, typename OutputFunction, typename Iterator>
class transform_input_output_iterator
class transform_input_output_iterator
: public detail::transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type
{

/*! \cond
*/

public:

typedef typename
detail::transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type
super_t;
public:
typedef typename detail::
transform_input_output_iterator_base<InputFunction, OutputFunction, Iterator>::type super_t;

friend class thrust::iterator_core_access;
friend class thrust::iterator_core_access;
/*! \endcond
*/

/*! Null constructor does nothing.
*/
__host__ __device__ transform_input_output_iterator() {}
harrism marked this conversation as resolved.
Show resolved Hide resolved

/*! This constructor takes as argument a \c Iterator an \c InputFunction and an
* \c OutputFunction and copies them to a new \p transform_input_output_iterator
*
Expand All @@ -110,29 +112,30 @@ template <typename InputFunction, typename OutputFunction, typename Iterator>
* \param input_function An \c InputFunction to be executed on values read from the iterator
* \param output_function An \c OutputFunction to be executed on values written to the iterator
*/
__host__ __device__
transform_input_output_iterator(Iterator const& io, InputFunction input_function, OutputFunction output_function)
: super_t(io), input_function(input_function), output_function(output_function)
{
}

/*! \cond
*/
private:

__host__ __device__
typename super_t::reference dereference() const
{
return detail::transform_input_output_iterator_proxy<
InputFunction, OutputFunction, Iterator
>(this->base_reference(), input_function, output_function);
}

InputFunction input_function;
OutputFunction output_function;

/*! \endcond
*/
__host__ __device__ transform_input_output_iterator(Iterator const &io,
InputFunction input_function,
OutputFunction output_function)
: super_t(io)
, input_function(input_function)
, output_function(output_function)
{}

/*! \cond
*/
private:
__host__ __device__ typename super_t::reference dereference() const
{
return detail::transform_input_output_iterator_proxy<InputFunction, OutputFunction, Iterator>(
this->base_reference(),
input_function,
output_function);
}

InputFunction input_function;
OutputFunction output_function;

/*! \endcond
*/
}; // end transform_input_output_iterator

/*! \p make_transform_input_output_iterator creates a \p transform_input_output_iterator from
Expand All @@ -146,10 +149,13 @@ template <typename InputFunction, typename OutputFunction, typename Iterator>
*/
template <typename InputFunction, typename OutputFunction, typename Iterator>
transform_input_output_iterator<InputFunction, OutputFunction, Iterator>
__host__ __device__
make_transform_input_output_iterator(Iterator io, InputFunction input_function, OutputFunction output_function)
__host__ __device__ make_transform_input_output_iterator(Iterator io,
InputFunction input_function,
OutputFunction output_function)
{
return transform_input_output_iterator<InputFunction, OutputFunction, Iterator>(io, input_function, output_function);
return transform_input_output_iterator<InputFunction, OutputFunction, Iterator>(io,
input_function,
output_function);
} // end make_transform_input_output_iterator

/*! \} // end fancyiterators
Expand All @@ -159,4 +165,3 @@ make_transform_input_output_iterator(Iterator io, InputFunction input_function,
*/

THRUST_NAMESPACE_END

miscco marked this conversation as resolved.
Show resolved Hide resolved
Loading