From 62606d3c0f5d996697510972baa0f974a34403ec Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 5 Jan 2022 02:19:52 -0800 Subject: [PATCH] Fixed test for subword_tokenize --- python/cudf/cudf/tests/test_subword_tokenizer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/tests/test_subword_tokenizer.py b/python/cudf/cudf/tests/test_subword_tokenizer.py index 214c21c023b..ec6e0b30cb1 100644 --- a/python/cudf/cudf/tests/test_subword_tokenizer.py +++ b/python/cudf/cudf/tests/test_subword_tokenizer.py @@ -124,9 +124,16 @@ def test_text_subword_tokenize(tmpdir): content = content + "100\n101\n102\n\n" hash_file.write(content) - cudf_tokenizer = SubwordTokenizer("voc_hash.txt") + cudf_tokenizer = SubwordTokenizer(hash_file) - tokens, masks, metadata = cudf_tokenizer(sr, 8, 8) + token_d = cudf_tokenizer( + sr, 8, 8, add_special_tokens=False, truncation=True + ) + tokens, masks, metadata = ( + token_d["input_ids"], + token_d["attention_mask"], + token_d["metadata"], + ) expected_tokens = cupy.asarray( [ 2023, @@ -172,6 +179,7 @@ def test_text_subword_tokenize(tmpdir): ], dtype=np.uint32, ) + expected_tokens = expected_tokens.reshape(-1, 8) assert_eq(expected_tokens, tokens) expected_masks = cupy.asarray( @@ -219,9 +227,11 @@ def test_text_subword_tokenize(tmpdir): ], dtype=np.uint32, ) + expected_masks = expected_masks.reshape(-1, 8) assert_eq(expected_masks, masks) expected_metadata = cupy.asarray( [0, 0, 3, 1, 0, 3, 2, 0, 3, 3, 0, 1, 4, 0, 1], dtype=np.uint32 ) + expected_metadata = expected_metadata.reshape(-1, 3) assert_eq(expected_metadata, metadata)