Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvement in cudf::strings::all_characters_of_type #13259

Merged
merged 21 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
7be60e1
Performance improvement in cudf::strings::all_characters_of_type
davidwendt May 1, 2023
238a121
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 1, 2023
71b8037
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 2, 2023
433516d
use max(int8) instead of max(char)
davidwendt May 2, 2023
83fece9
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 2, 2023
f445b9f
fix/add comments
davidwendt May 3, 2023
dba259d
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 3, 2023
b41914b
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 4, 2023
c6e6f3b
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 4, 2023
c547a8b
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 4, 2023
8b63260
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 4, 2023
0b9390b
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 8, 2023
39a8901
add some const decls
davidwendt May 8, 2023
93cfda9
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 8, 2023
5073fa5
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 8, 2023
6578d48
fix merge conflict
davidwendt May 9, 2023
c82a561
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 10, 2023
a681713
Merge branch 'branch-23.06' into char-types-perf
davidwendt May 10, 2023
07cc19c
remove unneeded header
davidwendt May 10, 2023
91051cb
fix comment
davidwendt May 11, 2023
35eae82
Merge branch 'branch-23.06' into char-types-perf
karthikeyann May 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ ConfigureBench(
string/url_decode.cu
)

ConfigureNVBench(STRINGS_NVBENCH string/like.cpp string/reverse.cpp string/lengths.cpp)
ConfigureNVBench(
STRINGS_NVBENCH string/char_types.cpp string/like.cpp string/reverse.cpp string/lengths.cpp
)

# ##################################################################################################
# * json benchmark -------------------------------------------------------------------
Expand Down
67 changes: 67 additions & 0 deletions cpp/benchmarks/string/char_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmarks/common/generate_input.hpp>
#include <benchmarks/fixture/rmm_pool_raii.hpp>
PointKernel marked this conversation as resolved.
Show resolved Hide resolved

#include <cudf/strings/char_types/char_types.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <nvbench/nvbench.cuh>

static void bench_char_types(nvbench::state& state)
{
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const api_type = state.get_string("api");

if (static_cast<std::size_t>(num_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
state.skip("Skip benchmarks greater than size_type limit");
}

data_profile const table_profile = data_profile_builder().distribution(
cudf::type_id::STRING, distribution_id::NORMAL, 0, row_width);
auto const table =
create_random_table({cudf::type_id::STRING}, row_count{num_rows}, table_profile);
cudf::strings_column_view input(table->view().column(0));
auto input_types = cudf::strings::string_character_types::SPACE;

state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));
// gather some throughput statistics as well
auto chars_size = input.chars_size();
state.add_global_memory_reads<nvbench::int8_t>(chars_size); // all bytes are read;
if (api_type == "all") {
state.add_global_memory_writes<nvbench::int8_t>(num_rows); // output is a bool8 per row
} else {
state.add_global_memory_writes<nvbench::int8_t>(chars_size);
}

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
if (api_type == "all") {
auto result = cudf::strings::all_characters_of_type(input, input_types);
} else {
auto result = cudf::strings::filter_characters_of_type(input, input_types);
}
});
}

NVBENCH_BENCH(bench_char_types)
.set_name("char_types")
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024, 2048, 4096})
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
.add_string_axis("api", {"all", "filter"});
93 changes: 59 additions & 34 deletions cpp/src/strings/char_types/char_types.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,59 +31,84 @@
#include <rmm/cuda_stream_view.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/logical.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {
//
std::unique_ptr<column> all_characters_of_type(strings_column_view const& strings,
namespace {

/**
* @brief Returns true for each string where all characters match the given types.
*
* Only the characters that match to `verify_types` are checked.
* Returns false if no characters are checked or one character does not match `types`.
* Returns true if at least one character is checked and all checked characters match `types`.
*/
struct char_types_fn {
column_device_view const d_column;
character_flags_table_type const* d_flags;
string_character_types const types;
string_character_types const verify_types;

__device__ bool operator()(size_type idx)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
{
if (d_column.is_null(idx)) return false;
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
auto const d_str = d_column.element<string_view>(idx);
auto const end = d_str.data() + d_str.size_bytes();

bool type_matched = !d_str.empty(); // require at least one character;
size_type check_count = 0; // count checked characters
for (auto itr = d_str.data(); type_matched && (itr < end); ++itr) {
uint8_t const chr = static_cast<uint8_t>(*itr);
if (is_utf8_continuation_char(chr)) { continue; }
auto u8 = static_cast<char_utf8>(chr); // holds UTF8 value
// using max(int8) here since max(char)=255 on ARM systems
if (u8 > std::numeric_limits<int8_t>::max()) { to_char_utf8(itr, u8); }

// lookup flags in table by code-point
auto const code_point = utf8_to_codepoint(u8);
auto const flag = code_point <= 0x00'FFFF ? d_flags[code_point] : 0;

if ((verify_types & flag) || // should flag be verified;
(flag == 0 && verify_types == ALL_TYPES)) // special edge case
{
type_matched = (types & flag) > 0;
++check_count;
}
}

return type_matched && (check_count > 0);
}
};
} // namespace

std::unique_ptr<column> all_characters_of_type(strings_column_view const& input,
string_character_types types,
string_character_types verify_types,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto strings_count = strings.size();
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_column = *strings_column;
auto d_strings = column_device_view::create(input.parent(), stream);

// create output column
auto results = make_numeric_column(data_type{type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
auto results = make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);
auto results_view = results->mutable_view();
auto d_results = results_view.data<bool>();
// get the static character types table
auto d_flags = detail::get_character_flags_table();

// set the output values by checking the character types for each string
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
[d_column, d_flags, types, verify_types, d_results] __device__(size_type idx) {
if (d_column.is_null(idx)) return false;
auto d_str = d_column.element<string_view>(idx);
bool check = !d_str.empty(); // require at least one character
size_type check_count = 0;
for (auto itr = d_str.begin(); check && (itr != d_str.end()); ++itr) {
auto code_point = detail::utf8_to_codepoint(*itr);
// lookup flags in table by code-point
auto flag = code_point <= 0x00'FFFF ? d_flags[code_point] : 0;
if ((verify_types & flag) || // should flag be verified
(flag == 0 && verify_types == ALL_TYPES)) // special edge case
{
check = (types & flag) > 0;
++check_count;
}
}
return check && (check_count > 0);
});
//
results->set_null_count(strings.null_count());
thrust::make_counting_iterator<size_type>(input.size()),
results->mutable_view().data<bool>(),
char_types_fn{*d_strings, d_flags, types, verify_types});

results->set_null_count(input.null_count());
return results;
}

Expand Down