Fix token-count logic in nvtext::tokenize_with_vocabulary #14393

Merged 2 commits on Nov 14, 2023
cpp/benchmarks/text/vocab.cpp (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ static void bench_vocab_tokenize(nvbench::state& state)
 
   auto const vocab_col = [] {
     data_profile const profile = data_profile_builder().no_validity().distribution(
-      cudf::type_id::STRING, distribution_id::NORMAL, 0, 5);
+      cudf::type_id::STRING, distribution_id::NORMAL, 0, 15);
     auto const col = create_random_column(cudf::type_id::STRING, row_count{100}, profile);
     return cudf::strings::filter_characters_of_type(
       cudf::strings_column_view(col->view()),
cpp/src/text/vocabulary_tokenize.cu (6 additions, 2 deletions)

@@ -276,8 +276,12 @@ __global__ void token_counts_fn(cudf::column_device_view const d_strings,
   __syncwarp();
 
   for (auto itr = d_output + lane_idx + 1; itr < d_output_end; itr += cudf::detail::warp_size) {
-    // add one if at the edge of a token or at the string's end
-    count += ((*itr && !(*(itr - 1))) || (itr + 1 == d_output_end));
+    // add one if at the edge of a token or if at the string's end
+    if (*itr) {
+      count += !(*(itr - 1));
+    } else {
+      count += (itr + 1 == d_output_end);
+    }
   }
   __syncwarp();
 
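The behavioral difference between the removed expression and the new branch is narrow: the two rules agree at every position except the last one, where the old expression added one unconditionally even when the last flag and the flag before it were both set, producing an off-by-one token count. Below is a minimal host-side sketch, not the kernel itself: the warp-strided loop is collapsed into a serial one, and the flag vector `f` plus the helpers `count_old`/`count_new` are hypothetical stand-ins for the per-byte values in `d_output`.

// Host-side illustration only: compares the old and new counting rules from the
// kernel above on a plain flag array. Names and flag values are illustrative.
#include <cstdio>
#include <vector>

// old rule: count += ((*itr && !(*(itr - 1))) || (itr + 1 == d_output_end));
int count_old(std::vector<int> const& f)
{
  int count = 0;
  for (std::size_t i = 1; i < f.size(); ++i) {
    count += ((f[i] && !f[i - 1]) || (i + 1 == f.size()));
  }
  return count;
}

// new rule: branch on the flag so the end-of-string term applies only when the flag is clear
int count_new(std::vector<int> const& f)
{
  int count = 0;
  for (std::size_t i = 1; i < f.size(); ++i) {
    if (f[i]) {
      count += !f[i - 1];
    } else {
      count += (i + 1 == f.size());
    }
  }
  return count;
}

int main()
{
  std::vector<int> const last_clear = {1, 1, 0, 1, 1, 0};  // last flag clear: both rules give 2
  std::vector<int> const last_set   = {1, 1, 0, 1, 1, 1};  // last two flags set: old rule counts one extra
  std::printf("last flag clear: old=%d new=%d\n", count_old(last_clear), count_new(last_clear));
  std::printf("last two set:    old=%d new=%d\n", count_old(last_set), count_new(last_set));
  return 0;
}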
cpp/tests/text/tokenize_tests.cpp (6 additions, 6 deletions)

@@ -246,14 +246,14 @@ TEST_F(TextTokenizeTest, Vocabulary)
 
 TEST_F(TextTokenizeTest, VocabularyLongStrings)
 {
-  cudf::test::strings_column_wrapper vocabulary(  // leaving out 'cat' on purpose
+  cudf::test::strings_column_wrapper vocabulary(
     {"ate", "chased", "cheese", "dog", "fox", "jumped", "mouse", "mousé", "over", "the"});
   auto vocab = nvtext::load_vocabulary(cudf::strings_column_view(vocabulary));
 
   std::vector<std::string> h_strings(
     4,
     "the fox jumped chased the dog cheese mouse at the over there dog mouse cat plus the horse "
-    "jumped over the mouse house with the dog");
+    "jumped over the mousé house with the dog ");
   cudf::test::strings_column_wrapper input(h_strings.begin(), h_strings.end());
   auto input_view = cudf::strings_column_view(input);
   auto delimiter = cudf::string_scalar(" ");

@@ -262,10 +262,10 @@ TEST_F(TextTokenizeTest, VocabularyLongStrings)
 
   using LCW = cudf::test::lists_column_wrapper<cudf::size_type>;
   // clang-format off
-  LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3},
-                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 6, -1, -1, 9, 3}});
+  LCW expected({LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3},
+                LCW{ 9, 4, 5, 1, 9, 3, 2, 6, -1, 9, 8, -1, 3, 6, -1, -1, 9, -1, 5, 8, 9, 7, -1, -1, 9, 3}});
   // clang-format on
   CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
 
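For context, the public entry point exercised by this test is the vocabulary tokenizer API. The sketch below is illustrative only and assumes the interface visible in the test above (nvtext::load_vocabulary, nvtext::tokenize_with_vocabulary, a cudf::string_scalar delimiter, and a default id of -1 for out-of-vocabulary tokens); the header paths and argument order are assumptions, so consult nvtext/tokenize.hpp for the authoritative signatures.

// Illustrative usage sketch (not part of the PR). Header paths and the
// tokenize_with_vocabulary argument order are assumptions based on the test above.
#include <cudf/column/column.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <nvtext/tokenize.hpp>

#include <memory>

std::unique_ptr<cudf::column> tokenize_example(cudf::strings_column_view const& input,
                                               cudf::strings_column_view const& vocab_strings)
{
  // Build the vocabulary object once; it can be reused across tokenize calls.
  auto vocabulary = nvtext::load_vocabulary(vocab_strings);
  // Whitespace-delimited tokens; tokens not found in the vocabulary map to -1,
  // matching the expected values in the test above.
  auto const delimiter = cudf::string_scalar(" ");
  return nvtext::tokenize_with_vocabulary(input, *vocabulary, delimiter, -1);
}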