Skip to content

Commit

Permalink
Workaround thrust-copy-if limit in json get_tree_representation (#12190)
Browse files Browse the repository at this point in the history
Workaround in json's get_tree_representation due to limitation in `thrust::copy_if` which fails if the input-iterator spans more than int-max.

Found existing thrust issue: https://github.com/NVIDIA/thrust/issues/1302
This calls the thrust::copy_if in chunks if the iterator can span greater than int-max.

Found while working on #12079

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Elias Stehle (https://github.com/elstehle)
  - Mike Wilson (https://github.com/hyperbolic2346)

URL: #12190
  • Loading branch information
davidwendt authored Dec 6, 2022
1 parent 1ca4dad commit 394f414
Showing 1 changed file with 64 additions and 20 deletions.
84 changes: 64 additions & 20 deletions cpp/src/io/json/json_tree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,50 @@

namespace cudf::io::json {
namespace detail {
namespace {

/**
* @brief Utility for calling thrust::copy_if
*
* Workaround for thrust::copy_if bug (https://github.com/NVIDIA/thrust/issues/1302)
* where it cannot iterate over int-max values `distance(first,last) > int-max`
* This calls thrust::copy_if in 2B chunks instead.
*/
template <typename InputIterator,
typename StencilIterator,
typename OutputIterator,
typename Predicate>
OutputIterator thrust_copy_if(rmm::exec_policy policy,
InputIterator first,
InputIterator last,
StencilIterator stencil,
OutputIterator result,
Predicate pred)
{
auto const copy_size = std::min(static_cast<std::size_t>(std::distance(first, last)),
static_cast<std::size_t>(std::numeric_limits<int>::max()));

auto itr = first;
while (itr != last) {
auto const copy_end =
static_cast<std::size_t>(std::distance(itr, last)) <= copy_size ? last : itr + copy_size;
result = thrust::copy_if(policy, itr, copy_end, stencil, result, pred);
stencil += std::distance(itr, copy_end);
itr = copy_end;
}
return result;
}

template <typename InputIterator, typename OutputIterator, typename Predicate>
OutputIterator thrust_copy_if(rmm::exec_policy policy,
InputIterator first,
InputIterator last,
OutputIterator result,
Predicate pred)
{
return thrust_copy_if(policy, first, last, first, result, pred);
}
} // namespace

// The node that a token represents
struct token_to_node {
Expand Down Expand Up @@ -279,12 +323,12 @@ tree_meta_t get_tree_representation(device_span<PdaTokenT const> tokens,
thrust::exclusive_scan(
rmm::exec_policy(stream), push_pop_it, push_pop_it + num_tokens, token_levels.begin());

auto const node_levels_end = thrust::copy_if(rmm::exec_policy(stream),
token_levels.begin(),
token_levels.end(),
tokens.begin(),
node_levels.begin(),
is_node);
auto const node_levels_end = thrust_copy_if(rmm::exec_policy(stream),
token_levels.begin(),
token_levels.end(),
tokens.begin(),
node_levels.begin(),
is_node);
CUDF_EXPECTS(thrust::distance(node_levels.begin(), node_levels_end) == num_nodes,
"node level count mismatch");
}
Expand All @@ -295,12 +339,12 @@ tree_meta_t get_tree_representation(device_span<PdaTokenT const> tokens,
// This block of code is generalized logical stack algorithm. TODO: make this a separate function.
{
rmm::device_uvector<NodeIndexT> node_token_ids(num_nodes, stream);
thrust::copy_if(rmm::exec_policy(stream),
thrust::make_counting_iterator<NodeIndexT>(0),
thrust::make_counting_iterator<NodeIndexT>(0) + num_tokens,
tokens.begin(),
node_token_ids.begin(),
is_node);
thrust_copy_if(rmm::exec_policy(stream),
thrust::make_counting_iterator<NodeIndexT>(0),
thrust::make_counting_iterator<NodeIndexT>(0) + num_tokens,
tokens.begin(),
node_token_ids.begin(),
is_node);

// previous push node_id
// if previous node is a push, then i-1
Expand Down Expand Up @@ -349,7 +393,7 @@ tree_meta_t get_tree_representation(device_span<PdaTokenT const> tokens,
rmm::device_uvector<NodeT> node_categories(num_nodes, stream, mr);
auto const node_categories_it =
thrust::make_transform_output_iterator(node_categories.begin(), token_to_node{});
auto const node_categories_end = thrust::copy_if(
auto const node_categories_end = thrust_copy_if(
rmm::exec_policy(stream), tokens.begin(), tokens.end(), node_categories_it, is_node);
CUDF_EXPECTS(node_categories_end - node_categories_it == num_nodes,
"node category count mismatch");
Expand All @@ -366,13 +410,13 @@ tree_meta_t get_tree_representation(device_span<PdaTokenT const> tokens,
node_range_tuple_it, node_ranges{tokens, token_indices, include_quote_char});

auto const node_range_out_end =
thrust::copy_if(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(0) + num_tokens,
node_range_out_it,
[is_node, tokens_gpu = tokens.begin()] __device__(size_type i) -> bool {
return is_node(tokens_gpu[i]);
});
thrust_copy_if(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(0) + num_tokens,
node_range_out_it,
[is_node, tokens_gpu = tokens.begin()] __device__(size_type i) -> bool {
return is_node(tokens_gpu[i]);
});
CUDF_EXPECTS(node_range_out_end - node_range_out_it == num_nodes, "node range count mismatch");

return {std::move(node_categories),
Expand Down

0 comments on commit 394f414

Please sign in to comment.