diff --git a/cpp/src/replace/clamp.cu b/cpp/src/replace/clamp.cu index 950cb484ddf..0c934533d54 100644 --- a/cpp/src/replace/clamp.cu +++ b/cpp/src/replace/clamp.cu @@ -47,37 +47,43 @@ namespace cudf { namespace detail { namespace { -template -std::pair, std::unique_ptr> form_offsets_and_char_column( - cudf::column_device_view input, - size_type, - Transformer offsets_transformer, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - std::unique_ptr offsets_column{}; - auto strings_count = input.size(); - size_type bytes = 0; - if (input.nullable()) { - auto input_begin = - cudf::detail::make_null_replacement_iterator(input, string_view{}); - auto offsets_transformer_itr = - thrust::make_transform_iterator(input_begin, offsets_transformer); - std::tie(offsets_column, bytes) = cudf::detail::make_offsets_child_column( - offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); - } else { - auto offsets_transformer_itr = - thrust::make_transform_iterator(input.begin(), offsets_transformer); - std::tie(offsets_column, bytes) = cudf::detail::make_offsets_child_column( - offsets_transformer_itr, offsets_transformer_itr + strings_count, stream, mr); +template +struct clamp_strings_fn { + column_device_view const d_strings; + OptionalScalarIterator lo_itr; + ReplaceScalarIterator lo_replace_itr; + OptionalScalarIterator hi_itr; + ReplaceScalarIterator hi_replace_itr; + size_type* d_offsets{}; + char* d_chars{}; + + __device__ void operator()(size_type idx) const + { + if (d_strings.is_null(idx)) { + if (!d_chars) { d_offsets[idx] = 0; } + return; + } + auto const element = d_strings.element(idx); + auto const d_lo = (*lo_itr).value_or(element); + auto const d_hi = (*hi_itr).value_or(element); + auto const d_lo_replace = *(*lo_replace_itr); + auto const d_hi_replace = *(*hi_replace_itr); + auto d_output = d_chars ? d_chars + d_offsets[idx] : nullptr; + + auto d_str = [d_lo, d_lo_replace, d_hi, d_hi_replace, element] { + if (element < d_lo) { return d_lo_replace; } + if (d_hi < element) { return d_hi_replace; } + return element; + }(); + + if (d_output) { + cudf::strings::detail::copy_string(d_output, d_str); + } else { + d_offsets[idx] = d_str.size_bytes(); + } } - - // build chars column - auto chars_column = cudf::strings::detail::create_chars_child_column(bytes, stream, mr); - - return std::pair(std::move(offsets_column), std::move(chars_column)); -} +}; template std::unique_ptr clamp_string_column(strings_column_view const& input, @@ -90,58 +96,11 @@ std::unique_ptr clamp_string_column(strings_column_view const& inp { auto input_device_column = column_device_view::create(input.parent(), stream); auto d_input = *input_device_column; - size_type null_count = input.null_count(); - - // build offset column - auto offsets_transformer = [lo_itr, hi_itr, lo_replace_itr, hi_replace_itr] __device__( - string_view element, bool is_valid = true) { - const auto d_lo = (*lo_itr).value_or(element); - const auto d_hi = (*hi_itr).value_or(element); - const auto d_lo_replace = *(*lo_replace_itr); - const auto d_hi_replace = *(*hi_replace_itr); - size_type bytes = 0; - - if (is_valid) { - if (element < d_lo) { - bytes = d_lo_replace.size_bytes(); - } else if (d_hi < element) { - bytes = d_hi_replace.size_bytes(); - } else { - bytes = element.size_bytes(); - } - } - return bytes; - }; + auto fn = clamp_strings_fn{ + d_input, lo_itr, lo_replace_itr, hi_itr, hi_replace_itr}; auto [offsets_column, chars_column] = - form_offsets_and_char_column(d_input, null_count, offsets_transformer, stream, mr); - - auto d_offsets = offsets_column->view().template data(); - auto d_chars = chars_column->mutable_view().template data(); - // fill in chars - auto copy_transformer = - [d_input, lo_itr, hi_itr, lo_replace_itr, hi_replace_itr, d_offsets, d_chars] __device__( - size_type idx) { - if (d_input.is_null(idx)) { return; } - auto input_element = d_input.element(idx); - const auto d_lo = (*lo_itr).value_or(input_element); - const auto d_hi = (*hi_itr).value_or(input_element); - const auto d_lo_replace = *(*lo_replace_itr); - const auto d_hi_replace = *(*hi_replace_itr); - - if (input_element < d_lo) { - memcpy(d_chars + d_offsets[idx], d_lo_replace.data(), d_lo_replace.size_bytes()); - } else if (d_hi < input_element) { - memcpy(d_chars + d_offsets[idx], d_hi_replace.data(), d_hi_replace.size_bytes()); - } else { - memcpy(d_chars + d_offsets[idx], input_element.data(), input_element.size_bytes()); - } - }; - - thrust::for_each_n(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - input.size(), - copy_transformer); + cudf::strings::detail::make_strings_children(fn, input.size(), stream, mr); return make_strings_column(input.size(), std::move(offsets_column), diff --git a/cpp/tests/replace/clamp_test.cpp b/cpp/tests/replace/clamp_test.cpp index 5b276668b8c..a13829c5abc 100644 --- a/cpp/tests/replace/clamp_test.cpp +++ b/cpp/tests/replace/clamp_test.cpp @@ -381,7 +381,7 @@ struct ClampStringTest : public cudf::test::BaseFixture {}; TEST_F(ClampStringTest, WithNullableColumn) { - std::vector strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector strings{"A", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; std::vector valids{1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1}; cudf::test::strings_column_wrapper input(strings.begin(), strings.end(), valids.begin()); @@ -391,7 +391,7 @@ TEST_F(ClampStringTest, WithNullableColumn) lo->set_valid_async(true); hi->set_valid_async(true); - std::vector expected_strings{"B", "b", "c", "D", "e", "F", "G", "H", "i", "e", "B"}; + std::vector expected_strings{"B", "b", "c", "", "e", "F", "G", "H", "", "e", "B"}; cudf::test::strings_column_wrapper expected( expected_strings.begin(), expected_strings.end(), valids.begin()); @@ -423,7 +423,7 @@ TEST_F(ClampStringTest, WithNonNullableColumn) TEST_F(ClampStringTest, WithNullableColumnNullLow) { - std::vector strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector strings{"A", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; std::vector valids{1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1}; cudf::test::strings_column_wrapper input(strings.begin(), strings.end(), valids.begin()); @@ -433,7 +433,7 @@ TEST_F(ClampStringTest, WithNullableColumnNullLow) lo->set_valid_async(false); hi->set_valid_async(true); - std::vector expected_strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "e", "B"}; + std::vector expected_strings{"A", "b", "c", "", "e", "F", "G", "H", "", "e", "B"}; cudf::test::strings_column_wrapper expected( expected_strings.begin(), expected_strings.end(), valids.begin()); @@ -445,7 +445,7 @@ TEST_F(ClampStringTest, WithNullableColumnNullLow) TEST_F(ClampStringTest, WithNullableColumnNullHigh) { - std::vector strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector strings{"A", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; std::vector valids{1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1}; cudf::test::strings_column_wrapper input(strings.begin(), strings.end(), valids.begin()); @@ -455,7 +455,7 @@ TEST_F(ClampStringTest, WithNullableColumnNullHigh) lo->set_valid_async(true); hi->set_valid_async(false); - std::vector expected_strings{"B", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector expected_strings{"B", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; cudf::test::strings_column_wrapper expected( expected_strings.begin(), expected_strings.end(), valids.begin()); @@ -467,7 +467,7 @@ TEST_F(ClampStringTest, WithNullableColumnNullHigh) TEST_F(ClampStringTest, WithNullableColumnBothLoAndHiNull) { - std::vector strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector strings{"A", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; std::vector valids{1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1}; cudf::test::strings_column_wrapper input(strings.begin(), strings.end(), valids.begin()); @@ -484,7 +484,7 @@ TEST_F(ClampStringTest, WithNullableColumnBothLoAndHiNull) TEST_F(ClampStringTest, WithReplaceString) { - std::vector strings{"A", "b", "c", "D", "e", "F", "G", "H", "i", "j", "B"}; + std::vector strings{"A", "b", "c", "", "e", "F", "G", "H", "", "j", "B"}; std::vector valids{1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1}; cudf::test::strings_column_wrapper input(strings.begin(), strings.end(), valids.begin()); @@ -498,7 +498,7 @@ TEST_F(ClampStringTest, WithReplaceString) hi->set_valid_async(true); hi_replace->set_valid_async(true); - std::vector expected_strings{"Z", "b", "c", "D", "e", "F", "G", "H", "z", "z", "B"}; + std::vector expected_strings{"Z", "b", "c", "", "e", "F", "G", "H", "", "z", "B"}; cudf::test::strings_column_wrapper expected( expected_strings.begin(), expected_strings.end(), valids.begin());