Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1423 from jrhemstad/fix-transform-iterator-noncop…
Browse files Browse the repository at this point in the history
…yable

Fix transform_iterator with non-copyable types
  • Loading branch information
alliepiper authored May 20, 2021
2 parents 13e608f + 8355fa6 commit 8760d0c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 27 additions & 0 deletions testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <thrust/sequence.h>
#include <thrust/iterator/counting_iterator.h>

#include <memory>

template <class Vector>
void TestTransformIterator(void)
{
Expand Down Expand Up @@ -84,3 +86,28 @@ struct TestTransformIteratorReduce
};
VariableUnitTest<TestTransformIteratorReduce, IntegralTypes> TestTransformIteratorReduceInstance;


struct ExtractValue{
int operator()(std::unique_ptr<int> const& n){
return *n;
}
};

void TestTransformIteratorNonCopyable(){

thrust::host_vector<std::unique_ptr<int>> hv(4);
hv[0].reset(new int{1});
hv[1].reset(new int{2});
hv[2].reset(new int{3});
hv[3].reset(new int{4});

auto transformed = thrust::make_transform_iterator(hv.begin(), ExtractValue{});
ASSERT_EQUAL(transformed[0], 1);
ASSERT_EQUAL(transformed[1], 2);
ASSERT_EQUAL(transformed[2], 3);
ASSERT_EQUAL(transformed[3], 4);

}

DECLARE_UNITTEST(TestTransformIteratorNonCopyable);

2 changes: 1 addition & 1 deletion thrust/iterator/transform_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ template <class AdaptableUnaryFunction, class Iterator, class Reference = use_de
// Create a temporary to allow iterators with wrapped references to
// convert to their value type before calling m_f. Note that this
// disallows non-constant operations through m_f.
typename thrust::iterator_value<Iterator>::type x = *this->base();
typename thrust::iterator_value<Iterator>::type const& x = *this->base();
return m_f(x);
}

Expand Down

0 comments on commit 8760d0c

Please sign in to comment.