diff --git a/gptme/tools/python.py b/gptme/tools/python.py index 5e2d4c5d..8217b5b0 100644 --- a/gptme/tools/python.py +++ b/gptme/tools/python.py @@ -26,7 +26,7 @@ ) if TYPE_CHECKING: - from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip + from IPython.core.interactiveshell import InteractiveShell # fmt: skip logger = getLogger(__name__) @@ -35,7 +35,7 @@ # https://github.com/ErikBjare/gptme/issues/29 # IPython instance -_ipython: "InteractiveShellEmbed | None" = None +_ipython: "InteractiveShell | None" = None registered_functions: dict[str, Callable] = {} @@ -54,15 +54,47 @@ def register_function(func: T) -> T: def _get_ipython(): global _ipython - from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip + from IPython.core.interactiveshell import InteractiveShell # fmt: skip if _ipython is None: - _ipython = InteractiveShellEmbed() + _ipython = InteractiveShell() _ipython.push(registered_functions) return _ipython +class TeeIO(io.StringIO): + def __init__(self, original_stream): + super().__init__() + self.original_stream = original_stream + self.in_result_block = False + + def write(self, s): + # hack to get rid of ipython result-prompt ("Out[0]: ...") and everything after it + if s.startswith("Out["): + self.in_result_block = True + if self.in_result_block: + if s.startswith("\n"): + self.in_result_block = False + else: + s = "" + self.original_stream.write(s) + self.original_stream.flush() # Ensure immediate display + return super().write(s) + + +@contextmanager +def capture_and_display(): + stdout_capture = TeeIO(sys.stdout) + stderr_capture = TeeIO(sys.stderr) + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = stdout_capture, stderr_capture + try: + yield stdout_capture, stderr_capture + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + def execute_python( code: str | None, args: list[str] | None, @@ -70,6 +102,7 @@ def execute_python( confirm: ConfirmFunc = lambda _: True, ) -> Generator[Message, None, None]: """Executes a python codeblock and returns the output.""" + from IPython.core.interactiveshell import ExecutionResult # fmt: skip if code is not None and args is not None: code = code.strip() @@ -88,31 +121,11 @@ def execute_python( _ipython = _get_ipython() # Capture and display output in real-time - - class TeeIO(io.StringIO): - def __init__(self, original_stream): - super().__init__() - self.original_stream = original_stream - - def write(self, s): - self.original_stream.write(s) - self.original_stream.flush() # Ensure immediate display - return super().write(s) - - @contextmanager - def capture_and_display(): - stdout_capture = TeeIO(sys.stdout) - stderr_capture = TeeIO(sys.stderr) - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = stdout_capture, stderr_capture - try: - yield stdout_capture, stderr_capture - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr - with capture_and_display() as (stdout_capture, stderr_capture): # Execute the code (output will be displayed in real-time) - result = _ipython.run_cell(code, silent=False, store_history=False) + result: ExecutionResult = _ipython.run_cell( + code, silent=False, store_history=False + ) captured_stdout = stdout_capture.getvalue() captured_stderr = stderr_capture.getvalue()