Skip to content

Commit

Permalink
Escapable blocks of text (#83)
Browse files Browse the repository at this point in the history
allow escapable blocks of text when wrapped in triple backticks or
quotes. Fixes #52
  • Loading branch information
sghael authored Jul 24, 2024
1 parent 93b5ed9 commit 67491ba
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 16 deletions.
51 changes: 35 additions & 16 deletions gptcli/cli.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
import re
from typing import Any, Dict, Optional, Tuple

from openai import BadRequestError, OpenAIError
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from openai import OpenAIError, BadRequestError
from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
from prompt_toolkit.key_binding.bindings import named_commands
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from typing import Any, Dict, Optional, Tuple

from rich.text import Text
from gptcli.session import (
ALL_COMMANDS,
COMMAND_CLEAR,
COMMAND_QUIT,
COMMAND_RERUN,
ChatListener,
InvalidArgumentError,
ResponseStreamer,
UserInputProvider,
)

from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT,
COMMAND_RERUN, ChatListener, InvalidArgumentError,
ResponseStreamer, UserInputProvider)

TERMINAL_WELCOME = """
Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear
Expand Down Expand Up @@ -121,12 +114,38 @@ def response_streamer(self) -> ResponseStreamer:


def parse_args(input: str) -> Tuple[str, Dict[str, Any]]:
# Extract parts enclosed in specific delimiters (triple backticks, triple quotes, single backticks)
extracted_parts = []
delimiters = ['```', '"""', '`']

def replacer(match):
for i, delimiter in enumerate(delimiters):
part = match.group(i + 1)
if part is not None:
extracted_parts.append((part, delimiter))
break
return f"__EXTRACTED_PART_{len(extracted_parts) - 1}__"

# Construct the regex pattern dynamically from the delimiters list
pattern_fragments = [re.escape(d) + '(.*?)' + re.escape(d) for d in delimiters]
pattern = re.compile('|'.join(pattern_fragments), re.DOTALL)

input = pattern.sub(replacer, input)

# Parse the remaining string for arguments
args = {}
regex = r"--(\w+)(?:\s+|=)([^\s]+)"
regex = r'--(\w+)(?:=(\S+)|\s+(\S+))?'
matches = re.findall(regex, input)

if matches:
args = dict(matches)
input = input.split("--")[0].strip()
for key, value1, value2 in matches:
value = value1 if value1 else value2 if value2 else ''
args[key] = value.strip("\"'")
input = re.sub(regex, "", input).strip()

# Add back the extracted parts, with enclosing backticks or quotes
for i, (part, delimiter) in enumerate(extracted_parts):
input = input.replace(f"__EXTRACTED_PART_{i}__", f"{delimiter}{part.strip()}{delimiter}")

return input, args

Expand Down
57 changes: 57 additions & 0 deletions tests/test_term_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,60 @@ def test_parse_args():
"this is a prompt",
{"bar": "1.0", "baz": "2.0"},
)
assert parse_args("this is a prompt --bar 1.0") == (
"this is a prompt",
{"bar": "1.0"},
)


def test_parse_with_escape_blocks():
test_cases = [
(
# escaped text at end of prompt
"this is a prompt --bar=1.0 {start}--baz=2.0{end}",
"this is a prompt {start}--baz=2.0{end}",
{"bar": "1.0"},
),
(
# escaped text in middle of prompt with equal assignment
"this is a prompt {start}--bar=1.0{end} --baz=2.0",
"this is a prompt {start}--bar=1.0{end}",
{"baz": "2.0"},
),
(
# escaped text in middle of prompt with space assignment
"this is a prompt {start}--bar 1.0{end} --baz 2.0",
"this is a prompt {start}--bar 1.0{end}",
{"baz": "2.0"},
),
(
# escaped text in multiple escape sequences
'this is a prompt --bar=1.0 {start}my first context block{end} and then ```my second context block``` --baz=2.0',
'this is a prompt {start}my first context block{end} and then ```my second context block```',
{'bar': '1.0', 'baz': '2.0'},
),
(
# entire prompt is escaped
"{start}this is a prompt --bar=1.0 --baz=2.0{end}",
"{start}this is a prompt --bar=1.0 --baz=2.0{end}",
{},
),
(
# multi-line escaped text
"this is a prompt \n--bar=1.0 --baz=2.0\n{start}--foo=3.0 \n another line \nmy final line{end}",
"this is a prompt \n \n{start}--foo=3.0 \n another line \nmy final line{end}",
{'bar': '1.0', 'baz': '2.0'},
)

]

delimiters = ["```", '"""', "`"]

for start, end in [(d, d) for d in delimiters]:
for prompt, expected_prompt, expected_args in test_cases:
formatted_prompt = prompt.format(start=start, end=end)
formatted_expected_prompt = expected_prompt.format(start=start, end=end)
assert parse_args(formatted_prompt) == (
formatted_expected_prompt,
expected_args,
)

0 comments on commit 67491ba

Please sign in to comment.