Skip to content

Commit

Permalink
add roles for llama2
Browse files Browse the repository at this point in the history
  • Loading branch information
PhaneeshB committed Sep 11, 2023
1 parent c854208 commit fbdd77a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
28 changes: 20 additions & 8 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,14 +1782,26 @@ def autocomplete(self, prompt):
def create_prompt(model_name, history):
global start_message
system_message = start_message[model_name]
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[
f"{B_INST} {item[0].strip()} {E_INST} {item[1].strip()} "
for item in history[1:]
]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"

else:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
return msg


Expand Down
21 changes: 12 additions & 9 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,28 @@ def user(message, history):
def create_prompt(model_name, history):
system_message = start_message[model_name]

if model_name in [
"vicuna",
"llama2_7b",
"llama2_13b",
"llama2_70b",
]:
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)

msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg


Expand Down

0 comments on commit fbdd77a

Please sign in to comment.