diff --git a/thrust/testing/tabulate_output_iterator.cu b/thrust/testing/tabulate_output_iterator.cu
new file mode 100644
index 00000000000..789ed6cf04e
--- /dev/null
+++ b/thrust/testing/tabulate_output_iterator.cu
@@ -0,0 +1,159 @@
+// NOTE(review): the header names below were reconstructed from the names this test uses; confirm before merging.
+#include <thrust/copy.h>
+#include <thrust/distance.h>
+#include <thrust/gather.h>
+#include <thrust/host_vector.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/iterator_traits.h>
+#include <thrust/iterator/tabulate_output_iterator.h>
+#include <thrust/iterator/transform_iterator.h>
+#include <thrust/iterator/zip_iterator.h>
+#include <thrust/sequence.h>
+#include <thrust/tuple.h>
+
+#include <cuda/std/type_traits>
+
+#include <unittest/unittest.h>
+
+// Host-only callable that stores the assigned value at out[index].
+template <typename OutItT>
+struct host_write_op
+{
+  OutItT out;
+
+  template <typename IndexT, typename T>
+  _CCCL_HOST void operator()(IndexT index, T val)
+  {
+    out[index] = val;
+  }
+};
+
+// Host-only callable that stores the first tuple element of the assigned value at out[index].
+template <typename OutItT>
+struct host_write_first_op
+{
+  OutItT out;
+
+  template <typename IndexT, typename T>
+  _CCCL_HOST void operator()(IndexT index, T val)
+  {
+    // val is a thrust::tuple(value, input_index). Only write out the value part.
+    out[index] = thrust::get<0>(val);
+  }
+};
+
+// Device-only counterpart of host_write_first_op.
+template <typename OutItT>
+struct device_write_first_op
+{
+  OutItT out;
+
+  template <typename IndexT, typename T>
+  _CCCL_DEVICE void operator()(IndexT index, T val)
+  {
+    // val is a thrust::tuple(value, input_index). Only write out the value part.
+    out[index] = thrust::get<0>(val);
+  }
+};
+
+// Predicate over (value, input_index) tuples that keeps every select_every_nth-th input item.
+struct select_op
+{
+  std::size_t select_every_nth;
+
+  template <typename T, typename IndexT>
+  _CCCL_HOST_DEVICE bool operator()(thrust::tuple<T, IndexT> key_index_pair) const
+  {
+    // Select every n-th item
+    return (thrust::get<1>(key_index_pair) % select_every_nth == 0);
+  }
+};
+
+// Maps an output position to the input position it is gathered from.
+struct index_to_gather_index_op
+{
+  std::size_t gather_stride;
+
+  template <typename IndexT>
+  _CCCL_HOST_DEVICE IndexT operator()(IndexT index) const
+  {
+    // Gather the i-th output item from input[i * gather_stride]
+    return index * static_cast<IndexT>(gather_stride);
+  }
+};
+
+// Runs copy_if into a tabulate_output_iterator and compares the result against a gather of every n-th input item.
+template <class Vector>
+void TestTabulateOutputIterator()
+{
+  using T     = typename Vector::value_type;
+  using it_t  = typename Vector::iterator;
+  using space = typename thrust::iterator_system<it_t>::type;
+
+  static constexpr std::size_t num_items = 240;
+  Vector input(num_items);
+  Vector output(num_items, T{42});
+
+  // Use operator type that supports the targeted system
+  using op_t = typename ::cuda::std::conditional<(::cuda::std::is_same<space, thrust::host_system_tag>::value),
+                                                 host_write_first_op<it_t>,
+                                                 device_write_first_op<it_t>>::type;
+
+  // Construct tabulate_output_iterator
+  op_t op{output.begin()};
+  auto tabulate_out_it = thrust::make_tabulate_output_iterator(op);
+
+  // Prepare input
+  thrust::sequence(input.begin(), input.end(), 1);
+  auto iota_it   = thrust::make_counting_iterator(0);
+  auto zipped_in = thrust::make_zip_iterator(input.begin(), iota_it);
+
+  // Run copy_if using tabulate_output_iterator as the output iterator
+  static constexpr std::size_t select_every_nth = 3;
+  auto selected_it_end =
+    thrust::copy_if(zipped_in, zipped_in + num_items, tabulate_out_it, select_op{select_every_nth});
+  const auto num_selected = static_cast<std::size_t>(thrust::distance(tabulate_out_it, selected_it_end));
+
+  // Prepare expected data
+  Vector expected_output(num_items, T{42});
+  const std::size_t expected_num_selected = (num_items + select_every_nth - 1) / select_every_nth;
+  auto gather_index_it =
+    thrust::make_transform_iterator(thrust::make_counting_iterator(0), index_to_gather_index_op{select_every_nth});
+  thrust::gather(gather_index_it, gather_index_it + expected_num_selected, input.cbegin(), expected_output.begin());
+
+  ASSERT_EQUAL(expected_num_selected, num_selected);
+  ASSERT_EQUAL(output, expected_output);
+}
+DECLARE_VECTOR_UNITTEST(TestTabulateOutputIterator);
+
+// Checks that assigning through operator[] forwards (index, value) to the callable, including repeated writes.
+void TestTabulateOutputIterator()
+{
+  // NOTE(review): the element type was lost in extraction; int is assumed from the integer literals below.
+  using vector_t = thrust::host_vector<int>;
+  using vec_it_t = typename vector_t::iterator;
+  using op_t     = host_write_op<vec_it_t>;
+
+  vector_t out(4, 42);
+  thrust::tabulate_output_iterator<op_t> tabulate_out_it{op_t{out.begin()}};
+
+  tabulate_out_it[1] = 2;
+  ASSERT_EQUAL(out[0], 42);
+  ASSERT_EQUAL(out[1], 2);
+  ASSERT_EQUAL(out[2], 42);
+  ASSERT_EQUAL(out[3], 42);
+
+  tabulate_out_it[3] = 0;
+  ASSERT_EQUAL(out[0], 42);
+  ASSERT_EQUAL(out[1], 2);
+  ASSERT_EQUAL(out[2], 42);
+  ASSERT_EQUAL(out[3], 0);
+
+  tabulate_out_it[1] = 4;
+  ASSERT_EQUAL(out[0], 42);
+  ASSERT_EQUAL(out[1], 4);
+  ASSERT_EQUAL(out[2], 42);
+  ASSERT_EQUAL(out[3], 0);
+}
+
+DECLARE_UNITTEST(TestTabulateOutputIterator);
diff --git 
a/thrust/thrust/iterator/detail/tabulate_output_iterator.inl b/thrust/thrust/iterator/detail/tabulate_output_iterator.inl
new file mode 100644
index 00000000000..b5ed7258015
--- /dev/null
+++ b/thrust/thrust/iterator/detail/tabulate_output_iterator.inl
@@ -0,0 +1,71 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#pragma once
+
+#include <thrust/detail/config.h>
+
+#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
+#  pragma GCC system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
+#  pragma clang system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
+#  pragma system_header
+#endif // no system header
+
+// NOTE(review): header names reconstructed from the names used below; confirm against the tree.
+#include <thrust/detail/type_traits.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/iterator_adaptor.h>
+
+THRUST_NAMESPACE_BEGIN
+
+template <typename BinaryFunction, typename System, typename DifferenceT>
+class tabulate_output_iterator;
+
+namespace detail
+{
+
+// Proxy reference that invokes a BinaryFunction with the index of the dereferenced iterator and the assigned value
+template <typename BinaryFunction, typename DifferenceT>
+class tabulate_output_iterator_proxy
+{
+public:
+  _CCCL_HOST_DEVICE tabulate_output_iterator_proxy(BinaryFunction fun, DifferenceT index)
+      : fun(fun)
+      , index(index)
+  {}
+
+  // Assignment does not store x; it calls fun(index, x) and returns a copy of this proxy.
+  _CCCL_EXEC_CHECK_DISABLE
+  template <typename T>
+  _CCCL_HOST_DEVICE tabulate_output_iterator_proxy operator=(const T& x)
+  {
+    fun(index, x);
+    return *this;
+  }
+
+private:
+  BinaryFunction fun;
+  DifferenceT index;
+};
+
+// Alias template for the iterator_adaptor instantiation to be used for tabulate_output_iterator
+template <typename BinaryFunction, typename System, typename DifferenceT>
+using tabulate_output_iterator_base =
+  thrust::iterator_adaptor<tabulate_output_iterator<BinaryFunction, System, DifferenceT>,
+                           counting_iterator<DifferenceT>,
+                           thrust::use_default,
+                           System,
+                           thrust::use_default,
+                           tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>;
+
+// Register tabulate_output_iterator_proxy with 'is_proxy_reference' from
+// type_traits to enable its use with algorithms.
+template <class BinaryFunction, class DifferenceT>
+struct is_proxy_reference<tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>
+    : public thrust::detail::true_type
+{};
+
+} // namespace detail
+THRUST_NAMESPACE_END
diff --git a/thrust/thrust/iterator/tabulate_output_iterator.h b/thrust/thrust/iterator/tabulate_output_iterator.h
new file mode 100644
index 00000000000..af9a244063e
--- /dev/null
+++ b/thrust/thrust/iterator/tabulate_output_iterator.h
@@ -0,0 +1,120 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#pragma once
+
+#include <thrust/detail/config.h>
+
+#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
+#  pragma GCC system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
+#  pragma clang system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
+#  pragma system_header
+#endif // no system header
+
+#include <thrust/iterator/detail/tabulate_output_iterator.inl>
+
+THRUST_NAMESPACE_BEGIN
+
+/*! \addtogroup iterators
+ *  \{
+ */
+
+/*! \addtogroup fancyiterator Fancy Iterators
+ *  \ingroup iterators
+ *  \{
+ */
+
+/*! \p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a
+ *  dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced
+ *  iterator and the assigned value.
+ *
+ *  The following code snippet demonstrates how to create a \p tabulate_output_iterator which prints the index and the
+ *  assigned value.
+ *
+ *  \code
+ *  #include <thrust/iterator/tabulate_output_iterator.h>
+ *  #include <cstdio>
+ *
+ *  struct print_op
+ *  {
+ *    __host__ __device__
+ *    void operator()(int index, float value) const
+ *    {
+ *      printf("%d: %f\n", index, value);
+ *    }
+ *  };
+ *
+ *  int main()
+ *  {
+ *    auto tabulate_it = thrust::make_tabulate_output_iterator(print_op{});
+ *
+ *    tabulate_it[0] = 1.0f; // prints: 0: 1.000000
+ *    tabulate_it[1] = 3.0f; // prints: 1: 3.000000
+ *    tabulate_it[9] = 5.0f; // prints: 9: 5.000000
+ *  }
+ *  \endcode
+ *
+ *  \see make_tabulate_output_iterator
+ */
+
+// NOTE(review): the defaults (use_default, ptrdiff_t) were lost in extraction and are assumed; confirm.
+template <typename BinaryFunction, typename System = use_default, typename DifferenceT = ptrdiff_t>
+class tabulate_output_iterator : public detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>
+{
+  /*! \cond
+   */
+
+public:
+  using super_t = detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>;
+
+  friend class thrust::iterator_core_access;
+  /*! \endcond
+   */
+
+  tabulate_output_iterator() = default;
+
+  /*! This constructor takes as argument a \c BinaryFunction and copies it to a new \p tabulate_output_iterator
+   *
+   * \param fun A \c BinaryFunction called whenever a value is assigned to this \p tabulate_output_iterator.
+   */
+  _CCCL_HOST_DEVICE tabulate_output_iterator(BinaryFunction fun)
+      : fun(fun)
+  {}
+
+  /*! \cond
+   */
+
+private:
+  // Builds the proxy from the callable and the current index read from the underlying counting_iterator.
+  _CCCL_HOST_DEVICE typename super_t::reference dereference() const
+  {
+    return detail::tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>(fun, *this->base());
+  }
+
+  BinaryFunction fun;
+
+  /*! \endcond
+   */
+}; // end tabulate_output_iterator
+
+/*! \p make_tabulate_output_iterator creates a \p tabulate_output_iterator from a \c BinaryFunction.
+ *
+ *  \param fun The \c BinaryFunction invoked whenever assigning to a dereferenced \p tabulate_output_iterator
+ *  \return A \p tabulate_output_iterator that holds a copy of \p fun
+ *  \see tabulate_output_iterator
+ */
+template <typename BinaryFunction>
+tabulate_output_iterator<BinaryFunction> _CCCL_HOST_DEVICE make_tabulate_output_iterator(BinaryFunction fun)
+{
+  return tabulate_output_iterator<BinaryFunction>(fun);
+} // end make_tabulate_output_iterator
+
+/*! \} // end fancyiterators
+ */
+
+/*! \} // end iterators
+ */
+
+THRUST_NAMESPACE_END