diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 84c2466b9ed..3eb551ead18 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -97,13 +97,12 @@ struct replace_multi_regex_fn { } // all the ranges have been updated from each regex match; // look for any that match at this character position (ch_pos) - auto itr = thrust::find_if( - thrust::seq, d_ranges, d_ranges + number_of_patterns, [ch_pos] __device__(auto range) { + auto itr = + thrust::find_if(thrust::seq, d_ranges, d_ranges + number_of_patterns, [ch_pos](auto range) { return range.first == ch_pos; }); - if (itr != - d_ranges + - number_of_patterns) { // match found, compute and replace the string in the output + if (itr != d_ranges + number_of_patterns) { + // match found, compute and replace the string in the output size_type ptn_idx = static_cast(itr - d_ranges); size_type begin = d_ranges[ptn_idx].first; size_type end = d_ranges[ptn_idx].second; @@ -149,22 +148,27 @@ std::unique_ptr replace_re( auto repls_column = column_device_view::create(repls.parent(), stream); auto d_repls = *repls_column; auto d_flags = get_character_flags_table(); + // compile regexes into device objects size_type regex_insts = 0; std::vector>> h_progs; - rmm::device_vector progs; + thrust::host_vector progs; for (auto itr = patterns.begin(); itr != patterns.end(); ++itr) { - auto prog = reprog_device::create(*itr, d_flags, strings_count, stream); - auto insts = prog->insts_counts(); - if (insts > regex_insts) regex_insts = insts; + auto prog = reprog_device::create(*itr, d_flags, strings_count, stream); + regex_insts = std::max(regex_insts, prog->insts_counts()); progs.push_back(*prog); h_progs.emplace_back(std::move(prog)); } - auto d_progs = progs.data().get(); - // copy null mask - auto null_mask = copy_bitmask(strings.parent()); - auto null_count = strings.null_count(); + // copy all the reprog_device instances to a device memory array + rmm::device_buffer progs_buffer{sizeof(reprog_device) * progs.size()}; + CUDA_TRY(cudaMemcpyAsync(progs_buffer.data(), + progs.data(), + progs.size() * sizeof(reprog_device), + cudaMemcpyHostToDevice, + stream.value())); + reprog_device* d_progs = reinterpret_cast(progs_buffer.data()); + // create working buffer for ranges pairs rmm::device_vector found_ranges(patterns.size() * strings_count); auto d_found_ranges = found_ranges.data().get(); @@ -178,7 +182,7 @@ std::unique_ptr replace_re( replace_multi_regex_fn{ d_strings, d_progs, static_cast(progs.size()), d_found_ranges, d_repls}, strings_count, - null_count, + strings.null_count(), stream, mr); else if (regex_insts <= RX_MEDIUM_INSTS) @@ -186,7 +190,7 @@ std::unique_ptr replace_re( replace_multi_regex_fn{ d_strings, d_progs, static_cast(progs.size()), d_found_ranges, d_repls}, strings_count, - null_count, + strings.null_count(), stream, mr); else @@ -194,15 +198,15 @@ std::unique_ptr replace_re( replace_multi_regex_fn{ d_strings, d_progs, static_cast(progs.size()), d_found_ranges, d_repls}, strings_count, - null_count, + strings.null_count(), stream, mr); return make_strings_column(strings_count, std::move(children.first), std::move(children.second), - null_count, - std::move(null_mask), + strings.null_count(), + cudf::detail::copy_bitmask(strings.parent(), stream, mr), stream, mr); } diff --git a/cpp/tests/merge/merge_string_test.cpp b/cpp/tests/merge/merge_string_test.cpp index 625a947d8e8..0cd5d68ea39 100644 --- a/cpp/tests/merge/merge_string_test.cpp +++ b/cpp/tests/merge/merge_string_test.cpp @@ -188,12 +188,10 @@ TYPED_TEST(MergeStringTest, Merge2StringKeyColumns) "hi", "hj"}); - auto seq_out2 = cudf::detail::make_counting_transform_iterator(0, [outputRows](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return (row); - }); + auto seq_out2 = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); fixed_width_column_wrapper expectedDataWrap2( seq_out2, seq_out2 + outputRows); @@ -376,12 +374,11 @@ TYPED_TEST(MergeStringTest, Merge2StringKeyNullColumns) "hj"}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0}); - auto seq_out2 = cudf::detail::make_counting_transform_iterator(0, [outputRows](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return (row); - }); + auto seq_out2 = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); + fixed_width_column_wrapper expectedDataWrap2( seq_out2, seq_out2 + outputRows); diff --git a/cpp/tests/partitioning/round_robin_test.cpp b/cpp/tests/partitioning/round_robin_test.cpp index 59bdc80dc07..160365834fe 100644 --- a/cpp/tests/partitioning/round_robin_test.cpp +++ b/cpp/tests/partitioning/round_robin_test.cpp @@ -55,12 +55,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinPartitions13_3) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows); @@ -191,12 +189,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinPartitions11_3) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows); @@ -324,12 +320,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinDegeneratePartitions11_15) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows); @@ -460,12 +454,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinDegeneratePartitions11_11) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows); @@ -528,12 +520,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinNPartitionsDivideNRows) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows); @@ -644,12 +634,10 @@ TYPED_TEST(RoundRobinTest, RoundRobinSinglePartition) cudf::size_type inputRows = static_cast(rrColWrap1).size(); - auto sequence_l = cudf::detail::make_counting_transform_iterator(0, [](auto row) { - if (cudf::type_to_id() == cudf::type_id::BOOL8) { - return (row % 2 == 0) ? 1 : 0; - } else - return row; - }); + auto sequence_l = cudf::detail::make_counting_transform_iterator( + 0, [bool8 = (cudf::type_to_id() == cudf::type_id::BOOL8)](auto row) { + return bool8 ? static_cast(row % 2 == 0) : row; + }); cudf::test::fixed_width_column_wrapper rrColWrap2(sequence_l, sequence_l + inputRows);