From 9132353c35ce21230df487d3b8b6f997b6b03850 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Wed, 20 Sep 2023 17:19:30 +0200
Subject: [PATCH] include changes from llama (#26260)

* include changes from llama

* add a test
---
 .../models/code_llama/tokenization_code_llama.py |  2 ++
 .../code_llama/test_tokenization_code_llama.py   | 12 ++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/src/transformers/models/code_llama/tokenization_code_llama.py b/src/transformers/models/code_llama/tokenization_code_llama.py
index da1012095cfb23..1dbe6731852eed 100644
--- a/src/transformers/models/code_llama/tokenization_code_llama.py
+++ b/src/transformers/models/code_llama/tokenization_code_llama.py
@@ -293,6 +293,8 @@ def _tokenize(self, text, **kwargs):
         `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
         """
         tokens = self.sp_model.encode(text, out_type=str)
+        if not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
         # 1. Encode string + prefix ex: "<unk> Hey"
         tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
         # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
diff --git a/tests/models/code_llama/test_tokenization_code_llama.py b/tests/models/code_llama/test_tokenization_code_llama.py
index fa39a0571d5d36..beab1a5b1b89ee 100644
--- a/tests/models/code_llama/test_tokenization_code_llama.py
+++ b/tests/models/code_llama/test_tokenization_code_llama.py
@@ -559,6 +559,18 @@ def test_special_token_special_word(self):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " Hello how")
 
+    def test_spm_edge_cases(self):
+        # the word inform should be split as ['in', 'form']
+        tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
+        tokens = tokenizer.tokenize("[INST] How are you doing?<s>[/INST]")
+        self.assertEqual(
+            tokens, ["▁[", "INST", "]", "▁How", "▁are", "▁you", "▁doing", "?", "<s>", "[", "/", "INST", "]"]
+        )
+        inputs_ids = tokenizer.encode("[INST] How are you doing?<s>[/INST]")
+        self.assertEqual(
+            inputs_ids, [1, 518, 25580, 29962, 1128, 526, 366, 2599, 29973, 1, 29961, 29914, 25580, 29962]
+        )
+
     def test_infilling_tokenization(self):
         PROMPTS = [
             '''def remove_non_ascii(s: str) -> str:
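
Note (not part of the patch itself): the early return added to `_tokenize` skips the `<unk>`-prefix workaround whenever the incoming text does not start with `SPIECE_UNDERLINE` or a space, which is exactly the case for the fragment that follows a special token such as `<s>`. Below is a minimal sketch that replays the new test outside the suite; it assumes `transformers` is installed and the `codellama/CodeLlama-7b-hf` checkpoint can be downloaded.

from transformers import CodeLlamaTokenizer

# Load the slow (sentencepiece-backed) tokenizer with the non-legacy
# behavior this patch targets.
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)

# "[/INST]" follows "<s>" with no space, so "[" must not pick up a spurious
# "▁" prefix; compare "▁[" at the start of the string, where the space is real.
tokens = tokenizer.tokenize("[INST] How are you doing?<s>[/INST]")
print(tokens)
# Expected, per the new test:
# ["▁[", "INST", "]", "▁How", "▁are", "▁you", "▁doing", "?", "<s>", "[", "/", "INST", "]"]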