diff --git a/gptcli/cli.py b/gptcli/cli.py index d3f2f93..65e1040 100644 --- a/gptcli/cli.py +++ b/gptcli/cli.py @@ -1,26 +1,19 @@ import re +from typing import Any, Dict, Optional, Tuple + +from openai import BadRequestError, OpenAIError from prompt_toolkit import PromptSession from prompt_toolkit.history import FileHistory -from openai import OpenAIError, BadRequestError from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent from prompt_toolkit.key_binding.bindings import named_commands from rich.console import Console from rich.live import Live from rich.markdown import Markdown -from typing import Any, Dict, Optional, Tuple - from rich.text import Text -from gptcli.session import ( - ALL_COMMANDS, - COMMAND_CLEAR, - COMMAND_QUIT, - COMMAND_RERUN, - ChatListener, - InvalidArgumentError, - ResponseStreamer, - UserInputProvider, -) +from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT, + COMMAND_RERUN, ChatListener, InvalidArgumentError, + ResponseStreamer, UserInputProvider) TERMINAL_WELCOME = """ Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear @@ -121,12 +114,38 @@ def response_streamer(self) -> ResponseStreamer: def parse_args(input: str) -> Tuple[str, Dict[str, Any]]: + # Extract parts enclosed in specific delimiters (triple backticks, triple quotes, single backticks) + extracted_parts = [] + delimiters = ['```', '"""', '`'] + + def replacer(match): + for i, delimiter in enumerate(delimiters): + part = match.group(i + 1) + if part is not None: + extracted_parts.append((part, delimiter)) + break + return f"__EXTRACTED_PART_{len(extracted_parts) - 1}__" + + # Construct the regex pattern dynamically from the delimiters list + pattern_fragments = [re.escape(d) + '(.*?)' + re.escape(d) for d in delimiters] + pattern = re.compile('|'.join(pattern_fragments), re.DOTALL) + + input = pattern.sub(replacer, input) + + # Parse the remaining string for arguments args = {} - regex = r"--(\w+)(?:\s+|=)([^\s]+)" + regex = r'--(\w+)(?:=(\S+)|\s+(\S+))?' matches = re.findall(regex, input) + if matches: - args = dict(matches) - input = input.split("--")[0].strip() + for key, value1, value2 in matches: + value = value1 if value1 else value2 if value2 else '' + args[key] = value.strip("\"'") + input = re.sub(regex, "", input).strip() + + # Add back the extracted parts, with enclosing backticks or quotes + for i, (part, delimiter) in enumerate(extracted_parts): + input = input.replace(f"__EXTRACTED_PART_{i}__", f"{delimiter}{part.strip()}{delimiter}") return input, args diff --git a/tests/test_term_utils.py b/tests/test_term_utils.py index 511a4c8..fee724e 100644 --- a/tests/test_term_utils.py +++ b/tests/test_term_utils.py @@ -16,3 +16,60 @@ def test_parse_args(): "this is a prompt", {"bar": "1.0", "baz": "2.0"}, ) + assert parse_args("this is a prompt --bar 1.0") == ( + "this is a prompt", + {"bar": "1.0"}, + ) + + +def test_parse_with_escape_blocks(): + test_cases = [ + ( + # escaped text at end of prompt + "this is a prompt --bar=1.0 {start}--baz=2.0{end}", + "this is a prompt {start}--baz=2.0{end}", + {"bar": "1.0"}, + ), + ( + # escaped text in middle of prompt with equal assignment + "this is a prompt {start}--bar=1.0{end} --baz=2.0", + "this is a prompt {start}--bar=1.0{end}", + {"baz": "2.0"}, + ), + ( + # escaped text in middle of prompt with space assignment + "this is a prompt {start}--bar 1.0{end} --baz 2.0", + "this is a prompt {start}--bar 1.0{end}", + {"baz": "2.0"}, + ), + ( + # escaped text in multiple escape sequences + 'this is a prompt --bar=1.0 {start}my first context block{end} and then ```my second context block``` --baz=2.0', + 'this is a prompt {start}my first context block{end} and then ```my second context block```', + {'bar': '1.0', 'baz': '2.0'}, + ), + ( + # entire prompt is escaped + "{start}this is a prompt --bar=1.0 --baz=2.0{end}", + "{start}this is a prompt --bar=1.0 --baz=2.0{end}", + {}, + ), + ( + # multi-line escaped text + "this is a prompt \n--bar=1.0 --baz=2.0\n{start}--foo=3.0 \n another line \nmy final line{end}", + "this is a prompt \n \n{start}--foo=3.0 \n another line \nmy final line{end}", + {'bar': '1.0', 'baz': '2.0'}, + ) + + ] + + delimiters = ["```", '"""', "`"] + + for start, end in [(d, d) for d in delimiters]: + for prompt, expected_prompt, expected_args in test_cases: + formatted_prompt = prompt.format(start=start, end=end) + formatted_expected_prompt = expected_prompt.format(start=start, end=end) + assert parse_args(formatted_prompt) == ( + formatted_expected_prompt, + expected_args, + ) \ No newline at end of file