Skip to content

Commit

Permalink
Retain newlines in chat template when continue_final_message=True (#34253)
Browse files Browse the repository at this point in the history

* Retain newlines in chat template when continue_final_message=True

* Add try/except

* Add regression test

* Simplify test

* Apply suggestions from code review

Co-authored-by: Matt <[email protected]>

---------

Co-authored-by: Matt <[email protected]>
  • Loading branch information
lewtun and Rocketknight1 authored Nov 15, 2024
1 parent a3d69a8 commit 8ba3e15
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,8 +1690,12 @@ def apply_chat_template(
final_message = chat[-1]["content"]
if isinstance(final_message, (list, tuple)):
final_message = final_message[-1]["text"]
final_message = final_message.strip()
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
try:
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
except: # noqa: E722
# Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here.
final_message = final_message.strip()
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
rendered.append(rendered_chat)

if not is_batched:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,38 @@ def test_continue_final_message(self):
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)

@require_jinja
def test_continue_final_message_with_trim(self):
    """Regression test for chat templates with trimming: https://github.com/huggingface/transformers/pull/34214"""

    # Template applies Jinja's `| trim` to each message body, mirroring templates
    # (e.g. Llama-3.1) that strip messages before rendering.
    template_with_trim = """
{%- for message in messages %}
    {{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
{%- endfor %}"""
    # The assistant turn deliberately carries trailing whitespace so that the
    # rendered text (trimmed) differs from the raw message content.
    conversation = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "assistant message "},  # Note the trailing whitespace
    ]
    expected_full = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n"
    # With continue_final_message=True the final turn must stay open (no <|im_end|>).
    expected_prefill = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message"

    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            full_render = tokenizer.apply_chat_template(
                conversation, chat_template=template_with_trim, tokenize=False, continue_final_message=False
            )
            self.assertEqual(full_render, expected_full)
            prefill_render = tokenizer.apply_chat_template(
                conversation, chat_template=template_with_trim, tokenize=False, continue_final_message=True
            )
            # Assert that the final message is unterminated
            self.assertEqual(prefill_render, expected_prefill)

@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
Expand Down

0 comments on commit 8ba3e15

Please sign in to comment.