def save(log: LogManager, filename: str):
    """Save the most recent code block in the log to ``filename``.

    Prints a notice and returns without writing when the log contains no
    code block. When ``filename`` already exists the user is asked to
    confirm the overwrite; declining leaves the file untouched.

    Args:
        log: Log whose ``get_last_code_block()`` yields a ``(lang, content)``
            tuple, or a falsy value when there is no code block.
        filename: Path of the file to write.
    """
    codeblock = log.get_last_code_block()
    if not codeblock:
        print("No code block found")
        return
    # only the body is written; the fence info string is not needed here
    _, content = codeblock

    if Path(filename).exists():
        confirm = ask_execute("File already exists, overwrite?", default=False)
        if not confirm:
            return

    # Explicit UTF-8: the locale default encoding may be unable to represent
    # non-ASCII characters in model output on some platforms.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(content)
    print(f"Saved code block to {filename}")
https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683 @@ -215,38 +215,20 @@ def print_clear(): sys.stdout.flush() # pause inference on finished code-block, letting user run the command before continuing - # TODO: support nested code blocks - # TODO: test for nested code blocks - code_start = "\n```" - code_end = "\n```\n" - codeblock_started = code_start in output[:-3] - codeblock_finished = code_end in output[-7:] - if codeblock_started and codeblock_finished: - n_start = output.count(code_start) - n_end = output.count(code_end) - if n_start == n_end: - print("\nFound codeblock, breaking") - # noreorder - from .tools import is_supported_codeblock # fmt: skip - - # if closing a code block supported by tools, abort generation to let them run - if is_supported_codeblock(output): - print("\n") - break - else: - logger.warning( - "Code block not supported by tools, continuing generation" - ) - - # pause inference in finished patch - patch_started = "```patch" in output[:-3] - patch_finished = "\n>>>>>>> UPDATED" in output[-30:] - if patch_started and patch_finished: - if "```" not in output[-10:]: - print("\n```", end="") - output += "\n```" - print("\n") - break + if codeblocks := extract_codeblocks(output): + lang, _ = codeblocks[0] + print("\nFound codeblock, breaking") + # noreorder + from .tools import is_supported_codeblock_tool # fmt: skip + + # if closing a code block supported by tools, abort generation to let them run + if is_supported_codeblock_tool(lang): + print("\n") + break + else: + logger.warning( + "Code block not supported by tools, continuing generation" + ) except KeyboardInterrupt: return Message("assistant", output + "... 
^C Interrupted") finally: diff --git a/gptme/logmanager.py b/gptme/logmanager.py index 2b7ff9d9..a8a145ef 100644 --- a/gptme/logmanager.py +++ b/gptme/logmanager.py @@ -218,13 +218,11 @@ def get_last_code_block( self, role: RoleLiteral | None = None, history: int | None = None, - content=False, - ) -> str | None: + ) -> tuple[str, str] | None: """Returns the last code block in the log, if any. If `role` set, only check that role. If `history` set, only check n messages back. - If `content` set, return the content of the code block, else return the whole message. """ msgs = self.log if role: @@ -233,7 +231,7 @@ def get_last_code_block( msgs = msgs[-history:] for msg in msgs[::-1]: - codeblocks = msg.get_codeblocks(content=content) + codeblocks = msg.get_codeblocks() if codeblocks: return codeblocks[-1] return None diff --git a/gptme/message.py b/gptme/message.py index 0aa1aac5..baec8902 100644 --- a/gptme/message.py +++ b/gptme/message.py @@ -186,13 +186,12 @@ def from_toml(cls, toml: str) -> Self: timestamp=datetime.fromisoformat(msg["timestamp"]), ) - def get_codeblocks(self, content=False) -> list[str]: + def get_codeblocks(self) -> list[tuple[str, str]]: """ - Get all codeblocks. - If `content` set, return the content of the code block, else return the whole message. + Get all codeblocks from the message content, as a list of tuples (lang, content). 
""" - codeblocks = [] content_str = self.content + # prepend newline to make sure we get the first codeblock if not content_str.startswith("\n"): content_str = "\n" + content_str @@ -201,19 +200,10 @@ def get_codeblocks(self, content=False) -> list[str]: backtick_count = content_str.count("\n```") if backtick_count < 2: return [] - for i in range(1, backtick_count, 2): - codeblock_str = content_str.split("\n```")[i] - # get codeblock language or filename from first line - lang_or_fn = codeblock_str.split("\n")[0] - codeblock_str = "\n".join(codeblock_str.split("\n")[1:]) - - if content: - codeblocks.append(codeblock_str) - else: - full_codeblock = f"```{lang_or_fn}\n{codeblock_str}\n```" - codeblocks.append(full_codeblock) - return codeblocks + from .util import extract_codeblocks # noreorder + + return extract_codeblocks(content_str) def format_msgs( diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index 7ea4d108..6c4d23db 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -4,6 +4,7 @@ from xml.etree import ElementTree from ..message import Message +from ..util import extract_codeblocks from .base import ToolSpec from .browser import tool as browser_tool from .gh import tool as gh_tool @@ -93,12 +94,12 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]: assert msg.role == "assistant", "Only assistant messages can be executed" # get all markdown code blocks - for codeblock in get_codeblocks(msg.content): + for lang, content in extract_codeblocks(msg.content): try: - if is_supported_codeblock(codeblock): - yield from codeblock_to_tooluse(codeblock).execute(ask) + if is_supported_codeblock_tool(lang): + yield from codeblock_to_tooluse(lang, content).execute(ask) else: - logger.info(f"Codeblock not supported: {codeblock}") + logger.info(f"Codeblock not supported: {lang}") except Exception as e: logger.exception(e) yield Message( @@ -112,29 +113,28 @@ def execute_msg(msg: Message, ask: bool) -> 
def codeblock_to_tooluse(lang: str, content: str) -> ToolUse:
    """Parse a codeblock into a ToolUse. Codeblock must be a supported type.

    Args:
        lang: Fence info string of the codeblock: the block type, optionally
            followed by space-separated arguments.
        content: Body of the codeblock.

    Returns:
        A ToolUse for the tool that handles ``lang``.

    Raises:
        ValueError: If no tool is found for ``lang``.
    """
    tool = get_tool_for_codeblock(lang)
    if not tool:
        assert not is_supported_codeblock_tool(lang)
        raise ValueError(
            f"Unknown codeblock type '{lang}', neither supported language or filename."
        )

    # NOTE: special case, the save tool receives the whole info string as its
    # single argument; every other tool gets the words after the block type.
    args = lang.split(" ")[1:] if tool.name != "save" else [lang]
    return ToolUse(tool.name, args, content)


def execute_codeblock(
    lang: str, codeblock: str, ask: bool
) -> Generator[Message, None, None]:
    """Execute a codeblock with the tool matching ``lang`` and yield its output.

    Args:
        lang: Fence info string of the codeblock: the block type, optionally
            followed by space-separated arguments.
        codeblock: Body of the codeblock, handed to the tool's ``execute``.
        ask: Forwarded to the tool's ``execute``.

    Yields:
        The messages produced by the tool. Nothing is yielded when no tool
        matches ``lang`` or the matching tool has no ``execute``.
    """
    if tool := get_tool_for_codeblock(lang):
        if tool.execute:
            args = lang.split(" ")[1:]
            yield from tool.execute(codeblock, ask, args)
        # A tool was found: do not fall through to the "unknown" path below.
        return

    # The support check takes the info string, not the block content.
    assert not is_supported_codeblock_tool(lang)
    logger.debug("Unknown codeblock, neither supported language or filename.")
leading or trailing whitespace or ``` - content = codeblock - if content.startswith("```"): - content = codeblock[3:] - if codeblock.endswith("```"): - content = content[:-3] - lang_or_fn = content.splitlines()[0].strip() - if is_supported_codeblock_tool(lang_or_fn): - return True - - # if not, it might be a message containing a code block - # TODO: this doesn't really make sense? - # codeblocks = list(get_codeblocks(codeblock)) - # if codeblocks: - # all_supported = True - # for cb in codeblocks: - # lang_or_fn = cb.strip().splitlines()[0].strip() - # supported = is_supported_codeblock_tool(lang_or_fn) - # print(f"supported: {supported}\n{cb}") - # all_supported = all_supported and supported - # if not all_supported: - # return False - - if lang_or_fn not in ["json", "csv", "stdout", "stderr", "output"]: - logger.warning(f"Unsupported codeblock type: {lang_or_fn}") - return False - - def get_tool_for_codeblock(lang_or_fn: str) -> ToolSpec | None: block_type = lang_or_fn.split(" ")[0] for tool in loaded_tools: @@ -212,12 +181,6 @@ def is_supported_codeblock_tool(lang_or_fn: str) -> bool: return False -def get_codeblocks(content: str) -> Generator[str, None, None]: - """Returns all codeblocks in a message.""" - for codeblock in ("\n" + content).split("\n```")[1::2]: - yield codeblock + "\n" - - def get_tooluse_xml(content: str) -> Generator[ToolUse, None, None]: """Returns all ToolUse in a message. 
diff --git a/gptme/tools/reduce.py b/gptme/tools/reduce.py index 0ac6b410..3377f794 100644 --- a/gptme/tools/reduce.py +++ b/gptme/tools/reduce.py @@ -73,9 +73,11 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None: content_staged = msg.content # Truncate long codeblocks - for codeblock in msg.get_codeblocks(): - assert codeblock in content_staged - lang_or_fn, content = codeblock.split("```", 1)[1].split("\n", 1) + for lang, content in msg.get_codeblocks(): + # check that the reformatted codeblock is in the content + full_block = f"```{lang}\n{content}\n```" + assert full_block in content_staged, f"{full_block} not in {content_staged}" + # truncate the middle part of the codeblock, keeping the first and last n lines lines = content.split("\n") if len(lines) > lines_pre + lines_post + 1: @@ -85,13 +87,12 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None: continue # replace the codeblock with the truncated version - assert codeblock in content_staged content_staged_prev = content_staged content_staged = content_staged.replace( - codeblock, f"```{lang_or_fn}\n{content}\n```" + full_block, f"```{lang}\n{content}\n```" ) assert content_staged != content_staged_prev - assert codeblock not in content_staged + assert full_block not in content_staged if content_staged != msg.content: msg_new = copy(msg) diff --git a/gptme/tools/save.py b/gptme/tools/save.py index 7fe21007..948d3c1b 100644 --- a/gptme/tools/save.py +++ b/gptme/tools/save.py @@ -41,6 +41,9 @@ def execute_save( assert fn, "No filename provided" # strip leading newlines code = code.lstrip("\n") + # ensure it ends with a newline + if not code.endswith("\n"): + code += "\n" if ask: confirm = ask_execute(f"Save to {fn}?") @@ -98,6 +101,9 @@ def execute_append( assert fn, "No filename provided" # strip leading newlines code = code.lstrip("\n") + # ensure it ends with a newline + if not code.endswith("\n"): + code += "\n" if ask: confirm = 
def extract_codeblocks(markdown: str) -> list[tuple[str, str]]:
    """Extract the outermost fenced code blocks from a markdown string.

    A line whose stripped form starts with three backticks is treated as a
    fence. A fence carrying an info string that differs from the innermost
    open fence opens a nested block; any other fence closes the innermost
    open block. Nested blocks are kept verbatim, fences included, inside the
    content of the outermost block.

    Fences longer than three backticks are accepted: the whole run of leading
    backticks is removed before the info string is read, so a four-backtick
    block yields a clean info string and is closed by its bare four-backtick
    fence.

    Args:
        markdown: Text to scan.

    Returns:
        A list of ``(lang, content)`` tuples in order of appearance, where
        ``lang`` is the info string of the opening fence ("" if absent) and
        ``content`` is the block body without the surrounding fences. Blocks
        that are never closed are omitted.
    """
    # speed check (early exit): a complete block needs at least two fences
    if markdown.count("```") < 2:
        return []

    codeblocks: list[tuple[str, str]] = []
    # info strings of the currently open fences, outermost first
    stack: list[str] = []
    current_block: list[str] = []
    current_lang = ""

    for line in markdown.split("\n"):
        stripped_line = line.strip()

        if not stripped_line.startswith("```"):
            # ordinary line: only kept while inside a block
            if stack:
                current_block.append(line)
            continue

        # drop the entire backtick run, not just the first three characters
        info = stripped_line.lstrip("`")

        if not stack:  # start of a new outermost block
            stack.append(info)
            current_lang = info
        elif info and stack[-1] != info:  # nested start
            current_block.append(line)
            stack.append(info)
        else:  # end of a block
            if len(stack) == 1:  # outermost block is complete
                codeblocks.append((current_lang, "\n".join(current_block)))
                current_block = []
                current_lang = ""
            else:  # nested end stays part of the outer content
                current_block.append(line)
            stack.pop()

    return codeblocks
is_generated_name, transform_examples_to_chat_directives, @@ -21,63 +21,6 @@ def test_epoch_to_age(): assert epoch_to_age(epoch_yesterday) == "yesterday" -def test_is_supported_codeblock(): - block_plain = """``` -some plaintext -``` -""" - # clean unsupported block - assert not is_supported_codeblock(block_plain) - - block_python = """```python -print("hello world") -``` -""" - # clean supported block - assert is_supported_codeblock(block_python) - - # has preamble - # NOTE: this should not be supported by this function, clean it first if you really want to - # s = f"""bla bla\n{block_python}""" - # assert is_supported_codeblock(s) - - # last block is plain/unsupported - # NOTE: this should not be supported by this function, clean it first if you really want to - # s = f"""{block_python}\n{block_plain}""" - # assert not is_supported_codeblock(s) - - -def test_get_codeblocks(): - s = """```python -print("hello world") -``` -""" - - codeblocks = list(get_codeblocks(s)) - assert len(codeblocks) == 1 - assert ( - codeblocks[0] - == """python -print("hello world") -""" - ) - - # test a codeblock which contains triple backticks - s = """```python -print("hello ``` world") -``` -""" - - codeblocks = list(get_codeblocks(s)) - assert len(codeblocks) == 1 - assert ( - codeblocks[0] - == """python -print("hello ``` world") -""" - ) - - def test_transform_examples_to_chat_directives(): src = """ # Example @@ -114,3 +57,80 @@ def test_transform_examples_to_chat_directives_tricky(): Assistant: lol""" assert transform_examples_to_chat_directives(src, strict=True) == expected + + +def test_extract_codeblocks_basic(): + markdown = """ +Some text +```python +def hello(): + print("Hello, World!") +``` +More text +""" + assert extract_codeblocks(markdown) == [ + ("python", 'def hello():\n print("Hello, World!")') + ] + + +def test_extract_codeblocks_multiple(): + markdown = """ +```java +public class Main { + public static void main(String[] args) { + System.out.println("Hello, 
Java!"); + } +} +``` +Some text +```python +def greet(name): + return f"Hello, {name}!" +``` +""" + assert extract_codeblocks(markdown) == [ + ( + "java", + 'public class Main {\n public static void main(String[] args) {\n System.out.println("Hello, Java!");\n }\n}', + ), + ("python", 'def greet(name):\n return f"Hello, {name}!"'), + ] + + +def test_extract_codeblocks_nested(): + markdown = """ +```python +def print_readme(): + print('''Usage: +```javascript +callme() +``` +''') +``` +""" + assert extract_codeblocks(markdown) == [ + ( + "python", + "def print_readme():\n print('''Usage:\n```javascript\ncallme()\n```\n''')", + ) + ] + + +def test_extract_codeblocks_empty(): + assert extract_codeblocks("") == [] + + +def test_extract_codeblocks_text_only(): + assert extract_codeblocks("Just some regular text\nwithout any code blocks.") == [] + + +def test_extract_codeblocks_no_language(): + markdown = """ +``` +def hello(): + print("Hello, World!") +``` +""" + assert extract_codeblocks(markdown) == [ + ("", 'def hello():\n print("Hello, World!")') + ]