diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt
index 004a9b22826..4c8ac4165fe 100644
--- a/cpp/benchmarks/CMakeLists.txt
+++ b/cpp/benchmarks/CMakeLists.txt
@@ -308,6 +308,7 @@ set(STRINGS_BENCH_SRC
   "${CMAKE_CURRENT_SOURCE_DIR}/string/case_benchmark.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/string/contains_benchmark.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/string/convert_durations_benchmark.cpp"
+  "${CMAKE_CURRENT_SOURCE_DIR}/string/find_benchmark.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/string/replace_benchmark.cpp"
   "${CMAKE_CURRENT_SOURCE_DIR}/string/url_decode_benchmark.cpp")

diff --git a/cpp/benchmarks/string/find_benchmark.cpp b/cpp/benchmarks/string/find_benchmark.cpp
new file mode 100644
index 00000000000..200527d606e
--- /dev/null
+++ b/cpp/benchmarks/string/find_benchmark.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+enum FindAPI { find, find_multi, contains, starts_with, ends_with };  // selects which strings API BM_find_scalar times
+
+class StringFindScalar : public cudf::benchmark {  // fixture named by the BENCHMARK_DEFINE_F / BENCHMARK_REGISTER_F uses below
+};
+
+static void BM_find_scalar(benchmark::State& state, FindAPI find_api)  // times the API chosen by find_api on a random strings column
+{
+  cudf::size_type const n_rows{static_cast(state.range(0))};  // first benchmark argument (row_count from generate_bench_args)
+  cudf::size_type const max_str_length{static_cast(state.range(1))};  // second benchmark argument (rowlen from generate_bench_args)
+  data_profile table_profile;
+  table_profile.set_distribution_params(
+    cudf::type_id::STRING, distribution_id::NORMAL, 0, max_str_length);  // NOTE(review): presumably 0 and max_str_length bound the string lengths; confirm
+  auto const table =
+    create_random_table({cudf::type_id::STRING}, 1, row_count{n_rows}, table_profile);
+  cudf::strings_column_view input(table->view().column(0));  // the only column of the generated table
+  cudf::string_scalar target("+");  // single target used by find, contains, starts_with and ends_with
+  cudf::test::strings_column_wrapper targets({"+", "-"});  // two targets, used only by the find_multi case
+
+  for (auto _ : state) {
+    cuda_event_timer raii(state, true, 0);  // created once per iteration; NOTE(review): presumably supplies the manual time required by UseManualTime(); confirm
+    switch (find_api) {  // results of the calls are discarded; only the call itself is being timed
+      case find: cudf::strings::find(input, target); break;
+      case find_multi:
+        cudf::strings::find_multiple(input, cudf::strings_column_view(targets));
+        break;
+      case contains: cudf::strings::contains(input, target); break;
+      case starts_with: cudf::strings::starts_with(input, target); break;
+      case ends_with: cudf::strings::ends_with(input, target); break;
+    }
+  }
+
+  state.SetBytesProcessed(state.iterations() * input.chars_size());  // throughput is reported from input.chars_size() per iteration
+}
+
+static void generate_bench_args(benchmark::internal::Benchmark* b)  // registers the (row_count, rowlen) pairs each benchmark runs with
+{
+  int const min_rows = 1 << 12;  // 4096
+  int const max_rows = 1 << 24;  // 16777216
+  int const row_mult = 8;  // row_count grows by this factor each outer step
+  int const min_rowlen = 1 << 5;  // 32
+  int const max_rowlen = 1 << 13;  // 8192
+  int const len_mult = 4;  // rowlen grows by this factor each inner step
+  for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) {
+    for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) {
+      // avoid generating combinations that exceed the cudf column limit
+      size_t total_chars = static_cast(row_count) * rowlen;  // product is held in a size_t before the limit check
+      if (total_chars < std::numeric_limits::max()) {  // pairs at or above the limit are not registered
+        b->Args({row_count, rowlen});
+      }
+    }
+  }
+}
+
+#define STRINGS_BENCHMARK_DEFINE(name) \
+  BENCHMARK_DEFINE_F(StringFindScalar, name) \
+  (::benchmark::State & st) { BM_find_scalar(st, name); } \
+  BENCHMARK_REGISTER_F(StringFindScalar, name) \
+    ->Apply(generate_bench_args) \
+    ->UseManualTime() \
+    ->Unit(benchmark::kMillisecond);
+
+STRINGS_BENCHMARK_DEFINE(find)  // each use defines and registers one benchmark; the name is also the FindAPI value passed to BM_find_scalar
+STRINGS_BENCHMARK_DEFINE(find_multi)
+STRINGS_BENCHMARK_DEFINE(contains)
+STRINGS_BENCHMARK_DEFINE(starts_with)
+STRINGS_BENCHMARK_DEFINE(ends_with)
diff --git a/cpp/src/strings/find.cu b/cpp/src/strings/find.cu
index c300fb0cc8c..57d5d6afc75 100644
--- a/cpp/src/strings/find.cu
+++ b/cpp/src/strings/find.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -315,7 +315,8 @@ std::unique_ptr starts_with(
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
 {
   auto pfn = [] __device__(string_view d_string, string_view d_target) {
-    return d_string.find(d_target) == 0;
+    return (d_target.size_bytes() <= d_string.size_bytes()) &&  // a target longer than the string (in bytes) is never a prefix
+           (d_target.compare(d_string.data(), d_target.size_bytes()) == 0);  // compare the target with the string's leading bytes
   };
   return contains_fn(strings, target, pfn, stream, mr);
 }
@@ -327,7 +328,8 @@ std::unique_ptr starts_with(
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
 {
   auto pfn = [] __device__(string_view d_string, string_view d_target) {
-    return d_string.find(d_target) == 0;
+    return (d_target.size_bytes() <= d_string.size_bytes()) &&  // a target longer than the string (in bytes) is never a prefix
+           (d_target.compare(d_string.data(), d_target.size_bytes()) == 0);  // compare the target with the string's leading bytes
   };
   return contains_fn(strings, targets, pfn, stream, mr);
 }
@@ -339,10 +341,10 @@ std::unique_ptr ends_with(
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
 {
   auto pfn = [] __device__(string_view d_string, string_view d_target) {
-    auto str_length = d_string.length();
-    auto tgt_length = d_target.length();
-    if (str_length < tgt_length) return false;
-    return d_string.find(d_target, str_length - tgt_length) >= 0;
+    auto const str_size = d_string.size_bytes();  // size_bytes() replaces the earlier length()-based check
+    auto const tgt_size = d_target.size_bytes();
+    return (tgt_size <= str_size) &&  // a target longer than the string (in bytes) is never a suffix
+           (d_target.compare(d_string.data() + str_size - tgt_size, tgt_size) == 0);  // compare the target with the string's last tgt_size bytes
   };

   return contains_fn(strings, target, pfn, stream, mr);
@@ -355,10 +357,10 @@ std::unique_ptr ends_with(
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
 {
   auto pfn = [] __device__(string_view d_string, string_view d_target) {
-    auto str_length = d_string.length();
-    auto tgt_length = d_target.length();
-    if (str_length < tgt_length) return false;
-    return d_string.find(d_target, str_length - tgt_length) >= 0;
+    auto const str_size = d_string.size_bytes();  // size_bytes() replaces the earlier length()-based check
+    auto const tgt_size = d_target.size_bytes();
+    return (tgt_size <= str_size) &&  // a target longer than the string (in bytes) is never a suffix
+           (d_target.compare(d_string.data() + str_size - tgt_size, tgt_size) == 0);  // compare the target with the string's last tgt_size bytes
   };

   return contains_fn(strings, targets, pfn, stream, mr);