Skip to content

Commit

Permalink
refactor: Move utility functions into dedicated modules, make tool co…
Browse files Browse the repository at this point in the history
…ntent editable

- Move clipboard.py and useredit.py into util/
- Split out name generation into util/generate_name.py
- Create util/ask_execute.py for user interaction functions
- Make tool content editable before execution in patch/save tools
- Update imports across codebase
  • Loading branch information
ErikBjare committed Dec 5, 2024
1 parent 2c872f3 commit d01b943
Show file tree
Hide file tree
Showing 15 changed files with 304 additions and 172 deletions.
2 changes: 1 addition & 1 deletion gptme/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
from .tools.base import ConfirmFunc
from .tools.browser import read_url
from .util import (
ask_execute,
console,
path_with_tilde,
print_bell,
rich_to_str,
)
from .util.ask_execute import ask_execute
from .util.cost import log_costs
from .util.readline import add_history

Expand Down
10 changes: 8 additions & 2 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@
from .logmanager import ConversationMeta, get_user_conversations
from .message import Message
from .prompts import get_prompt
from .tools import all_tools, init_tools, ToolFormat, set_tool_format
from .util import epoch_to_age, generate_name
from .tools import (
ToolFormat,
all_tools,
init_tools,
set_tool_format,
)
from .util import epoch_to_age
from .util.generate_name import generate_name
from .util.readline import add_history

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .llm.models import get_model
from .tools import ToolUse, execute_msg, loaded_tools
from .tools.base import ConfirmFunc, get_tool_format
from .useredit import edit_text_with_editor
from .util.useredit import edit_text_with_editor

logger = logging.getLogger(__name__)

Expand Down
19 changes: 17 additions & 2 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
from pathlib import Path

from ..message import Message
from ..util import print_preview
from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse
from ..util.ask_execute import get_editable_text, set_editable_text, print_preview
from .base import (
ConfirmFunc,
Parameter,
ToolSpec,
ToolUse,
)

instructions = """
To patch/modify files, we use an adapted version of git conflict markers.
Expand Down Expand Up @@ -188,6 +193,7 @@ def execute_patch(
Applies the patch.
"""

fn = None
if code is not None and args is not None:
fn = " ".join(args)
if not fn:
Expand All @@ -197,6 +203,9 @@ def execute_patch(
code = kwargs.get("patch", "")
fn = kwargs.get("path", "")

assert code is not None, "No patch provided"
assert fn is not None, "No path provided"

if code is None:
yield Message("system", "No patch provided")
return
Expand All @@ -217,10 +226,16 @@ def execute_patch(
# TODO: include patch headers to delimit multiple patches
print_preview(patches_str, lang="diff")

# Make patch content editable before confirmation
set_editable_text(code, "patch")

if not confirm(f"Apply patch to {fn}?"):
print("Patch not applied")
return

# Get potentially edited content
code = get_editable_text()

try:
with open(path) as f:
original_content = f.read()
Expand Down
2 changes: 1 addition & 1 deletion gptme/tools/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import TYPE_CHECKING, TypeVar

from ..message import Message
from ..util import print_preview
from ..util.ask_execute import print_preview
from .base import (
ConfirmFunc,
Parameter,
Expand Down
25 changes: 20 additions & 5 deletions gptme/tools/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from pathlib import Path

from ..message import Message
from ..util import print_preview
from ..util.ask_execute import (
clear_editable_text,
get_editable_text,
set_editable_text,
print_preview,
)
from .base import (
ConfirmFunc,
Parameter,
Expand Down Expand Up @@ -99,10 +104,20 @@ def execute_save(
yield Message("system", "File already exists with identical content.")
return

if not confirm(f"Save to {fn}?"):
# early return
yield Message("system", "Save cancelled.")
return
# Make content editable before confirmation
ext = Path(fn).suffix.lstrip(".")
set_editable_text(content, ext)

try:
if not confirm(f"Save to {fn}?"):
# early return
yield Message("system", "Save cancelled.")
return

# Get potentially edited content
content = get_editable_text()
finally:
clear_editable_text()

# if the file exists, ask to overwrite
if path.exists():
Expand Down
3 changes: 2 additions & 1 deletion gptme/tools/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .base import Parameter

from ..message import Message
from ..util import get_installed_programs, get_tokenizer, print_preview
from ..util import get_installed_programs, get_tokenizer
from ..util.ask_execute import print_preview
from .base import ConfirmFunc, ToolSpec, ToolUse

logger = logging.getLogger(__name__)
Expand Down
9 changes: 7 additions & 2 deletions gptme/tools/tmux.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
from time import sleep

from ..message import Message
from ..util import print_preview
from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse
from ..util.ask_execute import print_preview
from .base import (
ConfirmFunc,
Parameter,
ToolSpec,
ToolUse,
)

logger = logging.getLogger(__name__)

Expand Down
154 changes: 0 additions & 154 deletions gptme/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
import functools
import io
import logging
import random
import re
import shutil
import subprocess
import sys
import termios
import textwrap
from datetime import datetime, timedelta
from functools import lru_cache
Expand All @@ -19,14 +17,10 @@

from rich import print
from rich.console import Console
from rich.syntax import Syntax

from ..clipboard import copy, set_copytext

EMOJI_WARN = "⚠️"

logger = logging.getLogger(__name__)

console = Console(log_path=False)

_warned_models = set()
Expand All @@ -50,85 +44,6 @@ def get_tokenizer(model: str):
return tiktoken.get_encoding("cl100k_base")


actions = [
"running",
"jumping",
"walking",
"skipping",
"hopping",
"flying",
"swimming",
"crawling",
"sneaking",
"sprinting",
"sneaking",
"dancing",
"singing",
"laughing",
]
adjectives = [
"funny",
"happy",
"sad",
"angry",
"silly",
"crazy",
"sneaky",
"sleepy",
"hungry",
# colors
"red",
"blue",
"green",
"pink",
"purple",
"yellow",
"orange",
]
nouns = [
"cat",
"dog",
"rat",
"mouse",
"fish",
"elephant",
"dinosaur",
# birds
"bird",
"pelican",
# fictional
"dragon",
"unicorn",
"mermaid",
"monster",
"alien",
"robot",
# sea creatures
"whale",
"shark",
"walrus",
"octopus",
"squid",
"jellyfish",
"starfish",
"penguin",
"seal",
]


def generate_name():
action = random.choice(actions)
adjective = random.choice(adjectives)
noun = random.choice(nouns)
return f"{action}-{adjective}-{noun}"


def is_generated_name(name: str) -> bool:
"""if name is a name generated by generate_name"""
all_words = actions + adjectives + nouns
return name.count("-") == 2 and all(word in all_words for word in name.split("-"))


def epoch_to_age(epoch, incl_date=False):
# takes epoch and returns "x minutes ago", "3 hours ago", "yesterday", etc.
age = datetime.now() - datetime.fromtimestamp(epoch)
Expand All @@ -148,75 +63,6 @@ def epoch_to_age(epoch, incl_date=False):
)


copiable = False


def set_copiable():
global copiable
copiable = True


def clear_copiable():
global copiable
copiable = False


def print_preview(code: str, lang: str, copy: bool = False): # pragma: no cover
print()
print("[bold white]Preview[/bold white]")

if copy:
set_copiable()
set_copytext(code)

# NOTE: we can set background_color="default" to remove background
print(Syntax(code.strip("\n"), lang))
print()


override_auto = False


def ask_execute(question="Execute code?", default=True) -> bool: # pragma: no cover
global override_auto
if override_auto:
return True

print_bell() # Ring the bell just before asking for input
termios.tcflush(sys.stdin, termios.TCIFLUSH) # flush stdin

choicestr = f"[{'Y' if default else 'y'}/{'n' if default else 'N'}{'/c' if copiable else ''}/?]"
answer = console.input(
f"[bold bright_yellow on red] {question} {choicestr} [/] ",
)

if not override_auto and copiable and "c" == answer.lower().strip():
if copy():
print("Copied to clipboard.")
return False
clear_copiable()

# secret option to stop asking for the rest of the session
if answer.lower() in ["auto"]:
return (override_auto := True)

# secret option to ask for help
if answer.lower() in ["help", "h", "?"]:
lines = [
"Options:",
" y - execute the code",
" n - do not execute the code",
(" c - copy the code to the clipboard\n" if copiable else ""),
" auto - stop asking for the rest of the session",
f"Default is '{'y' if default else 'n'}' if answer is empty.",
]
helptext = "\n".join(line for line in lines if line)
print(helptext)
return ask_execute(question, default)

return answer.lower() in (["y", "yes"] + [""] if default else [])


def clean_example(s: str, strict=False) -> str:
orig = s
s = re.sub(
Expand Down
Loading

0 comments on commit d01b943

Please sign in to comment.