From 9b54cec03fe754aa872dd2cbf2868e4f52ecfbbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Wed, 6 Sep 2023 14:01:31 +0200 Subject: [PATCH] feat: added config, refactoring, hide initial system messages, better context awareness --- gptme/cli.py | 118 +++++++++++----- gptme/config.py | 51 +++++++ gptme/logmanager.py | 29 +++- gptme/message.py | 5 +- gptme/prompts.py | 111 ++++++--------- gptme/tools/__init__.py | 291 +++------------------------------------ gptme/tools/python.py | 103 ++++++++++++++ gptme/tools/shell.py | 95 ++++++++++++- gptme/tools/summarize.py | 39 ++++++ gptme/util.py | 27 ++++ poetry.lock | 24 +++- pyproject.toml | 2 + tests/test_shell.py | 2 +- 13 files changed, 510 insertions(+), 387 deletions(-) create mode 100644 gptme/config.py create mode 100644 gptme/tools/python.py create mode 100644 gptme/tools/summarize.py diff --git a/gptme/cli.py b/gptme/cli.py index d7f05714..c00a4d8f 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -36,24 +36,25 @@ from rich import print from rich.console import Console +from .config import get_config from .constants import role_color from .logmanager import LogManager, print_log from .message import Message from .prompts import initial_prompt from .tools import ( - _execute_codeblock, - _execute_linecmd, - _execute_python, - _execute_save, - _execute_shell, + execute_codeblock, + execute_linecmd, + execute_python, + execute_shell, ) +from .tools.shell import get_shell from .util import epoch_to_age, generate_unique_name, msgs2dicts logger = logging.getLogger(__name__) LLMChoice = Literal["openai", "llama"] -OpenAIModel = Literal["gpt-3.5-turbo", "gpt4"] +ModelChoice = Literal["gpt-3.5-turbo", "gpt4"] def get_logfile(logdir: Path) -> Path: @@ -70,24 +71,34 @@ def execute_msg(msg: Message) -> Generator[Message, None, None]: assert msg.role == "assistant", "Only assistant messages can be executed" for line in msg.content.splitlines(): - yield from _execute_linecmd(line) + yield from execute_linecmd(line) # get all markdown code blocks # we support blocks beginning with ```python and ```bash codeblocks = [codeblock for codeblock in msg.content.split("```")[1::2]] for codeblock in codeblocks: - yield from _execute_codeblock(codeblock) - - yield from _execute_save(msg.content) + yield from execute_codeblock(codeblock) Actions = Literal[ - "continue", "summarize", "load", "shell", "python", "replay", "undo", "help", "exit" + "continue", + "summarize", + "log", + "summarize", + "context", + "load", + "shell", + "python", + "replay", + "undo", + "help", + "exit", ] action_descriptions: dict[Actions, str] = { "continue": "Continue", "undo": "Undo the last action", + "log": "Show the conversation log", "summarize": "Summarize the conversation so far", "load": "Load a file", "shell": "Execute a shell command", @@ -106,13 +117,18 @@ def handle_cmd( name, *args = cmd.split(" ") match name: case "bash" | "sh" | "shell": - yield from _execute_shell(" ".join(args), ask=not no_confirm) + yield from execute_shell(" ".join(args), ask=not no_confirm) case "python" | "py": - yield from _execute_python(" ".join(args), ask=not no_confirm) + yield from execute_python(" ".join(args), ask=not no_confirm) case "continue": raise NotImplementedError + case "log": + logmanager.print(show_hidden="--hidden" in args) case "summarize": raise NotImplementedError + case "context": + # print context msg + print(_gen_context_msg()) case "undo": # if int, undo n messages n = int(args[0]) if args and args[0].isdigit() else 1 @@ -170,10 +186,10 @@ def handle_cmd( type=click.Choice(["openai", "llama"]), ) @click.option( - "--openai-model", - default="gpt-3.5-turbo", - help="OpenAI model to use", - type=click.Choice(["gpt-3.5-turbo", "gpt4"]), + "--model", + default="gpt-4", + help="Model to use (gpt-3.5 not recommended)", + type=click.Choice(["gpt-4", "gpt-3.5-turbo", "wizardcoder-..."]), ) @click.option( "--stream/--no-stream", @@ -185,25 +201,36 @@ def handle_cmd( @click.option( "-y", "--no-confirm", is_flag=True, help="Skips all confirmation prompts." ) +@click.option( + "--show-hidden", + is_flag=True, + help="Show hidden system messages.", +) def main( prompt: str | None, prompt_system: str, name: str, llm: LLMChoice, - openai_model: OpenAIModel, + model: ModelChoice, stream: bool, verbose: bool, no_confirm: bool, + show_hidden: bool, ): + config = get_config() logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO) load_dotenv() _load_readline_history() if llm == "openai": - if "OPENAI_API_KEY" not in os.environ: - print("Error: OPENAI_API_KEY not set, see the README.") + if "OPENAI_API_KEY" in os.environ: + api_key = os.environ["OPENAI_API_KEY"] + elif api_key := config["env"]["OPENAI_API_KEY"]: + pass + else: + print("Error: OPENAI_API_KEY not set in env or config, see README.") sys.exit(1) - openai.api_key = os.environ["OPENAI_API_KEY"] + openai.api_key = api_key else: openai.api_base = "http://localhost:8000/v1" @@ -214,10 +241,16 @@ def main( LOGDIR = Path("~/.local/share/gptme/logs").expanduser() LOGDIR.mkdir(parents=True, exist_ok=True) - if name: - if name == "random": - name = generate_unique_name() - logpath = LOGDIR / (f"{datetime.now().strftime('%Y-%m-%d')}-{name}") + datestr = datetime.now().strftime("%Y-%m-%d") + logpath = LOGDIR + if name == "random": + # check if name exists, if so, generate another one + while not logpath or logpath.exists(): + if name == "random": + name = generate_unique_name() + logpath = LOGDIR / (f"{datestr}-{name}") + elif name: + logpath = LOGDIR / (f"{datestr}-{name}") else: # let user select between starting a new conversation and loading a previous one # using the library @@ -244,13 +277,15 @@ def main( name = input("Name for conversation (or empty for random words): ") if not name: name = generate_unique_name() - logpath = LOGDIR / (datetime.now().strftime("%Y-%m-%d") + f"-{name}") + logpath = LOGDIR / f"{datestr}-{name}" else: logpath = LOGDIR / prev_conv_files[index - 1].parent print(f"Using logdir {logpath}") logfile = get_logfile(logpath) - logmanager = LogManager.load(logfile, initial_msgs=promptmsgs) + logmanager = LogManager.load( + logfile, initial_msgs=promptmsgs, show_hidden=show_hidden + ) logmanager.print() print("--- ^^^ past messages ^^^ ---") @@ -302,7 +337,14 @@ def main( # if large context, try to reduce/summarize # print response try: - msg_response = reply(logmanager.prepare_messages(), openai_model, stream) + # performs reduction/context trimming + msgs = logmanager.prepare_messages() + + # append temporary message with current context + msgs += [_gen_context_msg()] + + # generate response + msg_response = reply(msgs, model, stream) # log response and run tools if msg_response: @@ -313,6 +355,20 @@ def main( print("Interrupted") +def _gen_context_msg() -> Message: + shell = get_shell() + msgstr = "" + + _, pwd, _ = shell.run_command("echo pwd: $(pwd)") + msgstr += f"$ pwd\n{pwd.strip()}\n" + + ret, git, _ = shell.run_command("git status -s") + if ret == 0: + msgstr += f"$ git status\n{git}\n" + + return Message("system", msgstr.strip(), hide=True) + + CONFIG_PATH = Path("~/.config/gptme").expanduser() CONFIG_PATH.mkdir(parents=True, exist_ok=True) HISTORY_FILE = CONFIG_PATH / "history" @@ -356,7 +412,7 @@ def prompt_input(prompt: str, value=None) -> str: return value -def reply(messages: list[Message], model: OpenAIModel, stream: bool = False) -> Message: +def reply(messages: list[Message], model: str, stream: bool = False) -> Message: if stream: return reply_stream(messages, model) else: @@ -370,7 +426,7 @@ def reply(messages: list[Message], model: OpenAIModel, stream: bool = False) -> top_p = 0.1 -def _chat_complete(messages: list[Message], model: OpenAIModel) -> str: +def _chat_complete(messages: list[Message], model: str) -> str: # This will generate code and such, so we need appropriate temperature and top_p params # top_p controls diversity, temperature controls randomness response = openai.ChatCompletion.create( # type: ignore @@ -382,7 +438,7 @@ def _chat_complete(messages: list[Message], model: OpenAIModel) -> str: return response.choices[0].message.content -def reply_stream(messages: list[Message], model: OpenAIModel) -> Message: +def reply_stream(messages: list[Message], model: str) -> Message: print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r") response = openai.ChatCompletion.create( # type: ignore model=model, diff --git a/gptme/config.py b/gptme/config.py new file mode 100644 index 00000000..982a16c4 --- /dev/null +++ b/gptme/config.py @@ -0,0 +1,51 @@ +import os +from typing import TypedDict + +import toml + + +class Config(TypedDict): + prompt: dict + env: dict + + +default_config: Config = { + "prompt": {"about_user": "I am a curious human programmer."}, + "env": {"OPENAI_API_KEY": None}, +} + +_config: Config | None = None + + +def get_config() -> Config: + global _config + if _config is None: + _config = _load_config() + return _config + + +def _load_config() -> Config: + # Define the path to the config file + config_path = os.path.expanduser("~/.config/gptme/config.toml") + + # Check if the config file exists + if not os.path.exists(config_path): + # If not, create it and write some default settings + os.makedirs(os.path.dirname(config_path), exist_ok=True) + with open(config_path, "w") as config_file: + toml.dump(default_config, config_file) + print(f"Created config file at {config_path}") + + # Now you can read the settings from the config file like this: + with open(config_path, "r") as config_file: + config: dict = toml.load(config_file) + + # TODO: validate + config = Config(**config) # type: ignore + + return config # type: ignore + + +if __name__ == "__main__": + config = get_config() + print(config) diff --git a/gptme/logmanager.py b/gptme/logmanager.py index bccb7b3c..0a888ff4 100644 --- a/gptme/logmanager.py +++ b/gptme/logmanager.py @@ -18,11 +18,15 @@ class LogManager: def __init__( - self, log: list[Message] | None = None, logfile: PathLike | None = None + self, + log: list[Message] | None = None, + logfile: PathLike | None = None, + show_hidden=False, ): self.log = log or [] assert logfile is not None, "logfile must be specified" self.logfile = logfile + self.show_hidden = show_hidden # TODO: Check if logfile has contents, then maybe load, or should it overwrite? def append(self, msg: Message, quiet=False) -> None: @@ -36,8 +40,8 @@ def write(self) -> None: """Writes the log to the logfile.""" write_log(self.log, self.logfile) - def print(self): - print_log(self.log, oneline=False) + def print(self, show_hidden: bool | None = None): + print_log(self.log, oneline=False, show_hidden=show_hidden or self.show_hidden) def undo(self, n: int = 1) -> None: """Removes the last message from the log.""" @@ -64,6 +68,7 @@ def prepare_messages(self) -> list[Message]: """Prepares the log into messages before sending it to the LLM.""" msgs = self.log msgs_reduced = list(reduce_log(msgs)) + if len(msgs) != len(msgs_reduced): print( f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens" @@ -76,13 +81,15 @@ def prepare_messages(self) -> list[Message]: return msgs_limited @classmethod - def load(cls, logfile=None, initial_msgs=initial_prompt()) -> "LogManager": + def load( + cls, logfile=None, initial_msgs=initial_prompt(), **kwargs + ) -> "LogManager": """Loads a conversation log.""" with open(logfile, "r") as file: msgs = [Message(**json.loads(line)) for line in file.readlines()] if not msgs: msgs = initial_msgs - return cls(msgs, logfile=logfile) + return cls(msgs, logfile=logfile, **kwargs) def write_log(msg_or_log: Message | list[Message], logfile: PathLike) -> None: @@ -106,9 +113,15 @@ def write_log(msg_or_log: Message | list[Message], logfile: PathLike) -> None: ) -def print_log(log: Message | list[Message], oneline: bool = True) -> None: +def print_log( + log: Message | list[Message], oneline: bool = True, show_hidden=False +) -> None: """Prints the log to the console.""" + skipped_hidden = 0 for msg in log if isinstance(log, list) else [log]: + if msg.hide and not show_hidden: + skipped_hidden += 1 + continue color = role_color[msg.role] userprefix = f"[bold {color}]{msg.user}[/bold {color}]" # get terminal width @@ -141,3 +154,7 @@ def print_log(log: Message | list[Message], oneline: bool = True) -> None: print(f" ```{output[code_end+3:]}") else: print(f"\n{userprefix}: {output.rstrip()}") + if skipped_hidden: + print( + f"[grey30]Skipped {skipped_hidden} hidden system messages, show with --show-hidden[/]" + ) diff --git a/gptme/message.py b/gptme/message.py index b122f5d8..d1d6aee5 100644 --- a/gptme/message.py +++ b/gptme/message.py @@ -11,10 +11,12 @@ def __init__( content: str, user: str | None = None, pinned: bool = False, + hide: bool = False, ): assert role in ["system", "user", "assistant"] self.role = role self.content = content.strip() + self.timestamp = datetime.now() if user: self.user = user else: @@ -23,7 +25,8 @@ def __init__( # Wether this message should be pinned to the top of the chat, and never context-trimmed. self.pinned = pinned - self.timestamp = datetime.now() + # Wether this message should be hidden from the chat output (but still be sent to the assistant) + self.hide = hide def to_dict(self): """Return a dict representation of the message, serializable to JSON.""" diff --git a/gptme/prompts.py b/gptme/prompts.py index 25578105..f80a935b 100644 --- a/gptme/prompts.py +++ b/gptme/prompts.py @@ -1,18 +1,13 @@ import os -import subprocess from .cli import __doc__ as cli_doc +from .config import get_config from .message import Message USER = os.environ["USER"] -ABOUT_ERB = """ -Erik Bjäreholt is a software engineer who is passionate about building tools that make people's lives easier. -He is known for building ActivityWatch, a open-source time tracking app. -""" - ABOUT_ACTIVITYWATCH = """ -ActivityWatch is a free and open-source time tracking app. +ActivityWatch is a free and open-source automated time-tracker that helps you track how you spend your time on your devices. It runs locally on the user's computer and has a REST API available at http://localhost:5600/api/. @@ -23,8 +18,11 @@ def initial_prompt(short: bool = False) -> list[Message]: """Initial prompt to start the conversation. If no history given.""" + config = get_config() + include_about = False - include_user = False + include_user = True + include_project = False include_tools = not short assert cli_doc @@ -33,21 +31,32 @@ def initial_prompt(short: bool = False) -> list[Message]: msgs.append(Message("system", cli_doc)) if include_user: msgs.append(Message("system", "$ whoami\n" + USER)) - pwd = subprocess.run(["pwd"], capture_output=True, text=True).stdout - msgs.append(Message("system", f"$ pwd\n{pwd}")) - if USER == "erb": - msgs.append( - Message( - "system", "Here is some information about the user: " + ABOUT_ERB - ) + + # NOTE: this is better to have as a temporary message that's updated with every request, so that the information is up-to-date + # pwd = subprocess.run(["pwd"], capture_output=True, text=True).stdout + # msgs.append(Message("system", f"$ pwd\n{pwd}")) + + msgs.append( + Message( + "system", + "Here is some information about the user: " + + config["prompt"]["about_user"], ) - msgs.append( - Message( - "system", - "Here is some information about ActivityWatch: " - + ABOUT_ACTIVITYWATCH, - ) + ) + if include_project: + # TODO: detect from git root folder name + project = "activitywatch" + # TODO: enshrine in config + config["prompt"]["project"] = { + "activitywatch": ABOUT_ACTIVITYWATCH, + } + msgs.append( + Message( + "system", + f"Some information about the current project {project}: " + + config["prompt"]["project"][project], ) + ) if include_tools: include_saveload = False @@ -59,16 +68,22 @@ def initial_prompt(short: bool = False) -> list[Message]: The assistant shows the user how to use tools to interact with the system and access the internet. The assistant should be concise and not verbose, it should assume the user is very knowledgeable. All commands should be copy-pasteable and runnable, do not use placeholders like `$REPO` or ``. +Do not suggest the user open a browser or editor, instead show them how to do it in the shell. +When the output of a command is of interest, end the code block so that the user can execute it before continuing. Here are some examples: # Terminal Use by writing a code block like this: +> User: learn about the project ```bash -pwd ls ``` +> stdout: README.md +```bash +cat README.md +``` # Python interpreter Use by writing a code block like this: @@ -82,8 +97,8 @@ def initial_prompt(short: bool = False) -> list[Message]: else """ # Save files Saving is done using `echo` with a redirect operator. -Example to save `hello.py`: +> User: write a Hello world script to hello.py ```bash echo '#!/usr/bin/env python print("Hello world!")' > hello.py @@ -91,66 +106,23 @@ def initial_prompt(short: bool = False) -> list[Message]: # Read files Loading is done using `cat`. -Example to load `hello.py`: +> User: read hello.py ```bash cat hello.py ``` # Putting it together -Run the script `hello.py` and save it to hello.sh: - -# hello.sh +> User: run hello.py ```bash -#!/usr/bin/env bash -chmod +x hello.sh hello.py python hello.py ``` """, + hide=True, ) ) - # Short/concise example use - # DEPRECATED - include_exampleuse = False - if include_exampleuse: - msgs.append( - Message( - "system", - """ -Example use: - -User: Look in the current directory and learn about the project. -Assistant: $ ls -System: README.md Makefile src pyproject.toml -Assistant: $ cat README.md -System: ... -""".strip(), - ) - ) - - # Karpathy wisdom and CoT hint - include_wisdom = False - if include_wisdom: - msgs.append( - Message( - "system", - """ - Always remember you are an AI language model, and to generate good answers you might need to reason step-by-step. - (In the words of Andrej Karpathy: LLMs need tokens to think) - """.strip(), - ) - ) - - # The most basic prompt, always given. - # msgs.append( - # Message( - # "assistant", - # "Hello, I am your personal AI assistant. How may I help you today?", - # ) - # ) - # gh examples include_gh = True if include_gh: @@ -178,6 +150,7 @@ def initial_prompt(short: bool = False) -> list[Message]: gh run view $RUN --repo $REPO --log ``` """, + hide=True, ) ) diff --git a/gptme/tools/__init__.py b/gptme/tools/__init__.py index e7b28037..da57efe9 100644 --- a/gptme/tools/__init__.py +++ b/gptme/tools/__init__.py @@ -1,298 +1,41 @@ -import ast -import atexit -import code -import io import logging -import os -import re -import textwrap -from contextlib import redirect_stderr, redirect_stdout from typing import Generator -import openai -from rich import print -from rich.console import Console -from rich.syntax import Syntax - -from ..cache import memory from ..message import Message -from ..util import len_tokens -from .shell import ShellSession +from .python import execute_python +from .shell import execute_shell +from .summarize import summarize logger = logging.getLogger(__name__) -EMOJI_WARN = "⚠️" - - -def _print_preview(code=None, lang=None): - # print a preview section header - print() - print("[bold white]Preview[/bold white]") - if code: - print(Syntax(code.strip(), lang)) - print() - - -def _execute_save(text: str, ask=True) -> Generator[Message, None, None]: - """Saves a codeblock to a file.""" - # last scanned codeblock - prev_codeblock = "" - # currently scanning codeblock - codeblock = "" - - for line in text.splitlines(): - if line.strip().startswith("// save:"): - filename = line.split(":")[1].strip() - content = "\n".join(prev_codeblock.split("\n")[1:-2]) - _print_preview() - print(f"# filename: {filename}") - print(textwrap.indent(content, "> ")) - confirm = input("Save to " + filename + "? (Y/n) ") - if confirm.lower() in ["y", "Y", "", "yes"]: - with open(filename, "w") as file: - file.write(content) - yield Message("system", "Saved to " + filename) - if line.startswith("```") or codeblock: - codeblock += line + "\n" - # if block if complete - if codeblock.startswith("```") and codeblock.strip().endswith("```"): - prev_codeblock = codeblock - codeblock = "" +__all__ = [ + "execute_linecmd", + "execute_codeblock", + "execute_python", + "execute_shell", + "summarize", +] -def _execute_load(filename: str) -> Generator[Message, None, None]: - if not os.path.exists(filename): - yield Message( - "system", "Tried to load file '" + filename + "', but it does not exist." - ) - confirm = input("Load from " + filename + "? (Y/n) ") - if confirm.lower() in ["y", "Y", "", "yes"]: - with open(filename, "r") as file: - data = file.read() - yield Message("system", f"# filename: {filename}\n\n{data}") - - -def _execute_linecmd(line: str) -> Generator[Message, None, None]: +# DEPRECATED +def execute_linecmd(line: str) -> Generator[Message, None, None]: """Executes a line command and returns the response.""" if line.startswith("terminal: "): cmd = line[len("terminal: ") :] - yield from _execute_shell(cmd) + yield from execute_shell(cmd) elif line.startswith("python: "): cmd = line[len("python: ") :] - yield from _execute_python(cmd) - elif line.strip().startswith("// load: "): - filename = line[len("load: ") :] - yield from _execute_load(filename) + yield from execute_python(cmd) -def _execute_codeblock(codeblock: str) -> Generator[Message, None, None]: +def execute_codeblock(codeblock: str) -> Generator[Message, None, None]: """Executes a codeblock and returns the output.""" codeblock_lang = codeblock.splitlines()[0].strip() codeblock = codeblock[len(codeblock_lang) :] if codeblock_lang in ["python"]: - yield from _execute_python(codeblock) + yield from execute_python(codeblock) elif codeblock_lang in ["terminal", "bash", "sh"]: - yield from _execute_shell(codeblock) + yield from execute_shell(codeblock) else: logger.warning(f"Unknown codeblock type {codeblock_lang}") - - -def _shorten_stdout(stdout: str) -> str: - """Shortens stdout to 1000 tokens.""" - lines = stdout.split("\n") - - # strip iso8601 timestamps - lines = [ - re.sub(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}[.]\d{3,9}Z?", "", line) - for line in lines - ] - # strip dates like "2017-08-02 08:48:43 +0000 UTC" - lines = [ - re.sub( - r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}( [+]\d{4})?( UTC)?", "", line - ).strip() - for line in lines - ] - - # strip common prefixes, useful for things like `gh runs view` - if len(lines) > 5: - prefix = os.path.commonprefix(lines) - if prefix: - lines = [line[len(prefix) :] for line in lines] - - pre_lines = 30 - post_lines = 30 - if len(lines) > pre_lines + post_lines: - lines = ( - lines[:pre_lines] - + [f"... ({len(lines) - pre_lines - post_lines} truncated) ..."] - + lines[-post_lines:] - ) - - return "\n".join(lines) - - -# init shell -shell = ShellSession() - -# close on exit -atexit.register(shell.close) - - -def ask_execute() -> bool: - console = Console() - answer = console.input( - f"[bold yellow on red] {EMOJI_WARN} Execute code? (Y/n) [/] ", - ) - return answer.lower() in ["y", "Y", "", "yes"] - - -def _execute_shell(cmd: str, ask=True) -> Generator[Message, None, None]: - """Executes a shell command and returns the output.""" - cmd = cmd.strip() - if cmd.startswith("$ "): - cmd = cmd[len("$ ") :] - if ask: - _print_preview(f"$ {cmd}", "bash") - confirm = ask_execute() - print() - - if not ask or confirm: - returncode, stdout, stderr = shell.run_command(cmd) - stdout = _shorten_stdout(stdout.strip()) - stderr = _shorten_stdout(stderr.strip()) - - msg = f"Ran command:\n```bash\n{cmd}\n```\n\n" - if stdout: - msg += f"stdout:\n```\n{stdout}\n```\n\n" - if stderr: - msg += f"stderr:\n```\n{stderr}\n```\n\n" - if not stdout and not stderr: - msg += "No output\n" - msg += f"Return code: {returncode}" - - yield Message("system", msg) - - -locals_ = {} # type: ignore - - -def _execute_python(code: str, ask=True) -> Generator[Message, None, None]: - """Executes a python codeblock and returns the output.""" - code = code.strip() - if ask: - _print_preview(code, "python") - confirm = ask_execute() - print() - - if not ask or confirm: - # remove blank lines - code = "\n".join([line for line in code.split("\n") if line.strip()]) - - exc = None - with redirect_stdout(io.StringIO()) as out, redirect_stderr( - io.StringIO() - ) as err: - try: - exec(code, locals_, locals_) # type: ignore - except Exception as e: - exc = e - stdout = out.getvalue().strip() - stderr = err.getvalue().strip() - # print(f"Completed execution: stdout={stdout}, stderr={stderr}, exc={exc}") - - output = "" - if stdout: - output += f"stdout:\n{stdout}\n\n" - if stderr: - output += f"stderr:\n{stderr}\n\n" - if exc: - tb = exc.__traceback__ - while tb.tb_next: # type: ignore - tb = tb.tb_next # type: ignore - output += f"Exception during execution on line {tb.tb_lineno}:\n {exc.__class__.__name__}: {exc}" # type: ignore - yield Message("system", "Executed code block.\n\n" + output) - else: - yield Message("system", "Aborted, user chose not to run command.") - - -def test_execute_python(): - assert _execute_python("1 + 1", ask=False) == ">>> 1 + 1\n2\n" - assert _execute_python("a = 2\na", ask=False) == ">>> a = 2\n>>> a\n2\n" - assert _execute_python("print(1)", ask=False) == ">>> print(1)\n" - - -@memory.cache -def _llm_summarize(content: str) -> str: - """Summarizes a long text using a LLM algorithm.""" - response = openai.Completion.create( - model="text-davinci-003", - prompt="Please summarize the following:\n" + content + "\n\nSummary:", - temperature=0, - max_tokens=256, - ) - summary = response.choices[0].text - logger.info( - f"Summarized long output ({len_tokens(content)} -> {len_tokens(summary)} tokens): " - + summary - ) - return summary - - -def summarize(msg: Message) -> Message: - """Uses a cheap LLM to summarize long outputs.""" - if len_tokens(msg.content) > 200: - # first 100 tokens - beginning = " ".join(msg.content.split()[:150]) - # last 100 tokens - end = " ".join(msg.content.split()[-100:]) - summary = _llm_summarize(beginning + "\n...\n" + end) - else: - summary = _llm_summarize(msg.content) - return Message("system", f"Here is a summary of the response:\n{summary}") - - -# OLD -def old(): - # parse code into statements - try: - statements = ast.parse(code).body - except SyntaxError as e: - yield Message("system", f"SyntaxError: {e}") - return - - output = "" - # execute statements - for stmt in statements: - stmt_str = ast.unparse(stmt) - output += ">>> " + stmt_str + "\n" - try: - # if stmt is assignment or function def, have to use exec - if ( - isinstance(stmt, ast.Assign) - or isinstance(stmt, ast.AnnAssign) - or isinstance(stmt, ast.Assert) - or isinstance(stmt, ast.ClassDef) - or isinstance(stmt, ast.FunctionDef) - or isinstance(stmt, ast.Import) - or isinstance(stmt, ast.ImportFrom) - or isinstance(stmt, ast.If) - or isinstance(stmt, ast.For) - or isinstance(stmt, ast.While) - or isinstance(stmt, ast.With) - or isinstance(stmt, ast.Try) - or isinstance(stmt, ast.AsyncFor) - or isinstance(stmt, ast.AsyncFunctionDef) - or isinstance(stmt, ast.AsyncWith) - ): - with io.StringIO() as buf, redirect_stdout(buf): - exec(stmt_str, globals(), locals_) - result = buf.getvalue().strip() - else: - result = eval(stmt_str, globals(), locals_) - if result: - output += str(result) + "\n" - except Exception as e: - output += f"{e.__class__.__name__}: {e}\n" - break diff --git a/gptme/tools/python.py b/gptme/tools/python.py new file mode 100644 index 00000000..d657a0bc --- /dev/null +++ b/gptme/tools/python.py @@ -0,0 +1,103 @@ +import ast +import io +from contextlib import redirect_stderr, redirect_stdout +from typing import Generator + +from ..message import Message +from ..util import ask_execute, print_preview + +locals_ = {} # type: ignore + + +def execute_python(code: str, ask=True) -> Generator[Message, None, None]: + """Executes a python codeblock and returns the output.""" + code = code.strip() + if ask: + print_preview(code, "python") + confirm = ask_execute() + print() + + if not ask or confirm: + # remove blank lines + code = "\n".join([line for line in code.split("\n") if line.strip()]) + + exc = None + with redirect_stdout(io.StringIO()) as out, redirect_stderr( + io.StringIO() + ) as err: + try: + exec(code, locals_, locals_) # type: ignore + except Exception as e: + exc = e + stdout = out.getvalue().strip() + stderr = err.getvalue().strip() + # print(f"Completed execution: stdout={stdout}, stderr={stderr}, exc={exc}") + + output = "" + if stdout: + output += f"stdout:\n{stdout}\n\n" + if stderr: + output += f"stderr:\n{stderr}\n\n" + if exc: + tb = exc.__traceback__ + while tb.tb_next: # type: ignore + tb = tb.tb_next # type: ignore + output += f"Exception during execution on line {tb.tb_lineno}:\n {exc.__class__.__name__}: {exc}" # type: ignore + yield Message("system", "Executed code block.\n\n" + output) + else: + yield Message("system", "Aborted, user chose not to run command.") + + +def test_execute_python(): + assert execute_python("1 + 1", ask=False) == ">>> 1 + 1\n2\n" + assert execute_python("a = 2\na", ask=False) == ">>> a = 2\n>>> a\n2\n" + assert execute_python("print(1)", ask=False) == ">>> print(1)\n" + + # test that vars are preserved between executions + assert execute_python("a = 2", ask=False) == ">>> a = 2\n" + assert execute_python("a", ask=False) == ">>> a\n2\n" + + +# OLD +def old(code: str): + # parse code into statements + try: + statements = ast.parse(code).body + except SyntaxError as e: + yield Message("system", f"SyntaxError: {e}") + return + + output = "" + # execute statements + for stmt in statements: + stmt_str = ast.unparse(stmt) + output += ">>> " + stmt_str + "\n" + try: + # if stmt is assignment or function def, have to use exec + if ( + isinstance(stmt, ast.Assign) + or isinstance(stmt, ast.AnnAssign) + or isinstance(stmt, ast.Assert) + or isinstance(stmt, ast.ClassDef) + or isinstance(stmt, ast.FunctionDef) + or isinstance(stmt, ast.Import) + or isinstance(stmt, ast.ImportFrom) + or isinstance(stmt, ast.If) + or isinstance(stmt, ast.For) + or isinstance(stmt, ast.While) + or isinstance(stmt, ast.With) + or isinstance(stmt, ast.Try) + or isinstance(stmt, ast.AsyncFor) + or isinstance(stmt, ast.AsyncFunctionDef) + or isinstance(stmt, ast.AsyncWith) + ): + with io.StringIO() as buf, redirect_stdout(buf): + exec(stmt_str, globals(), locals_) + result = buf.getvalue().strip() + else: + result = eval(stmt_str, globals(), locals_) + if result: + output += str(result) + "\n" + except Exception as e: + output += f"{e.__class__.__name__}: {e}\n" + break diff --git a/gptme/tools/shell.py b/gptme/tools/shell.py index 6411535a..54d576e6 100644 --- a/gptme/tools/shell.py +++ b/gptme/tools/shell.py @@ -1,10 +1,17 @@ +import atexit import os +import re import select +import shlex import subprocess +from typing import Generator + +from ..message import Message +from ..util import ask_execute, print_preview class ShellSession: - def __init__(self): + def __init__(self) -> None: self.process = subprocess.Popen( ["bash"], stdin=subprocess.PIPE, @@ -13,12 +20,14 @@ def __init__(self): bufsize=0, # Unbuffered universal_newlines=True, ) - self.stdout_fd = self.process.stdout.fileno() - self.stderr_fd = self.process.stderr.fileno() + self.stdout_fd = self.process.stdout.fileno() # type: ignore + self.stderr_fd = self.process.stderr.fileno() # type: ignore self.delimiter = "END_OF_COMMAND_OUTPUT" - def run_command(self, command): + def run_command(self, command: str | list[str]) -> tuple[int | None, str, str]: assert self.process.stdin + if isinstance(command, list): + command = " ".join(shlex.quote(arg) for arg in command) full_command = f"{command}; echo ReturnCode:$?; echo {self.delimiter}" self.process.stdin.write(full_command + "\n") @@ -62,3 +71,81 @@ def close(self): self.process.terminate() self.process.wait(timeout=0.2) self.process.kill() + + +_shell = None + + +def get_shell() -> ShellSession: + global _shell + if _shell is None: + # init shell + _shell = ShellSession() + + # close on exit + atexit.register(_shell.close) + return _shell + + +def execute_shell(cmd: str, ask=True) -> Generator[Message, None, None]: + """Executes a shell command and returns the output.""" + shell = get_shell() + + cmd = cmd.strip() + if cmd.startswith("$ "): + cmd = cmd[len("$ ") :] + if ask: + print_preview(f"$ {cmd}", "bash") + confirm = ask_execute() + print() + + if not ask or confirm: + returncode, stdout, stderr = shell.run_command(cmd) + stdout = _shorten_stdout(stdout.strip()) + stderr = _shorten_stdout(stderr.strip()) + + msg = f"Ran command:\n```bash\n{cmd}\n```\n\n" + if stdout: + msg += f"stdout:\n```\n{stdout}\n```\n\n" + if stderr: + msg += f"stderr:\n```\n{stderr}\n```\n\n" + if not stdout and not stderr: + msg += "No output\n" + msg += f"Return code: {returncode}" + + yield Message("system", msg) + + +def _shorten_stdout(stdout: str) -> str: + """Shortens stdout to 1000 tokens.""" + lines = stdout.split("\n") + + # strip iso8601 timestamps + lines = [ + re.sub(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}[.]\d{3,9}Z?", "", line) + for line in lines + ] + # strip dates like "2017-08-02 08:48:43 +0000 UTC" + lines = [ + re.sub( + r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}( [+]\d{4})?( UTC)?", "", line + ).strip() + for line in lines + ] + + # strip common prefixes, useful for things like `gh runs view` + if len(lines) > 5: + prefix = os.path.commonprefix(lines) + if prefix: + lines = [line[len(prefix) :] for line in lines] + + pre_lines = 30 + post_lines = 30 + if len(lines) > pre_lines + post_lines: + lines = ( + lines[:pre_lines] + + [f"... ({len(lines) - pre_lines - post_lines} truncated) ..."] + + lines[-post_lines:] + ) + + return "\n".join(lines) diff --git a/gptme/tools/summarize.py b/gptme/tools/summarize.py new file mode 100644 index 00000000..f0bce43c --- /dev/null +++ b/gptme/tools/summarize.py @@ -0,0 +1,39 @@ +import logging + +import openai + +from ..cache import memory +from ..message import Message +from ..util import len_tokens + +logger = logging.getLogger(__name__) + + +@memory.cache +def _llm_summarize(content: str) -> str: + """Summarizes a long text using a LLM algorithm.""" + response = openai.Completion.create( + model="text-davinci-003", + prompt="Please summarize the following:\n" + content + "\n\nSummary:", + temperature=0, + max_tokens=256, + ) + summary = response.choices[0].text + logger.info( + f"Summarized long output ({len_tokens(content)} -> {len_tokens(summary)} tokens): " + + summary + ) + return summary + + +def summarize(msg: Message) -> Message: + """Uses a cheap LLM to summarize long outputs.""" + if len_tokens(msg.content) > 200: + # first 100 tokens + beginning = " ".join(msg.content.split()[:150]) + # last 100 tokens + end = " ".join(msg.content.split()[-100:]) + summary = _llm_summarize(beginning + "\n...\n" + end) + else: + summary = _llm_summarize(msg.content) + return Message("system", f"Here is a summary of the response:\n{summary}") diff --git a/gptme/util.py b/gptme/util.py index 75c5acef..52530e96 100644 --- a/gptme/util.py +++ b/gptme/util.py @@ -1,8 +1,14 @@ import random from datetime import datetime, timedelta +from rich import print +from rich.console import Console +from rich.syntax import Syntax + from .message import Message +EMOJI_WARN = "⚠️" + def len_tokens(content: str | list[Message]) -> float: """Approximate the number of tokens in a string by assuming words have len 4 (lol).""" @@ -82,3 +88,24 @@ def epoch_to_age(epoch): return "yesterday" else: return f"{age.days} days ago ({datetime.fromtimestamp(epoch).strftime('%Y-%m-%d')})" + + +def print_preview(code=None, lang=None): + # print a preview section header + print() + print("[bold white]Preview[/bold white]") + if code: + print(Syntax(code.strip(), lang)) + print() + + +def ask_execute(default=True) -> bool: + # TODO: add a way to outsource ask_execute decision to another agent/LLM + console = Console() + choicestr = f"({'Y' if default else 'y'}/{'n' if default else 'N'})" + # answer = None + # while not answer or answer.lower() not in ["y", "yes", "n", "no", ""]: + answer = console.input( + f"[bold yellow on red] {EMOJI_WARN} Execute code? {choicestr} [/] ", + ) + return answer.lower() in (["y", "yes"] + [""] if default else []) diff --git a/poetry.lock b/poetry.lock index 4d2f4a93..eca4533c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1219,6 +1219,17 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -1250,6 +1261,17 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "types-toml" +version = "0.10.8.7" +description = "Typing stubs for toml" +optional = false +python-versions = "*" +files = [ + {file = "types-toml-0.10.8.7.tar.gz", hash = "sha256:58b0781c681e671ff0b5c0319309910689f4ab40e8a2431e205d70c94bb6efb1"}, + {file = "types_toml-0.10.8.7-py3-none-any.whl", hash = "sha256:61951da6ad410794c97bec035d59376ce1cbf4453dc9b6f90477e81e4442d631"}, +] + [[package]] name = "typing-extensions" version = "4.7.1" @@ -1411,4 +1433,4 @@ server = ["llama-cpp-python"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a9bcaee93ada169e79f168e62a4017a834cc789faafa3bdee2b4888e8f158462" +content-hash = "e23ed065d01ac6420ee333923fcc7844c69b747e15104aad76dd9489beea0136" diff --git a/pyproject.toml b/pyproject.toml index dbd1d12c..1010c076 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ python-dotenv = "^1.0.0" rich = "^13.5.2" pick = "^2.2.0" joblib = "^1.3.2" +toml = "^0.10.2" [tool.poetry.group.dev.dependencies] pytest = "^7.2" @@ -27,6 +28,7 @@ pytest-cov = "*" mypy = "*" ruff = "*" black = "*" +types-toml = "^0.10.8.7" [tool.poetry.extras] server = ["llama-cpp-python"] diff --git a/tests/test_shell.py b/tests/test_shell.py index 000af16c..3c312703 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,4 +1,4 @@ -from gptme.tools import ShellSession +from gptme.tools.shell import ShellSession def test_echo():