Skip to content

Commit

Permalink
Merge branch 'branch-0.19' of https://github.com/rapidsai/cudf into o…
Browse files Browse the repository at this point in the history
…ptimize
  • Loading branch information
skirui-source committed Feb 25, 2021
2 parents 3cdec51 + c80f9db commit 905cd31
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 51 deletions.
16 changes: 8 additions & 8 deletions cpp/include/cudf/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -557,14 +557,14 @@ struct column_gatherer_impl<struct_view, MapItRoot> {
std::back_inserter(output_struct_members),
[&gather_map_begin, &gather_map_end, nullify_out_of_bounds, stream, mr](
cudf::column_view const& col) {
return cudf::type_dispatcher(col.type(),
column_gatherer{},
col,
gather_map_begin,
gather_map_end,
nullify_out_of_bounds,
stream,
mr);
return cudf::type_dispatcher<dispatch_storage_type>(col.type(),
column_gatherer{},
col,
gather_map_begin,
gather_map_end,
nullify_out_of_bounds,
stream,
mr);
});

gather_bitmask(
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/lists/copying/gather.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -96,17 +96,17 @@ std::unique_ptr<column> gather_list_leaf(column_view const& column,
size_type gather_map_size = gd.gather_map_size;

// call the normal gather
auto leaf_column =
cudf::type_dispatcher(column.type(),
cudf::detail::column_gatherer{},
column,
gather_map_begin,
gather_map_begin + gather_map_size,
// note : we don't need to bother checking for out-of-bounds here since
// our inputs at this stage aren't coming from the user.
false,
stream,
mr);
auto leaf_column = cudf::type_dispatcher<dispatch_storage_type>(
column.type(),
cudf::detail::column_gatherer{},
column,
gather_map_begin,
gather_map_begin + gather_map_size,
// note : we don't need to bother checking for out-of-bounds here since
// our inputs at this stage aren't coming from the user.
false,
stream,
mr);

// the column_gatherer doesn't create the null mask because it expects
// that will be done in the gather_bitmask() step. however, gather_bitmask()
Expand Down
61 changes: 32 additions & 29 deletions cpp/src/strings/findall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/findall.hpp>
#include <cudf/strings/string_view.cuh>
Expand Down Expand Up @@ -118,42 +118,43 @@ std::unique_ptr<table> findall_re(
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource(),
rmm::cuda_stream_view stream = rmm::cuda_stream_default)
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);

auto d_flags = detail::get_character_flags_table();
auto const d_flags = detail::get_character_flags_table();
// compile regex into device object
auto prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto d_prog = *prog;
int regex_insts = prog->insts_counts();
auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto const regex_insts = d_prog->insts_counts();

rmm::device_vector<size_type> find_counts(strings_count);
auto d_find_counts = find_counts.data().get();
rmm::device_uvector<size_type> find_counts(strings_count, stream);
auto d_find_counts = find_counts.data();

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_SMALL>{d_strings, d_prog});
findall_count_fn<RX_STACK_SMALL>{*d_strings, *d_prog});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_MEDIUM>{d_strings, d_prog});
findall_count_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_LARGE>{d_strings, d_prog});
findall_count_fn<RX_STACK_LARGE>{*d_strings, *d_prog});

std::vector<std::unique_ptr<column>> results;

size_type columns =
*thrust::max_element(rmm::exec_policy(stream), find_counts.begin(), find_counts.end());
size_type const columns = thrust::reduce(rmm::exec_policy(stream),
find_counts.begin(),
find_counts.end(),
0,
thrust::maximum<size_type>{});
// boundary case: if no columns, return all nulls column (issue #119)
if (columns == 0)
results.emplace_back(std::make_unique<column>(
Expand All @@ -164,30 +165,32 @@ std::unique_ptr<table> findall_re(
strings_count));

for (int32_t column_index = 0; column_index < columns; ++column_index) {
rmm::device_vector<string_index_pair> indices(strings_count);
string_index_pair* d_indices = indices.data().get();
rmm::device_uvector<string_index_pair> indices(strings_count, stream);
auto d_indices = indices.data();

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_SMALL>{d_strings, d_prog, column_index, d_find_counts});
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_SMALL>{*d_strings, *d_prog, column_index, d_find_counts});
else if (regex_insts <= RX_MEDIUM_INSTS)
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_MEDIUM>{d_strings, d_prog, column_index, d_find_counts});
findall_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog, column_index, d_find_counts});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_LARGE>{d_strings, d_prog, column_index, d_find_counts});
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_indices,
findall_fn<RX_STACK_LARGE>{*d_strings, *d_prog, column_index, d_find_counts});
//
results.emplace_back(make_strings_column(indices, stream, mr));
results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr));
}
return std::make_unique<table>(std::move(results));
}
Expand Down
6 changes: 4 additions & 2 deletions cpp/tests/collect_list/collect_list_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ template <typename T>
struct TypedCollectListTest : public CollectListTest {
};

using TypesForTest = cudf::test::
Concat<cudf::test::IntegralTypes, cudf::test::FloatingPointTypes, cudf::test::DurationTypes>;
using TypesForTest = cudf::test::Concat<cudf::test::IntegralTypes,
cudf::test::FloatingPointTypes,
cudf::test::DurationTypes,
cudf::test::FixedPointTypes>;

TYPED_TEST_CASE(TypedCollectListTest, TypesForTest);

Expand Down
1 change: 1 addition & 0 deletions cpp/tests/copying/gather_list_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ template <typename T>
class GatherTestListTyped : public cudf::test::BaseFixture {
};
using FixedWidthTypesNotBool = cudf::test::Concat<cudf::test::IntegralTypesNotBool,
cudf::test::FixedPointTypes,
cudf::test::FloatingPointTypes,
cudf::test::DurationTypes,
cudf::test::TimestampTypes>;
Expand Down

0 comments on commit 905cd31

Please sign in to comment.