From 675649e453e9cbba19104fa7649beb8346c06cbb Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Mon, 19 Aug 2024 06:55:48 +0000
Subject: [PATCH 1/4] Fix unit test

---
 tests/test_data.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_data.py b/tests/test_data.py
index 28483c36..917f5e90 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -122,8 +122,8 @@ def setUp(self):
         )
 
     def test_maybe_insert_system_message(self):
-        # does not accept system prompt
-        mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+        # does not accept system prompt. Use community checkpoint since it has no HF token requirement
+        mistral_tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
         # accepts system prompt. use codellama since it has no HF token requirement
         llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
         messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]

From 338fbb1d75ac2416e2c512fff87e5f01ef8c8f45 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Mon, 19 Aug 2024 10:24:53 +0000
Subject: [PATCH 2/4] Fix chat template tests

---
 src/alignment/data.py     |  2 +-
 tests/test_data.py        | 20 ++++++++++----------
 tests/test_model_utils.py |  2 +-
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/alignment/data.py b/src/alignment/data.py
index 84544c68..56a4af62 100644
--- a/src/alignment/data.py
+++ b/src/alignment/data.py
@@ -32,7 +32,7 @@ def maybe_insert_system_message(messages, tokenizer):
     # chat template can be one of two attributes, we check in order
     chat_template = tokenizer.chat_template
     if chat_template is None:
-        chat_template = tokenizer.default_chat_template
+        chat_template = tokenizer.get_chat_template()
 
     # confirm the jinja template refers to a system message before inserting
     if "system" in chat_template or "<|im_start|>" in chat_template:
diff --git a/tests/test_data.py b/tests/test_data.py
index 917f5e90..bcf600f5 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -122,21 +122,21 @@ def setUp(self):
         )
 
     def test_maybe_insert_system_message(self):
-        # does not accept system prompt. Use community checkpoint since it has no HF token requirement
-        mistral_tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
-        # accepts system prompt. use codellama since it has no HF token requirement
-        llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
+        # Chat template that does not accept system prompt. Use community checkpoint since it has no HF token requirement
+        tokenizer_sys_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
+        # Chat template that accepts system prompt
+        tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
         messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
         messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]
 
-        mistral_messages = deepcopy(messages_sys_excl)
-        llama_messages = deepcopy(messages_sys_excl)
-        maybe_insert_system_message(mistral_messages, mistral_tokenizer)
-        maybe_insert_system_message(llama_messages, llama_tokenizer)
+        messages_proc_excl = deepcopy(messages_sys_excl)
+        message_proc_incl = deepcopy(messages_sys_excl)
+        maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl)
+        maybe_insert_system_message(message_proc_incl, tokenizer_sys_incl)
 
         # output from mistral should not have a system message, output from llama should
-        self.assertEqual(mistral_messages, messages_sys_excl)
-        self.assertEqual(llama_messages, messages_sys_incl)
+        self.assertEqual(messages_proc_excl, messages_sys_excl)
+        self.assertEqual(message_proc_incl, messages_sys_incl)
 
     def test_sft(self):
         dataset = self.dataset.map(
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index 07bb6813..b19d5464 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -75,7 +75,7 @@ def test_default_chat_template_no_overwrite(self):
         processed_tokenizer = get_tokenizer(model_args, DataArguments())
 
         assert getattr(processed_tokenizer, "chat_template") is None
-        self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template)
+        self.assertEqual(base_tokenizer.get_chat_template(), processed_tokenizer.get_chat_template())
 
     def test_chatml_chat_template(self):
         chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

From 2fe365264da21f7167a3ecc7277b409a3d44a9c6 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Mon, 19 Aug 2024 10:26:50 +0000
Subject: [PATCH 3/4] Remove deprecated test

---
 tests/test_model_utils.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index b19d5464..d15097ce 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -64,19 +64,6 @@ def test_default_chat_template(self):
         tokenizer = get_tokenizer(self.model_args, DataArguments())
         self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
 
-    def test_default_chat_template_no_overwrite(self):
-        """
-        If no chat template is passed explicitly in the config, then for models with a
-        `default_chat_template` but no `chat_template` we do not set a `chat_template`,
-        and that we do not change `default_chat_template`
-        """
-        model_args = ModelArguments(model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B")
-        base_tokenizer = AutoTokenizer.from_pretrained("m-a-p/OpenCodeInterpreter-SC2-7B")
-        processed_tokenizer = get_tokenizer(model_args, DataArguments())
-
-        assert getattr(processed_tokenizer, "chat_template") is None
-        self.assertEqual(base_tokenizer.get_chat_template(), processed_tokenizer.get_chat_template())
-
     def test_chatml_chat_template(self):
         chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
         tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))

From ca4fe2744d5aef9fbe198cae314ea75e76aa99e9 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Mon, 19 Aug 2024 10:29:58 +0000
Subject: [PATCH 4/4] up

---
 tests/test_model_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index d15097ce..e0fc6fe2 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -15,7 +15,6 @@
 import unittest
 
 import torch
-from transformers import AutoTokenizer
 
 from alignment import (
     DataArguments,
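
Taken together, the series leaves maybe_insert_system_message behaving roughly as
sketched below. This is an illustrative sketch, not a verbatim copy of
src/alignment/data.py: the early return for a conversation that already starts
with a system turn is inferred from what the reworked test asserts rather than
shown in any hunk above, and tokenizer.get_chat_template() assumes a transformers
release recent enough to have deprecated default_chat_template, which is what
patches 2/4 and 3/4 are reacting to.

from copy import deepcopy

from transformers import AutoTokenizer


def maybe_insert_system_message(messages, tokenizer):
    # Assumed guard (the top of the function is outside every hunk above):
    # leave conversations that already open with a system turn untouched.
    if messages[0]["role"] == "system":
        return

    # chat template can be one of two attributes, we check in order
    chat_template = tokenizer.chat_template
    if chat_template is None:
        chat_template = tokenizer.get_chat_template()

    # confirm the jinja template refers to a system message before inserting
    if "system" in chat_template or "<|im_start|>" in chat_template:
        messages.insert(0, {"role": "system", "content": ""})


# Same checkpoints as the reworked test: Qwen2's ChatML template contains
# "<|im_start|>", so an empty system turn is inserted; the mistral-community
# v0.3 template refers to no system role, so the messages pass through as-is.
user_only = [{"role": "user", "content": "Tell me a joke."}]
tok_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
tok_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")

msgs_incl, msgs_excl = deepcopy(user_only), deepcopy(user_only)
maybe_insert_system_message(msgs_incl, tok_incl)
maybe_insert_system_message(msgs_excl, tok_excl)
assert msgs_incl[0]["role"] == "system"
assert msgs_excl == user_only

Checking the tokenizer's chat_template attribute first and only then calling
get_chat_template() preserves the original two-step lookup that previously
ended at the deprecated default_chat_template.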