Skip to content

Commit

Permalink
Better handling missing SYS in llama conversation tokenizer (#24997)
Browse files Browse the repository at this point in the history
* Better handling missing SYS in llama conversation tokenizer

The existing code failed to add SYS to the generated token ids if the
conversation had history without SYS, even though it did modify the passed
conversation object.

Rearrange the code so that modifications to the conversation object are
taken into account for token id generation.

* Fix formatting with black

* Avoid one-liners

* Also fix fast tokenizer

* Drop List decl
  • Loading branch information
ichernev authored Jul 24, 2023
1 parent 6704923 commit efb2ba6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
21 changes: 11 additions & 10 deletions src/transformers/models/llama/tokenization_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,17 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
`List[int]`:
Input ids for the conversation.
"""
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
)
elif conversation.new_user_input:
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
else:
raise ValueError("Last message must be from user")

dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]
Expand All @@ -365,14 +376,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
)

dialog_tokens: List[int] = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
)
elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])

dialog_tokens += sum(
[
[self.bos_token_id]
Expand All @@ -384,8 +387,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[in
],
[],
)
if not (dialogue[-1][0]):
raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
)
Expand Down
21 changes: 11 additions & 10 deletions src/transformers/models/llama/tokenization_llama_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ def _build_conversation_input_ids(self, conversation: "Conversation"):
`List[int]`:
Input ids for the conversation.
"""
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
)
elif conversation.new_user_input:
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
else:
raise ValueError("Last message must be from user")

dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]
Expand All @@ -221,14 +232,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation"):
)

dialog_tokens = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
)
elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])

dialog_tokens += sum(
[
[self.bos_token_id]
Expand All @@ -240,8 +243,6 @@ def _build_conversation_input_ids(self, conversation: "Conversation"):
],
[],
)
if not (dialogue[-1][0]):
raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
)
Expand Down

0 comments on commit efb2ba6

Please sign in to comment.