Skip to content

Commit

Permalink
fix: improve rich usage, change calls to use gptme.util.console.{print,input,log}
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Sep 26, 2024
1 parent 54d91d7 commit 8cf53cb
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 45 deletions.
42 changes: 16 additions & 26 deletions gptme/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import errno
import importlib.metadata
import io
import logging
import os
import re
Expand All @@ -15,8 +14,6 @@

import click
from pick import pick
from rich import print # noqa: F401
from rich.console import Console

from .commands import (
CMDFIX,
Expand All @@ -35,10 +32,15 @@
from .prompts import get_prompt
from .tools import ToolUse, execute_msg, has_tool
from .tools.browser import read_url
from .util import epoch_to_age, generate_name, print_bell
from .util import (
console,
epoch_to_age,
generate_name,
print_bell,
rich_to_str,
)

logger = logging.getLogger(__name__)
print_builtin = __builtins__["print"] # type: ignore


script_path = Path(os.path.realpath(__file__))
Expand Down Expand Up @@ -135,10 +137,10 @@ def main(
"""Main entrypoint for the CLI."""
if version:
# print version
print_builtin(f"gptme {importlib.metadata.version('gptme-python')}")
print(f"gptme {importlib.metadata.version('gptme-python')}")

# print dirs
print_builtin(f"Logs dir: {get_logs_dir()}")
print(f"Logs dir: {get_logs_dir()}")

exit(0)

Expand Down Expand Up @@ -227,18 +229,18 @@ def chat(
logfile = get_logfile(
name, interactive=(not prompt_msgs and interactive) and sys.stdin.isatty()
)
print(f"Using logdir {logfile.parent}")
console.log(f"Using logdir {logfile.parent}")
log = LogManager.load(logfile, initial_msgs=initial_msgs, show_hidden=show_hidden)

# change to workspace directory
# use if exists, create if @log, or use given path
if (logfile.parent / "workspace").exists():
assert workspace in ["@log", "."], "Workspace already exists"
workspace_path = logfile.parent / "workspace"
print(f"Using workspace at {workspace_path}")
console.log(f"Using workspace at {workspace_path}")
elif workspace == "@log":
workspace_path = logfile.parent / "workspace"
print(f"Creating workspace at {workspace_path}")
console.log(f"Creating workspace at {workspace_path}")
os.makedirs(workspace_path, exist_ok=True)
else:
workspace_path = Path(workspace)
Expand All @@ -258,7 +260,7 @@ def chat(

# print log
log.print()
print("--- ^^^ past messages ^^^ ---")
console.print("--- ^^^ past messages ^^^ ---")

# main loop
while True:
Expand Down Expand Up @@ -320,12 +322,6 @@ def step(
stream: bool = True,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""

# if last message was from assistant, try to run tools again
# FIXME: can't do this here because it will run twice
# if log[-1].role == "assistant":
# yield from execute_msg(log[-1], ask=not no_confirm)

# If last message was a response, ask for input.
# If last message was from the user (such as from crash/edited log),
# then skip asking for input and generate response
Expand Down Expand Up @@ -397,7 +393,7 @@ def get_name(name: str) -> Path:
if not logpath.exists():
break
else:
print(f"Name {name} already exists, try again.")
console.print(f"Name {name} already exists, try again.")
else:
# if name starts with date, use as is
try:
Expand Down Expand Up @@ -483,9 +479,9 @@ def prompt_user(value=None) -> str: # pragma: no cover
def prompt_input(prompt: str, value=None) -> str: # pragma: no cover
prompt = prompt.strip() + ": "
if value:
print(prompt + value)
console.print(prompt + value)
else:
prompt = _rich_to_str(prompt)
prompt = rich_to_str(prompt, color_system="256")

# https://stackoverflow.com/a/53260487/965332
original_stdout = sys.stdout
Expand All @@ -495,12 +491,6 @@ def prompt_input(prompt: str, value=None) -> str: # pragma: no cover
return value


def _rich_to_str(s: str) -> str:
console = Console(file=io.StringIO(), color_system="256")
console.print(s)
return console.file.getvalue() # type: ignore


def _read_stdin() -> str:
chunk_size = 1024 # 1 KB
all_data = ""
Expand Down
6 changes: 4 additions & 2 deletions gptme/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from tomlkit import TOMLDocument
from tomlkit.container import Container

from .util import console

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -93,7 +95,7 @@ def _load_config() -> tomlkit.TOMLDocument:
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w") as config_file:
tomlkit.dump(default_config.dict(), config_file)
print(f"Created config file at {config_path}")
console.log(f"Created config file at {config_path}")

# Now you can read the settings from the config file like this:
with open(config_path) as config_file:
Expand Down Expand Up @@ -131,7 +133,7 @@ def get_workspace_prompt(workspace: str) -> str:
]
if project_config_paths:
project_config_path = project_config_paths[0]
logger.info(f"Using project configuration at {project_config_path}")
console.log(f"Using project configuration at {project_config_path}")
# load project config
with open(project_config_path) as f:
project_config = tomlkit.load(f)
Expand Down
5 changes: 3 additions & 2 deletions gptme/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .models import PROVIDERS, get_recommended_model, set_default_model
from .tabcomplete import register_tabcomplete
from .tools import init_tools
from .util import console

logger = logging.getLogger(__name__)
_init_done = False
Expand Down Expand Up @@ -58,8 +59,8 @@ def init(model: str | None, interactive: bool):

if not model:
model = get_recommended_model(provider)
logger.info(
"No model specified, using recommended model for provider: %s", model
console.log(
f"No model specified, using recommended model for provider: {model}"
)
set_default_model(model)

Expand Down
19 changes: 6 additions & 13 deletions gptme/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import base64
import builtins
import dataclasses
import io
import logging
import shutil
import sys
Expand All @@ -12,15 +10,13 @@
from typing import Any, Literal

import tomlkit
from rich import print
from rich.console import Console
from rich.syntax import Syntax
from tomlkit._utils import escape_string
from typing_extensions import Self

from .codeblock import Codeblock
from .constants import ROLE_COLOR
from .util import get_tokenizer
from .util import console, get_tokenizer, rich_to_str

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,11 +243,7 @@ def format_msgs(
continue
elif highlight:
lang = block.split("\n")[0]
console = Console(
file=io.StringIO(), width=shutil.get_terminal_size().columns
)
console.print(Syntax(block.rstrip(), lang))
block = console.file.getvalue() # type: ignore
block = rich_to_str(Syntax(block.rstrip(), lang))
output += f"```{block.rstrip()}\n```"
outputs.append(f"{userprefix}: {output.rstrip()}")
return outputs
Expand All @@ -276,12 +268,13 @@ def print_msg(
skipped_hidden += 1
continue
try:
print(s)
console.print(s)
except Exception:
# rich can throw errors, if so then print the raw message
builtins.print(s)
logger.exception("Error printing message")
print(s)
if skipped_hidden:
print(
console.print(
f"[grey30]Skipped {skipped_hidden} hidden system messages, show with --show-hidden[/]"
)

Expand Down
13 changes: 11 additions & 2 deletions gptme/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import io
import logging
import random
import re
import sys
import textwrap
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Any

import tiktoken
from rich import print
Expand All @@ -15,6 +17,8 @@

logger = logging.getLogger(__name__)

console = Console()


def get_tokenizer(model: str):
if "gpt-4" in model or "gpt-3.5" in model:
Expand Down Expand Up @@ -130,8 +134,7 @@ def print_preview(code: str, lang: str): # pragma: no cover


def ask_execute(question="Execute code?", default=True) -> bool: # pragma: no cover
# TODO: add a way to outsource ask_execute decision to another agent/LLM
console = Console()
# TODO: add a way to outsource ask_execute decision to another agent/LLM, possibly by overriding rich console somehow
choicestr = f"({'Y' if default else 'y'}/{'n' if default else 'N'})"
# answer = None
# while not answer or answer.lower() not in ["y", "yes", "n", "no", ""]:
Expand Down Expand Up @@ -212,3 +215,9 @@ def decorator(func): # pragma: no cover
return func

return decorator


def rich_to_str(s: Any, **kwargs) -> str:
c = Console(file=io.StringIO(), **kwargs)
c.print(s)
return c.file.getvalue() # type: ignore

0 comments on commit 8cf53cb

Please sign in to comment.