Skip to content

Commit

Permalink
fix: more codeblock refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Sep 9, 2024
1 parent d2af905 commit ebe446a
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 107 deletions.
7 changes: 3 additions & 4 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,12 @@ def chat(
# then exit
elif not interactive:
# noreorder
from .tools import is_supported_codeblock_tool # fmt: skip
from .tools import is_supported_langtag # fmt: skip

# continue if we can run tools on the last message
runnable = False
if codeblock := log.get_last_code_block("assistant", history=1):
lang, _ = codeblock
if is_supported_codeblock_tool(lang):
if codeblock := log.get_last_codeblock("assistant", history=1):
if is_supported_langtag(codeblock.lang):
runnable = True
if not runnable:
logger.info("Non-interactive and exhausted prompts, exiting")
Expand Down
8 changes: 7 additions & 1 deletion gptme/codeblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def __post_init__(self):
if self.path is None and self.is_filename:
self.path = self.lang

def to_markdown(self) -> str:
    """Render this codeblock as a fenced markdown block.

    The opening fence carries ``self.lang`` as its info string, followed by
    ``self.content`` on its own line(s) and a closing fence.
    """
    fence = "```"
    opening = f"{fence}{self.lang}"
    return "\n".join((opening, f"{self.content}", fence))

def to_xml(self) -> str:
    """Serialize this codeblock as a ``<codeblock>`` XML element.

    The element carries a ``lang`` attribute and, only when ``self.path`` is
    set, a ``path`` attribute; the content is placed on its own line(s)
    between the tags.

    Attribute values and content are XML-escaped so that content containing
    ``<``, ``&`` or quotes still produces well-formed XML that an XML parser
    can read back. A missing path is omitted instead of being written as the
    literal string ``"None"``, so a parser looking the attribute up gets no
    value rather than a bogus path.
    """
    # Function-scope import: keeps the method self-contained without touching
    # the module's import block.
    from xml.sax.saxutils import escape

    # escape() handles &, <, > by default; double quotes must be added for
    # values placed inside double-quoted attributes.
    quote_entities = {'"': "&quot;"}

    attrs = f'lang="{escape(str(self.lang), quote_entities)}"'
    if self.path is not None:
        attrs += f' path="{escape(str(self.path), quote_entities)}"'
    return f"<codeblock {attrs}>\n{escape(self.content)}\n</codeblock>"

@classmethod
def from_markdown(cls, content: str) -> "Codeblock":
if content.strip().startswith("```"):
Expand All @@ -32,7 +38,7 @@ def from_xml(cls, content: str) -> "Codeblock":
</codeblock>
"""
root = ElementTree.fromstring(content)
return cls(root.attrib["lang"], root.text or "", root.attrib.get("filename"))
return cls(root.attrib["lang"], root.text or "", root.attrib.get("path"))

@property
def is_filename(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,19 @@ def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover
print("Applied edited messages, write /log to see the result")


# TODO: remove?
def save(log: LogManager, filename: str):
# save the most recent code block to a file
codeblock = log.get_last_code_block()
codeblock = log.get_last_codeblock()
if not codeblock:
print("No code block found")
return
_, content = codeblock
if Path(filename).exists():
confirm = ask_execute("File already exists, overwrite?", default=False)
if not confirm:
return
with open(filename, "w") as f:
f.write(content)
f.write(codeblock.content)
print(f"Saved code block to {filename}")


Expand Down
8 changes: 4 additions & 4 deletions gptme/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def print_clear():
sys.stdout.flush()

# pause inference on finished code-block, letting user run the command before continuing
if codeblocks := Codeblock.extract_codeblocks(output):
lang, _ = codeblocks[0]
if codeblocks := Codeblock.iter_from_markdown(output):
codeblock = codeblocks[0]
# noreorder
from .tools import is_supported_codeblock_tool # fmt: skip
from .tools import is_supported_langtag # fmt: skip

# if closing a code block supported by tools, abort generation to let them run
if is_supported_codeblock_tool(lang):
if is_supported_langtag(codeblock.lang):
print("\nFound codeblock, breaking")
break
except KeyboardInterrupt:
Expand Down
5 changes: 3 additions & 2 deletions gptme/logmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from rich import print

from .codeblock import Codeblock
from .constants import CMDFIX
from .dirs import get_logs_dir
from .message import Message, len_tokens, print_msg
Expand Down Expand Up @@ -215,11 +216,11 @@ def load(
msgs = initial_msgs
return cls(msgs, logdir=logdir, branch=branch, **kwargs)

def get_last_code_block(
def get_last_codeblock(
self,
role: RoleLiteral | None = None,
history: int | None = None,
) -> tuple[str, str] | None:
) -> Codeblock | None:
"""Returns the last code block in the log, if any.
If `role` set, only check that role.
Expand Down
9 changes: 5 additions & 4 deletions gptme/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Generator
from copy import copy

from ..codeblock import Codeblock
from .message import Message, len_tokens
from .models import get_model

Expand Down Expand Up @@ -72,13 +73,13 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None:
content_staged = msg.content

# Truncate long codeblocks
for lang, content in msg.get_codeblocks():
for codeblock in msg.get_codeblocks():
# check that the reformatted codeblock is in the content
full_block = f"```{lang}\n{content}\n```"
full_block = codeblock.to_markdown()
assert full_block in content_staged, f"{full_block} not in {content_staged}"

# truncate the middle part of the codeblock, keeping the first and last n lines
lines = content.split("\n")
lines = codeblock.content.split("\n")
if len(lines) > lines_pre + lines_post + 1:
content = "\n".join([*lines[:lines_pre], "[...]", *lines[-lines_post:]])
else:
Expand All @@ -88,7 +89,7 @@ def truncate_msg(msg: Message, lines_pre=10, lines_post=10) -> Message | None:
# replace the codeblock with the truncated version
content_staged_prev = content_staged
content_staged = content_staged.replace(
full_block, f"```{lang}\n{content}\n```"
full_block, Codeblock(codeblock.lang, content).to_markdown()
)
assert content_staged != content_staged_prev
assert full_block not in content_staged
Expand Down
21 changes: 12 additions & 9 deletions gptme/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]:
# get all markdown code blocks
for codeblock in Codeblock.iter_from_markdown(msg.content):
try:
if get_tool_for_langtag(codeblock.lang):
yield from ToolUse.from_codeblock(codeblock).execute(ask)
else:
logger.info(f"Codeblock not supported: {codeblock.lang}")
yield from execute_codeblock(codeblock, ask)
except Exception as e:
logger.exception(e)
yield Message(
Expand All @@ -160,14 +157,16 @@ def execute_msg(msg: Message, ask: bool) -> Generator[Message, None, None]:


def execute_codeblock(
lang: str, codeblock: str, ask: bool
codeblock: Codeblock, ask: bool
) -> Generator[Message, None, None]:
"""Executes a codeblock and returns the output."""
if tool := get_tool_for_langtag(lang):
ToolUse.from_codeblock(codeblock)
if tool := get_tool_for_langtag(codeblock.lang):
if tool.execute:
args = lang.split(" ")[1:]
yield from tool.execute(codeblock, ask, args)
logger.debug("Unknown codeblock, neither supported language or filename.")
args = codeblock.lang.split(" ")[1:]
yield from tool.execute(codeblock.content, ask, args)
else:
logger.info(f"Codeblock not supported: {codeblock.lang}")


def get_tool_for_langtag(lang: str) -> ToolSpec | None:
Expand All @@ -182,6 +181,10 @@ def get_tool_for_langtag(lang: str) -> ToolSpec | None:
return None


def is_supported_langtag(lang: str) -> bool:
    """Report whether ``get_tool_for_langtag`` finds a tool for ``lang``."""
    tool = get_tool_for_langtag(lang)
    return True if tool else False


def get_tool(tool_name: str) -> ToolSpec:
"""Returns a tool by name."""
for tool in loaded_tools:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_codeblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from gptme.codeblock import Codeblock


def test_extract_codeblocks_basic():
    """A single fenced block surrounded by prose is parsed into one Codeblock."""
    markdown = """
Some text
```python
def hello():
    print("Hello, World!")
```
More text
"""
    # Callers of iter_from_markdown read `.lang` / `.content` from the
    # results, so the expectation is built from Codeblock instances rather
    # than bare (lang, content) tuples.
    assert list(Codeblock.iter_from_markdown(markdown)) == [
        Codeblock("python", 'def hello():\n    print("Hello, World!")')
    ]


def test_extract_codeblocks_multiple():
    """Two fenced blocks in one document are returned in document order."""
    markdown = """
```java
public class Main {
    public static void main(String[] args) {
        System.out.println("Hello, Java!");
    }
}
```
Some text
```python
def greet(name):
    return f"Hello, {name}!"
```
"""
    # Callers of iter_from_markdown read `.lang` / `.content` from the
    # results, so the expectation is built from Codeblock instances rather
    # than bare (lang, content) tuples.
    assert list(Codeblock.iter_from_markdown(markdown)) == [
        Codeblock(
            "java",
            'public class Main {\n    public static void main(String[] args) {\n        System.out.println("Hello, Java!");\n    }\n}',
        ),
        Codeblock("python", 'def greet(name):\n    return f"Hello, {name}!"'),
    ]


def test_extract_codeblocks_nested():
    """A fence nested inside the outer block stays part of the outer block's content."""
    markdown = """
```python
def print_readme():
    print('''Usage:
```javascript
callme()
```
''')
```
"""
    # Callers of iter_from_markdown read `.lang` / `.content` from the
    # results, so the expectation is built from Codeblock instances rather
    # than bare (lang, content) tuples.
    assert list(Codeblock.iter_from_markdown(markdown)) == [
        Codeblock(
            "python",
            "def print_readme():\n    print('''Usage:\n```javascript\ncallme()\n```\n''')",
        )
    ]


def test_extract_codeblocks_empty():
    """An empty document contains no codeblocks."""
    found = Codeblock.iter_from_markdown("")
    assert found == []


def test_extract_codeblocks_text_only():
    """Plain prose without any fences contains no codeblocks."""
    prose = "Just some regular text\nwithout any code blocks."
    found = Codeblock.iter_from_markdown(prose)
    assert found == []


def test_extract_codeblocks_no_language():
    """A fence without an info string yields a Codeblock whose lang is empty."""
    markdown = """
```
def hello():
    print("Hello, World!")
```
"""
    # Callers of iter_from_markdown read `.lang` / `.content` from the
    # results, so the expectation is built from Codeblock instances rather
    # than bare (lang, content) tuples.
    assert list(Codeblock.iter_from_markdown(markdown)) == [
        Codeblock("", 'def hello():\n    print("Hello, World!")')
    ]
4 changes: 2 additions & 2 deletions tests/test_logmanager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gptme.logmanager import LogManager, Message


def test_get_last_code_block():
def test_get_last_codeblock():
# tests that the last code block is indeed returned, with the correct formatting
log = LogManager()
log.append(
Expand All @@ -18,7 +18,7 @@ def test_get_last_code_block():
""",
)
)
assert ("python", "print('world')") == log.get_last_code_block()
assert ("python", "print('world')") == log.get_last_codeblock()


def test_branch():
Expand Down
78 changes: 0 additions & 78 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from gptme.util import (
epoch_to_age,
extract_codeblocks,
generate_name,
is_generated_name,
transform_examples_to_chat_directives,
Expand Down Expand Up @@ -57,80 +56,3 @@ def test_transform_examples_to_chat_directives_tricky():
Assistant: lol"""

assert transform_examples_to_chat_directives(src, strict=True) == expected


def test_extract_codeblocks_basic():
markdown = """
Some text
```python
def hello():
print("Hello, World!")
```
More text
"""
assert extract_codeblocks(markdown) == [
("python", 'def hello():\n print("Hello, World!")')
]


def test_extract_codeblocks_multiple():
markdown = """
```java
public class Main {
public static void main(String[] args) {
System.out.println("Hello, Java!");
}
}
```
Some text
```python
def greet(name):
return f"Hello, {name}!"
```
"""
assert extract_codeblocks(markdown) == [
(
"java",
'public class Main {\n public static void main(String[] args) {\n System.out.println("Hello, Java!");\n }\n}',
),
("python", 'def greet(name):\n return f"Hello, {name}!"'),
]


def test_extract_codeblocks_nested():
markdown = """
```python
def print_readme():
print('''Usage:
```javascript
callme()
```
''')
```
"""
assert extract_codeblocks(markdown) == [
(
"python",
"def print_readme():\n print('''Usage:\n```javascript\ncallme()\n```\n''')",
)
]


def test_extract_codeblocks_empty():
assert extract_codeblocks("") == []


def test_extract_codeblocks_text_only():
assert extract_codeblocks("Just some regular text\nwithout any code blocks.") == []


def test_extract_codeblocks_no_language():
markdown = """
```
def hello():
print("Hello, World!")
```
"""
assert extract_codeblocks(markdown) == [
("", 'def hello():\n print("Hello, World!")')
]

0 comments on commit ebe446a

Please sign in to comment.