
Commit

fix: support nested codeblocks, rewrote/refactored codeblock parsing/management
ErikBjare committed Aug 13, 2024
1 parent 07f1cbb commit 3e291a4
Showing 11 changed files with 181 additions and 179 deletions.
14 changes: 10 additions & 4 deletions gptme/cli.py
@@ -225,10 +225,16 @@ def chat(
# then exit
elif not interactive:
# noreorder
from .tools import is_supported_codeblock # fmt: skip

codeblock = log.get_last_code_block("assistant", history=1, content=False)
if not (codeblock and is_supported_codeblock(codeblock)):
from .tools import is_supported_codeblock_tool # fmt: skip

# continue if we can run tools on the last message
runnable = False
if codeblock := log.get_last_code_block("assistant", history=1):
print("Checking for codeblock")
lang, _ = codeblock
if is_supported_codeblock_tool(lang):
runnable = True
if not runnable:
logger.info("Non-interactive and exhausted prompts, exiting")
break

7 changes: 4 additions & 3 deletions gptme/commands.py
@@ -200,16 +200,17 @@ def edit(log: LogManager) -> Generator[Message, None, None]:  # pragma: no cover

def save(log: LogManager, filename: str):
# save the most recent code block to a file
code = log.get_last_code_block(content=True)
if not code:
codeblock = log.get_last_code_block()
if not codeblock:
print("No code block found")
return
_, content = codeblock
if Path(filename).exists():
confirm = ask_execute("File already exists, overwrite?", default=False)
if not confirm:
return
with open(filename, "w") as f:
f.write(code)
f.write(content)
print(f"Saved code block to {filename}")


48 changes: 15 additions & 33 deletions gptme/llm.py
@@ -11,7 +11,7 @@
from .constants import PROMPT_ASSISTANT
from .message import Message
from .models import MODELS, get_summary_model
from .util import len_tokens, msgs2dicts
from .util import extract_codeblocks, len_tokens, msgs2dicts

# Optimized for code
# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
@@ -215,38 +215,20 @@ def print_clear():
sys.stdout.flush()

# pause inference on finished code-block, letting user run the command before continuing
# TODO: support nested code blocks
# TODO: test for nested code blocks
code_start = "\n```"
code_end = "\n```\n"
codeblock_started = code_start in output[:-3]
codeblock_finished = code_end in output[-7:]
if codeblock_started and codeblock_finished:
n_start = output.count(code_start)
n_end = output.count(code_end)
if n_start == n_end:
print("\nFound codeblock, breaking")
# noreorder
from .tools import is_supported_codeblock # fmt: skip

# if closing a code block supported by tools, abort generation to let them run
if is_supported_codeblock(output):
print("\n")
break
else:
logger.warning(
"Code block not supported by tools, continuing generation"
)

# pause inference in finished patch
patch_started = "```patch" in output[:-3]
patch_finished = "\n>>>>>>> UPDATED" in output[-30:]
if patch_started and patch_finished:
if "```" not in output[-10:]:
print("\n```", end="")
output += "\n```"
print("\n")
break
if codeblocks := extract_codeblocks(output):
lang, _ = codeblocks[0]
print("\nFound codeblock, breaking")
# noreorder
from .tools import is_supported_codeblock_tool # fmt: skip

# if closing a code block supported by tools, abort generation to let them run
if is_supported_codeblock_tool(lang):
print("\n")
break
else:
logger.warning(
"Code block not supported by tools, continuing generation"
)
except KeyboardInterrupt:
return Message("assistant", output + "... ^C Interrupted")
finally:
6 changes: 2 additions & 4 deletions gptme/logmanager.py
@@ -218,13 +218,11 @@ def get_last_code_block(
self,
role: RoleLiteral | None = None,
history: int | None = None,
content=False,
) -> str | None:
) -> tuple[str, str] | None:
"""Returns the last code block in the log, if any.
If `role` set, only check that role.
If `history` set, only check n messages back.
If `content` set, return the content of the code block, else return the whole message.
"""
msgs = self.log
if role:
@@ -233,7 +231,7 @@ def get_last_code_block(
msgs = msgs[-history:]

for msg in msgs[::-1]:
codeblocks = msg.get_codeblocks(content=content)
codeblocks = msg.get_codeblocks()
if codeblocks:
return codeblocks[-1]
return None
22 changes: 6 additions & 16 deletions gptme/message.py
@@ -186,13 +186,12 @@ def from_toml(cls, toml: str) -> Self:
timestamp=datetime.fromisoformat(msg["timestamp"]),
)

def get_codeblocks(self, content=False) -> list[str]:
def get_codeblocks(self) -> list[tuple[str, str]]:
"""
Get all codeblocks.
If `content` set, return the content of the code block, else return the whole message.
Get all codeblocks from the message content, as a list of tuples (lang, content).
"""
codeblocks = []
content_str = self.content

# prepend newline to make sure we get the first codeblock
if not content_str.startswith("\n"):
content_str = "\n" + content_str
@@ -201,19 +200,10 @@ def get_codeblocks(self, content=False) -> list[str]:
backtick_count = content_str.count("\n```")
if backtick_count < 2:
return []
for i in range(1, backtick_count, 2):
codeblock_str = content_str.split("\n```")[i]
# get codeblock language or filename from first line
lang_or_fn = codeblock_str.split("\n")[0]
codeblock_str = "\n".join(codeblock_str.split("\n")[1:])

if content:
codeblocks.append(codeblock_str)
else:
full_codeblock = f"```{lang_or_fn}\n{codeblock_str}\n```"
codeblocks.append(full_codeblock)

return codeblocks
from .util import extract_codeblocks # noreorder

return extract_codeblocks(content_str)


def format_msgs(
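
With this change, Message.get_codeblocks() returns (lang, content) tuples instead of raw strings. A minimal usage sketch (not part of the commit; it assumes the Message(role, content) constructor used elsewhere in this diff):

from gptme.message import Message

# a message whose content contains a single python codeblock
msg = Message("assistant", "Here you go:\n```python\nprint('hello')\n```")
print(msg.get_codeblocks())
# expected with the new return type: [('python', "print('hello')")]
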
71 changes: 17 additions & 54 deletions gptme/tools/__init__.py
@@ -4,6 +4,7 @@
from xml.etree import ElementTree

from ..message import Message
from ..util import extract_codeblocks
from .base import ToolSpec
from .browser import tool as browser_tool
from .gh import tool as gh_tool
@@ -93,12 +94,12 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]:
assert msg.role == "assistant", "Only assistant messages can be executed"

# get all markdown code blocks
for codeblock in get_codeblocks(msg.content):
for lang, content in extract_codeblocks(msg.content):
try:
if is_supported_codeblock(codeblock):
yield from codeblock_to_tooluse(codeblock).execute(ask)
if is_supported_codeblock_tool(lang):
yield from codeblock_to_tooluse(lang, content).execute(ask)
else:
logger.info(f"Codeblock not supported: {codeblock}")
logger.info(f"Codeblock not supported: {lang}")
except Exception as e:
logger.exception(e)
yield Message(
@@ -112,29 +113,28 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]:
yield from tooluse.execute(ask)


def codeblock_to_tooluse(codeblock: str) -> ToolUse:
def codeblock_to_tooluse(lang: str, content: str) -> ToolUse:
"""Parses a codeblock into a ToolUse. Codeblock must be a supported type."""
lang_or_fn = codeblock.splitlines()[0].strip()
codeblock_content = codeblock[len(lang_or_fn) :]
if tool := get_tool_for_codeblock(lang_or_fn):
if tool := get_tool_for_codeblock(lang):
# NOTE: special case
args = lang_or_fn.split(" ")[1:] if tool.name != "save" else [lang_or_fn]
return ToolUse(tool.name, args, codeblock_content)
args = lang.split(" ")[1:] if tool.name != "save" else [lang]
return ToolUse(tool.name, args, content)
else:
assert not is_supported_codeblock(codeblock)
assert not is_supported_codeblock_tool(lang)
raise ValueError(
f"Unknown codeblock type '{lang_or_fn}', neither supported language or filename."
f"Unknown codeblock type '{lang}', neither supported language or filename."
)


def execute_codeblock(codeblock: str, ask: bool) -> Generator[Message, None, None]:
def execute_codeblock(
lang: str, codeblock: str, ask: bool
) -> Generator[Message, None, None]:
"""Executes a codeblock and returns the output."""
lang_or_fn = codeblock.splitlines()[0].strip()
if tool := get_tool_for_codeblock(lang_or_fn):
if tool := get_tool_for_codeblock(lang):
if tool.execute:
args = lang_or_fn.split(" ")[1:]
args = lang.split(" ")[1:]
yield from tool.execute(codeblock, ask, args)
assert not is_supported_codeblock(codeblock)
assert not is_supported_codeblock_tool(codeblock)
logger.debug("Unknown codeblock, neither supported language or filename.")


@@ -162,37 +162,6 @@ def is_supported(self) -> bool:
return is_supported_codeblock_tool(self.lang_or_fn)


def is_supported_codeblock(codeblock: str) -> bool:
"""Returns whether a codeblock is supported by tools."""
# if the codeblock are the clean contents of a code block,
# with a tool on the first line, without any leading or trailing whitespace or ```
content = codeblock
if content.startswith("```"):
content = codeblock[3:]
if codeblock.endswith("```"):
content = content[:-3]
lang_or_fn = content.splitlines()[0].strip()
if is_supported_codeblock_tool(lang_or_fn):
return True

# if not, it might be a message containing a code block
# TODO: this doesn't really make sense?
# codeblocks = list(get_codeblocks(codeblock))
# if codeblocks:
# all_supported = True
# for cb in codeblocks:
# lang_or_fn = cb.strip().splitlines()[0].strip()
# supported = is_supported_codeblock_tool(lang_or_fn)
# print(f"supported: {supported}\n{cb}")
# all_supported = all_supported and supported
# if not all_supported:
# return False

if lang_or_fn not in ["json", "csv", "stdout", "stderr", "output"]:
logger.warning(f"Unsupported codeblock type: {lang_or_fn}")
return False


def get_tool_for_codeblock(lang_or_fn: str) -> ToolSpec | None:
block_type = lang_or_fn.split(" ")[0]
for tool in loaded_tools:
@@ -212,12 +181,6 @@ def is_supported_codeblock_tool(lang_or_fn: str) -> bool:
return False


def get_codeblocks(content: str) -> Generator[str, None, None]:
"""Returns all codeblocks in a message."""
for codeblock in ("\n" + content).split("\n```")[1::2]:
yield codeblock + "\n"


def get_tooluse_xml(content: str) -> Generator[ToolUse, None, None]:
"""Returns all ToolUse in a message.
13 changes: 7 additions & 6 deletions gptme/tools/reduce.py
@@ -73,9 +73,11 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None:
content_staged = msg.content

# Truncate long codeblocks
for codeblock in msg.get_codeblocks():
assert codeblock in content_staged
lang_or_fn, content = codeblock.split("```", 1)[1].split("\n", 1)
for lang, content in msg.get_codeblocks():
# check that the reformatted codeblock is in the content
full_block = f"```{lang}\n{content}\n```"
assert full_block in content_staged, f"{full_block} not in {content_staged}"

# truncate the middle part of the codeblock, keeping the first and last n lines
lines = content.split("\n")
if len(lines) > lines_pre + lines_post + 1:
@@ -85,13 +87,12 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None:
continue

# replace the codeblock with the truncated version
assert codeblock in content_staged
content_staged_prev = content_staged
content_staged = content_staged.replace(
codeblock, f"```{lang_or_fn}\n{content}\n```"
full_block, f"```{lang}\n{content}\n```"
)
assert content_staged != content_staged_prev
assert codeblock not in content_staged
assert full_block not in content_staged

if content_staged != msg.content:
msg_new = copy(msg)
6 changes: 6 additions & 0 deletions gptme/tools/save.py
@@ -41,6 +41,9 @@ def execute_save(
assert fn, "No filename provided"
# strip leading newlines
code = code.lstrip("\n")
# ensure it ends with a newline
if not code.endswith("\n"):
code += "\n"

if ask:
confirm = ask_execute(f"Save to {fn}?")
@@ -98,6 +101,9 @@ def execute_append(
assert fn, "No filename provided"
# strip leading newlines
code = code.lstrip("\n")
# ensure it ends with a newline
if not code.endswith("\n"):
code += "\n"

if ask:
confirm = ask_execute(f"Append to {fn}?")
35 changes: 35 additions & 0 deletions gptme/util.py
@@ -177,3 +177,38 @@ def transform_examples_to_chat_directives(s: str, strict=False) -> str:
if strict:
assert s != orig, "Couldn't find place to put start of directive"
return s


def extract_codeblocks(markdown: str) -> list[tuple[str, str]]:
# speed check (early exit): check if message contains a code block
backtick_count = markdown.count("```")
if backtick_count < 2:
return []

codeblocks = []
lines = markdown.split("\n")
stack: list[str] = []
current_block = []
current_lang = ""

for line in lines:
stripped_line = line.strip()
if stripped_line.startswith("```"):
if not stack: # Start of a new block
stack.append(stripped_line[3:])
current_lang = stripped_line[3:]
elif stripped_line[3:] and stack[-1] != stripped_line[3:]: # Nested start
current_block.append(line)
stack.append(stripped_line[3:])
else: # End of a block
if len(stack) == 1: # Outermost block
codeblocks.append((current_lang, "\n".join(current_block)))
current_block = []
current_lang = ""
else: # Nested end
current_block.append(line)
stack.pop()
elif stack:
current_block.append(line)

return codeblocks
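
To illustrate the headline fix, a minimal usage sketch of the new parser (not part of the commit; it assumes the function above is importable as gptme.util.extract_codeblocks). The inner ``` fence no longer terminates the outer block:

from gptme.util import extract_codeblocks

# markdown containing a codeblock nested inside another codeblock
md = (
    "Intro text\n"
    "```markdown\n"
    "Outer block with a nested block:\n"
    "```python\n"
    "print('hello')\n"
    "```\n"
    "```\n"
)
print(extract_codeblocks(md))
# only the outermost block is returned, with the nested fence kept in its content:
# [('markdown', "Outer block with a nested block:\n```python\nprint('hello')\n```")]
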
2 changes: 1 addition & 1 deletion tests/test_logmanager.py
@@ -18,7 +18,7 @@ def test_get_last_code_block():
""",
)
)
assert log.get_last_code_block(content=True) == "print('world')"
assert ("python", "print('world')") == log.get_last_code_block()


def test_branch():
