Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stream and capture ipython output #357

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions gptme/tools/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import dataclasses
import functools
import importlib.util
import io
import re
import sys
from collections.abc import Callable, Generator
from contextlib import contextmanager
from logging import getLogger
from typing import TYPE_CHECKING, TypeVar

Expand All @@ -23,7 +26,7 @@
)

if TYPE_CHECKING:
from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip
from IPython.core.interactiveshell import InteractiveShell # fmt: skip

logger = getLogger(__name__)

Expand All @@ -32,7 +35,7 @@
# https://github.com/ErikBjare/gptme/issues/29

# IPython instance
_ipython: "InteractiveShellEmbed | None" = None
_ipython: "InteractiveShell | None" = None


registered_functions: dict[str, Callable] = {}
Expand All @@ -51,22 +54,55 @@ def register_function(func: T) -> T:

def _get_ipython():
global _ipython
from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip
from IPython.core.interactiveshell import InteractiveShell # fmt: skip

if _ipython is None:
_ipython = InteractiveShellEmbed()
_ipython = InteractiveShell()
_ipython.push(registered_functions)

return _ipython


class TeeIO(io.StringIO):
def __init__(self, original_stream):
super().__init__()
self.original_stream = original_stream
self.in_result_block = False

def write(self, s):
# hack to get rid of ipython result-prompt ("Out[0]: ...") and everything after it
if s.startswith("Out["):
self.in_result_block = True
if self.in_result_block:
if s.startswith("\n"):
self.in_result_block = False
else:
s = ""
self.original_stream.write(s)
self.original_stream.flush() # Ensure immediate display
return super().write(s)


@contextmanager
def capture_and_display():
stdout_capture = TeeIO(sys.stdout)
stderr_capture = TeeIO(sys.stderr)
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = stdout_capture, stderr_capture
try:
yield stdout_capture, stderr_capture
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr


def execute_python(
code: str | None,
args: list[str] | None,
kwargs: dict[str, str] | None,
confirm: ConfirmFunc = lambda _: True,
) -> Generator[Message, None, None]:
"""Executes a python codeblock and returns the output."""
from IPython.core.interactiveshell import ExecutionResult # fmt: skip

if code is not None and args is not None:
code = code.strip()
Expand All @@ -84,12 +120,15 @@ def execute_python(
# Create an IPython instance if it doesn't exist yet
_ipython = _get_ipython()

# Capture the standard output and error streams
from IPython.utils.capture import capture_output # fmt: skip
# Capture and display output in real-time
with capture_and_display() as (stdout_capture, stderr_capture):
# Execute the code (output will be displayed in real-time)
result: ExecutionResult = _ipython.run_cell(
code, silent=False, store_history=False
)

with capture_output() as captured:
# Execute the code
result = _ipython.run_cell(code, silent=False, store_history=False)
captured_stdout = stdout_capture.getvalue()
captured_stderr = stderr_capture.getvalue()

output = ""
# TODO: should we include captured stdout with messages like these?
Expand All @@ -102,16 +141,16 @@ def execute_python(
output += f"Result:\n```\n{result.result}\n```\n\n"

# only show stdout if there is no result
elif captured.stdout:
output += f"```stdout\n{captured.stdout.rstrip()}\n```\n\n"
if captured.stderr:
output += f"```stderr\n{captured.stderr.rstrip()}\n```\n\n"
elif captured_stdout:
output += f"```stdout\n{captured_stdout.rstrip()}\n```\n\n"
if captured_stderr:
output += f"```stderr\n{captured_stderr.rstrip()}\n```\n\n"
if result.error_in_exec:
tb = result.error_in_exec.__traceback__
while tb.tb_next: # type: ignore
tb = tb.tb_next # type: ignore
# type: ignore
output += f"Exception during execution on line {tb.tb_lineno}:\n {result.error_in_exec.__class__.__name__}: {result.error_in_exec}"
while tb and tb.tb_next:
tb = tb.tb_next
if tb:
output += f"Exception during execution on line {tb.tb_lineno}:\n {result.error_in_exec.__class__.__name__}: {result.error_in_exec}"

# strip ANSI escape sequences
# TODO: better to signal to the terminal that we don't want colors?
Expand Down
Loading