feat: added config, refactoring, hide initial system messages, better context awareness
ErikBjare committed Sep 6, 2023
1 parent 3f8f238 commit 9b54cec
Showing 13 changed files with 510 additions and 387 deletions.
118 changes: 87 additions & 31 deletions gptme/cli.py
@@ -36,24 +36,25 @@
 from rich import print
 from rich.console import Console

+from .config import get_config
 from .constants import role_color
 from .logmanager import LogManager, print_log
 from .message import Message
 from .prompts import initial_prompt
 from .tools import (
-    _execute_codeblock,
-    _execute_linecmd,
-    _execute_python,
-    _execute_save,
-    _execute_shell,
+    execute_codeblock,
+    execute_linecmd,
+    execute_python,
+    execute_shell,
 )
+from .tools.shell import get_shell
 from .util import epoch_to_age, generate_unique_name, msgs2dicts

 logger = logging.getLogger(__name__)


 LLMChoice = Literal["openai", "llama"]
-OpenAIModel = Literal["gpt-3.5-turbo", "gpt4"]
+ModelChoice = Literal["gpt-3.5-turbo", "gpt4"]


 def get_logfile(logdir: Path) -> Path:
@@ -70,24 +71,34 @@ def execute_msg(msg: Message) -> Generator[Message, None, None]:
     assert msg.role == "assistant", "Only assistant messages can be executed"

     for line in msg.content.splitlines():
-        yield from _execute_linecmd(line)
+        yield from execute_linecmd(line)

     # get all markdown code blocks
     # we support blocks beginning with ```python and ```bash
     codeblocks = [codeblock for codeblock in msg.content.split("```")[1::2]]
     for codeblock in codeblocks:
-        yield from _execute_codeblock(codeblock)
-
-    yield from _execute_save(msg.content)
+        yield from execute_codeblock(codeblock)


 Actions = Literal[
-    "continue", "summarize", "load", "shell", "python", "replay", "undo", "help", "exit"
+    "continue",
+    "summarize",
+    "log",
+    "context",
+    "load",
+    "shell",
+    "python",
+    "replay",
+    "undo",
+    "help",
+    "exit",
 ]

 action_descriptions: dict[Actions, str] = {
     "continue": "Continue",
     "undo": "Undo the last action",
+    "log": "Show the conversation log",
     "summarize": "Summarize the conversation so far",
     "load": "Load a file",
     "shell": "Execute a shell command",
@@ -106,13 +117,18 @@ def handle_cmd(
     name, *args = cmd.split(" ")
     match name:
         case "bash" | "sh" | "shell":
-            yield from _execute_shell(" ".join(args), ask=not no_confirm)
+            yield from execute_shell(" ".join(args), ask=not no_confirm)
         case "python" | "py":
-            yield from _execute_python(" ".join(args), ask=not no_confirm)
+            yield from execute_python(" ".join(args), ask=not no_confirm)
         case "continue":
             raise NotImplementedError
+        case "log":
+            logmanager.print(show_hidden="--hidden" in args)
         case "summarize":
             raise NotImplementedError
+        case "context":
+            # print context msg
+            print(_gen_context_msg())
         case "undo":
             # if int, undo n messages
             n = int(args[0]) if args and args[0].isdigit() else 1
@@ -170,10 +186,10 @@ def handle_cmd(
     type=click.Choice(["openai", "llama"]),
 )
 @click.option(
-    "--openai-model",
-    default="gpt-3.5-turbo",
-    help="OpenAI model to use",
-    type=click.Choice(["gpt-3.5-turbo", "gpt4"]),
+    "--model",
+    default="gpt-4",
+    help="Model to use (gpt-3.5 not recommended)",
+    type=click.Choice(["gpt-4", "gpt-3.5-turbo", "wizardcoder-..."]),
 )
 @click.option(
     "--stream/--no-stream",
@@ -185,25 +201,36 @@ def handle_cmd(
 @click.option(
     "-y", "--no-confirm", is_flag=True, help="Skips all confirmation prompts."
 )
+@click.option(
+    "--show-hidden",
+    is_flag=True,
+    help="Show hidden system messages.",
+)
 def main(
     prompt: str | None,
     prompt_system: str,
     name: str,
     llm: LLMChoice,
-    openai_model: OpenAIModel,
+    model: ModelChoice,
     stream: bool,
     verbose: bool,
     no_confirm: bool,
+    show_hidden: bool,
 ):
+    config = get_config()
     logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
     load_dotenv()
     _load_readline_history()

     if llm == "openai":
-        if "OPENAI_API_KEY" not in os.environ:
-            print("Error: OPENAI_API_KEY not set, see the README.")
+        if "OPENAI_API_KEY" in os.environ:
+            api_key = os.environ["OPENAI_API_KEY"]
+        elif api_key := config["env"]["OPENAI_API_KEY"]:
+            pass
+        else:
+            print("Error: OPENAI_API_KEY not set in env or config, see README.")
             sys.exit(1)
-        openai.api_key = os.environ["OPENAI_API_KEY"]
+        openai.api_key = api_key
     else:
         openai.api_base = "http://localhost:8000/v1"
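In effect, the API key is now resolved from the environment first, falling back to the new config file. A minimal standalone sketch of that precedence (the resolve_openai_key helper is hypothetical, not part of this commit):

```python
import os

from gptme.config import get_config


def resolve_openai_key() -> str | None:
    """Hypothetical helper mirroring main(): environment takes precedence over config."""
    return os.environ.get("OPENAI_API_KEY") or get_config()["env"]["OPENAI_API_KEY"]
```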

@@ -214,10 +241,16 @@ def main(

LOGDIR = Path("~/.local/share/gptme/logs").expanduser()
LOGDIR.mkdir(parents=True, exist_ok=True)
if name:
if name == "random":
name = generate_unique_name()
logpath = LOGDIR / (f"{datetime.now().strftime('%Y-%m-%d')}-{name}")
datestr = datetime.now().strftime("%Y-%m-%d")
logpath = LOGDIR
if name == "random":
# check if name exists, if so, generate another one
while not logpath or logpath.exists():
if name == "random":
name = generate_unique_name()
logpath = LOGDIR / (f"{datestr}-{name}")
elif name:
logpath = LOGDIR / (f"{datestr}-{name}")
else:
# let user select between starting a new conversation and loading a previous one
# using the library
@@ -244,13 +277,15 @@ def main(
         name = input("Name for conversation (or empty for random words): ")
         if not name:
             name = generate_unique_name()
-        logpath = LOGDIR / (datetime.now().strftime("%Y-%m-%d") + f"-{name}")
+        logpath = LOGDIR / f"{datestr}-{name}"
     else:
         logpath = LOGDIR / prev_conv_files[index - 1].parent

     print(f"Using logdir {logpath}")
     logfile = get_logfile(logpath)
-    logmanager = LogManager.load(logfile, initial_msgs=promptmsgs)
+    logmanager = LogManager.load(
+        logfile, initial_msgs=promptmsgs, show_hidden=show_hidden
+    )
     logmanager.print()
     print("--- ^^^ past messages ^^^ ---")
@@ -302,7 +337,14 @@ def main(
         # if large context, try to reduce/summarize
         # print response
         try:
-            msg_response = reply(logmanager.prepare_messages(), openai_model, stream)
+            # performs reduction/context trimming
+            msgs = logmanager.prepare_messages()
+
+            # append temporary message with current context
+            msgs += [_gen_context_msg()]
+
+            # generate response
+            msg_response = reply(msgs, model, stream)

             # log response and run tools
             if msg_response:
@@ -313,6 +355,20 @@ def main(
             print("Interrupted")


+def _gen_context_msg() -> Message:
+    shell = get_shell()
+    msgstr = ""
+
+    _, pwd, _ = shell.run_command("echo pwd: $(pwd)")
+    msgstr += f"$ pwd\n{pwd.strip()}\n"
+
+    ret, git, _ = shell.run_command("git status -s")
+    if ret == 0:
+        msgstr += f"$ git status\n{git}\n"
+
+    return Message("system", msgstr.strip(), hide=True)
+
+
 CONFIG_PATH = Path("~/.config/gptme").expanduser()
 CONFIG_PATH.mkdir(parents=True, exist_ok=True)
 HISTORY_FILE = CONFIG_PATH / "history"
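For illustration, the hidden system message produced by _gen_context_msg might look like this (hypothetical working directory and git state):

```
$ pwd
pwd: /home/user/myproject
$ git status
 M gptme/cli.py
?? notes.txt
```

Because the message is constructed with hide=True, it is sent to the model on each turn but skipped when the log is printed, unless --show-hidden is passed.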
@@ -356,7 +412,7 @@ def prompt_input(prompt: str, value=None) -> str:
     return value


-def reply(messages: list[Message], model: OpenAIModel, stream: bool = False) -> Message:
+def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
     if stream:
         return reply_stream(messages, model)
     else:
@@ -370,7 +426,7 @@ def reply(messages: list[Message], model: OpenAIModel, stream: bool = False) -> Message:
 top_p = 0.1


-def _chat_complete(messages: list[Message], model: OpenAIModel) -> str:
+def _chat_complete(messages: list[Message], model: str) -> str:
     # This will generate code and such, so we need appropriate temperature and top_p params
     # top_p controls diversity, temperature controls randomness
     response = openai.ChatCompletion.create(  # type: ignore
@@ -382,7 +438,7 @@ def _chat_complete(messages: list[Message], model: OpenAIModel) -> str:
     return response.choices[0].message.content


-def reply_stream(messages: list[Message], model: OpenAIModel) -> Message:
+def reply_stream(messages: list[Message], model: str) -> Message:
     print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")
     response = openai.ChatCompletion.create(  # type: ignore
         model=model,
51 changes: 51 additions & 0 deletions gptme/config.py
@@ -0,0 +1,51 @@
+import os
+from typing import TypedDict
+
+import toml
+
+
+class Config(TypedDict):
+    prompt: dict
+    env: dict
+
+
+default_config: Config = {
+    "prompt": {"about_user": "I am a curious human programmer."},
+    "env": {"OPENAI_API_KEY": None},
+}
+
+_config: Config | None = None
+
+
+def get_config() -> Config:
+    global _config
+    if _config is None:
+        _config = _load_config()
+    return _config
+
+
+def _load_config() -> Config:
+    # Define the path to the config file
+    config_path = os.path.expanduser("~/.config/gptme/config.toml")
+
+    # Check if the config file exists
+    if not os.path.exists(config_path):
+        # If not, create it and write some default settings
+        os.makedirs(os.path.dirname(config_path), exist_ok=True)
+        with open(config_path, "w") as config_file:
+            toml.dump(default_config, config_file)
+        print(f"Created config file at {config_path}")
+
+    # Read the settings from the config file
+    with open(config_path, "r") as config_file:
+        config: dict = toml.load(config_file)
+
+    # TODO: validate
+    config = Config(**config)  # type: ignore
+
+    return config  # type: ignore
+
+
+if __name__ == "__main__":
+    config = get_config()
+    print(config)
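As a sketch of the config shape this introduces, a populated ~/.config/gptme/config.toml would look roughly like the following (the key value is a placeholder; the default is None, whose exact TOML serialization depends on the toml encoder):

```toml
[prompt]
about_user = "I am a curious human programmer."

[env]
# placeholder value; only consulted if the OPENAI_API_KEY env var is unset
OPENAI_API_KEY = "sk-..."
```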
29 changes: 23 additions & 6 deletions gptme/logmanager.py
@@ -18,11 +18,15 @@

class LogManager:
def __init__(
self, log: list[Message] | None = None, logfile: PathLike | None = None
self,
log: list[Message] | None = None,
logfile: PathLike | None = None,
show_hidden=False,
):
self.log = log or []
assert logfile is not None, "logfile must be specified"
self.logfile = logfile
self.show_hidden = show_hidden
# TODO: Check if logfile has contents, then maybe load, or should it overwrite?

def append(self, msg: Message, quiet=False) -> None:
@@ -36,8 +40,8 @@ def write(self) -> None:
"""Writes the log to the logfile."""
write_log(self.log, self.logfile)

def print(self):
print_log(self.log, oneline=False)
def print(self, show_hidden: bool | None = None):
print_log(self.log, oneline=False, show_hidden=show_hidden or self.show_hidden)

def undo(self, n: int = 1) -> None:
"""Removes the last message from the log."""
@@ -64,6 +68,7 @@ def prepare_messages(self) -> list[Message]:
"""Prepares the log into messages before sending it to the LLM."""
msgs = self.log
msgs_reduced = list(reduce_log(msgs))

if len(msgs) != len(msgs_reduced):
print(
f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens"
@@ -76,13 +81,15 @@ def prepare_messages(self) -> list[Message]:
         return msgs_limited

     @classmethod
-    def load(cls, logfile=None, initial_msgs=initial_prompt()) -> "LogManager":
+    def load(
+        cls, logfile=None, initial_msgs=initial_prompt(), **kwargs
+    ) -> "LogManager":
         """Loads a conversation log."""
         with open(logfile, "r") as file:
             msgs = [Message(**json.loads(line)) for line in file.readlines()]
         if not msgs:
             msgs = initial_msgs
-        return cls(msgs, logfile=logfile)
+        return cls(msgs, logfile=logfile, **kwargs)


 def write_log(msg_or_log: Message | list[Message], logfile: PathLike) -> None:
@@ -106,9 +113,15 @@ def write_log(msg_or_log: Message | list[Message], logfile: PathLike) -> None:
     )


-def print_log(log: Message | list[Message], oneline: bool = True) -> None:
+def print_log(
+    log: Message | list[Message], oneline: bool = True, show_hidden=False
+) -> None:
     """Prints the log to the console."""
+    skipped_hidden = 0
     for msg in log if isinstance(log, list) else [log]:
+        if msg.hide and not show_hidden:
+            skipped_hidden += 1
+            continue
         color = role_color[msg.role]
         userprefix = f"[bold {color}]{msg.user}[/bold {color}]"
         # get terminal width
@@ -141,3 +154,7 @@ def print_log(log: Message | list[Message], oneline: bool = True) -> None:
                 print(f" ```{output[code_end+3:]}")
         else:
             print(f"\n{userprefix}: {output.rstrip()}")
+    if skipped_hidden:
+        print(
+            f"[grey30]Skipped {skipped_hidden} hidden system messages, show with --show-hidden[/]"
+        )
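A quick sketch of how the new show_hidden plumbing fits together (the log path is hypothetical):

```python
from pathlib import Path

from gptme.logmanager import LogManager

# hypothetical conversation log path
logfile = Path("~/.local/share/gptme/logs/2023-09-06-example/log.jsonl").expanduser()

# load the conversation; show_hidden passes through **kwargs to __init__
logmanager = LogManager.load(logfile, show_hidden=True)
logmanager.print()  # hidden system messages are included
```

Note that print() combines its argument with the instance default using or, so an explicit show_hidden=False still falls back to self.show_hidden; only True can override in one direction.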