Commit

fixup

ArthurZucker committed Apr 23, 2024
1 parent ca25a81 commit 7bfe577

Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 8 additions & 6 deletions tests/models/gemma/test_tokenization_gemma.py
@@ -30,10 +30,10 @@
     get_tests_dir,
     nested_simplify,
     require_jinja,
+    require_read_token,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
-    require_read_token,
     slow,
 )

@@ -320,11 +320,13 @@ def test_integration_test_xnli(self):
         encoded1 = pyth_tokenizer.encode(string)
         encoded2 = rust_tokenizer.encode(string)

-        self.assertEqual(encoded1, encoded2, msg=
-            "Hint: the following tokenization diff were obtained for slow vs fast:\n "
-            f"elements in slow: {set(pyth_tokenizer.tokenize(string))-set(rust_tokenizer.tokenize(string))} \nvs\n "
-            f"elements in fast: {set(rust_tokenizer.tokenize(string))-set(pyth_tokenizer.tokenize(string))} \n\n{string}"
-        )
+        self.assertEqual(
+            encoded1,
+            encoded2,
+            msg="Hint: the following tokenization diff were obtained for slow vs fast:\n "
+            f"elements in slow: {set(pyth_tokenizer.tokenize(string))-set(rust_tokenizer.tokenize(string))} \nvs\n "
+            f"elements in fast: {set(rust_tokenizer.tokenize(string))-set(pyth_tokenizer.tokenize(string))} \n\n{string}",
+        )

         decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
         decoded2 = rust_tokenizer.decode(encoded1, skip_special_tokens=True)
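For context, the reformatted assertion above is the slow-vs-fast consistency check used throughout the Gemma tokenization tests. The sketch below reproduces it outside the test class, assuming a Gemma checkpoint is reachable; the checkpoint name is illustrative (the real Gemma repos are gated, which is why the suite imports require_read_token).

    # Minimal sketch of the slow/fast tokenizer consistency check. The checkpoint
    # name is an assumption, not the test's fixture; Gemma repos are gated, so an
    # authenticated token is required (hence require_read_token in the test suite).
    from transformers import GemmaTokenizer, GemmaTokenizerFast

    checkpoint = "google/gemma-2b"  # assumed checkpoint for illustration
    slow = GemmaTokenizer.from_pretrained(checkpoint)
    fast = GemmaTokenizerFast.from_pretrained(checkpoint)

    string = "Hello, y'all! How are you?"
    encoded_slow = slow.encode(string)
    encoded_fast = fast.encode(string)

    if encoded_slow != encoded_fast:
        # Mirror the hint message in the test: show tokens only one side produced.
        only_slow = set(slow.tokenize(string)) - set(fast.tokenize(string))
        only_fast = set(fast.tokenize(string)) - set(slow.tokenize(string))
        raise AssertionError(f"slow-only: {only_slow}\nfast-only: {only_fast}\n{string}")
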
4 changes: 3 additions & 1 deletion tests/models/llama/test_tokenization_llama.py
@@ -602,7 +602,9 @@ def test_special_token_special_word(self):
         self.assertEqual(decoded_tokens, "hello")

     def test_no_prefix_space(self):
-        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False)
+        tokenizer = LlamaTokenizerFast.from_pretrained(
+            "huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False
+        )
         tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)

         example_inputs = tokenizer.tokenize("<REPR_END>inform<s>. Hey. .")
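The multi-line call above is only a formatting change; the behaviour under test is add_prefix_space=False, i.e. the fast Llama tokenizer should not inject a leading space piece before the first word. A small sketch of that difference, using the same arguments as the test and assuming access to the huggyllama/llama-7b repo (the printed expectations are indicative, not values copied from the test file):

    # Sketch of what add_prefix_space toggles; expectations in the comments are
    # indicative of the intended behaviour, not asserted values from the test.
    from transformers import LlamaTokenizerFast

    no_prefix = LlamaTokenizerFast.from_pretrained(
        "huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=False
    )
    with_prefix = LlamaTokenizerFast.from_pretrained(
        "huggyllama/llama-7b", legacy=False, from_slow=True, add_prefix_space=True
    )

    text = "inform. Hey. ."
    print(no_prefix.tokenize(text))    # expected to start with "inform" (no leading "▁")
    print(with_prefix.tokenize(text))  # expected to start with "▁inform"
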
