From e8c2bdc3887ac282fb66c7b0f730f9fab2325e18 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 20 Oct 2022 14:54:44 -0400 Subject: [PATCH] Use gather-based strings factory in cudf::strings::strip --- cpp/benchmarks/string/filter.cpp | 26 ++++++++------------ cpp/src/strings/strip.cu | 42 +++++++++++--------------------- 2 files changed, 24 insertions(+), 44 deletions(-) diff --git a/cpp/benchmarks/string/filter.cpp b/cpp/benchmarks/string/filter.cpp index 4001fef5da6..f7f5ceee0a2 100644 --- a/cpp/benchmarks/string/filter.cpp +++ b/cpp/benchmarks/string/filter.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "string_bench_args.hpp" + #include #include #include @@ -27,7 +29,6 @@ #include #include -#include #include enum FilterAPI { filter, filter_chars, strip }; @@ -62,21 +63,14 @@ static void BM_filter_chars(benchmark::State& state, FilterAPI api) static void generate_bench_args(benchmark::internal::Benchmark* b) { - int const min_rows = 1 << 12; - int const max_rows = 1 << 24; - int const row_mult = 8; - int const min_rowlen = 1 << 5; - int const max_rowlen = 1 << 13; - int const len_mult = 4; - for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) { - for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) { - // avoid generating combinations that exceed the cudf column limit - size_t total_chars = static_cast(row_count) * rowlen; - if (total_chars < static_cast(std::numeric_limits::max())) { - b->Args({row_count, rowlen}); - } - } - } + int const min_rows = 1 << 12; + int const max_rows = 1 << 24; + int const row_multiplier = 8; + int const min_length = 1 << 5; + int const max_length = 1 << 13; + int const length_multiplier = 2; + generate_string_bench_args( + b, min_rows, max_rows, row_multiplier, min_length, max_length, length_multiplier); } #define STRINGS_BENCHMARK_DEFINE(name) \ diff --git a/cpp/src/strings/strip.cu b/cpp/src/strings/strip.cu index 5d51a5a7bed..42f0183f245 100644 --- a/cpp/src/strings/strip.cu +++ b/cpp/src/strings/strip.cu @@ -15,11 +15,9 @@ */ #include -#include -#include #include +#include #include -#include #include #include #include @@ -35,35 +33,24 @@ namespace detail { namespace { /** - * @brief Strip characters from the beginning and/or end of a string. + * @brief Strip characters from the beginning and/or end of a string * * This functor strips the beginning and/or end of each string * of any characters found in d_to_strip or whitespace if * d_to_strip is empty. * */ -struct strip_fn { +struct strip_transform_fn { column_device_view const d_strings; side_type const side; // right, left, or both string_view const d_to_strip; - int32_t* d_offsets{}; - char* d_chars{}; - __device__ void operator()(size_type idx) + __device__ string_index_pair operator()(size_type idx) { - if (d_strings.is_null(idx)) { - if (!d_chars) d_offsets[idx] = 0; - return; - } - - auto const d_str = d_strings.element(idx); - + if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; } + auto const d_str = d_strings.element(idx); auto const d_stripped = strip(d_str, d_to_strip, side); - if (d_chars) { - copy_string(d_chars + d_offsets[idx], d_stripped); - } else { - d_offsets[idx] = d_stripped.size_bytes(); - } + return string_index_pair{d_stripped.data(), d_stripped.size_bytes()}; } }; @@ -83,15 +70,14 @@ std::unique_ptr strip( auto const d_column = column_device_view::create(input.parent(), stream); - // this utility calls the strip_fn to build the offsets and chars columns - auto children = cudf::strings::detail::make_strings_children( - strip_fn{*d_column, side, d_to_strip}, input.size(), stream, mr); + auto result = rmm::device_uvector(input.size(), stream); + thrust::transform(rmm::exec_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(input.size()), + result.begin(), + strip_transform_fn{*d_column, side, d_to_strip}); - return make_strings_column(input.size(), - std::move(children.first), - std::move(children.second), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr)); + return make_strings_column(result.begin(), result.end(), stream, mr); } } // namespace detail