diff --git a/gptme/tools/python.py b/gptme/tools/python.py index d44078f7..0a829f12 100644 --- a/gptme/tools/python.py +++ b/gptme/tools/python.py @@ -7,8 +7,11 @@ import dataclasses import functools import importlib.util +import io import re +import sys from collections.abc import Callable, Generator +from contextlib import contextmanager from logging import getLogger from typing import TYPE_CHECKING, TypeVar @@ -23,7 +26,7 @@ ) if TYPE_CHECKING: - from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip + from IPython.core.interactiveshell import InteractiveShell # fmt: skip logger = getLogger(__name__) @@ -32,7 +35,7 @@ # https://github.com/ErikBjare/gptme/issues/29 # IPython instance -_ipython: "InteractiveShellEmbed | None" = None +_ipython: "InteractiveShell | None" = None registered_functions: dict[str, Callable] = {} @@ -51,15 +54,47 @@ def register_function(func: T) -> T: def _get_ipython(): global _ipython - from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip + from IPython.core.interactiveshell import InteractiveShell # fmt: skip if _ipython is None: - _ipython = InteractiveShellEmbed() + _ipython = InteractiveShell() _ipython.push(registered_functions) return _ipython +class TeeIO(io.StringIO): + def __init__(self, original_stream): + super().__init__() + self.original_stream = original_stream + self.in_result_block = False + + def write(self, s): + # hack to get rid of ipython result-prompt ("Out[0]: ...") and everything after it + if s.startswith("Out["): + self.in_result_block = True + if self.in_result_block: + if s.startswith("\n"): + self.in_result_block = False + else: + s = "" + self.original_stream.write(s) + self.original_stream.flush() # Ensure immediate display + return super().write(s) + + +@contextmanager +def capture_and_display(): + stdout_capture = TeeIO(sys.stdout) + stderr_capture = TeeIO(sys.stderr) + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = stdout_capture, stderr_capture + try: + yield stdout_capture, stderr_capture + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + def execute_python( code: str | None, args: list[str] | None, @@ -67,6 +102,7 @@ def execute_python( confirm: ConfirmFunc = lambda _: True, ) -> Generator[Message, None, None]: """Executes a python codeblock and returns the output.""" + from IPython.core.interactiveshell import ExecutionResult # fmt: skip if code is not None and args is not None: code = code.strip() @@ -84,12 +120,15 @@ def execute_python( # Create an IPython instance if it doesn't exist yet _ipython = _get_ipython() - # Capture the standard output and error streams - from IPython.utils.capture import capture_output # fmt: skip + # Capture and display output in real-time + with capture_and_display() as (stdout_capture, stderr_capture): + # Execute the code (output will be displayed in real-time) + result: ExecutionResult = _ipython.run_cell( + code, silent=False, store_history=False + ) - with capture_output() as captured: - # Execute the code - result = _ipython.run_cell(code, silent=False, store_history=False) + captured_stdout = stdout_capture.getvalue() + captured_stderr = stderr_capture.getvalue() output = "" # TODO: should we include captured stdout with messages like these? @@ -102,16 +141,16 @@ def execute_python( output += f"Result:\n```\n{result.result}\n```\n\n" # only show stdout if there is no result - elif captured.stdout: - output += f"```stdout\n{captured.stdout.rstrip()}\n```\n\n" - if captured.stderr: - output += f"```stderr\n{captured.stderr.rstrip()}\n```\n\n" + elif captured_stdout: + output += f"```stdout\n{captured_stdout.rstrip()}\n```\n\n" + if captured_stderr: + output += f"```stderr\n{captured_stderr.rstrip()}\n```\n\n" if result.error_in_exec: tb = result.error_in_exec.__traceback__ - while tb.tb_next: # type: ignore - tb = tb.tb_next # type: ignore - # type: ignore - output += f"Exception during execution on line {tb.tb_lineno}:\n {result.error_in_exec.__class__.__name__}: {result.error_in_exec}" + while tb and tb.tb_next: + tb = tb.tb_next + if tb: + output += f"Exception during execution on line {tb.tb_lineno}:\n {result.error_in_exec.__class__.__name__}: {result.error_in_exec}" # strip ANSI escape sequences # TODO: better to signal to the terminal that we don't want colors?