From efb2ba666da57a92a8ea47c62ed9e90fbf8aa344 Mon Sep 17 00:00:00 2001 From: Iskren Ivov Chernev Date: Mon, 24 Jul 2023 16:21:10 +0300 Subject: [PATCH] Better handling missing SYS in llama conversation tokenizer (#24997) * Better handling missing SYS in llama conversation tokenizer The existing code failed to add SYS to the generated token ids if the conversation has history without SYS, even though it did modify the passed conversation object. Rearrange the code so that modifications to the conversation object are taken into account for token id generation. * Fix formatting with black * Avoid one-liners * Also fix fast tokenizer * Drop List decl --- .../models/llama/tokenization_llama.py | 21 ++++++++++--------- .../models/llama/tokenization_llama_fast.py | 21 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index ddc643a360f28e..00dbc117a4c5a5 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -356,6 +356,17 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in `List[int]`: Input ids for the conversation. 
""" + if len(conversation.past_user_inputs) > 0: + if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") + dialogue = list(conversation.iter_texts()) if not all([is_user for is_user, msg in dialogue[::2]]) or not all( [not is_user for is_user, msg in dialogue[1::2]] @@ -365,14 +376,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in ) dialog_tokens: List[int] = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1]) - dialog_tokens += sum( [ [self.bos_token_id] @@ -384,8 +387,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in ], [], ) - if not (dialogue[-1][0]): - raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}") dialog_tokens += [self.bos_token_id] + self.encode( f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False ) diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py index c04e2da114cc12..82dfbe8925f04a 100644 --- a/src/transformers/models/llama/tokenization_llama_fast.py +++ 
b/src/transformers/models/llama/tokenization_llama_fast.py @@ -212,6 +212,17 @@ def _build_conversation_input_ids(self, conversation: "Conversation"): `List[int]`: Input ids for the conversation. """ + if len(conversation.past_user_inputs) > 0: + if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") + dialogue = list(conversation.iter_texts()) if not all([is_user for is_user, msg in dialogue[::2]]) or not all( [not is_user for is_user, msg in dialogue[1::2]] @@ -221,14 +232,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation"): ) dialog_tokens = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1]) - dialog_tokens += sum( [ [self.bos_token_id] @@ -240,8 +243,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation"): ], [], ) - if not (dialogue[-1][0]): - raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}") dialog_tokens += [self.bos_token_id] + self.encode( f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False )