Skip to content

Commit

Permalink
More prompting and better events system
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyaGusev committed Jan 27, 2024
1 parent faa49ae commit 442684e
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 177 deletions.
232 changes: 90 additions & 142 deletions gradio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,38 @@
import fire

from tale_studio.recurrentgpt import RecurrentGPT, State
from tale_studio.embedders import EMBEDDER_LIST, DEFAULT_EMBEDDER_NAME
from tale_studio.embedders import EMBEDDER_LIST
from tale_studio.utils import OPENAI_MODELS
from tale_studio.human_simulator import Human
from tale_studio.files import LOCAL_MODELS_LIST, SAVES_DIR_PATH
from tale_studio.prompt_templates import PROMPT_TEMPLATE_LIST, DEFAULT_PROMPT_TEMPLATE_NAME, PROMPT_TEMPLATES
from tale_studio.prompt_templates import PROMPT_TEMPLATE_LIST, PROMPT_TEMPLATES, DEFAULT_PROMPT_TEMPLATE_NAME
from tale_studio.model_settings import ModelSettings


API_KEY = os.getenv("OPENAI_API_KEY", None)
MODEL_LIST = list(OPENAI_MODELS) + list(LOCAL_MODELS_LIST)
DEFAULT_MODEL_NAME = "gpt-3.5-turbo-16k"
DEFAULT_NOVEL_TYPE = "Science Fiction"
DEFAULT_DESCRIPTION = "Рассказ на русском языке в сеттинге коммунизма в высокотехнологичном будущем"
API_KEY = os.getenv("OPENAI_API_KEY", None)


def validate_inputs(
model_name,
prompt_template,
api_key
):
if prompt_template == "openai" and model_name not in OPENAI_MODELS:
def validate_inputs(model_state):
if model_state.prompt_template == "openai" and model_state.model_name not in OPENAI_MODELS:
raise gr.Error("Please set the correct prompt template!")
if model_name in OPENAI_MODELS and not API_KEY and not api_key:
if model_state.model_name in OPENAI_MODELS and not API_KEY and not model_state.api_key:
raise gr.Error("Please set the API key!")


def generate_plan(
novel_type,
description,
model_name,
prompt_template,
embedder_name,
api_key
):
validate_inputs(model_name, prompt_template, api_key)
writer = RecurrentGPT(
embedder_name=embedder_name,
model_name=model_name,
prompt_template=prompt_template,
api_key=api_key
)
def generate_plan(novel_type, description, model_state):
validate_inputs(model_state)
writer = RecurrentGPT(model_state)
state = writer.generate_plan(novel_type=novel_type, description=description)
return (state, state.name, state.synopsis, state.plan)
return (state, state.name, state.language, state.synopsis, state.plan)


def generate_first_paragraphs(
state,
name,
synopsis,
plan,
model_name,
prompt_template,
embedder_name,
api_key
):
validate_inputs(model_name, prompt_template, api_key)
writer = RecurrentGPT(
embedder_name=embedder_name,
model_name=model_name,
prompt_template=prompt_template,
api_key=api_key
)

state.name = name
state.synopsis = synopsis
state.plan = plan
def generate_first_paragraphs(state, model_state):
assert state is not None
validate_inputs(model_state)
writer = RecurrentGPT(model_state)
state = writer.generate_first_paragraphs(state)
return (
state,
Expand All @@ -82,37 +49,18 @@ def generate_first_paragraphs(
)


def step(
state,
plan,
short_memory,
paragraphs,
selected_instruction,
model_name,
prompt_template,
embedder_name,
api_key,
selection_mode,
):
validate_inputs(model_name, prompt_template, api_key)
writer = RecurrentGPT(
embedder_name=embedder_name,
model_name=model_name,
prompt_template=prompt_template
)

def step(state, model_state, selection_mode):
assert state is not None
state.instruction = selected_instruction
state.short_memory = short_memory
state.paragraphs = [p.strip() for p in paragraphs.split("\n\n") if p.strip()]
validate_inputs(model_state)
writer = RecurrentGPT(model_state)

if selection_mode == "gpt":
human = Human(model_name=model_name, prompt_template=prompt_template)
human = Human(model_name=model_state.model_name, prompt_template=model_state.prompt_template)
state = human.step(state)
elif selection_mode == "random":
state.instruction = random.choice(state.next_instructions)
else:
assert selected_instruction
assert instruction

state = writer.step(state)

Expand All @@ -131,22 +79,12 @@ def save(
file_name,
root_dir,
state,
name,
synopsis,
plan,
short_memory,
paragraphs,
):
if not file_name:
raise gr.Error("File name should not be empty")
if not name:
raise gr.Error("Please set a name of the story")

state.name = name
state.synopsis = synopsis
state.plan = plan
state.short_memory = short_memory
state.paragraphs = [p.strip() for p in paragraphs.split("\n\n") if p.strip()]
with open(os.path.join(root_dir, file_name), "w") as w:
json.dump(state.to_dict(), w, ensure_ascii=False, indent=4)

Expand Down Expand Up @@ -191,22 +129,25 @@ def on_selection_mode_select(evt: gr.SelectData):
return gr.Row.update(visible=is_manual)


def on_model_name_select(evt: gr.SelectData):
def on_model_name_select(model_state, evt: gr.SelectData):
value = evt.value
is_local = "gguf" in value
return gr.update(visible=not is_local)
model_state.model_name = value
return gr.update(visible=not is_local), model_state


def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
def on_prompt_template_name_select(model_state, prompt_template, evt: gr.SelectData):
value = evt.value
is_custom = "custom" in value
prompt_template_text = PROMPT_TEMPLATES[value]
prompt_template = PROMPT_TEMPLATES[value]
is_openai = "openai" in value
return gr.update(value=prompt_template_text, interactive=is_custom, visible=not is_openai)
model_state.prompt_template = prompt_template
return model_state, gr.update(value=prompt_template, interactive=is_custom, visible=not is_openai)


with gr.Blocks(title="TaleStudio", css="footer {visibility: hidden}") as demo:
state = gr.State(None)
state = gr.State(State())
model_state = gr.State(ModelSettings())
gr.Markdown("# Tale Studio")
with gr.Tab("Main"):
with gr.Row():
Expand Down Expand Up @@ -246,22 +187,27 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
)

with gr.Row():
with gr.Column(scale=1):
with gr.Column(scale=5):
name = gr.Textbox(
label="Name (editable)",
max_lines=1,
lines=1
)
language = gr.Textbox(
label="Language (editable)",
max_lines=1,
lines=1
)
synopsis = gr.Textbox(
label="Synopsis (editable)",
max_lines=8,
lines=8
)
with gr.Column(scale=3):
with gr.Column(scale=9):
plan = gr.Textbox(
label="Outline (editable)",
lines=13,
max_lines=13
lines=17,
max_lines=17
)

with gr.Group():
Expand Down Expand Up @@ -301,7 +247,7 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
["Instruction 1", "Instruction 2", "Instruction 3"],
label="Instruction Selection",
)
selected_instruction = gr.Textbox(
instruction = gr.Textbox(
label="Selected Instruction (editable)",
max_lines=5,
lines=5,
Expand Down Expand Up @@ -339,7 +285,7 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
with gr.Group():
model_name = gr.Dropdown(
MODEL_LIST,
value=DEFAULT_MODEL_NAME,
value=model_state.value.model_name,
multiselect=False,
label="Model name",
)
Expand All @@ -362,77 +308,77 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
)
embedder_name = gr.Dropdown(
EMBEDDER_LIST,
value=DEFAULT_EMBEDDER_NAME,
value=model_state.value.embedder_name,
multiselect=False,
label="Embedder name",
)

# Sync inputs
def set_field(state, field, value):
setattr(state, field, value)
return state

state_fields = {
"name": name,
"language": language,
"description": description,
"novel_type": novel_type,
"synopsis": synopsis,
"plan": plan,
"instruction": instruction,
"short_memory": short_memory
}
for key, field in state_fields.items():
field.change((lambda s, f, k=key: set_field(s, k, f)), [state, field], state)

def set_paragraphs(state, paragraphs):
state.paragraphs = [p.strip() for p in paragraphs.split("\n\n") if p.strip()]
return state

paragraphs.change(set_paragraphs, [state, paragraphs], state)

model_state_fields = {
"model_name": model_name,
"prompt_template": prompt_template,
"embedder_name": embedder_name,
"api_key": api_key
}
for key, field in model_state_fields.items():
field.change((lambda s, f, k=key: set_field(s, k, f)), [model_state, field], model_state)

# Main events
btn_init.click(
generate_plan,
inputs=[
novel_type,
description,
model_name,
prompt_template,
embedder_name,
api_key
],
outputs=[state, name, synopsis, plan]
inputs=[novel_type, description, model_state],
outputs=[state, name, language, synopsis, plan]
).success(
generate_first_paragraphs,
inputs=[
state,
name,
synopsis,
plan,
model_name,
prompt_template,
embedder_name,
api_key
],
inputs=[state, model_state],
outputs=[state, short_memory, paragraphs, instruction1, instruction2, instruction3]
)

btn_step.click(
step,
inputs=[
state,
plan,
short_memory,
paragraphs,
selected_instruction,
model_name,
prompt_template,
embedder_name,
api_key,
selection_mode,
],
inputs=[state, model_state, selection_mode],
outputs=[
state,
short_memory,
paragraphs,
instruction1,
instruction2,
instruction3,
selected_instruction
instruction
]
)

# Save/Load
btn_save.click(
lambda: (gr.update(visible=True), gr.update(visible=False)),
outputs=[file_saver, save_load_buttons]
)
btn_confirm_save.click(
save,
inputs=[
save_filename,
save_root,
state,
name,
synopsis,
plan,
short_memory,
paragraphs,
]
inputs=[save_filename, save_root, state]
).success(
lambda: (gr.update(visible=False), gr.update(visible=True)),
outputs=[file_saver, save_load_buttons]
Expand Down Expand Up @@ -486,10 +432,12 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
instruction3,
]
)

# Other events
selected_plan.select(
on_selected_plan_select,
inputs=[instruction1, instruction2, instruction3],
outputs=[selected_instruction]
outputs=[instruction]
)
selection_mode.select(
on_selection_mode_select,
Expand All @@ -498,13 +446,13 @@ def on_prompt_template_name_select(prompt_template_text, evt: gr.SelectData):
)
model_name.select(
on_model_name_select,
inputs=[],
outputs=[api_key]
inputs=[model_state],
outputs=[api_key, model_state]
)
prompt_template_name.select(
on_prompt_template_name_select,
inputs=[prompt_template],
outputs=[prompt_template]
inputs=[model_state, prompt_template],
outputs=[model_state, prompt_template]
)
demo.queue()

Expand Down
1 change: 0 additions & 1 deletion tale_studio/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@ def get_embedder(cls, embedder_name: str):
"embaas/sentence-transformers-multilingual-e5-base",
"sentence-transformers/multi-qa-mpnet-base-cos-v1"
]
DEFAULT_EMBEDDER_NAME = "embaas/sentence-transformers-multilingual-e5-base"
Loading

0 comments on commit 442684e

Please sign in to comment.