Skip to content

Commit

Permalink
Fix token-count logic in nvtext::tokenize_with_vocabulary (#14393)
Browse files Browse the repository at this point in the history
Fixes a bug introduced in #14336 while trying to simplify the token-counting logic, as discussed in #14336 (comment).
The simplification caused an error which was found when running the nvtext benchmarks.
The appropriate gtest has been updated to cover this case.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #14393
  • Loading branch information
davidwendt authored Nov 14, 2023
1 parent b0c1b7b commit b446a6f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cpp/benchmarks/text/vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ static void bench_vocab_tokenize(nvbench::state& state)

auto const vocab_col = [] {
data_profile const profile = data_profile_builder().no_validity().distribution(
cudf::type_id::STRING, distribution_id::NORMAL, 0, 5);
cudf::type_id::STRING, distribution_id::NORMAL, 0, 15);
auto const col = create_random_column(cudf::type_id::STRING, row_count{100}, profile);
return cudf::strings::filter_characters_of_type(
cudf::strings_column_view(col->view()),
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/text/vocabulary_tokenize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,12 @@ __global__ void token_counts_fn(cudf::column_device_view const d_strings,
__syncwarp();

for (auto itr = d_output + lane_idx + 1; itr < d_output_end; itr += cudf::detail::warp_size) {
// add one if at the edge of a token or at the string's end
count += ((*itr && !(*(itr - 1))) || (itr + 1 == d_output_end));
// add one if at the edge of a token or if at the string's end
if (*itr) {
count += !(*(itr - 1));
} else {
count += (itr + 1 == d_output_end);
}
}
__syncwarp();

Expand Down
12 changes: 6 additions & 6 deletions cpp/tests/text/tokenize_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,14 @@ TEST_F(TextTokenizeTest, Vocabulary)

TEST_F(TextTokenizeTest, VocabularyLongStrings)
{
cudf::test::strings_column_wrapper vocabulary( // leaving out 'cat' on purpose
cudf::test::strings_column_wrapper vocabulary(
{"ate", "chased", "cheese", "dog", "fox", "jumped", "mouse", "mousé", "over", "the"});
auto vocab = nvtext::load_vocabulary(cudf::strings_column_view(vocabulary));

std::vector<std::string> h_strings(
4,
"the fox jumped chased the dog cheese mouse at the over there dog mouse cat plus the horse "
"jumped over the mouse house with the dog");
"jumped over the mousé house with the dog ");
cudf::test::strings_column_wrapper input(h_strings.begin(), h_strings.end());
auto input_view = cudf::strings_column_view(input);
auto delimiter = cudf::string_scalar(" ");
Expand All @@ -262,10 +262,10 @@ TEST_F(TextTokenizeTest, VocabularyLongStrings)

using LCW = cudf::test::lists_column_wrapper<cudf::size_type>;
// clang-format off
LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3}});
LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3}});
// clang-format on
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);

Expand Down

0 comments on commit b446a6f

Please sign in to comment.