Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several improvements to zip_iterator/zip_function #1710

Merged
merged 6 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions thrust/testing/zip_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ struct SumThreeTuple
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
}; // end SumThreeTuple

template <typename T>
struct TestZipFunctionCtor
{
void operator()()
{
ASSERT_EQUAL(thrust::zip_function<SumThree>()(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
ASSERT_EQUAL(thrust::zip_function<SumThree>(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
# ifdef __cpp_deduction_guides
ASSERT_EQUAL(thrust::zip_function(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
# endif // __cpp_deduction_guides
}
};
SimpleUnitTest<TestZipFunctionCtor, type_list<int>> TestZipFunctionCtorInstance;

template <typename T>
struct TestZipFunctionTransform
{
Expand Down
1 change: 1 addition & 0 deletions thrust/testing/zip_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct TestZipIteratorManipulation

// test construction
ZipIterator iter0 = make_zip_iterator(t);
ASSERT_EQUAL(true, iter0 == ZipIterator{t});

ASSERT_EQUAL_QUIET(v0.begin(), get<0>(iter0.get_iterator_tuple()));
ASSERT_EQUAL_QUIET(v1.begin(), get<1>(iter0.get_iterator_tuple()));
Expand Down
40 changes: 16 additions & 24 deletions thrust/thrust/iterator/zip_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,20 @@ THRUST_NAMESPACE_BEGIN
* #include <thrust/tuple.h>
* #include <thrust/device_vector.h>
* ...
* thrust::device_vector<int> int_v(3);
* int_v[0] = 0; int_v[1] = 1; int_v[2] = 2;
* thrust::device_vector<int> int_v{0, 1, 2};
* thrust::device_vector<float> float_v{0.0f, 1.0f, 2.0f};
* thrust::device_vector<char> char_v{'a', 'b', 'c'};
*
* thrust::device_vector<float> float_v(3);
* float_v[0] = 0.0f; float_v[1] = 1.0f; float_v[2] = 2.0f;
* // aliases for iterators
* using IntIterator = thrust::device_vector<int>::iterator;
* using FloatIterator = thrust::device_vector<float>::iterator;
* using CharIterator = thrust::device_vector<char>::iterator;
*
* thrust::device_vector<char> char_v(3);
* char_v[0] = 'a'; char_v[1] = 'b'; char_v[2] = 'c';
*
* // typedef these iterators for shorthand
* typedef thrust::device_vector<int>::iterator IntIterator;
* typedef thrust::device_vector<float>::iterator FloatIterator;
* typedef thrust::device_vector<char>::iterator CharIterator;
*
* // typedef a tuple of these iterators
* typedef thrust::tuple<IntIterator, FloatIterator, CharIterator> IteratorTuple;
* // alias for a tuple of these iterators
* using IteratorTuple = thrust::tuple<IntIterator, FloatIterator, CharIterator>;
*
* // typedef the zip_iterator of this tuple
* typedef thrust::zip_iterator<IteratorTuple> ZipIterator;
* using ZipIterator = thrust::zip_iterator<IteratorTuple>;
*
* // finally, create the zip_iterator
* ZipIterator iter(thrust::make_tuple(int_v.begin(), float_v.begin(), char_v.begin()));
Expand Down Expand Up @@ -116,15 +111,8 @@ THRUST_NAMESPACE_BEGIN
*
* int main()
* {
* thrust::device_vector<int> int_in(3), int_out(3);
* int_in[0] = 0;
* int_in[1] = 1;
* int_in[2] = 2;
*
* thrust::device_vector<float> float_in(3), float_out(3);
* float_in[0] = 0.0f;
* float_in[1] = 10.0f;
* float_in[2] = 20.0f;
* thrust::device_vector<int> int_in{0, 1, 2}, int_out(3);
* thrust::device_vector<float> float_in{0.0f, 10.0f, 20.0f}, float_out(3);
*
* thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(int_in.end(), float_in.end())),
Expand All @@ -146,6 +134,10 @@ template <typename IteratorTuple>
class zip_iterator : public detail::zip_iterator_base<IteratorTuple>::type
{
public:
/*! The underlying iterator tuple type. Alias to zip_iterator's first template argument.
*/
using iterator_tuple = IteratorTuple;

/*! Default constructor does nothing.
*/
#if defined(_CCCL_COMPILER_MSVC_2017)
Expand Down
55 changes: 25 additions & 30 deletions thrust/thrust/zip_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,54 +95,40 @@ _CCCL_HOST_DEVICE auto apply_impl(Function&& func, Tuple&& args, index_sequence<
* #include <thrust/zip_function.h>
*
* struct SumTuple {
* float operator()(Tuple tup) {
* return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup);
* float operator()(auto tup) const {
* return thrust::get<0>(tup) + thrust::get<1>(tup) + thrust::get<2>(tup);
* }
* };
* struct SumArgs {
* float operator()(float a, float b, float c) {
* float operator()(float a, float b, float c) const {
* return a + b + c;
* }
* };
*
* int main() {
* thrust::device_vector<float> A(3);
* thrust::device_vector<float> B(3);
* thrust::device_vector<float> C(3);
* thrust::device_vector<float> A{0.f, 1.f, 2.f};
* thrust::device_vector<float> B{1.f, 2.f, 3.f};
* thrust::device_vector<float> C{2.f, 3.f, 4.f};
* thrust::device_vector<float> D(3);
* A[0] = 0.f; A[1] = 1.f; A[2] = 2.f;
* B[0] = 1.f; B[1] = 2.f; B[2] = 3.f;
* C[0] = 2.f; C[1] = 3.f; C[2] = 4.f;
*
* // The following four invocations of transform are equivalent
* auto begin = thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin()));
* auto end = thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end()));
*
* // The following four invocations of transform are equivalent:
* // Transform with 3-tuple
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* SumTuple{});
* thrust::transform(begin, end, D.begin(), SumTuple{});
*
* // Transform with 3 parameters
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* adapted);
* thrust::transform(begin, end, D.begin(), adapted);
*
* // Transform with 3 parameters with convenience function
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* thrust::make_zip_function(SumArgs{}));
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function(SumArgs{}));
*
* // Transform with 3 parameters with convenience function and lambda
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* thrust::make_zip_function([] (float a, float b, float c) {
* return a + b + c;
* }));
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function([] (float a, float b, float c) {
* return a + b + c;
* }));
* return 0;
* }
* \endcode
Expand All @@ -154,6 +140,9 @@ template <typename Function>
class zip_function
{
public:
//! Default constructs the contained function object.
zip_function() = default;

_CCCL_HOST_DEVICE zip_function(Function func)
: func(std::move(func))
{}
Expand Down Expand Up @@ -181,6 +170,12 @@ class zip_function

# endif // _CCCL_STD_VER

//! Returns a reference to the underlying function.
_CCCL_HOST_DEVICE Function& underlying_function() const
{
return func;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this. This makes it easier to "undo" the zip_function when the zip iterator has been destructured, and will be really helpful for the thrust::transform performance work.


private:
mutable Function func;
};
Expand Down
Loading