Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds thrust::tabulate_output_iterator #2282

Merged
merged 5 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions thrust/testing/tabulate_output_iterator.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/tabulate_output_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>

#include <cuda/std/type_traits>

#include <unittest/unittest.h>

template <typename OutItT>
struct host_write_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_HOST void operator()(IndexT index, T val)
{
out[index] = val;
}
};
bernhardmgruber marked this conversation as resolved.
Show resolved Hide resolved

template <typename OutItT>
struct host_write_first_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_HOST void operator()(IndexT index, T val)
{
// val is a thrust::tuple(value, input_index). Only write out the value part.
out[index] = thrust::get<0>(val);
}
};

template <typename OutItT>
struct device_write_first_op
{
OutItT out;

template <typename IndexT, typename T>
_CCCL_DEVICE void operator()(IndexT index, T val)
{
// val is a thrust::tuple(value, input_index). Only write out the value part.
out[index] = thrust::get<0>(val);
}
};
elstehle marked this conversation as resolved.
Show resolved Hide resolved

struct select_op
{
std::size_t select_every_nth;

template <typename T, typename IndexT>
_CCCL_HOST_DEVICE bool operator()(thrust::tuple<T, IndexT> key_index_pair)
{
// Select every n-th item
return (thrust::get<1>(key_index_pair) % select_every_nth == 0);
}
};

struct index_to_gather_index_op
{
std::size_t gather_stride;

template <typename IndexT>
_CCCL_HOST_DEVICE IndexT operator()(IndexT index)
{
// Gather the i-th output item from input[i*3]
return index * static_cast<IndexT>(gather_stride);
}
};

template <class Vector>
void TestTabulateOutputIterator()
{
using T = typename Vector::value_type;
using it_t = typename Vector::iterator;
using space = typename thrust::iterator_system<typename Vector::iterator>::type;

static constexpr std::size_t num_items = 240;
Vector input(num_items);
Vector output(num_items, T{42});

// Use operator type that supports the targeted system
using op_t = typename ::cuda::std::conditional<(::cuda::std::is_same<space, thrust::host_system_tag>::value),
host_write_first_op<it_t>,
device_write_first_op<it_t>>::type;

// Construct tabulate_output_iterator
op_t op{output.begin()};
auto tabulate_out_it = thrust::make_tabulate_output_iterator(op);

// Prepare input
thrust::sequence(input.begin(), input.end(), 1);
auto iota_it = thrust::make_counting_iterator(0);
auto zipped_in = thrust::make_zip_iterator(input.begin(), iota_it);

// Run copy_if using tabulate_output_iterator as the output iterator
static constexpr std::size_t select_every_nth = 3;
auto selected_it_end =
thrust::copy_if(zipped_in, zipped_in + num_items, tabulate_out_it, select_op{select_every_nth});
const auto num_selected = static_cast<std::size_t>(thrust::distance(tabulate_out_it, selected_it_end));

// Prepare expected data
Vector expected_output(num_items, T{42});
const std::size_t expected_num_selected = (num_items + select_every_nth - 1) / select_every_nth;
auto gather_index_it =
thrust::make_transform_iterator(thrust::make_counting_iterator(0), index_to_gather_index_op{select_every_nth});
thrust::gather(gather_index_it, gather_index_it + expected_num_selected, input.cbegin(), expected_output.begin());

ASSERT_EQUAL(expected_num_selected, num_selected);
ASSERT_EQUAL(output, expected_output);
}
DECLARE_VECTOR_UNITTEST(TestTabulateOutputIterator);

void TestTabulateOutputIterator()
{
using vector_t = thrust::host_vector<int>;
using vec_it_t = typename vector_t::iterator;
using op_t = host_write_op<vec_it_t>;

vector_t out(4, 42);
thrust::tabulate_output_iterator<op_t> tabulate_out_it{op_t{out.begin()}};

tabulate_out_it[1] = 2;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 2);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 42);

tabulate_out_it[3] = 0;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 2);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 0);

tabulate_out_it[1] = 4;
ASSERT_EQUAL(out[0], 42);
ASSERT_EQUAL(out[1], 4);
ASSERT_EQUAL(out[2], 42);
ASSERT_EQUAL(out[3], 0);
}

DECLARE_UNITTEST(TestTabulateOutputIterator);
69 changes: 69 additions & 0 deletions thrust/thrust/iterator/detail/tabulate_output_iterator.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/tabulate_output_iterator.h>

THRUST_NAMESPACE_BEGIN

template <typename BinaryFunction, typename System, typename DifferenceT>
class tabulate_output_iterator;

namespace detail
{

// Proxy reference that invokes a BinaryFunction with the index of the dereferenced iterator and the assigned value
template <typename BinaryFunction, typename DifferenceT>
class tabulate_output_iterator_proxy
{
public:
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy(BinaryFunction fun, DifferenceT index)
: fun(fun)
, index(index)
{}

_CCCL_EXEC_CHECK_DISABLE
template <typename T>
_CCCL_HOST_DEVICE tabulate_output_iterator_proxy operator=(const T& x)
{
fun(index, x);
return *this;
}

private:
BinaryFunction fun;
DifferenceT index;
};

// Alias template for the iterator_adaptor instantiation to be used for tabulate_output_iterator
template <typename BinaryFunction, typename System, typename DifferenceT>
using tabulate_output_iterator_base =
thrust::iterator_adaptor<tabulate_output_iterator<BinaryFunction, System, DifferenceT>,
counting_iterator<DifferenceT>,
thrust::use_default,
System,
thrust::use_default,
tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>>;

// Register tabulate_output_iterator_proxy with 'is_proxy_reference' from
// type_traits to enable its use with algorithms.
template <class BinaryFunction, class OutputIterator>
struct is_proxy_reference<tabulate_output_iterator_proxy<BinaryFunction, OutputIterator>>
bernhardmgruber marked this conversation as resolved.
Show resolved Hide resolved
: public thrust::detail::true_type
{};

} // namespace detail
THRUST_NAMESPACE_END
117 changes: 117 additions & 0 deletions thrust/thrust/iterator/tabulate_output_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <thrust/iterator/detail/tabulate_output_iterator.inl>
elstehle marked this conversation as resolved.
Show resolved Hide resolved

THRUST_NAMESPACE_BEGIN

/*! \addtogroup iterators
* \{
*/

/*! \addtogroup fancyiterator Fancy Iterators
* \ingroup iterators
* \{
*/

/*! \p tabulate_output_iterator is a special kind of output iterator which, whenever a value is assigned to a
* dereferenced iterator, calls the given callable with the index that corresponds to the offset of the dereferenced
* iterator and the assigned value.
*
* The following code snippet demonstrated how to create a \p tabulate_output_iterator which prints the index and the
* assigned value.
*
* \code
* #include <thrust/iterator/tabulate_output_iterator.h>
*
* // note: functor inherits form binary function
* struct print_op
* {
* __host__ __device__
* void operator()(int index, float value) const
* {
* printf("%d: %f\n", index, value);
* }
* };
*
* int main()
* {
* auto tabulate_it = thrust::make_tabulate_output_iterator(print_op{});
*
* tabulate_it[0] = 1.0f; // prints: 0: 1.0
* tabulate_it[1] = 3.0f; // prints: 1: 3.0
* tabulate_it[9] = 5.0f; // prints: 9: 5.0
* }
* \endcode
*
* \see make_tabulate_output_iterator
*/

template <typename BinaryFunction, typename System = use_default, typename DifferenceT = ptrdiff_t>
class tabulate_output_iterator : public detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>
{
/*! \cond
*/

public:
using super_t = detail::tabulate_output_iterator_base<BinaryFunction, System, DifferenceT>;

friend class thrust::iterator_core_access;
/*! \endcond
*/

tabulate_output_iterator() = default;

/*! This constructor takes as argument a \c BinaryFunction and copies it to a new \p tabulate_output_iterator
*
* \param fun A \c BinaryFunction called whenever a value is assigned to this \p tabulate_output_iterator.
*/
_CCCL_HOST_DEVICE tabulate_output_iterator(BinaryFunction fun)
: fun(fun)
{}

/*! \cond
*/

private:
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
return detail::tabulate_output_iterator_proxy<BinaryFunction, DifferenceT>(fun, *this->base());
}

BinaryFunction fun;

/*! \endcond
*/
}; // end tabulate_output_iterator

/*! \p make_tabulate_output_iterator creates a \p tabulate_output_iterator from a \c BinaryFunction.
*
* \param fun The \c BinaryFunction invoked whenever assigning to a dereferenced \p tabulate_output_iterator
* \see tabulate_output_iterator
*/
template <typename BinaryFunction>
tabulate_output_iterator<BinaryFunction> _CCCL_HOST_DEVICE make_tabulate_output_iterator(BinaryFunction fun)
{
return tabulate_output_iterator<BinaryFunction>(fun);
} // end make_tabulate_output_iterator

/*! \} // end fancyiterators
*/

/*! \} // end iterators
*/

THRUST_NAMESPACE_END
Loading