-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds
thrust::tabulate_output_iterator
(#2282)
* adds tabulate output iterator * uses cccl exec space macros * addresses review comments * fixes documentation and example * moves to using alias template instead of member type
- Loading branch information
Showing
3 changed files
with
336 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
#include <thrust/copy.h> | ||
#include <thrust/device_vector.h> | ||
#include <thrust/functional.h> | ||
#include <thrust/gather.h> | ||
#include <thrust/host_vector.h> | ||
#include <thrust/iterator/counting_iterator.h> | ||
#include <thrust/iterator/tabulate_output_iterator.h> | ||
#include <thrust/iterator/transform_iterator.h> | ||
#include <thrust/iterator/zip_iterator.h> | ||
#include <thrust/reduce.h> | ||
#include <thrust/sequence.h> | ||
|
||
#include <cuda/std/type_traits> | ||
|
||
#include <unittest/unittest.h> | ||
|
||
template <typename OutItT> | ||
struct host_write_op | ||
{ | ||
OutItT out; | ||
|
||
template <typename IndexT, typename T> | ||
_CCCL_HOST void operator()(IndexT index, T val) | ||
{ | ||
out[index] = val; | ||
} | ||
}; | ||
|
||
template <typename OutItT> | ||
struct host_write_first_op | ||
{ | ||
OutItT out; | ||
|
||
template <typename IndexT, typename T> | ||
_CCCL_HOST void operator()(IndexT index, T val) | ||
{ | ||
// val is a thrust::tuple(value, input_index). Only write out the value part. | ||
out[index] = thrust::get<0>(val); | ||
} | ||
}; | ||
|
||
template <typename OutItT> | ||
struct device_write_first_op | ||
{ | ||
OutItT out; | ||
|
||
template <typename IndexT, typename T> | ||
_CCCL_DEVICE void operator()(IndexT index, T val) | ||
{ | ||
// val is a thrust::tuple(value, input_index). Only write out the value part. | ||
out[index] = thrust::get<0>(val); | ||
} | ||
}; | ||
|
||
struct select_op | ||
{ | ||
std::size_t select_every_nth; | ||
|
||
template <typename T, typename IndexT> | ||
_CCCL_HOST_DEVICE bool operator()(thrust::tuple<T, IndexT> key_index_pair) | ||
{ | ||
// Select every n-th item | ||
return (thrust::get<1>(key_index_pair) % select_every_nth == 0); | ||
} | ||
}; | ||
|
||
struct index_to_gather_index_op | ||
{ | ||
std::size_t gather_stride; | ||
|
||
template <typename IndexT> | ||
_CCCL_HOST_DEVICE IndexT operator()(IndexT index) | ||
{ | ||
// Gather the i-th output item from input[i*3] | ||
return index * static_cast<IndexT>(gather_stride); | ||
} | ||
}; | ||
|
||
template <class Vector> | ||
void TestTabulateOutputIterator() | ||
{ | ||
using T = typename Vector::value_type; | ||
using it_t = typename Vector::iterator; | ||
using space = typename thrust::iterator_system<typename Vector::iterator>::type; | ||
|
||
static constexpr std::size_t num_items = 240; | ||
Vector input(num_items); | ||
Vector output(num_items, T{42}); | ||
|
||
// Use operator type that supports the targeted system | ||
using op_t = typename ::cuda::std::conditional<(::cuda::std::is_same<space, thrust::host_system_tag>::value), | ||
host_write_first_op<it_t>, | ||
device_write_first_op<it_t>>::type; | ||
|
||
// Construct tabulate_output_iterator | ||
op_t op{output.begin()}; | ||
auto tabulate_out_it = thrust::make_tabulate_output_iterator(op); | ||
|
||
// Prepare input | ||
thrust::sequence(input.begin(), input.end(), 1); | ||
auto iota_it = thrust::make_counting_iterator(0); | ||
auto zipped_in = thrust::make_zip_iterator(input.begin(), iota_it); | ||
|
||
// Run copy_if using tabulate_output_iterator as the output iterator | ||
static constexpr std::size_t select_every_nth = 3; | ||
auto selected_it_end = | ||
thrust::copy_if(zipped_in, zipped_in + num_items, tabulate_out_it, select_op{select_every_nth}); | ||
const auto num_selected = static_cast<std::size_t>(thrust::distance(tabulate_out_it, selected_it_end)); | ||
|
||
// Prepare expected data | ||
Vector expected_output(num_items, T{42}); | ||
const std::size_t expected_num_selected = (num_items + select_every_nth - 1) / select_every_nth; | ||
auto gather_index_it = | ||
thrust::make_transform_iterator(thrust::make_counting_iterator(0), index_to_gather_index_op{select_every_nth}); | ||
thrust::gather(gather_index_it, gather_index_it + expected_num_selected, input.cbegin(), expected_output.begin()); | ||
|
||
ASSERT_EQUAL(expected_num_selected, num_selected); | ||
ASSERT_EQUAL(output, expected_output); | ||
} | ||
DECLARE_VECTOR_UNITTEST(TestTabulateOutputIterator); | ||
|
||
void TestTabulateOutputIterator() | ||
{ | ||
using vector_t = thrust::host_vector<int>; | ||
using vec_it_t = typename vector_t::iterator; | ||
using op_t = host_write_op<vec_it_t>; | ||
|
||
vector_t out(4, 42); | ||
thrust::tabulate_output_iterator<op_t> tabulate_out_it{op_t{out.begin()}}; | ||
|
||
tabulate_out_it[1] = 2; | ||
ASSERT_EQUAL(out[0], 42); | ||
ASSERT_EQUAL(out[1], 2); | ||
ASSERT_EQUAL(out[2], 42); | ||
ASSERT_EQUAL(out[3], 42); | ||
|
||
tabulate_out_it[3] = 0; | ||
ASSERT_EQUAL(out[0], 42); | ||
ASSERT_EQUAL(out[1], 2); | ||
ASSERT_EQUAL(out[2], 42); | ||
ASSERT_EQUAL(out[3], 0); | ||
|
||
tabulate_out_it[1] = 4; | ||
ASSERT_EQUAL(out[0], 42); | ||
ASSERT_EQUAL(out[1], 4); | ||
ASSERT_EQUAL(out[2], 42); | ||
ASSERT_EQUAL(out[3], 0); | ||
} | ||
|
||
DECLARE_UNITTEST(TestTabulateOutputIterator); |
69 changes: 69 additions & 0 deletions
69
thrust/thrust/iterator/detail/tabulate_output_iterator.inl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#pragma once | ||
|
||
#include <thrust/detail/config.h> | ||
|
||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#include <thrust/iterator/counting_iterator.h> | ||
#include <thrust/iterator/iterator_adaptor.h> | ||
#include <thrust/iterator/tabulate_output_iterator.h> | ||
|
||
THRUST_NAMESPACE_BEGIN | ||
|
||
template <typename BinaryFunction, typename System, typename DifferenceT> | ||
class tabulate_output_iterator; | ||
|
||
namespace detail | ||
{ | ||
|
||
// Proxy reference that invokes a BinaryFunction with the index of the dereferenced iterator and the assigned value | ||
template <typename BinaryFunction, typename DifferenceT> | ||
class tabulate_output_iterator_proxy | ||
{ | ||
public: | ||
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy(BinaryFunction fun, DifferenceT index) | ||
: fun(fun) | ||
, index(index) | ||
{} | ||
|
||
_CCCL_EXEC_CHECK_DISABLE | ||
template <typename T> | ||
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy operator=(const T& x) | ||
{ | ||
fun(index, x); | ||
return *this; | ||
} | ||
|
||
private: | ||
BinaryFunction fun; | ||
DifferenceT index; | ||
}; | ||
|
||
// Alias template for the iterator_adaptor instantiation to be used for tabulate_output_iterator | ||
template <typename BinaryFunction, typename System, typename DifferenceT> | ||
using tabulate_output_iterator_base = | ||
thrust::iterator_adaptor<tabulate_output_iterator<BinaryFunction, System, DifferenceT>, | ||
counting_iterator<DifferenceT>, | ||
thrust::use_default, | ||
System, | ||
thrust::use_default, | ||
tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>; | ||
|
||
// Register tabulate_output_iterator_proxy with 'is_proxy_reference' from | ||
// type_traits to enable its use with algorithms. | ||
template <class BinaryFunction, class OutputIterator> | ||
struct is_proxy_reference<tabulate_output_iterator_proxy<BinaryFunction, OutputIterator>> | ||
: public thrust::detail::true_type | ||
{}; | ||
|
||
} // namespace detail | ||
THRUST_NAMESPACE_END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#pragma once | ||
|
||
#include <thrust/detail/config.h> | ||
|
||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#include <thrust/iterator/detail/tabulate_output_iterator.inl> | ||
|
||
THRUST_NAMESPACE_BEGIN | ||
|
||
/*! \addtogroup iterators | ||
* \{ | ||
*/ | ||
|
||
/*! \addtogroup fancyiterator Fancy Iterators | ||
* \ingroup iterators | ||
* \{ | ||
*/ | ||
|
||
/*! \p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a | ||
* dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced | ||
* iterator and the assigned value. | ||
* | ||
* The following code snippet demonstrated how to create a \p tabulate_output_iterator which prints the index and the | ||
* assigned value. | ||
* | ||
* \code | ||
* #include <thrust/iterator/tabulate_output_iterator.h> | ||
* | ||
* // note: functor inherits form binary function | ||
* struct print_op | ||
* { | ||
* __host__ __device__ | ||
* void operator()(int index, float value) const | ||
* { | ||
* printf("%d: %f\n", index, value); | ||
* } | ||
* }; | ||
* | ||
* int main() | ||
* { | ||
* auto tabulate_it = thrust::make_tabulate_output_iterator(print_op{}); | ||
* | ||
* tabulate_it[0] = 1.0f; // prints: 0: 1.0 | ||
* tabulate_it[1] = 3.0f; // prints: 1: 3.0 | ||
* tabulate_it[9] = 5.0f; // prints: 9: 5.0 | ||
* } | ||
* \endcode | ||
* | ||
* \see make_tabulate_output_iterator | ||
*/ | ||
|
||
template <typename BinaryFunction, typename System = use_default, typename DifferenceT = ptrdiff_t> | ||
class tabulate_output_iterator : public detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT> | ||
{ | ||
/*! \cond | ||
*/ | ||
|
||
public: | ||
using super_t = detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>; | ||
|
||
friend class thrust::iterator_core_access; | ||
/*! \endcond | ||
*/ | ||
|
||
tabulate_output_iterator() = default; | ||
|
||
/*! This constructor takes as argument a \c BinaryFunction and copies it to a new \p tabulate_output_iterator | ||
* | ||
* \param fun A \c BinaryFunction called whenever a value is assigned to this \p tabulate_output_iterator. | ||
*/ | ||
_CCCL_HOST_DEVICE tabulate_output_iterator(BinaryFunction fun) | ||
: fun(fun) | ||
{} | ||
|
||
/*! \cond | ||
*/ | ||
|
||
private: | ||
_CCCL_HOST_DEVICE typename super_t::reference dereference() const | ||
{ | ||
return detail::tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>(fun, *this->base()); | ||
} | ||
|
||
BinaryFunction fun; | ||
|
||
/*! \endcond | ||
*/ | ||
}; // end tabulate_output_iterator | ||
|
||
/*! \p make_tabulate_output_iterator creates a \p tabulate_output_iterator from a \c BinaryFunction. | ||
* | ||
* \param fun The \c BinaryFunction invoked whenever assigning to a dereferenced \p tabulate_output_iterator | ||
* \see tabulate_output_iterator | ||
*/ | ||
template <typename BinaryFunction> | ||
tabulate_output_iterator<BinaryFunction> _CCCL_HOST_DEVICE make_tabulate_output_iterator(BinaryFunction fun) | ||
{ | ||
return tabulate_output_iterator<BinaryFunction>(fun); | ||
} // end make_tabulate_output_iterator | ||
|
||
/*! \} // end fancyiterators | ||
*/ | ||
|
||
/*! \} // end iterators | ||
*/ | ||
|
||
THRUST_NAMESPACE_END |