Skip to content

Commit

Permalink
Performance improvement in cudf::strings::all_characters_of_type (#13259)
Browse files Browse the repository at this point in the history

Improves performance for the `cudf::strings::all_characters_of_type()` API, which covers many cudf `is_X` functions. The solution improves performance for all string lengths, as measured by the new benchmark included in this PR.
Additionally, the code was cleaned up to help with maintenance and clarity.

Reference #13048

Authors:
  - David Wendt (https://github.com/davidwendt)
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #13259
  • Loading branch information
davidwendt authored May 15, 2023
1 parent 79c0116 commit 2825d5a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 35 deletions.
3 changes: 2 additions & 1 deletion cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ ConfigureBench(
)

ConfigureNVBench(
STRINGS_NVBENCH string/like.cpp string/reverse.cpp string/lengths.cpp string/case.cpp
STRINGS_NVBENCH string/case.cpp string/char_types.cpp string/lengths.cpp string/like.cpp
string/reverse.cpp
)

# ##################################################################################################
Expand Down
66 changes: 66 additions & 0 deletions cpp/benchmarks/string/char_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmarks/common/generate_input.hpp>

#include <cudf/strings/char_types/char_types.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <nvbench/nvbench.cuh>

/**
 * @brief nvbench benchmark body for the character-type strings APIs.
 *
 * Reads the `num_rows`, `row_width` and `api` axis values from `state`, builds a
 * random single-column strings table (profile: distribution_id::NORMAL with
 * bounds 0 and `row_width`), and times either
 * `cudf::strings::all_characters_of_type` (api == "all") or
 * `cudf::strings::filter_characters_of_type` (any other api value) using the
 * SPACE character type.
 *
 * Configurations where `num_rows * row_width` reaches the `cudf::size_type`
 * maximum are skipped before any data is generated.
 *
 * @param state nvbench state providing the axis values and receiving the measurements
 */
static void bench_char_types(nvbench::state& state)
{
  auto const num_rows  = static_cast<cudf::size_type>(state.get_int64("num_rows"));
  auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
  // Resolve the API selection once so the timed lambda does no string comparison.
  bool const check_all = state.get_string("api") == "all";

  if (static_cast<std::size_t>(num_rows) * static_cast<std::size_t>(row_width) >=
      static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
    state.skip("Skip benchmarks greater than size_type limit");
    // skip() only records the reason; return so no input is generated or measured.
    return;
  }

  data_profile const table_profile = data_profile_builder().distribution(
    cudf::type_id::STRING, distribution_id::NORMAL, 0, row_width);
  auto const table =
    create_random_table({cudf::type_id::STRING}, row_count{num_rows}, table_profile);
  cudf::strings_column_view input(table->view().column(0));
  auto const input_types = cudf::strings::string_character_types::SPACE;

  state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));
  // gather some throughput statistics as well
  auto const chars_size = input.chars_size();
  state.add_global_memory_reads<nvbench::int8_t>(chars_size);  // all bytes are read
  if (check_all) {
    state.add_global_memory_writes<nvbench::int8_t>(num_rows);  // output is a bool8 per row
  } else {
    // counts one written byte per input byte for the filtered output
    state.add_global_memory_writes<nvbench::int8_t>(chars_size);
  }

  state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
    if (check_all) {
      auto result = cudf::strings::all_characters_of_type(input, input_types);
    } else {
      auto result = cudf::strings::filter_characters_of_type(input, input_types);
    }
  });
}

// Registers bench_char_types with nvbench under the name "char_types".
// Axes: "row_width" and "num_rows" are the integer values read back by the
// benchmark body; "api" selects the function under test ("all" runs
// all_characters_of_type, "filter" runs filter_characters_of_type).
NVBENCH_BENCH(bench_char_types)
.set_name("char_types")
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024, 2048, 4096})
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
.add_string_axis("api", {"all", "filter"});
93 changes: 59 additions & 34 deletions cpp/src/strings/char_types/char_types.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,59 +31,84 @@
#include <rmm/cuda_stream_view.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/logical.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {
//
std::unique_ptr<column> all_characters_of_type(strings_column_view const& strings,
namespace {

/**
 * @brief Row-wise functor reporting whether a string's characters are all of the given types.
 *
 * A character participates in the check when its flags intersect `verify_types`,
 * or when it has no flags at all and `verify_types` is `ALL_TYPES`.
 * The result for a row is true only when at least one character participated
 * and every participating character carries a flag in `types`.
 * Null rows and empty strings yield false.
 */
struct char_types_fn {
  column_device_view const d_column;          ///< strings column being examined
  character_flags_table_type const* d_flags;  ///< flags table indexed by code point
  string_character_types const types;         ///< types every checked character must have
  string_character_types const verify_types;  ///< types selecting which characters are checked

  __device__ bool operator()(size_type idx) const
  {
    if (d_column.is_null(idx)) { return false; }

    auto const d_str = d_column.element<string_view>(idx);
    if (d_str.empty()) { return false; }  // at least one character is required

    auto const first = d_str.data();
    auto const last  = first + d_str.size_bytes();

    size_type checked = 0;  // number of characters that took part in the check
    for (auto ptr = first; ptr < last; ++ptr) {
      auto const byte = static_cast<uint8_t>(*ptr);
      // only the lead byte of a character starts a lookup
      if (is_utf8_continuation_char(byte)) { continue; }

      auto u8 = static_cast<char_utf8>(byte);  // UTF-8 value of the character
      // compare against max(int8) because max(char)=255 on ARM systems
      if (u8 > std::numeric_limits<int8_t>::max()) { to_char_utf8(ptr, u8); }

      // code points beyond the table range are treated as having no flags
      auto const code_point = utf8_to_codepoint(u8);
      auto const flag       = code_point <= 0x00'FFFF ? d_flags[code_point] : 0;

      bool const must_verify = (verify_types & flag) ||                    // flag is selected
                               (flag == 0 && verify_types == ALL_TYPES);  // special edge case
      if (!must_verify) { continue; }

      bool const matches = (types & flag) > 0;
      if (!matches) { return false; }  // one mismatch settles the row
      ++checked;
    }

    return checked > 0;
  }
};
} // namespace

/**
 * @brief Builds a BOOL8 column holding, per input row, the result of the character-type check.
 *
 * NOTE(review): this span comes from a diff view (hunk "@@ -31,59 +31,84 @@"), so
 * lines of the previous lambda-based implementation (those using `strings`,
 * `strings_count`, `d_column`, `d_results`) appear interleaved with the lines of
 * the new implementation (those using `input`, `d_strings`, `char_types_fn`).
 * The description below follows the new lines.
 *
 * The output column has `input.size()` rows; its null mask is a copy of the
 * parent column's mask and its null count is `input.null_count()`. Row values are
 * produced by `thrust::transform` over the row indices using `char_types_fn`.
 *
 * @param input Strings column to examine
 * @param types Character types forwarded to `char_types_fn` as `types`
 * @param verify_types Character types forwarded to `char_types_fn` as `verify_types`
 * @param stream CUDA stream passed to the device-view creation, the bitmask copy,
 *        the column factory and the thrust execution policy
 * @param mr Memory resource passed to `copy_bitmask` and `make_numeric_column`
 * @return The new BOOL8 column
 */
std::unique_ptr<column> all_characters_of_type(strings_column_view const& input,
string_character_types types,
string_character_types verify_types,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;
auto d_strings = column_device_view::create(input.parent(), stream);

// create output column
auto results = make_numeric_column(data_type{type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);
auto results_view = results->mutable_view();
auto d_results = results_view.data<bool>();
// get the static character types table
auto d_flags = detail::get_character_flags_table();

// set the output values by checking the character types for each string
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
[d_column, d_flags, types, verify_types, d_results] __device__(size_type idx) {
if (d_column.is_null(idx)) return false;
auto d_str = d_column.element<string_view>(idx);
bool check = !d_str.empty(); // require at least one character
size_type check_count = 0;
for (auto itr = d_str.begin(); check && (itr != d_str.end()); ++itr) {
auto code_point = detail::utf8_to_codepoint(*itr);
// lookup flags in table by code-point
auto flag = code_point <= 0x00'FFFF ? d_flags[code_point] : 0;
if ((verify_types & flag) || // should flag be verified
(flag == 0 && verify_types == ALL_TYPES)) // special edge case
{
check = (types & flag) > 0;
++check_count;
}
}
return check && (check_count > 0);
});
//
results->set_null_count(strings.null_count());
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<bool>(),
char_types_fn{*d_strings, d_flags, types, verify_types});

results->set_null_count(input.null_count());
return results;
}

Expand Down

0 comments on commit 2825d5a

Please sign in to comment.