From b446a6f187241e765c925da1053ece2679313a06 Mon Sep 17 00:00:00 2001
From: David Wendt <45795991+davidwendt@users.noreply.github.com>
Date: Tue, 14 Nov 2023 12:49:19 -0500
Subject: [PATCH] Fix token-count logic in nvtext::tokenize_with_vocabulary (#14393)

Fixes a bug introduced in #14336 when trying to simplify the token-counting
logic, as per this discussion:
https://github.com/rapidsai/cudf/pull/14336#discussion_r1378173552

The simplification caused an error that was found when running the nvtext
benchmarks. The appropriate gtest has been updated to cover this case.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Karthikeyan (https://github.com/karthikeyann)

URL: https://github.com/rapidsai/cudf/pull/14393
---
 cpp/benchmarks/text/vocab.cpp       |  2 +-
 cpp/src/text/vocabulary_tokenize.cu |  8 ++++++--
 cpp/tests/text/tokenize_tests.cpp   | 12 ++++++------
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/cpp/benchmarks/text/vocab.cpp b/cpp/benchmarks/text/vocab.cpp
index 6922b7214ff..80942e2697d 100644
--- a/cpp/benchmarks/text/vocab.cpp
+++ b/cpp/benchmarks/text/vocab.cpp
@@ -53,7 +53,7 @@ static void bench_vocab_tokenize(nvbench::state& state)
 
   auto const vocab_col = [] {
     data_profile const profile = data_profile_builder().no_validity().distribution(
-      cudf::type_id::STRING, distribution_id::NORMAL, 0, 5);
+      cudf::type_id::STRING, distribution_id::NORMAL, 0, 15);
     auto const col = create_random_column(cudf::type_id::STRING, row_count{100}, profile);
     return cudf::strings::filter_characters_of_type(
       cudf::strings_column_view(col->view()),
diff --git a/cpp/src/text/vocabulary_tokenize.cu b/cpp/src/text/vocabulary_tokenize.cu
index 41f8c0a8731..511f1995374 100644
--- a/cpp/src/text/vocabulary_tokenize.cu
+++ b/cpp/src/text/vocabulary_tokenize.cu
@@ -276,8 +276,12 @@ __global__ void token_counts_fn(cudf::column_device_view const d_strings,
   __syncwarp();
 
   for (auto itr = d_output + lane_idx + 1; itr < d_output_end; itr += cudf::detail::warp_size) {
-    // add one if at the edge of a token or at the string's end
-    count += ((*itr && !(*(itr - 1))) || (itr + 1 == d_output_end));
+    // add one if at the edge of a token or if at the string's end
+    if (*itr) {
+      count += !(*(itr - 1));
+    } else {
+      count += (itr + 1 == d_output_end);
+    }
   }
 
   __syncwarp();
diff --git a/cpp/tests/text/tokenize_tests.cpp b/cpp/tests/text/tokenize_tests.cpp
index 8118183a458..ea36e13de6f 100644
--- a/cpp/tests/text/tokenize_tests.cpp
+++ b/cpp/tests/text/tokenize_tests.cpp
@@ -246,14 +246,14 @@ TEST_F(TextTokenizeTest, Vocabulary)
 
 TEST_F(TextTokenizeTest, VocabularyLongStrings)
 {
-  cudf::test::strings_column_wrapper vocabulary(  // leaving out 'cat' on purpose
+  cudf::test::strings_column_wrapper vocabulary(
     {"ate", "chased", "cheese", "dog", "fox", "jumped", "mouse", "mousé", "over", "the"});
   auto vocab = nvtext::load_vocabulary(cudf::strings_column_view(vocabulary));
 
   std::vector<std::string> h_strings(
     4,
     "the fox jumped chased the dog cheese mouse at the over there dog mouse cat plus the horse "
-    "jumped over the mouse house with the dog");
+    "jumped over the mousé house with the dog ");
   cudf::test::strings_column_wrapper input(h_strings.begin(), h_strings.end());
   auto input_view = cudf::strings_column_view(input);
   auto delimiter = cudf::string_scalar(" ");
@@ -262,10 +262,10 @@ TEST_F(TextTokenizeTest, VocabularyLongStrings)
 
   using LCW = cudf::test::lists_column_wrapper<cudf::size_type>;
   // clang-format off
-  LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3}});
+  LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3}});
   // clang-format on
   CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
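
Note (not part of the patch above): the bug can be seen by comparing only the two
increment expressions in the vocabulary_tokenize.cu hunk. Writing a for *itr, b for
*(itr - 1), and c for (itr + 1 == d_output_end) (names invented for this sketch), the
removed line adds ((a && !b) || c) per iteration, while the added if/else amounts to
(a ? !b : c). The following standalone C++ check, which assumes nothing about cudf
beyond those two expressions, prints the inputs where the forms disagree:

#include <cstdio>

int main()
{
  // Enumerate every combination of the three conditions used in the kernel loop.
  // a stands for *itr, b for *(itr - 1), c for (itr + 1 == d_output_end).
  for (int a = 0; a < 2; ++a) {
    for (int b = 0; b < 2; ++b) {
      for (int c = 0; c < 2; ++c) {
        int const removed_form = ((a && !b) || c);  // expression deleted by this patch
        int const added_form   = a ? !b : c;        // logic added by this patch
        if (removed_form != added_form) {
          std::printf("a=%d b=%d c=%d -> removed adds %d, added adds %d\n",
                      a, b, c, removed_form, added_form);
        }
      }
    }
  }
  return 0;
}

Built with any C++ compiler, this prints exactly one line, for a=1 b=1 c=1: the removed
expression adds 1 where the added logic adds 0. The two forms are therefore not
equivalent, and the earlier simplification could count one extra token in that case.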