Skip to content

Commit

Permalink
Fix ConversationalPipeline tests (huggingface#26217)
Browse files Browse the repository at this point in the history
Add BlenderbotSmall templates and correct handling for conversation.past_user_inputs
  • Loading branch information
Rocketknight1 authored and parambharat committed Sep 26, 2023
1 parent 52e88ee commit 25b4f58
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,18 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
index += 1

return vocab_file, merge_file

@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
    """
    A very simple chat template that just adds whitespace between messages.

    Returns the Jinja template string used when no explicit `chat_template` is set on the
    tokenizer. Per the template body: a single space is emitted before each message whose
    `role` is `'user'`, the raw `content` of every message is emitted as-is, consecutive
    messages are separated by a single space (no trailing space after the last one), and
    `eos_token` is appended once after the final message.
    """
    # NOTE: this body must remain byte-identical to the Blenderbot source referenced in
    # the "Copied from" marker above, so the repo's copy-consistency check passes.
    return (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
        "{{ message['content'] }}"
        "{% if not loop.last %}{{ ' ' }}{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,18 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

@property
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
def default_chat_template(self):
    """
    A very simple chat template that just adds whitespace between messages.

    Returns the Jinja template string used when no explicit `chat_template` is set on the
    tokenizer. Per the template body: a single space is emitted before each message whose
    `role` is `'user'`, the raw `content` of every message is emitted as-is, consecutive
    messages are separated by a single space (no trailing space after the last one), and
    `eos_token` is appended once after the final message.
    """
    # NOTE: this body must remain byte-identical to the Blenderbot source referenced in
    # the "Copied from" marker above, so the repo's copy-consistency check passes.
    return (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
        "{{ message['content'] }}"
        "{% if not loop.last %}{{ ' ' }}{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
10 changes: 5 additions & 5 deletions tests/pipelines/test_pipelines_conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def test_integration_torch_conversation(self):
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
# Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 0)
-        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        self.assertEqual(len(conversation_1.past_user_inputs), 1)
+        self.assertEqual(len(conversation_2.past_user_inputs), 1)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_integration_torch_conversation_truncated_history(self):
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
# Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_1.past_user_inputs), 1)
# When
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
# Then
Expand Down Expand Up @@ -379,8 +379,8 @@ def test_integration_torch_conversation_encoder_decoder(self):
conversation_1 = Conversation("My name is Sarah and I live in London")
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
# Then
-        self.assertEqual(len(conversation_1.past_user_inputs), 0)
-        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        self.assertEqual(len(conversation_1.past_user_inputs), 1)
+        self.assertEqual(len(conversation_2.past_user_inputs), 1)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
Expand Down

0 comments on commit 25b4f58

Please sign in to comment.