Fix stopping strings for llama-3 and phi (oobabooga#6043)
oobabooga authored and PoetOnTheRun committed Oct 22, 2024
1 parent 110f7d4 commit e63c816
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions modules/chat.py
@@ -45,34 +45,35 @@ def str_presenter(dumper, data):
 yaml.representer.SafeRepresenter.add_representer(str, str_presenter)


-def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
+def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True):
     '''
-    Given a Jinja template, reverse-engineers the prefix and the suffix for
-    an assistant message (if impersonate=False) or an user message
-    (if impersonate=True)
+    Given a Jinja template, extracts the prefix and suffix for
+    an assistant message and a user message. It assumes that they
+    share the same suffix.
     '''

-    if impersonate:
-        messages = [
-            {"role": "user", "content": "<<|user-message-1|>>"},
-            {"role": "user", "content": "<<|user-message-2|>>"},
-        ]
-    else:
-        messages = [
-            {"role": "assistant", "content": "<<|user-message-1|>>"},
-            {"role": "assistant", "content": "<<|user-message-2|>>"},
-        ]
+    messages = [
+        {"role": "user", "content": "<<|user-message-1|>>"},
+        {"role": "assistant", "content": "<<|assistant-message-1|>>"},
+        {"role": "user", "content": "<<|user-message-2|>>"},
+        {"role": "assistant", "content": "<<|assistant-message-2|>>"},
+    ]

     prompt = renderer(messages=messages)
+    unwanted_suffix = renderer(messages=[])
+
+    suffix = prompt.split('<<|assistant-message-2|>>')[1]
+    if unwanted_suffix != '':
+        suffix = suffix[:-len(unwanted_suffix)]

-    suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
-    suffix = prompt.split("<<|user-message-2|>>")[1]
-    prefix = suffix_plus_prefix[len(suffix):]
+    prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
+    prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]

     if strip_trailing_spaces:
-        prefix = prefix.rstrip(' ')
+        prefix_user = prefix_user.rstrip(' ')
+        prefix_assistant = prefix_assistant.rstrip(' ')

-    return prefix, suffix
+    return prefix_user, prefix_assistant, suffix


 def generate_chat_prompt(user_input, state, **kwargs):
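
For context, here is a rough standalone sketch of what the rewritten helper computes (not part of the commit). The render function below is a stand-in for the Jinja2 chat-template renderer that modules/chat.py builds from the model's chat template, and the Llama-3-style tokens are assumed values for illustration only:

def render(messages):
    # Minimal Llama-3-style formatting of a message list (stand-in renderer).
    out = ""
    for m in messages:
        out += f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n{m['content']}<|eot_id|>"
    return out

messages = [
    {"role": "user", "content": "<<|user-message-1|>>"},
    {"role": "assistant", "content": "<<|assistant-message-1|>>"},
    {"role": "user", "content": "<<|user-message-2|>>"},
    {"role": "assistant", "content": "<<|assistant-message-2|>>"},
]

prompt = render(messages)
unwanted_suffix = render([])  # '' here; non-empty for templates that always emit something

# Same extraction steps as in the new function:
suffix = prompt.split('<<|assistant-message-2|>>')[1]
if unwanted_suffix != '':
    suffix = suffix[:-len(unwanted_suffix)]

prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]

print(repr(suffix))            # '<|eot_id|>'
print(repr(prefix_user))       # '<|start_header_id|>user<|end_header_id|>\n\n'
print(repr(prefix_assistant))  # '<|start_header_id|>assistant<|end_header_id|>\n\n'

The sentinel strings mark the message boundaries, so splitting the rendered prompt around them recovers the per-role prefixes and the shared suffix without having to parse the template itself.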
@@ -125,7 +126,12 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "user", "content": user_input})

     def remove_extra_bos(prompt):
-        for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
+        if hasattr(shared.tokenizer, 'bos_token_id'):
+            bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
+        else:
+            bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']
+
+        for bos_token in bos_tokens:
             while prompt.startswith(bos_token):
                 prompt = prompt[len(bos_token):]

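A minimal standalone sketch of the new BOS handling, assuming shared.tokenizer is a loaded Hugging Face tokenizer (the explicit tokenizer argument and the extra None check below are additions for the standalone version, not part of the commit). The idea is that the BOS string now comes from the tokenizer itself rather than from a fixed list, which presumably is what fixes templates whose BOS token, such as Llama-3's '<|begin_of_text|>', was missing from the old hardcoded list:

def remove_extra_bos(prompt, tokenizer):
    # Prefer the tokenizer's actual BOS string over a guessed list.
    if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
        bos_tokens = [tokenizer.decode(tokenizer.bos_token_id)]
    else:
        bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']

    # Drop any copies of the BOS string already emitted by the chat template,
    # since the backend typically adds BOS again at tokenization time.
    for bos_token in bos_tokens:
        while prompt.startswith(bos_token):
            prompt = prompt[len(bos_token):]

    return prompt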
@@ -137,6 +143,9 @@ def make_prompt(messages):
         else:
             prompt = renderer(messages=messages)

+        prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
+        prefix = prefix_user if impersonate else prefix_assistant
+
         if state['mode'] == 'chat-instruct':
             outer_messages = []
             if state['custom_system_message'].strip() != '':
@@ -148,29 +157,25 @@ def make_prompt(messages):
             command = command.replace('<|prompt|>', prompt)
             command = replace_character_names(command, state['name1'], state['name2'])

             if _continue:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
                 prefix += messages[-1]["content"]
-            else:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
-                if not impersonate:
-                    prefix = apply_extensions('bot_prefix', prefix, state)
+            elif not impersonate:
+                prefix = apply_extensions('bot_prefix', prefix, state)

             outer_messages.append({"role": "user", "content": command})
             outer_messages.append({"role": "assistant", "content": prefix})

             prompt = instruction_template.render(messages=outer_messages)
-            suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
             if len(suffix) > 0:
                 prompt = prompt[:-len(suffix)]

         else:

             if _continue:
-                suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
                 if len(suffix) > 0:
                     prompt = prompt[:-len(suffix)]
             else:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
                 if state['mode'] == 'chat' and not impersonate:
                     prefix = apply_extensions('bot_prefix', prefix, state)

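As an aside, a small illustration of why the rendered prompt gets trimmed by len(suffix) in the blocks above (hypothetical template values, not output from the code): rendering closes the final assistant turn, and cutting off that closing suffix leaves the prompt open so the model continues the assistant message instead of starting a new turn.

# Hypothetical values; in chat.py the suffix comes from the chat template.
rendered = "<|start_header_id|>assistant<|end_header_id|>\n\nBot reply so far<|eot_id|>"
suffix = "<|eot_id|>"

prompt = rendered
if len(suffix) > 0:
    prompt = prompt[:-len(suffix)]

print(repr(prompt))  # ends with 'Bot reply so far', ready to be continued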
@@ -249,15 +254,11 @@ def get_stopping_strings(state):
             renderers.append(renderer)

     for renderer in renderers:
-        prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
-        prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
-
-        stopping_strings += [
-            suffix_user + prefix_bot,
-            suffix_user + prefix_user,
-            suffix_bot + prefix_bot,
-            suffix_bot + prefix_user,
-        ]
+        prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer)
+
+        for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
+            stopping_strings.append(item)
+            stopping_strings.append(item.rstrip())

     if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
         stopping_strings += state.pop('stopping_strings')
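
To connect this back to the commit title, a quick sketch of the stopping strings the new loop would register, reusing the hypothetical Llama-3-style values from the earlier example (illustrative values, not captured output):

suffix = '<|eot_id|>'
prefix_user = '<|start_header_id|>user<|end_header_id|>\n\n'
prefix_assistant = '<|start_header_id|>assistant<|end_header_id|>\n\n'

stopping_strings = []
for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
    # Each candidate is registered both verbatim and with trailing whitespace stripped.
    stopping_strings.append(item)
    stopping_strings.append(item.rstrip())

print(stopping_strings)

Unlike the old suffix + prefix combinations, the bare suffix is now a stopping string on its own, so generation can stop at '<|eot_id|>' even if the next role header never appears, and the rstrip variants cover templates whose prefixes end in newlines or spaces; that appears to be the core of the fix for Llama-3 and Phi style templates.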

