Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace thrust/std::get with structured bindings #9915

Merged
merged 3 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions cpp/include/cudf/detail/merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,10 @@ struct tagged_element_relational_comparator {
__device__ weak_ordering compare(index_type lhs_tagged_index,
index_type rhs_tagged_index) const noexcept
{
side const l_side = thrust::get<0>(lhs_tagged_index);
side const r_side = thrust::get<0>(rhs_tagged_index);

cudf::size_type const l_indx = thrust::get<1>(lhs_tagged_index);
cudf::size_type const r_indx = thrust::get<1>(rhs_tagged_index);
auto const [l_side, l_indx] = lhs_tagged_index;
auto const [r_side, r_indx] = rhs_tagged_index;

column_device_view const* ptr_left_dview{l_side == side::LEFT ? &lhs : &rhs};

column_device_view const* ptr_right_dview{r_side == side::LEFT ? &lhs : &rhs};

auto erl_comparator = element_relational_comparator(
Expand Down
7 changes: 2 additions & 5 deletions cpp/include/cudf/strings/detail/merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ std::unique_ptr<column> merge(strings_column_view const& lhs,

// build offsets column
auto offsets_transformer = [d_lhs, d_rhs] __device__(auto index_pair) {
auto side = thrust::get<0>(index_pair);
auto index = thrust::get<1>(index_pair);
auto const [side, index] = index_pair;
if (side == side::LEFT ? d_lhs.is_null(index) : d_rhs.is_null(index)) return 0;
auto d_str =
side == side::LEFT ? d_lhs.element<string_view>(index) : d_rhs.element<string_view>(index);
Expand All @@ -90,9 +89,7 @@ std::unique_ptr<column> merge(strings_column_view const& lhs,
thrust::make_counting_iterator<size_type>(0),
strings_count,
[d_lhs, d_rhs, begin, d_offsets, d_chars] __device__(size_type idx) {
index_type index_pair = begin[idx];
auto side = thrust::get<0>(index_pair);
auto index = thrust::get<1>(index_pair);
auto const [side, index] = begin[idx];
if (side == side::LEFT ? d_lhs.is_null(index) : d_rhs.is_null(index)) return;
auto d_str = side == side::LEFT ? d_lhs.element<string_view>(index)
: d_rhs.element<string_view>(index);
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/dictionary/detail/merge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ std::unique_ptr<column> merge(dictionary_column_view const& lcol,
row_order.end(),
output_iter,
[lcol_iter, rcol_iter] __device__(auto const& index_pair) {
auto index = thrust::get<1>(index_pair);
return (thrust::get<0>(index_pair) == cudf::detail::side::LEFT
? lcol_iter[index]
: rcol_iter[index]);
auto const [side, index] = index_pair;
return side == cudf::detail::side::LEFT ? lcol_iter[index] : rcol_iter[index];
});

// build dictionary; the validity mask is updated by the caller
Expand Down
8 changes: 3 additions & 5 deletions cpp/src/merge/merge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ __global__ void materialize_merged_bitmask_kernel(
auto active_threads = __ballot_sync(0xffffffff, destination_row < num_destination_rows);

while (destination_row < num_destination_rows) {
index_type const& merged_idx = merged_indices[destination_row];
side const src_side = thrust::get<0>(merged_idx);
size_type const src_row = thrust::get<1>(merged_idx);
index_type const& merged_idx = merged_indices[destination_row];
auto const [src_side, src_row] = merged_idx;
codereport marked this conversation as resolved.
Show resolved Hide resolved
bool const from_left{src_side == side::LEFT};
bool source_bit_is_valid{true};
if (left_have_valids && from_left) {
Expand Down Expand Up @@ -284,8 +283,7 @@ struct column_merger {
row_order_.end(),
merged_view.begin<Element>(),
[d_lcol, d_rcol] __device__(index_type const& index_pair) {
auto side = thrust::get<0>(index_pair);
auto index = thrust::get<1>(index_pair);
auto const [side, index] = index_pair;
return side == side::LEFT ? d_lcol[index] : d_rcol[index];
});

Expand Down
5 changes: 2 additions & 3 deletions cpp/tests/copying/sample_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ struct SampleBasicTest : public SampleTest,

TEST_P(SampleBasicTest, CombinationOfParameters)
{
cudf::size_type const table_size = 1024;
cudf::size_type const n_samples = std::get<0>(GetParam());
cudf::sample_with_replacement multi_smpl = std::get<1>(GetParam());
cudf::size_type const table_size = 1024;
auto const [n_samples, multi_smpl] = GetParam();

auto data = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i; });
cudf::test::fixed_width_column_wrapper<int16_t> col1(data, data + table_size);
Expand Down
9 changes: 4 additions & 5 deletions cpp/tests/interop/from_arrow_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,10 @@ struct FromArrowTestSlice

TEST_P(FromArrowTestSlice, SliceTest)
{
auto tables = get_tables(10000);
auto cudf_table_view = tables.first->view();
auto arrow_table = tables.second;
auto start = std::get<0>(GetParam());
auto end = std::get<1>(GetParam());
auto tables = get_tables(10000);
auto cudf_table_view = tables.first->view();
auto arrow_table = tables.second;
auto const [start, end] = GetParam();

auto sliced_cudf_table = cudf::slice(cudf_table_view, {start, end})[0];
auto expected_cudf_table = cudf::table{sliced_cudf_table};
Expand Down
9 changes: 4 additions & 5 deletions cpp/tests/interop/to_arrow_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,10 @@ struct ToArrowTestSlice

TEST_P(ToArrowTestSlice, SliceTest)
{
auto tables = get_tables(10000);
auto cudf_table_view = tables.first->view();
auto arrow_table = tables.second;
auto start = std::get<0>(GetParam());
auto end = std::get<1>(GetParam());
auto tables = get_tables(10000);
auto cudf_table_view = tables.first->view();
auto arrow_table = tables.second;
auto const [start, end] = GetParam();

auto sliced_cudf_table = cudf::slice(cudf_table_view, {start, end})[0];
auto expected_arrow_table = arrow_table->Slice(start, end - start);
Expand Down
8 changes: 3 additions & 5 deletions cpp/tests/io/orc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,8 +1137,7 @@ struct OrcWriterTestDecimal : public OrcWriterTest,

TEST_P(OrcWriterTestDecimal, Decimal64)
{
auto const num_rows = std::get<0>(GetParam());
auto const scale = std::get<1>(GetParam());
auto const [num_rows, scale] = GetParam();

// Using int16_t because scale causes values to overflow if they already require 32 bits
auto const vals = random_values<int32_t>(num_rows);
Expand Down Expand Up @@ -1241,9 +1240,8 @@ struct OrcWriterTestStripes

TEST_P(OrcWriterTestStripes, StripeSize)
{
constexpr auto num_rows = 1000000;
auto size_bytes = std::get<0>(GetParam());
auto size_rows = std::get<1>(GetParam());
constexpr auto num_rows = 1000000;
auto const [size_bytes, size_rows] = GetParam();

const auto seq_col = random_values<int>(num_rows);
const auto validity =
Expand Down
3 changes: 1 addition & 2 deletions cpp/tests/transform/mask_to_bools_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ struct MaskToBoolsTest
TEST_P(MaskToBoolsTest, LargeDataSizeTest)
{
auto data = std::vector<bool>(10000);
cudf::size_type const begin_bit = std::get<0>(GetParam());
cudf::size_type const end_bit = std::get<1>(GetParam());
auto const [begin_bit, end_bit] = GetParam();
std::transform(data.cbegin(), data.cend(), data.begin(), [](auto val) {
return rand() % 2 == 0 ? true : false;
});
Expand Down