diff --git a/cpp/src/text/subword/load_hash_file.cu b/cpp/src/text/subword/load_hash_file.cu index bb0af41e602..1e81e603ca8 100644 --- a/cpp/src/text/subword/load_hash_file.cu +++ b/cpp/src/text/subword/load_hash_file.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -114,6 +114,52 @@ const aux_codepoint_data_type* get_aux_codepoint_data(rmm::cuda_stream_view stre }); } +namespace { +/** + * @brief Convert string to uint32. + * + * This just wraps the std::stoi but provides a nice error message + * in case the hash file format is incorrect. + */ +uint32_t str_to_uint32(std::string const& str, uint64_t line_no) +{ + try { + return std::stoi(str); // there is no std::stoui + } catch (std::exception exc) { + std::string message("Line "); + message += std::to_string(line_no) + ": "; + message += "cannot convert integer from '"; + message += str; + message += "': "; + message += exc.what(); + std::cerr << message << std::endl; + throw exc; + } +} + +/** + * @brief Convert string to uint64. + * + * This just wraps the std::stoul but provides a nice error message + * in case the hash file format is incorrect. + */ +uint64_t str_to_uint64(std::string const& str, uint64_t line_no) +{ + try { + return std::stoul(str); + } catch (std::exception exc) { + std::string message("Line "); + message += std::to_string(line_no) + ": "; + message += "cannot convert integer from '"; + message += str; + message += "': "; + message += exc.what(); + std::cerr << message << std::endl; + throw exc; + } +} +} // namespace + /** * @brief Loads a text file representing the hashed vocabulary into hashed_vocabulary struct. * @@ -145,15 +191,16 @@ hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabu std::ifstream hash_file(filename_hashed_vocabulary); CUDF_EXPECTS(hash_file.good(), "Could not open " + filename_hashed_vocabulary); + uint64_t line_no = 1; std::string line; std::getline(hash_file, line); - result.outer_hash_a = std::stoi(line); + result.outer_hash_a = str_to_uint32(line, line_no++); std::getline(hash_file, line); - result.outer_hash_b = std::stoi(line); + result.outer_hash_b = str_to_uint32(line, line_no++); std::getline(hash_file, line); - result.num_bins = std::stoi(line); + result.num_bins = str_to_uint32(line, line_no++); std::vector bin_coefficients(result.num_bins); std::vector bin_offsets(result.num_bins); @@ -161,32 +208,34 @@ hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabu for (int i = 0; i < result.num_bins; ++i) { std::getline(hash_file, line); size_t loc_of_space = line.find(" "); + CUDF_EXPECTS(loc_of_space != line.npos, "invalid hash file format"); std::string first_num = line.substr(0, loc_of_space); std::string second_num = line.substr(loc_of_space + 1, line.length()); - bin_coefficients[i] = std::stoull(first_num); - bin_offsets[i] = std::stoull(second_num); + bin_coefficients[i] = str_to_uint64(first_num, line_no); + bin_offsets[i] = str_to_uint32(second_num, line_no); + ++line_no; } std::getline(hash_file, line); - uint64_t hash_table_length = std::stoull(line); + uint64_t hash_table_length = str_to_uint64(line, line_no++); std::vector table(hash_table_length); - std::generate(table.begin(), table.end(), [&hash_file]() { + std::generate(table.begin(), table.end(), [&hash_file, &line_no]() { std::string line; std::getline(hash_file, line); - return std::stoull(line); + return str_to_uint64(line, line_no++); }); std::getline(hash_file, line); - result.unknown_token_id = std::stoi(line); + result.unknown_token_id = str_to_uint32(line, line_no++); std::getline(hash_file, line); - result.first_token_id = std::stoi(line); + result.first_token_id = str_to_uint32(line, line_no++); std::getline(hash_file, line); - result.separator_token_id = std::stoi(line); + result.separator_token_id = str_to_uint32(line, line_no++); // Transfer hash table to columns result.table = cudf::make_numeric_column(cudf::data_type{cudf::type_id::UINT64}, diff --git a/cpp/src/text/subword/wordpiece_tokenizer.cu b/cpp/src/text/subword/wordpiece_tokenizer.cu index b13c22670ee..4d048b3cf99 100644 --- a/cpp/src/text/subword/wordpiece_tokenizer.cu +++ b/cpp/src/text/subword/wordpiece_tokenizer.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -101,7 +102,7 @@ __global__ void init_data_and_mark_word_start_and_ends(uint32_t const* code_poin /** * @brief Resolves the string boundaries for the start and end words. * - * This kernel should be called after `mark_word_start_and_ends` with at + * This kernel should be called after `init_data_and_mark_word_start_and_ends` with at * least `num_strings` total threads. * * The start and end indices are updated to honor the string boundaries @@ -141,6 +142,125 @@ __global__ void mark_string_start_and_ends(uint32_t const* code_points, } } +/** + * @brief Currently supported special tokens. + * + * Code logic expects these to be 3 upper-case characters along + * with a single trailing space. + */ +__constant__ char special_tokens[35]{"BOS EOS UNK SEP PAD CLS MASK "}; +constexpr cudf::size_type MIN_ST_WIDTH = 4; // Min token size in special_tokens +constexpr cudf::size_type MAX_ST_WIDTH = 5; // Max token size in special_tokens + +struct mark_special_tokens { + /** + * @brief Check given code-point array to the list of known + * special tokens. + */ + __device__ bool is_special_token(uint32_t const* token, cudf::size_type size) const + { + if (size < MIN_ST_WIDTH || size > MAX_ST_WIDTH) return false; + char str_token[MAX_ST_WIDTH]; + // convert code-points to chars + thrust::transform(thrust::seq, token, token + size, str_token, [](uint32_t cp) { + // also upper-case them to match again special_tokens array + return static_cast(cp >= 'a' ? cp - 'a' + 'A' : cp); + }); + // search the special tokens array for the str_token + cudf::string_view tokens(special_tokens, sizeof(special_tokens)); + return tokens.find(str_token, size) >= 0; + } + + /** + * @brief Check code-points for special tokens and adjust indices. + * + * Tokens will appear in the `code_points` array as: + * `_[_ttt_]_` where `_` are single space characters and + * ttt is the variable-length token name + * + * The logic below uses the following variables to represent position + * values in the `code_points` array after locating a special token: + * ``` + * _ [ _ t t t _ ] _ + * ^ ^ ^ ^ + * si sp ep ei + * ``` + * where `si` is `start_index` + * `sp` is `start_pos` + * `ep` is `end_pos` + * `ei` is `end_index` + * + * When a special token is found, the `code_points` are adjusted + * to remove the spaces and capitalize the name. + * ``` + * _ [ _ t t t _ ] _ is updated to + * _ [ T T T ] _ ] _ + * ``` + * This is required for the downstream word-piece tokenizer to + * match it to the vocabulary hash table. + * + * The `start_word_indices` and `end_word_indices` are updated to + * identify the token and to ignore the extra trailing `]` character. + */ + __device__ void operator()(size_t idx) const + { + uint32_t const start_index = start_word_indices[idx]; + if ((start_index == std::numeric_limits::max()) || + ((start_index + MIN_ST_WIDTH + 2) > num_code_points)) + return; + if (code_points[start_index] != '[') return; + + // check for matching end bracket + uint32_t const start_pos = start_index + 2; // after the space delimiter + // search for next start-word and then check it is a ']' + uint32_t const end_index = [&] { + auto const begin = start_word_indices + start_pos; + auto const width = + std::min(static_cast(MAX_ST_WIDTH + 1), (num_code_points - start_pos)); + auto const end = begin + width; + // checking the next start-word is more reliable than arbitrarily searching for ']' + // in case the text is split across string rows + auto const iter = thrust::find_if(thrust::seq, begin + 1, end, [](auto swi) { + return swi != std::numeric_limits::max(); + }); + return iter == end ? start_index : static_cast(iter - start_word_indices); + }(); + if (code_points[end_index] != ']') return; + + // check for special token + auto const size = static_cast(end_index - start_pos); + if (!is_special_token(code_points + start_pos, size)) return; + + // special token found + // adjust code-points + auto const end_pos = end_index - 2; + // change _[_ttt_]_ to _[TTT]_ + for (auto left_idx = start_pos - 1; left_idx <= end_pos; ++left_idx) { + auto const cp = code_points[left_idx + 1]; + code_points[left_idx] = cp >= 'a' ? cp - 'a' + 'A' : cp; + } + code_points[end_pos] = ']'; + + // erase the intermediate indices + thrust::fill(thrust::seq, + start_word_indices + start_index + 1, // keep the first one + start_word_indices + end_index + 1, + std::numeric_limits::max()); + thrust::fill(thrust::seq, + end_word_indices + start_index, + end_word_indices + end_index + 1, + std::numeric_limits::max()); + + // reset the new end-word index + end_word_indices[end_pos] = end_pos + 1; + } + + uint32_t* const code_points; + uint32_t* const start_word_indices; + uint32_t* const end_word_indices; + size_t const num_code_points; +}; + /** * @brief Converts words into token ids. * @@ -345,6 +465,14 @@ void wordpiece_tokenizer::tokenize(uvector_pair& cps_and_offsets, rmm::cuda_stre num_strings); CHECK_CUDA(stream.value()); + // check for special tokens and adjust indices + thrust::for_each_n( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + num_code_points, + mark_special_tokens{ + device_code_points, device_start_word_indices, device_end_word_indices, num_code_points}); + // Now start_word_indices has the word starts scattered throughout the array. We need to select // all values not equal to the max uint32_t and place them at the start of the array. We leverage // the fact that the start_word_indices and the end_word indices are contiguous to only launch one diff --git a/cpp/tests/text/subword_tests.cpp b/cpp/tests/text/subword_tests.cpp index c87f6e9af8a..3cab612fccd 100644 --- a/cpp/tests/text/subword_tests.cpp +++ b/cpp/tests/text/subword_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ struct TextSubwordTest : public cudf::test::BaseFixture { // Create a fake hashed vocab text file for the tests in this source file. // The vocab only includes the following words: // 'this', 'is', 'a', 'test', 'tést' -// The period '.' character is also supported. +// The period '.' character also has a token id. void create_hashed_vocab(std::string const& hash_file) { std::vector> coefficients(23, {65559, 0}); @@ -262,3 +262,80 @@ TEST(TextSubwordTest, LoadVocabFileErrors) std::string hash_file = temp_env->get_temp_filepath("nothing.txt"); EXPECT_THROW(nvtext::load_vocabulary_file(hash_file), cudf::logic_error); } + +// This includes the words above and 7 special tokens: +// [BOS] [EOS] [UNK] [SEP] [PAD] [CLS] [MASK] +// The data here was generated by the utility: +// cudf.utils.hash_vocab_utils.hash_vocab() +void create_special_tokens_hashed_vocab(std::string const& hash_file) +{ + std::ofstream outfile(hash_file, std::ofstream::out); + outfile << "26899\n27424\n3\n"; + outfile << "1416131940466419714 0\n"; + outfile << "313740585393291779 2\n"; + outfile << "17006415773850330120 5\n"; + outfile << "13\n"; + outfile << "5903884228619468800\n"; + outfile << "6205475701751152650\n"; + outfile << "16285378285009240068\n"; + outfile << "5162333542489915397\n"; + outfile << "6064762127302393859\n"; + outfile << "6173800107753209857\n"; + outfile << "5322083323972878342\n"; + outfile << "6242701866907861003\n"; + outfile << "451412623368\n"; + outfile << "3014668\n"; + outfile << "5214737420442796034\n"; + outfile << "6206321707968233479\n"; + outfile << "6357001\n"; + outfile << "1\n2\n3\n\n"; +} + +TEST(TextSubwordTest, TokenizeWithSpecialTokens) +{ + std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); + create_special_tokens_hashed_vocab(hash_file); + + // clang-format off + std::vector h_strings{ + "[BOS]This is a tést.[eos]", + "[CLS]A test[SEP]this is.", + "[PAD] [A][MASK]", + "test this [CL", + "S] is a ."}; + // clang-format on + cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); + auto vocab = nvtext::load_vocabulary_file(hash_file); + auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, + vocab, + 8, + 6, + true, // do_lower_case + true, // do_truncate + MAX_ROWS_TENSOR); + + EXPECT_EQ(static_cast(h_strings.size()), result.nrows_tensor); + // clang-format off + cudf::test::fixed_width_column_wrapper expected_tokens( + { 5, 7, 8, 9, 10, 12, 6, 0, + 2, 9, 10, 3, 7, 8, 12, 0, + 0, 1, 9, 1, 4, 0, 0, 0, + 10, 7, 1, 1, 0, 0, 0, 0, + 1, 1, 8, 9, 12, 0, 0, 0}); + cudf::test::fixed_width_column_wrapper expected_attn( + {1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 0, 0, 0}); + cudf::test::fixed_width_column_wrapper expected_metadata( + {0, 0, 6, + 1, 0, 6, + 2, 0, 4, + 3, 0, 3, + 4, 0, 4}); + // clang-format on + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tensor_token_ids->view(), expected_tokens); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tensor_attention_mask->view(), expected_attn); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tensor_metadata->view(), expected_metadata); +}