Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use gather-based strings factory in cudf::strings::strip #11954

Merged
merged 1 commit into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions cpp/benchmarks/string/filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include "string_bench_args.hpp"

#include <benchmarks/common/generate_input.hpp>
#include <benchmarks/fixture/benchmark_fixture.hpp>
#include <benchmarks/synchronization/synchronization.hpp>
Expand All @@ -27,7 +29,6 @@
#include <cudf/strings/translate.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <limits>
#include <vector>

enum FilterAPI { filter, filter_chars, strip };
Expand Down Expand Up @@ -62,21 +63,14 @@ static void BM_filter_chars(benchmark::State& state, FilterAPI api)

static void generate_bench_args(benchmark::internal::Benchmark* b)
{
int const min_rows = 1 << 12;
int const max_rows = 1 << 24;
int const row_mult = 8;
int const min_rowlen = 1 << 5;
int const max_rowlen = 1 << 13;
int const len_mult = 4;
for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) {
for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) {
// avoid generating combinations that exceed the cudf column limit
size_t total_chars = static_cast<size_t>(row_count) * rowlen;
if (total_chars < static_cast<size_t>(std::numeric_limits<cudf::size_type>::max())) {
b->Args({row_count, rowlen});
}
}
}
int const min_rows = 1 << 12;
int const max_rows = 1 << 24;
int const row_multiplier = 8;
int const min_length = 1 << 5;
int const max_length = 1 << 13;
int const length_multiplier = 2;
generate_string_bench_args(
b, min_rows, max_rows, row_multiplier, min_length, max_length, length_multiplier);
}

#define STRINGS_BENCHMARK_DEFINE(name) \
Expand Down
42 changes: 14 additions & 28 deletions cpp/src/strings/strip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/detail/strip.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/strings/strip.hpp>
Expand All @@ -35,35 +33,24 @@ namespace detail {
namespace {

/**
* @brief Strip characters from the beginning and/or end of a string.
* @brief Strip characters from the beginning and/or end of a string
*
* This functor strips the beginning and/or end of each string
* of any characters found in d_to_strip or whitespace if
* d_to_strip is empty.
*
*/
struct strip_fn {
struct strip_transform_fn {
column_device_view const d_strings;
side_type const side; // right, left, or both
string_view const d_to_strip;
int32_t* d_offsets{};
char* d_chars{};

__device__ void operator()(size_type idx)
__device__ string_index_pair operator()(size_type idx)
{
if (d_strings.is_null(idx)) {
if (!d_chars) d_offsets[idx] = 0;
return;
}

auto const d_str = d_strings.element<string_view>(idx);

if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; }
auto const d_str = d_strings.element<string_view>(idx);
auto const d_stripped = strip(d_str, d_to_strip, side);
if (d_chars) {
copy_string(d_chars + d_offsets[idx], d_stripped);
} else {
d_offsets[idx] = d_stripped.size_bytes();
}
return string_index_pair{d_stripped.data(), d_stripped.size_bytes()};
}
};

Expand All @@ -83,15 +70,14 @@ std::unique_ptr<column> strip(

auto const d_column = column_device_view::create(input.parent(), stream);

// this utility calls the strip_fn to build the offsets and chars columns
auto children = cudf::strings::detail::make_strings_children(
strip_fn{*d_column, side, d_to_strip}, input.size(), stream, mr);
auto result = rmm::device_uvector<string_index_pair>(input.size(), stream);
thrust::transform(rmm::exec_policy(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(input.size()),
result.begin(),
strip_transform_fn{*d_column, side, d_to_strip});

return make_strings_column(input.size(),
std::move(children.first),
std::move(children.second),
input.null_count(),
cudf::detail::copy_bitmask(input.parent(), stream, mr));
return make_strings_column(result.begin(), result.end(), stream, mr);
upsj marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace detail
Expand Down