Skip to content

Commit

Permalink
Use gather-based strings factory in cudf::strings::strip
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwendt committed Oct 20, 2022
1 parent 6ca2ceb commit e8c2bdc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 44 deletions.
26 changes: 10 additions & 16 deletions cpp/benchmarks/string/filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include "string_bench_args.hpp"

#include <benchmarks/common/generate_input.hpp>
#include <benchmarks/fixture/benchmark_fixture.hpp>
#include <benchmarks/synchronization/synchronization.hpp>
Expand All @@ -27,7 +29,6 @@
#include <cudf/strings/translate.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <limits>
#include <vector>

enum FilterAPI { filter, filter_chars, strip };
Expand Down Expand Up @@ -62,21 +63,14 @@ static void BM_filter_chars(benchmark::State& state, FilterAPI api)

static void generate_bench_args(benchmark::internal::Benchmark* b)
{
int const min_rows = 1 << 12;
int const max_rows = 1 << 24;
int const row_mult = 8;
int const min_rowlen = 1 << 5;
int const max_rowlen = 1 << 13;
int const len_mult = 4;
for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) {
for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) {
// avoid generating combinations that exceed the cudf column limit
size_t total_chars = static_cast<size_t>(row_count) * rowlen;
if (total_chars < static_cast<size_t>(std::numeric_limits<cudf::size_type>::max())) {
b->Args({row_count, rowlen});
}
}
}
int const min_rows = 1 << 12;
int const max_rows = 1 << 24;
int const row_multiplier = 8;
int const min_length = 1 << 5;
int const max_length = 1 << 13;
int const length_multiplier = 2;
generate_string_bench_args(
b, min_rows, max_rows, row_multiplier, min_length, max_length, length_multiplier);
}

#define STRINGS_BENCHMARK_DEFINE(name) \
Expand Down
42 changes: 14 additions & 28 deletions cpp/src/strings/strip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/strip.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/strings/strip.hpp>
Expand All @@ -35,35 +33,24 @@ namespace detail {
namespace {

/**
* @brief Strip characters from the beginning and/or end of a string.
* @brief Strip characters from the beginning and/or end of a string
*
* This functor strips the beginning and/or end of each string
* of any characters found in d_to_strip or whitespace if
* d_to_strip is empty.
*
*/
struct strip_fn {
struct strip_transform_fn {
column_device_view const d_strings;
side_type const side; // right, left, or both
string_view const d_to_strip;
int32_t* d_offsets{};
char* d_chars{};

__device__ void operator()(size_type idx)
__device__ string_index_pair operator()(size_type idx)
{
if (d_strings.is_null(idx)) {
if (!d_chars) d_offsets[idx] = 0;
return;
}

auto const d_str = d_strings.element<string_view>(idx);

if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; }
auto const d_str = d_strings.element<string_view>(idx);
auto const d_stripped = strip(d_str, d_to_strip, side);
if (d_chars) {
copy_string(d_chars + d_offsets[idx], d_stripped);
} else {
d_offsets[idx] = d_stripped.size_bytes();
}
return string_index_pair{d_stripped.data(), d_stripped.size_bytes()};
}
};

Expand All @@ -83,15 +70,14 @@ std::unique_ptr<column> strip(

auto const d_column = column_device_view::create(input.parent(), stream);

// this utility calls the strip_fn to build the offsets and chars columns
auto children = cudf::strings::detail::make_strings_children(
strip_fn{*d_column, side, d_to_strip}, input.size(), stream, mr);
auto result = rmm::device_uvector<string_index_pair>(input.size(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(input.size()),
result.begin(),
strip_transform_fn{*d_column, side, d_to_strip});

return make_strings_column(input.size(),
std::move(children.first),
std::move(children.second),
input.null_count(),
cudf::detail::copy_bitmask(input.parent(), stream, mr));
return make_strings_column(result.begin(), result.end(), stream, mr);
}

} // namespace detail
Expand Down

0 comments on commit e8c2bdc

Please sign in to comment.