Skip to content

Commit

Permalink
Adds thrust::tabulate_output_iterator (#2282)
Browse files Browse the repository at this point in the history
* adds tabulate output iterator

* uses cccl exec space macros

* addresses review comments

* fixes documentation and example

* moves to using alias template instead of member type
  • Loading branch information
elstehle authored Aug 25, 2024
1 parent d62e979 commit 0d0d2d3
Show file tree
Hide file tree
Showing 3 changed files with 336 additions and 0 deletions.
150 changes: 150 additions & 0 deletions thrust/testing/tabulate_output_iterator.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/tabulate_output_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>

#include <cuda/std/type_traits>

#include <unittest/unittest.h>

template <typename OutItT>
struct host_write_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_HOST void operator()(IndexT index, T val)
{
out[index] = val;
}
};

template <typename OutItT>
struct host_write_first_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_HOST void operator()(IndexT index, T val)
{
// val is a thrust::tuple(value, input_index). Only write out the value part.
out[index] = thrust::get<0>(val);
}
};

template <typename OutItT>
struct device_write_first_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_DEVICE void operator()(IndexT index, T val)
{
// val is a thrust::tuple(value, input_index). Only write out the value part.
out[index] = thrust::get<0>(val);
}
};

struct select_op
{
std::size_t select_every_nth;

template <typename T, typename IndexT>
_CCCL_HOST_DEVICE bool operator()(thrust::tuple<T, IndexT> key_index_pair)
{
// Select every n-th item
return (thrust::get<1>(key_index_pair) % select_every_nth == 0);
}
};

struct index_to_gather_index_op
{
std::size_t gather_stride;

template <typename IndexT>
_CCCL_HOST_DEVICE IndexT operator()(IndexT index)
{
// Gather the i-th output item from input[i*3]
return index * static_cast<IndexT>(gather_stride);
}
};

template <class Vector>
void TestTabulateOutputIterator()
{
using T = typename Vector::value_type;
using it_t = typename Vector::iterator;
using space = typename thrust::iterator_system<typename Vector::iterator>::type;

static constexpr std::size_t num_items = 240;
Vector input(num_items);
Vector output(num_items, T{42});

// Use operator type that supports the targeted system
using op_t = typename ::cuda::std::conditional<(::cuda::std::is_same<space, thrust::host_system_tag>::value),
host_write_first_op<it_t>,
device_write_first_op<it_t>>::type;

// Construct tabulate_output_iterator
op_t op{output.begin()};
auto tabulate_out_it = thrust::make_tabulate_output_iterator(op);

// Prepare input
thrust::sequence(input.begin(), input.end(), 1);
auto iota_it = thrust::make_counting_iterator(0);
auto zipped_in = thrust::make_zip_iterator(input.begin(), iota_it);

// Run copy_if using tabulate_output_iterator as the output iterator
static constexpr std::size_t select_every_nth = 3;
auto selected_it_end =
thrust::copy_if(zipped_in, zipped_in + num_items, tabulate_out_it, select_op{select_every_nth});
const auto num_selected = static_cast<std::size_t>(thrust::distance(tabulate_out_it, selected_it_end));

// Prepare expected data
Vector expected_output(num_items, T{42});
const std::size_t expected_num_selected = (num_items + select_every_nth - 1) / select_every_nth;
auto gather_index_it =
thrust::make_transform_iterator(thrust::make_counting_iterator(0), index_to_gather_index_op{select_every_nth});
thrust::gather(gather_index_it, gather_index_it + expected_num_selected, input.cbegin(), expected_output.begin());

ASSERT_EQUAL(expected_num_selected, num_selected);
ASSERT_EQUAL(output, expected_output);
}
DECLARE_VECTOR_UNITTEST(TestTabulateOutputIterator);

void TestTabulateOutputIterator()
{
using vector_t = thrust::host_vector<int>;
using vec_it_t = typename vector_t::iterator;
using op_t = host_write_op<vec_it_t>;

vector_t out(4, 42);
thrust::tabulate_output_iterator<op_t> tabulate_out_it{op_t{out.begin()}};

tabulate_out_it[1] = 2;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 2);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 42);

tabulate_out_it[3] = 0;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 2);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 0);

tabulate_out_it[1] = 4;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 4);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 0);
}

DECLARE_UNITTEST(TestTabulateOutputIterator);
69 changes: 69 additions & 0 deletions thrust/thrust/iterator/detail/tabulate_output_iterator.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/tabulate_output_iterator.h>

THRUST_NAMESPACE_BEGIN

template <typename BinaryFunction, typename System, typename DifferenceT>
class tabulate_output_iterator;

namespace detail
{

// Proxy reference that invokes a BinaryFunction with the index of the dereferenced iterator and the assigned value
template <typename BinaryFunction, typename DifferenceT>
class tabulate_output_iterator_proxy
{
public:
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy(BinaryFunction fun, DifferenceT index)
: fun(fun)
, index(index)
{}

_CCCL_EXEC_CHECK_DISABLE
template <typename T>
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy operator=(const T& x)
{
fun(index, x);
return *this;
}

private:
BinaryFunction fun;
DifferenceT index;
};

// Alias template for the iterator_adaptor instantiation to be used for tabulate_output_iterator
template <typename BinaryFunction, typename System, typename DifferenceT>
using tabulate_output_iterator_base =
thrust::iterator_adaptor<tabulate_output_iterator<BinaryFunction, System, DifferenceT>,
counting_iterator<DifferenceT>,
thrust::use_default,
System,
thrust::use_default,
tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>;

// Register tabulate_output_iterator_proxy with 'is_proxy_reference' from
// type_traits to enable its use with algorithms.
template <class BinaryFunction, class OutputIterator>
struct is_proxy_reference<tabulate_output_iterator_proxy<BinaryFunction, OutputIterator>>
: public thrust::detail::true_type
{};

} // namespace detail
THRUST_NAMESPACE_END
117 changes: 117 additions & 0 deletions thrust/thrust/iterator/tabulate_output_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/iterator/detail/tabulate_output_iterator.inl>

THRUST_NAMESPACE_BEGIN

/*! \addtogroup iterators
* \{
*/

/*! \addtogroup fancyiterator Fancy Iterators
* \ingroup iterators
* \{
*/

/*! \p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a
* dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced
* iterator and the assigned value.
*
* The following code snippet demonstrated how to create a \p tabulate_output_iterator which prints the index and the
* assigned value.
*
* \code
* #include <thrust/iterator/tabulate_output_iterator.h>
*
* // note: functor inherits form binary function
* struct print_op
* {
* __host__ __device__
* void operator()(int index, float value) const
* {
* printf("%d: %f\n", index, value);
* }
* };
*
* int main()
* {
* auto tabulate_it = thrust::make_tabulate_output_iterator(print_op{});
*
* tabulate_it[0] = 1.0f; // prints: 0: 1.0
* tabulate_it[1] = 3.0f; // prints: 1: 3.0
* tabulate_it[9] = 5.0f; // prints: 9: 5.0
* }
* \endcode
*
* \see make_tabulate_output_iterator
*/

template <typename BinaryFunction, typename System = use_default, typename DifferenceT = ptrdiff_t>
class tabulate_output_iterator : public detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>
{
/*! \cond
*/

public:
using super_t = detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>;

friend class thrust::iterator_core_access;
/*! \endcond
*/

tabulate_output_iterator() = default;

/*! This constructor takes as argument a \c BinaryFunction and copies it to a new \p tabulate_output_iterator
*
* \param fun A \c BinaryFunction called whenever a value is assigned to this \p tabulate_output_iterator.
*/
_CCCL_HOST_DEVICE tabulate_output_iterator(BinaryFunction fun)
: fun(fun)
{}

/*! \cond
*/

private:
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
return detail::tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>(fun, *this->base());
}

BinaryFunction fun;

/*! \endcond
*/
}; // end tabulate_output_iterator

/*! \p make_tabulate_output_iterator creates a \p tabulate_output_iterator from a \c BinaryFunction.
*
* \param fun The \c BinaryFunction invoked whenever assigning to a dereferenced \p tabulate_output_iterator
* \see tabulate_output_iterator
*/
template <typename BinaryFunction>
tabulate_output_iterator<BinaryFunction> _CCCL_HOST_DEVICE make_tabulate_output_iterator(BinaryFunction fun)
{
return tabulate_output_iterator<BinaryFunction>(fun);
} // end make_tabulate_output_iterator

/*! \} // end fancyiterators
*/

/*! \} // end iterators
*/

THRUST_NAMESPACE_END

0 comments on commit 0d0d2d3

Please sign in to comment.