From 6b0c4aeaac3fd76e8848f5389dbe7d58b3f4b049 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 4 Jan 2022 08:52:47 -0800 Subject: [PATCH 1/5] Removed str.subword_tokenize --- cpp/include/nvtext/subword_tokenize.hpp | 23 +--- cpp/src/text/subword/subword_tokenize.cu | 21 ---- cpp/tests/text/subword_tests.cpp | 23 ++-- .../cudf/_lib/nvtext/subword_tokenize.pyx | 35 ------ python/cudf/cudf/_lib/strings/__init__.py | 1 - python/cudf/cudf/core/column/string.py | 113 ------------------ 6 files changed, 17 insertions(+), 199 deletions(-) diff --git a/cpp/include/nvtext/subword_tokenize.hpp b/cpp/include/nvtext/subword_tokenize.hpp index 8cc000ff095..2b09ec66203 100644 --- a/cpp/include/nvtext/subword_tokenize.hpp +++ b/cpp/include/nvtext/subword_tokenize.hpp @@ -130,9 +130,7 @@ struct tokenizer_result { * larger than the max value for cudf::size_type * * @param strings The input strings to tokenize. - * @param filename_hashed_vocabulary A path to the preprocessed vocab.txt file. - * Note that this is the file AFTER python/perfect_hash.py has been used - * for preprocessing. + * @param vocabulary_table The vocabulary table pre-loaded into this object. * @param max_sequence_length Limit of the number of token-ids per row in final tensor * for each string. * @param stride Each row in the output token-ids will replicate `max_sequence_length - stride` @@ -150,25 +148,6 @@ struct tokenizer_result { * @param mr Memory resource to allocate any returned objects. * @return token-ids, attention-mask, and metadata */ -tokenizer_result subword_tokenize( - cudf::strings_column_view const& strings, - std::string const& filename_hashed_vocabulary, - uint32_t max_sequence_length, - uint32_t stride, - bool do_lower_case, - bool do_truncate, - uint32_t max_rows_tensor, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); - -/** - * @copydoc subword_tokenize() - * - * This function differs from the one above by only the hashed vocabulary parameter. - * The file can be pre-loaded using the @ref load_vocabulary_file API and then - * passed in place of the file name in a call to this API. - * - * @param vocabulary_table The vocabulary table pre-loaded into this object. 
- */ tokenizer_result subword_tokenize( cudf::strings_column_view const& strings, hashed_vocabulary const& vocabulary_table, diff --git a/cpp/src/text/subword/subword_tokenize.cu b/cpp/src/text/subword/subword_tokenize.cu index 6de1044b492..7ebc10446c3 100644 --- a/cpp/src/text/subword/subword_tokenize.cu +++ b/cpp/src/text/subword/subword_tokenize.cu @@ -249,27 +249,6 @@ tokenizer_result subword_tokenize(cudf::strings_column_view const& strings, } // namespace detail -tokenizer_result subword_tokenize(cudf::strings_column_view const& strings, - std::string const& filename_hashed_vocabulary, - uint32_t max_sequence_length, - uint32_t stride, - bool do_lower_case, - bool do_truncate, - uint32_t max_rows_tensor, - rmm::mr::device_memory_resource* mr) -{ - auto vocab_table = load_vocabulary_file(filename_hashed_vocabulary, mr); - CUDF_FUNC_RANGE(); - return detail::subword_tokenize(strings, - *vocab_table, - max_sequence_length, - stride, - do_lower_case, - do_truncate, - max_rows_tensor, - rmm::cuda_stream_default, - mr); -} tokenizer_result subword_tokenize(cudf::strings_column_view const& strings, hashed_vocabulary const& vocabulary_table, diff --git a/cpp/tests/text/subword_tests.cpp b/cpp/tests/text/subword_tests.cpp index 65cc466fee7..298a68940ea 100644 --- a/cpp/tests/text/subword_tests.cpp +++ b/cpp/tests/text/subword_tests.cpp @@ -67,12 +67,13 @@ TEST(TextSubwordTest, Tokenize) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); uint32_t max_sequence_length = 16; uint32_t stride = 16; auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, max_sequence_length, stride, true, // do_lower_case @@ -119,12 +120,14 @@ TEST(TextSubwordTest, TokenizeMultiRow) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); + uint32_t max_sequence_length = 8; uint32_t stride = 6; auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, max_sequence_length, stride, true, // do_lower_case @@ -148,12 +151,14 @@ TEST(TextSubwordTest, TokenizeMaxEqualsTokens) cudf::test::strings_column_wrapper strings({"This is a test."}); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); + uint32_t max_sequence_length = 5; // five tokens in strings; uint32_t stride = 5; // this should not effect the result auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, max_sequence_length, stride, true, // do_lower_case @@ -175,8 +180,10 @@ TEST(TextSubwordTest, ParameterErrors) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); + EXPECT_THROW(nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, 12, // max_sequence_length 13, // stride <= max_sequence_length true, // do_lower_case @@ -185,7 +192,7 @@ TEST(TextSubwordTest, ParameterErrors) cudf::logic_error); 
EXPECT_THROW(nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, 5, 5, true, // do_lower_case @@ -199,8 +206,9 @@ TEST(TextSubwordTest, EmptyStrings) cudf::test::strings_column_wrapper strings; std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, 16, 16, true, // do_lower_case @@ -217,8 +225,9 @@ TEST(TextSubwordTest, AllNullStrings) cudf::test::strings_column_wrapper strings({"", "", ""}, {0, 0, 0}); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings}, - hash_file, + *vocab, 16, 16, true, // do_lower_case diff --git a/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx b/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx index 49f24436b88..426744ee46c 100644 --- a/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx +++ b/python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx @@ -58,38 +58,3 @@ def subword_tokenize_inmem_hash( masks = Column.from_unique_ptr(move(c_result.tensor_attention_mask)) metadata = Column.from_unique_ptr(move(c_result.tensor_metadata)) return tokens, masks, metadata - - -def subword_tokenize_vocab_file( - Column strings, - object hash_file, - uint32_t max_sequence_length=64, - uint32_t stride=48, - bool do_lower=True, - bool do_truncate=False, - uint32_t max_rows_tensor=500 -): - """ - Subword tokenizes text series by using the hashed vocabulary - stored on disk - """ - cdef column_view c_strings = strings.view() - cdef cpp_tokenizer_result c_result - cdef string c_hash_file = str(hash_file).encode() - with nogil: - c_result = tr_move( - cpp_subword_tokenize( - c_strings, - c_hash_file, - max_sequence_length, - stride, - do_lower, - do_truncate, - max_rows_tensor - ) - ) - # return the 3 tensor components - tokens = Column.from_unique_ptr(move(c_result.tensor_token_ids)) - masks = Column.from_unique_ptr(move(c_result.tensor_attention_mask)) - metadata = Column.from_unique_ptr(move(c_result.tensor_metadata)) - return tokens, masks, metadata diff --git a/python/cudf/cudf/_lib/strings/__init__.py b/python/cudf/cudf/_lib/strings/__init__.py index fbc1538cc74..7911d0eff2a 100644 --- a/python/cudf/cudf/_lib/strings/__init__.py +++ b/python/cudf/cudf/_lib/strings/__init__.py @@ -12,7 +12,6 @@ is_letter_multi, porter_stemmer_measure, ) -from cudf._lib.nvtext.subword_tokenize import subword_tokenize_vocab_file from cudf._lib.nvtext.tokenize import ( _count_tokens_column, _count_tokens_scalar, diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 1c9a013810a..a83110d273c 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -4711,119 +4711,6 @@ def filter_tokens( ), ) - def subword_tokenize( - self, - hash_file: str, - max_length: int = 64, - stride: int = 48, - do_lower: bool = True, - do_truncate: bool = False, - max_rows_tensor: int = 500, - ) -> Tuple[cupy.ndarray, cupy.ndarray, cupy.ndarray]: - """ - Run CUDA BERT subword tokenizer on cuDF strings column. - Encodes words to token ids using vocabulary from a pretrained - tokenizer. - - This function requires about 21x the number of character bytes - in the input strings column as working memory. 
- - ``Series.str.subword_tokenize`` is deprecated and will be removed. - Use ``cudf.core.subword_tokenizer.SubwordTokenizer`` instead. - - Parameters - ---------- - hash_file : str - Path to hash file containing vocabulary of words with token-ids. - This can be created from the raw vocabulary - using the ``cudf.utils.hash_vocab_utils.hash_vocab`` function - max_length : int, Default is 64 - Limits the length of the sequence returned. - If tokenized string is shorter than max_length, - output will be padded with 0s. - If the tokenized string is longer than max_length and - do_truncate == False, there will be multiple returned - sequences containing the overflowing token-ids. - stride : int, Default is 48 - If do_truncate == False and the tokenized string is larger - than max_length, the sequences containing the overflowing - token-ids can contain duplicated token-ids from the main - sequence. If max_length is equal to stride there are no - duplicated-id tokens. If stride is 80% of max_length, - 20% of the first sequence will be repeated on the second - sequence and so on until the entire sentence is encoded. - do_lower : bool, Default is True - If set to true, original text will be lowercased before encoding. - do_truncate : bool, Default is False - If set to true, strings will be truncated and padded to - max_length. Each input string will result in exactly one output - sequence. If set to false, there may be multiple output - sequences when the max_length is smaller than generated tokens. - max_rows_tensor : int, Default is 500 - Maximum number of rows for the output token-ids expected - to be generated by the tokenizer. - Used for allocating temporary working memory on the GPU device. - If the output generates a larger number of rows, behavior - is undefined. - This will vary based on stride, truncation, and max_length. - For example, for non-overlapping sequences output rows - will be the same as input rows. - - Returns - ------- - token-ids : cupy.ndarray - The token-ids for each string padded with 0s to max_length. - attention-mask : cupy.ndarray - The mask for token-ids result where corresponding positions - identify valid token-id values. - metadata : cupy.ndarray - Each row contains the index id of the original string and the - first and last index of the token-ids that are non-padded and - non-overlapping. - - Examples - -------- - >>> import cudf - >>> from cudf.utils.hash_vocab_utils import hash_vocab - >>> hash_vocab('bert-base-uncased-vocab.txt', 'voc_hash.txt') - >>> ser = cudf.Series(['this is the', 'best book']) - >>> stride, max_length = 8, 8 - >>> max_rows_tensor = len(ser) - >>> tokens, masks, metadata = ser.str.subword_tokenize('voc_hash.txt', - ... max_length=max_length, stride=stride, - ... max_rows_tensor=max_rows_tensor) - >>> tokens.reshape(-1, max_length) - array([[2023, 2003, 1996, 0, 0, 0, 0, 0], - [2190, 2338, 0, 0, 0, 0, 0, 0]], dtype=uint32) - >>> masks.reshape(-1, max_length) - array([[1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0]], dtype=uint32) - >>> metadata.reshape(-1, 3) - array([[0, 0, 2], - [1, 0, 1]], dtype=uint32) - """ - warnings.warn( - "`Series.str.subword_tokenize` is deprecated and will be removed " - "in future versions of cudf. 
Use " - "`cudf.core.subword_tokenizer.SubwordTokenizer` instead.", - FutureWarning, - ) - - tokens, masks, metadata = libstrings.subword_tokenize_vocab_file( - self._column, - hash_file, - max_length, - stride, - do_lower, - do_truncate, - max_rows_tensor, - ) - return ( - cupy.asarray(tokens), - cupy.asarray(masks), - cupy.asarray(metadata), - ) - def porter_stemmer_measure(self) -> SeriesOrIndex: """ Compute the Porter Stemmer measure for each string. From f95d5f707a0a3959602a99391ef94dd3a8b69860 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 4 Jan 2022 09:24:57 -0800 Subject: [PATCH 2/5] removed old test --- .../cudf/cudf/tests/test_subword_tokenizer.py | 158 +++++++++++++++--- python/cudf/cudf/tests/test_text.py | 131 --------------- 2 files changed, 134 insertions(+), 155 deletions(-) diff --git a/python/cudf/cudf/tests/test_subword_tokenizer.py b/python/cudf/cudf/tests/test_subword_tokenizer.py index 717b3de8479..214c21c023b 100644 --- a/python/cudf/cudf/tests/test_subword_tokenizer.py +++ b/python/cudf/cudf/tests/test_subword_tokenizer.py @@ -1,12 +1,14 @@ # Copyright (c) 2020-2021, NVIDIA CORPORATION. import os +import cupy import numpy as np import pytest from transformers import BertTokenizer import cudf from cudf.core.subword_tokenizer import SubwordTokenizer +from cudf.testing._utils import assert_eq @pytest.fixture(scope="module") @@ -26,30 +28,6 @@ def assert_equal_tokenization_outputs(hf_output, cudf_output): ) -def test_subword_tokenize_on_disk_vocab_str_api(datadir): - """ - Tests the subword-tokenizer API where - the vocabulary is not pre-loaded - and is accessed via the string accessor - """ - with open( - os.path.join(datadir, "test_sentences.txt"), encoding="utf-8" - ) as file: - input_sentence_ls = [line.strip() for line in file] - - vocab_dir = os.path.join(datadir, "bert_base_cased_sampled") - vocab_hash_path = os.path.join(vocab_dir, "vocab-hash.txt") - - ser = cudf.Series(input_sentence_ls) - tokens, masks, metadata = ser.str.subword_tokenize( - vocab_hash_path, - max_length=32, - stride=32, - do_lower=True, - max_rows_tensor=len(ser), - ) - - @pytest.mark.parametrize("seq_len", [32, 64]) @pytest.mark.parametrize("stride", [0, 15, 30]) @pytest.mark.parametrize("add_special_tokens", [True, False]) @@ -115,3 +93,135 @@ def test_subword_tokenize_with_truncation(datadir): truncation=False, add_special_tokens=True, ) + + +def test_text_subword_tokenize(tmpdir): + sr = cudf.Series( + [ + "This is a test", + "A test this is", + "Is test a this", + "Test test", + "this This", + ] + ) + hash_file = tmpdir.mkdir("nvtext").join("tmp_hashed_vocab.txt") + content = "1\n0\n23\n" + coefficients = [65559] * 23 + for c in coefficients: + content = content + str(c) + " 0\n" + # based on values from the bert_hash_table.txt file for the + # test words used here: 'this' 'is' 'a' test' + table = [0] * 23 + table[0] = 3015668 + table[1] = 6205475701751155871 + table[5] = 6358029 + table[16] = 451412625363 + table[20] = 6206321707968235495 + content = content + "23\n" + for v in table: + content = content + str(v) + "\n" + content = content + "100\n101\n102\n\n" + hash_file.write(content) + + cudf_tokenizer = SubwordTokenizer("voc_hash.txt") + + tokens, masks, metadata = cudf_tokenizer(sr, 8, 8) + expected_tokens = cupy.asarray( + [ + 2023, + 2003, + 1037, + 3231, + 0, + 0, + 0, + 0, + 1037, + 3231, + 2023, + 2003, + 0, + 0, + 0, + 0, + 2003, + 3231, + 1037, + 2023, + 0, + 0, + 0, + 0, + 3231, + 3231, + 0, + 0, + 0, + 0, + 0, + 0, + 2023, + 2023, + 0, + 0, + 0, + 0, + 0, + 0, + ], 
+ dtype=np.uint32, + ) + assert_eq(expected_tokens, tokens) + + expected_masks = cupy.asarray( + [ + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + ], + dtype=np.uint32, + ) + assert_eq(expected_masks, masks) + + expected_metadata = cupy.asarray( + [0, 0, 3, 1, 0, 3, 2, 0, 3, 3, 0, 1, 4, 0, 1], dtype=np.uint32 + ) + assert_eq(expected_metadata, metadata) diff --git a/python/cudf/cudf/tests/test_text.py b/python/cudf/cudf/tests/test_text.py index fcae0a21b6a..a447a60c709 100644 --- a/python/cudf/cudf/tests/test_text.py +++ b/python/cudf/cudf/tests/test_text.py @@ -1,6 +1,5 @@ # Copyright (c) 2019, NVIDIA CORPORATION. -import cupy import numpy as np import pytest @@ -655,136 +654,6 @@ def test_text_filter_tokens_error_cases(): sr.str.filter_tokens(3, delimiter=["a", "b"]) -def test_text_subword_tokenize(tmpdir): - sr = cudf.Series( - [ - "This is a test", - "A test this is", - "Is test a this", - "Test test", - "this This", - ] - ) - hash_file = tmpdir.mkdir("nvtext").join("tmp_hashed_vocab.txt") - content = "1\n0\n23\n" - coefficients = [65559] * 23 - for c in coefficients: - content = content + str(c) + " 0\n" - # based on values from the bert_hash_table.txt file for the - # test words used here: 'this' 'is' 'a' test' - table = [0] * 23 - table[0] = 3015668 - table[1] = 6205475701751155871 - table[5] = 6358029 - table[16] = 451412625363 - table[20] = 6206321707968235495 - content = content + "23\n" - for v in table: - content = content + str(v) + "\n" - content = content + "100\n101\n102\n\n" - hash_file.write(content) - - tokens, masks, metadata = sr.str.subword_tokenize(str(hash_file), 8, 8) - expected_tokens = cupy.asarray( - [ - 2023, - 2003, - 1037, - 3231, - 0, - 0, - 0, - 0, - 1037, - 3231, - 2023, - 2003, - 0, - 0, - 0, - 0, - 2003, - 3231, - 1037, - 2023, - 0, - 0, - 0, - 0, - 3231, - 3231, - 0, - 0, - 0, - 0, - 0, - 0, - 2023, - 2023, - 0, - 0, - 0, - 0, - 0, - 0, - ], - dtype=np.uint32, - ) - assert_eq(expected_tokens, tokens) - - expected_masks = cupy.asarray( - [ - 1, - 1, - 1, - 1, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 0, - 0, - 0, - 0, - 1, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - ], - dtype=np.uint32, - ) - assert_eq(expected_masks, masks) - - expected_metadata = cupy.asarray( - [0, 0, 3, 1, 0, 3, 2, 0, 3, 3, 0, 1, 4, 0, 1], dtype=np.uint32 - ) - assert_eq(expected_metadata, metadata) - - def test_edit_distance(): sr = cudf.Series(["kitten", "saturday", "address", "book"]) tg = cudf.Series(["sitting", "sunday", "addressee", "back"]) From cd13ce5799a03e5bb9a8e64044ae786df94b3602 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 4 Jan 2022 09:27:30 -0800 Subject: [PATCH 3/5] style fixes to cpp code --- cpp/src/text/subword/subword_tokenize.cu | 1 - cpp/tests/text/subword_tests.cpp | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cpp/src/text/subword/subword_tokenize.cu b/cpp/src/text/subword/subword_tokenize.cu index 7ebc10446c3..193cd80d9a6 100644 --- a/cpp/src/text/subword/subword_tokenize.cu +++ b/cpp/src/text/subword/subword_tokenize.cu @@ -249,7 +249,6 @@ tokenizer_result subword_tokenize(cudf::strings_column_view const& strings, } // namespace detail - tokenizer_result subword_tokenize(cudf::strings_column_view const& strings, hashed_vocabulary const& vocabulary_table, uint32_t max_sequence_length, diff --git 
a/cpp/tests/text/subword_tests.cpp b/cpp/tests/text/subword_tests.cpp index 298a68940ea..521a082faa2 100644 --- a/cpp/tests/text/subword_tests.cpp +++ b/cpp/tests/text/subword_tests.cpp @@ -67,7 +67,7 @@ TEST(TextSubwordTest, Tokenize) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); - auto vocab = nvtext::load_vocabulary_file(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); uint32_t max_sequence_length = 16; uint32_t stride = 16; @@ -120,8 +120,7 @@ TEST(TextSubwordTest, TokenizeMultiRow) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); - auto vocab = nvtext::load_vocabulary_file(hash_file); - + auto vocab = nvtext::load_vocabulary_file(hash_file); uint32_t max_sequence_length = 8; uint32_t stride = 6; @@ -151,8 +150,7 @@ TEST(TextSubwordTest, TokenizeMaxEqualsTokens) cudf::test::strings_column_wrapper strings({"This is a test."}); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); - auto vocab = nvtext::load_vocabulary_file(hash_file); - + auto vocab = nvtext::load_vocabulary_file(hash_file); uint32_t max_sequence_length = 5; // five tokens in strings; uint32_t stride = 5; // this should not effect the result @@ -180,7 +178,7 @@ TEST(TextSubwordTest, ParameterErrors) cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end()); std::string hash_file = temp_env->get_temp_filepath("hashed_vocab.txt"); create_hashed_vocab(hash_file); - auto vocab = nvtext::load_vocabulary_file(hash_file); + auto vocab = nvtext::load_vocabulary_file(hash_file); EXPECT_THROW(nvtext::subword_tokenize(cudf::strings_column_view{strings}, *vocab, From 56133b3803027782c0842699f95a48a07656d4ae Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 4 Jan 2022 11:00:04 -0800 Subject: [PATCH 4/5] remove tensorflow bug --- python/cudf/cudf/core/subword_tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/subword_tokenizer.py b/python/cudf/cudf/core/subword_tokenizer.py index 3502fc9acae..782b74ef4a6 100644 --- a/python/cudf/cudf/core/subword_tokenizer.py +++ b/python/cudf/cudf/core/subword_tokenizer.py @@ -21,7 +21,7 @@ def _cast_to_appropriate_type(ar, cast_type): from torch.utils.dlpack import from_dlpack elif cast_type == "tf": - from tf.experimental.dlpack import from_dlpack + from tensorflow.experimental.dlpack import from_dlpack return from_dlpack(ar.astype("int32").toDlpack()) From 62606d3c0f5d996697510972baa0f974a34403ec Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 5 Jan 2022 02:19:52 -0800 Subject: [PATCH 5/5] Fixed test for subword_tokenize --- python/cudf/cudf/tests/test_subword_tokenizer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/tests/test_subword_tokenizer.py b/python/cudf/cudf/tests/test_subword_tokenizer.py index 214c21c023b..ec6e0b30cb1 100644 --- a/python/cudf/cudf/tests/test_subword_tokenizer.py +++ b/python/cudf/cudf/tests/test_subword_tokenizer.py @@ -124,9 +124,16 @@ def test_text_subword_tokenize(tmpdir): content = content + "100\n101\n102\n\n" hash_file.write(content) - cudf_tokenizer = SubwordTokenizer("voc_hash.txt") + cudf_tokenizer = SubwordTokenizer(hash_file) - tokens, masks, metadata = cudf_tokenizer(sr, 8, 8) + token_d = 
cudf_tokenizer( + sr, 8, 8, add_special_tokens=False, truncation=True + ) + tokens, masks, metadata = ( + token_d["input_ids"], + token_d["attention_mask"], + token_d["metadata"], + ) expected_tokens = cupy.asarray( [ 2023, @@ -172,6 +179,7 @@ def test_text_subword_tokenize(tmpdir): ], dtype=np.uint32, ) + expected_tokens = expected_tokens.reshape(-1, 8) assert_eq(expected_tokens, tokens) expected_masks = cupy.asarray( @@ -219,9 +227,11 @@ def test_text_subword_tokenize(tmpdir): ], dtype=np.uint32, ) + expected_masks = expected_masks.reshape(-1, 8) assert_eq(expected_masks, masks) expected_metadata = cupy.asarray( [0, 0, 3, 1, 0, 3, 2, 0, 3, 3, 0, 1, 4, 0, 1], dtype=np.uint32 ) + expected_metadata = expected_metadata.reshape(-1, 3) assert_eq(expected_metadata, metadata)
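
With ``Series.str.subword_tokenize`` removed, the supported path is the pre-loaded-vocabulary tokenizer that the updated test exercises. The sketch below shows the migration end to end; the file names and input strings are illustrative (borrowed from the removed docstring example), and the tokenizer call simply mirrors the fixed test above rather than documenting the full ``SubwordTokenizer`` signature.

import cudf
from cudf.core.subword_tokenizer import SubwordTokenizer
from cudf.utils.hash_vocab_utils import hash_vocab

# One-time preprocessing: hash the raw BERT vocabulary into the format the
# GPU tokenizer loads (the removed accessor consumed this same hashed file).
hash_vocab("bert-base-uncased-vocab.txt", "voc_hash.txt")

ser = cudf.Series(["This is a test", "A test this is"])

# Load the hashed vocabulary once and reuse the tokenizer object across calls,
# instead of passing a file path on every invocation as the old accessor did.
tokenizer = SubwordTokenizer("voc_hash.txt")

# Same argument pattern as the fixed test_text_subword_tokenize above; the
# result is dict-like with cupy arrays under these three keys.
output = tokenizer(ser, 8, 8, add_special_tokens=False, truncation=True)

tokens = output["input_ids"]        # token ids per sequence, padded with 0s
masks = output["attention_mask"]    # 1 where the corresponding token id is valid
metadata = output["metadata"]       # row index plus first/last non-padded token positions

Unlike the removed accessor, which returned flat arrays that callers reshaped themselves (the old docstring example called ``tokens.reshape(-1, max_length)``), these tensors come back two-dimensional, which is why the fixed test reshapes its expected values to ``(-1, 8)`` and ``(-1, 3)`` before comparing.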