diff --git a/gradio_server.py b/gradio_server.py index 35659e8..17dbd79 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -27,18 +27,24 @@ def validate_inputs(model_state): raise gr.Error("Please set the API key!") -def generate_plan(novel_type, description, model_state): +def generate_name(state, model_state): + writer = RecurrentGPT(model_state) + state.name = writer.generate_name(state) + return (state, state.name) + + +def generate_meta(novel_type, description, model_state): validate_inputs(model_state) writer = RecurrentGPT(model_state) - state = writer.generate_plan(novel_type=novel_type, description=description) - return (state, state.name, state.language, state.synopsis, state.plan) + state = writer.generate_meta(novel_type=novel_type, description=description) + return (state, state.name, state.language, state.synopsis, state.outline) -def generate_first_paragraphs(state, model_state): +def generate_first_step(state, model_state): assert state is not None validate_inputs(model_state) writer = RecurrentGPT(model_state) - state = writer.generate_first_paragraphs(state) + state = writer.generate_first_step(state) return ( state, state.short_memory, @@ -96,7 +102,7 @@ def load(file_name): state, state.name, state.synopsis, - state.plan, + state.outline, state.short_memory, "\n\n".join(state.paragraphs), state.next_instructions[0], @@ -193,6 +199,10 @@ def on_prompt_template_name_select(model_state, prompt_template, evt: gr.SelectD max_lines=1, lines=1 ) + btn_gen_name = gr.Button( + "Regenerate", + variant="primary" + ) language = gr.Textbox( label="Language (editable)", max_lines=1, @@ -204,7 +214,7 @@ def on_prompt_template_name_select(model_state, prompt_template, evt: gr.SelectD lines=8 ) with gr.Column(scale=9): - plan = gr.Textbox( + outline = gr.Textbox( label="Outline (editable)", lines=17, max_lines=17 @@ -336,7 +346,7 @@ def on_prompt_template_name_select(model_state, prompt_template, evt: gr.SelectD value=model_state.value.generation_params.repetition_penalty, step=0.05, interactive=True, - label="Rep" + label="Repetition penalty" ) with gr.Column(scale=1, min_width=200): top_p = gr.Slider( @@ -367,7 +377,7 @@ def set_field(state, field, value): "description": description, "novel_type": novel_type, "synopsis": synopsis, - "plan": plan, + "outline": outline, "instruction": instruction, "short_memory": short_memory } @@ -404,11 +414,11 @@ def set_param(state, field, value): # Main events btn_init.click( - generate_plan, + generate_meta, inputs=[novel_type, description, model_state], - outputs=[state, name, language, synopsis, plan] + outputs=[state, name, language, synopsis, outline] ).success( - generate_first_paragraphs, + generate_first_step, inputs=[state, model_state], outputs=[state, short_memory, paragraphs, instruction1, instruction2, instruction3] ) @@ -426,6 +436,11 @@ def set_param(state, field, value): instruction ] ) + btn_gen_name.click( + generate_name, + inputs=[state, model_state], + outputs=[state, name] + ) # Save/Load btn_save.click( @@ -461,7 +476,7 @@ def set_param(state, field, value): state, name, synopsis, - plan, + outline, short_memory, paragraphs, instruction1, @@ -480,7 +495,7 @@ def set_param(state, field, value): state, name, synopsis, - plan, + outline, short_memory, paragraphs, instruction1, diff --git a/tale_studio/prompts/begin.jinja b/tale_studio/prompts/first_paragraphs.jinja similarity index 98% rename from tale_studio/prompts/begin.jinja rename to tale_studio/prompts/first_paragraphs.jinja index 41e25a3..8a1d6d7 100644 --- a/tale_studio/prompts/begin.jinja +++ b/tale_studio/prompts/first_paragraphs.jinja @@ -10,7 +10,7 @@ Synopsis: {{synopsis}} Outline: -{{plan}} +{{outline}} Remember, you are an author narrating events. Events should be narrated in the third person limited perspective and contain dialogues between the characters present. Do not conclude your output, leave it open for continuation. Write in a novelistic style and take your time to set the scene. Write 3 paragraphs below: diff --git a/tale_studio/prompts/process.jinja b/tale_studio/prompts/first_summary.jinja similarity index 89% rename from tale_studio/prompts/process.jinja rename to tale_studio/prompts/first_summary.jinja index 665dfdc..018ce11 100644 --- a/tale_studio/prompts/process.jinja +++ b/tale_studio/prompts/first_summary.jinja @@ -2,7 +2,7 @@ Write a summary that captures the critical information of the paragraphs. Then write 3 different instructions for what to write next, each containing around 5 sentences. -Each instruction should present a possible, exciting story continuation that fits the global plan presented below. +Each instruction should present a possible, exciting story continuation that fits the global outline presented below. The output should be in JSON with the following fields: { @@ -17,7 +17,7 @@ Synopsis: {{synopsis}} Outline: -{{plan}} +{{outline}} Paragraphs: {{paragraphs}} diff --git a/tale_studio/prompts/instruct.jinja b/tale_studio/prompts/instruct.jinja index aee9dfe..df5b7b1 100644 --- a/tale_studio/prompts/instruct.jinja +++ b/tale_studio/prompts/instruct.jinja @@ -1,5 +1,5 @@ You should output 3 different instructions, each is a possible interesting continuation of the story. -Each output instruction should contain around 5 sentences and fit the global plan. +Each output instruction should contain around 5 sentences and fit the global outline. Think about what plot can be attractive for common readers when writing output instructions. Write and organize your output by strictly following the format below using JSON. @@ -14,7 +14,7 @@ Use strictly this language: {{language}} Here is the outline: #### -{{plan}} +{{outline}} #### Previous short summary: diff --git a/tale_studio/prompts/plan.jinja b/tale_studio/prompts/meta.jinja similarity index 65% rename from tale_studio/prompts/plan.jinja rename to tale_studio/prompts/meta.jinja index ec168b2..e509189 100644 --- a/tale_studio/prompts/plan.jinja +++ b/tale_studio/prompts/meta.jinja @@ -3,9 +3,9 @@ Please write a detailed outline for a {% if not novel_type %}Science Fiction{% e Begin with a catchy name of the novel and determine its main language. Then, generate a summary of the whole novel containing its main idea. -Then, plan all 6 or 7 chapters of the novel. -The story should have an ending. +Then, write an outline with all 6 or 7 chapters of the novel. Chapter summaries should contain at least 2 sentences. +The story should have an ending. Write in a novelistic style. Write all the fields in the selected langauge. @@ -14,9 +14,17 @@ The output should be in JSON with the following fields: "language": , "name": , "synopsis": , - "chapter_summaries": [ - "Chapter 1. ", - ..., - "Chapter 6. " + "outline": [ + { + "index": 1, + "chapter_name": , + "chapter_summary": , + }, + ... + { + "index": 6, + "chapter_name": , + "chapter_summary": , + } ] } diff --git a/tale_studio/prompts/name.jinja b/tale_studio/prompts/name.jinja new file mode 100644 index 0000000..b35b677 --- /dev/null +++ b/tale_studio/prompts/name.jinja @@ -0,0 +1,12 @@ +Generate a name for the novel. + +Idea: {{description}} +Novel type: {{novel_type}} +Synopsis: {{synopsis}} +Outline: {{outline}} +Language: {{language}} + +Format output as a JSON with one key, "name": +{ + "name": "..." +} diff --git a/tale_studio/prompts/output.jinja b/tale_studio/prompts/output.jinja index 2a2f742..a8ff885 100644 --- a/tale_studio/prompts/output.jinja +++ b/tale_studio/prompts/output.jinja @@ -9,7 +9,7 @@ Use strictly this language: {{language}}. Do not switch to another language. Do not write any chapter numbers. Outline of possible past and future events: -{{plan}} +{{outline}} Summary of previous events: {{short_memory}} diff --git a/tale_studio/recurrentgpt.py b/tale_studio/recurrentgpt.py index d409a4e..7575e1a 100644 --- a/tale_studio/recurrentgpt.py +++ b/tale_studio/recurrentgpt.py @@ -12,7 +12,7 @@ class State: name: str = "" synopsis: str = "" - plan: str = "" + outline: str = "" novel_type: str = "" language: str = "English" description: str = "" @@ -67,118 +67,112 @@ def step(self, state: State): state.memory_index ) - output_paragraph = self.output( - plan=state.plan, + output_paragraph = self._complete_text( + "output", + outline=state.outline, language=state.language, short_memory=state.short_memory, input_paragraph=state.paragraphs[-1], input_instruction=state.instruction, input_long_term_memory=formatted_long_memory, ) + output_paragraph = " ".join([p.strip() for p in output_paragraph.split("\n") if p.strip()]) state.paragraphs.append(output_paragraph) state.update_index(self.embedder, self.passage_prefix) - state.short_memory = self.summarize( + state.short_memory = self._complete_json( + "summarize", language=state.language, short_memory=state.short_memory, input_paragraph=state.paragraphs[-2], - ) + )["updated_memory"] - state.next_instructions = self.instruct( + output = self._complete_json( + "instruct", language=state.language, short_memory=state.short_memory, output_paragraph=state.paragraphs[-1], - plan=state.plan, + outline=state.outline, input_long_term_memory=formatted_long_memory, ) - - return state - - def output(self, **kwargs): - prompt = encode_prompt("output.jinja", **kwargs) - print("OUTPUT PROMPT") - print(prompt) - print() - output_paragraph = self._complete_text(prompt) - output_paragraph = " ".join([p.strip() for p in output_paragraph.split("\n") if p.strip()]) - print("OUTPUT") - print(output_paragraph) - print("===========") - return output_paragraph - - def summarize(self, **kwargs): - prompt = encode_prompt("summarize.jinja", **kwargs) - print("SUMMARIZE PROMPT") - print(prompt) - print() - output = self._complete_json(prompt) - print("SUMMARIZE OUTPUT") - print(json.dumps(output, ensure_ascii=False, indent=4)) - print("===========") - return output["updated_memory"] - - def instruct(self, **kwargs): - prompt = encode_prompt("instruct.jinja", **kwargs) - print("INSTRUCT PROMPT") - print(prompt) - print() - output = self._complete_json(prompt) - print("INSTRUCT OUTPUT") - print(json.dumps(output, ensure_ascii=False, indent=4)) - print("===========") - return [ + state.next_instructions = [ output["instruction_1"].strip(), output["instruction_2"].strip(), output["instruction_3"].strip(), ] - def generate_plan( + return state + + def generate_name( + self, + state: State + ): + return self._complete_json( + "name", + novel_type=state.novel_type, + description=state.description, + synopsis=state.synopsis, + outline=state.outline, + language=state.language + )["name"] + + def generate_meta( self, description: str, novel_type: str, ): - plan_prompt = encode_prompt( - "plan.jinja", - description=description, - novel_type=novel_type - ) - print("PLAN PROMPT") - print(plan_prompt) - print() - plan_info = self._complete_json(plan_prompt) - print("PLAN OUTPUT") - print(json.dumps(plan_info, ensure_ascii=False, indent=4)) - print("===========") + while True: + try: + info = self._complete_json( + "meta", + description=description, + novel_type=novel_type + ) + outline = info["outline"] + + assert isinstance(outline, list) + assert outline + keys = ("index", "chapter_name", "chapter_summary") + assert all(key in outline[0] for key in keys) + + template = "Chapter {index}: {chapter_name}. {chapter_summary}" + chapters = [template.format(**ch) for ch in outline] + outline = "\n".join(chapters) + + break + except AssertionError: + continue - chapter_summaries = plan_info["chapter_summaries"] - if isinstance(chapter_summaries, dict): - chapter_summaries = [" ".join((k, v)) for k, v in chapter_summaries.items()] return State( - name=plan_info["name"], - synopsis=plan_info["synopsis"], - plan="\n".join(chapter_summaries), + name=info["name"], + synopsis=info["synopsis"], + language=info["language"], + outline=outline, novel_type=novel_type, - description=description, - language=plan_info["language"] + description=description ) - def generate_first_paragraphs( + def generate_first_step( self, state: State ): - plan_start = state.plan.split("\n")[0] - paragraphs = self.begin( + outline_start = state.outline.split("\n")[0] + paragraphs = self._complete_text( + "first_paragraphs", language=state.language, novel_type=state.novel_type, - plan=plan_start, + outline=outline_start, name=state.name, synopsis=state.synopsis, ) + paragraphs = paragraphs.split("\n") + paragraphs = [p.strip() for p in paragraphs if p.strip()] state.paragraphs = paragraphs - info = self.process( + info = self._complete_json( + "first_summary", novel_type=state.novel_type, - plan=state.plan, + outline=state.outline, language=state.language, name=state.name, synopsis=state.synopsis, @@ -192,37 +186,30 @@ def generate_first_paragraphs( ] return state - def begin(self, **kwargs): - begin_prompt = encode_prompt("begin.jinja", **kwargs) - print("BEGIN PROMPT") - print(begin_prompt) - print() - paragraphs = self._complete_text(begin_prompt).split("\n") - paragraphs = [p.strip() for p in paragraphs if p.strip()] - print("BEGIN OUTPUT") - print("\n\n".join(paragraphs)) - print("===========") - return paragraphs - - def process(self, **kwargs): - process_prompt = encode_prompt("process.jinja", **kwargs) - print("PROCESS PROMPT") - print(process_prompt) + def _complete_json(self, prompt_name, **kwargs): + prompt = encode_prompt(prompt_name, **kwargs) + print(f"{prompt_name.upper()} PROMPT") + print(prompt) print() - info = self._complete_json(process_prompt) - print("PROCESS OUTPUT") - print(json.dumps(info, ensure_ascii=False, indent=4)) - print("===========") - return info - - def _complete_json(self, prompt): - return novel_json_completion( + result = novel_json_completion( prompt, model_settings=self.model_settings ) + print(f"{prompt_name.upper()} OUTPUT") + print(json.dumps(result, ensure_ascii=False, indent=4)) + print("===========") + return result - def _complete_text(self, prompt): - return novel_completion( + def _complete_text(self, prompt_name, **kwargs): + prompt = encode_prompt(prompt_name, **kwargs) + print(f"{prompt_name.upper()} PROMPT") + print(prompt) + print() + result = novel_completion( prompt, model_settings=self.model_settings ) + print(f"{prompt_name.upper()} OUTPUT") + print(result) + print("===========") + return result diff --git a/tale_studio/utils.py b/tale_studio/utils.py index 7edddcf..6b7c934 100644 --- a/tale_studio/utils.py +++ b/tale_studio/utils.py @@ -100,7 +100,7 @@ def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: def encode_prompt(template_name, **kwargs): - template_path = PROMPTS_DIR_PATH / template_name + template_path = PROMPTS_DIR_PATH / f"{template_name}.jinja" with open(template_path) as f: template = Template(f.read()) return template.render(**kwargs).strip() + "\n"