Remove in-chat model overrides (#100)
kharvd authored Nov 17, 2024
1 parent efb8126 commit eb5991a
Showing 9 changed files with 58 additions and 248 deletions.
8 changes: 0 additions & 8 deletions README.md
@@ -112,13 +112,6 @@ optional arguments:
Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C to clear the conversation, `:r` or Ctrl-R to re-generate the last response.
To enter multi-line mode, enter a backslash `\` followed by a new line. Exit the multi-line mode by pressing ESC and then Enter.

-You can override the model parameters using `--model`, `--temperature` and `--top_p` arguments at the end of your prompt. For example:
-
-```
-> What is the meaning of life? --model gpt-4 --temperature 2.0
-The meaning of life is subjective and can be different for diverse human beings and unique-phil ethics.org/cultuties-/ it that reson/bdstals89im3_jrf334;mvs-bread99ef=g22me
-```
-
The `dev` assistant is instructed to be an expert in software development and provide short responses.

```bash
@@ -197,7 +190,6 @@ assistants:
- { role: system, content: !include "pirate.txt" }
```

-
### Customize OpenAI API URL

If you are using other models compatible with the OpenAI Python SDK, you can configure them by modifying the `openai_base_url` setting in the config file or using the `OPENAI_BASE_URL` environment variable.
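For illustration, the setting has the same effect as overriding the client's base URL directly; a minimal sketch using the OpenAI Python SDK (the URL and fallback values are placeholders, not values from this repository):

```python
import os

from openai import OpenAI

# Hypothetical endpoint for any OpenAI-compatible server; substitute your own.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "placeholder-key"),
    base_url=os.environ.get("OPENAI_BASE_URL", "http://localhost:8000/v1"),
)
```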
23 changes: 7 additions & 16 deletions gptcli/assistant.py
@@ -7,7 +7,6 @@
from gptcli.completion import (
    CompletionEvent,
    CompletionProvider,
-    ModelOverrides,
    Message,
)
from gptcli.providers.google import GoogleCompletionProvider
@@ -107,28 +106,20 @@ def from_config(cls, name: str, config: AssistantConfig):
    def init_messages(self) -> List[Message]:
        return self.config.get("messages", [])[:]

-    def supported_overrides(self) -> List[str]:
-        return ["model", "temperature", "top_p"]
-
-    def _param(self, param: str, override_params: ModelOverrides) -> Any:
-        # If the param is in the override_params, use that value
-        # Otherwise, use the value from the config
+    def _param(self, param: str) -> Any:
+        # Use the value from the config if exists
+        # Otherwise, use the default value
-        return override_params.get(
-            param, self.config.get(param, CONFIG_DEFAULTS[param])
-        )
+        return self.config.get(param, CONFIG_DEFAULTS[param])

-    def complete_chat(
-        self, messages, override_params: ModelOverrides = {}, stream: bool = True
-    ) -> Iterator[CompletionEvent]:
-        model = self._param("model", override_params)
+    def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]:
+        model = self._param("model")
        completion_provider = get_completion_provider(model)
        return completion_provider.complete(
            messages,
            {
                "model": model,
-                "temperature": float(self._param("temperature", override_params)),
-                "top_p": float(self._param("top_p", override_params)),
+                "temperature": float(self._param("temperature")),
+                "top_p": float(self._param("top_p")),
            },
            stream,
        )
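In effect, the commit reduces parameter resolution to a two-level lookup: the assistant's config first, then the global default. A standalone sketch of that behavior (the default values below are placeholders, not the project's actual `CONFIG_DEFAULTS`):

```python
from typing import Any, Dict

# Placeholder defaults for illustration only.
CONFIG_DEFAULTS: Dict[str, Any] = {"model": "gpt-4o", "temperature": 1.0, "top_p": 1.0}

def param(config: Dict[str, Any], name: str) -> Any:
    # Assistant config wins; otherwise fall back to the global default.
    return config.get(name, CONFIG_DEFAULTS[name])

assert param({"temperature": 0.2}, "temperature") == 0.2
assert param({}, "model") == "gpt-4o"
```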
62 changes: 13 additions & 49 deletions gptcli/cli.py
@@ -1,5 +1,4 @@
-import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Optional

from openai import BadRequestError, OpenAIError
from prompt_toolkit import PromptSession
@@ -11,9 +10,16 @@
from rich.markdown import Markdown
from rich.text import Text

-from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT,
-                            COMMAND_RERUN, ChatListener, InvalidArgumentError,
-                            ResponseStreamer, UserInputProvider)
+from gptcli.session import (
+    ALL_COMMANDS,
+    COMMAND_CLEAR,
+    COMMAND_QUIT,
+    COMMAND_RERUN,
+    ChatListener,
+    InvalidArgumentError,
+    ResponseStreamer,
+    UserInputProvider,
+)

TERMINAL_WELCOME = """
Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear
@@ -113,43 +119,6 @@ def response_streamer(self) -> ResponseStreamer:
        return CLIResponseStreamer(self.console, self.markdown)


-def parse_args(input: str) -> Tuple[str, Dict[str, Any]]:
-    # Extract parts enclosed in specific delimiters (triple backticks, triple quotes, single backticks)
-    extracted_parts = []
-    delimiters = ['```', '"""', '`']
-
-    def replacer(match):
-        for i, delimiter in enumerate(delimiters):
-            part = match.group(i + 1)
-            if part is not None:
-                extracted_parts.append((part, delimiter))
-                break
-        return f"__EXTRACTED_PART_{len(extracted_parts) - 1}__"
-
-    # Construct the regex pattern dynamically from the delimiters list
-    pattern_fragments = [re.escape(d) + '(.*?)' + re.escape(d) for d in delimiters]
-    pattern = re.compile('|'.join(pattern_fragments), re.DOTALL)
-
-    input = pattern.sub(replacer, input)
-
-    # Parse the remaining string for arguments
-    args = {}
-    regex = r'--(\w+)(?:=(\S+)|\s+(\S+))?'
-    matches = re.findall(regex, input)
-
-    if matches:
-        for key, value1, value2 in matches:
-            value = value1 if value1 else value2 if value2 else ''
-            args[key] = value.strip("\"'")
-        input = re.sub(regex, "", input).strip()
-
-    # Add back the extracted parts, with enclosing backticks or quotes
-    for i, (part, delimiter) in enumerate(extracted_parts):
-        input = input.replace(f"__EXTRACTED_PART_{i}__", f"{delimiter}{part.strip()}{delimiter}")
-
-    return input, args
-
-
class CLIFileHistory(FileHistory):
    def append_string(self, string: str) -> None:
        if string in ALL_COMMANDS:
@@ -163,12 +132,11 @@ def __init__(self, history_filename) -> None:
            history=CLIFileHistory(history_filename)
        )

-    def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
+    def get_user_input(self) -> str:
        while (next_user_input := self._request_input()) == "":
            pass

-        user_input, args = self._parse_input(next_user_input)
-        return user_input, args
+        return next_user_input

    def prompt(self, multiline=False):
        bindings = KeyBindings()
@@ -219,7 +187,3 @@ def _request_input(self):
            return line

        return self.prompt(multiline=True)
-
-    def _parse_input(self, input: str) -> Tuple[str, Dict[str, Any]]:
-        input, args = parse_args(input)
-        return input, args
6 changes: 0 additions & 6 deletions gptcli/completion.py
@@ -9,12 +9,6 @@ class Message(TypedDict):
    content: str


-class ModelOverrides(TypedDict, total=False):
-    model: str
-    temperature: float
-    top_p: float
-
-
class Pricing(TypedDict):
    prompt: float
    response: float
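The deleted `ModelOverrides` relied on `total=False`, which makes every key optional; a small reminder of those semantics:

```python
from typing import TypedDict

class Example(TypedDict, total=False):
    model: str
    temperature: float

# With total=False, both of these type-check: all keys are optional.
empty: Example = {}
partial: Example = {"model": "gpt-4"}
```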
5 changes: 2 additions & 3 deletions gptcli/composite.py
@@ -1,4 +1,4 @@
-from gptcli.completion import Message, ModelOverrides, UsageEvent
+from gptcli.completion import Message, UsageEvent
from gptcli.session import ChatListener, ResponseStreamer


@@ -56,8 +56,7 @@ def on_chat_response(
        self,
        messages: List[Message],
        response: Message,
-        overrides: ModelOverrides,
        usage: Optional[UsageEvent],
    ):
        for listener in self.listeners:
-            listener.on_chat_response(messages, response, overrides, usage)
+            listener.on_chat_response(messages, response, usage)
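The composite simply fans each event out to every registered listener; a standalone sketch of the pattern after the `overrides` argument is dropped (stub listener, not repository code):

```python
from typing import List

class PrintListener:
    # Stand-in for CostListener, logging listeners, etc.
    def on_chat_response(self, messages, response, usage=None):
        print(f"response={response!r} usage={usage!r}")

class CompositeListener:
    def __init__(self, listeners: List[PrintListener]):
        self.listeners = listeners

    def on_chat_response(self, messages, response, usage=None):
        # Forward unchanged to every listener.
        for listener in self.listeners:
            listener.on_chat_response(messages, response, usage)

composite = CompositeListener([PrintListener()])
composite.on_chat_response([], {"role": "assistant", "content": "hi"}, None)
```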
5 changes: 2 additions & 3 deletions gptcli/cost.py
@@ -1,5 +1,5 @@
from gptcli.assistant import Assistant
-from gptcli.completion import Message, ModelOverrides, UsageEvent
+from gptcli.completion import Message, UsageEvent
from gptcli.session import ChatListener

from rich.console import Console
@@ -22,13 +22,12 @@ def on_chat_response(
        self,
        messages: List[Message],
        response: Message,
-        args: ModelOverrides,
        usage: Optional[UsageEvent] = None,
    ):
        if usage is None:
            return

-        model = self.assistant._param("model", args)
+        model = self.assistant._param("model")
        num_tokens = usage.total_tokens
        cost = usage.cost

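`CostListener` now resolves the model from config alone and reads the totals off the `UsageEvent`. For intuition, per-token pricing of the `Pricing` shape shown above might be applied like this (the prices are made-up placeholders, not the repository's tables):

```python
from typing import TypedDict

class Pricing(TypedDict):
    prompt: float
    response: float

# Hypothetical per-token prices for illustration only.
PRICE: Pricing = {"prompt": 2.5e-06, "response": 1.0e-05}

def cost(prompt_tokens: int, response_tokens: int) -> float:
    return prompt_tokens * PRICE["prompt"] + response_tokens * PRICE["response"]

print(f"${cost(1200, 350):.4f}")  # -> $0.0065
```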
45 changes: 13 additions & 32 deletions gptcli/session.py
@@ -1,14 +1,12 @@
from abc import abstractmethod
-from typing_extensions import TypeGuard
from gptcli.assistant import Assistant
from gptcli.completion import (
    Message,
-    ModelOverrides,
    CompletionError,
    BadRequestError,
    UsageEvent,
)
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional


class ResponseStreamer:
@@ -45,15 +43,14 @@ def on_chat_response(
        self,
        messages: List[Message],
        response: Message,
-        overrides: ModelOverrides,
        usage: Optional[UsageEvent] = None,
    ):
        pass


class UserInputProvider:
    @abstractmethod
-    def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
+    def get_user_input(self) -> str:
        pass


@@ -85,7 +82,7 @@ def __init__(
    ):
        self.assistant = assistant
        self.messages: List[Message] = assistant.init_messages()
-        self.user_prompts: List[Tuple[Message, ModelOverrides]] = []
+        self.user_prompts: List[Message] = []
        self.listener = listener
        self.stream = stream

@@ -103,18 +100,17 @@ def _rerun(self):
        self.messages = self.messages[:-1]

        self.listener.on_chat_rerun(True)
-        _, args = self.user_prompts[-1]
-        self._respond(args)
+        self._respond()

-    def _respond(self, overrides: ModelOverrides) -> bool:
+    def _respond(self) -> bool:
        """
        Respond to the user's input and return whether the assistant's response was saved.
        """
        next_response: str = ""
        usage: Optional[UsageEvent] = None
        try:
            completion_iter = self.assistant.complete_chat(
-                self.messages, override_params=overrides, stream=self.stream
+                self.messages, stream=self.stream
            )

            with self.listener.response_streamer() as stream:
@@ -137,28 +133,16 @@ def _respond(self, overrides: ModelOverrides) -> bool:

        next_message: Message = {"role": "assistant", "content": next_response}
        self.listener.on_chat_message(next_message)
-        self.listener.on_chat_response(self.messages, next_message, overrides, usage)
+        self.listener.on_chat_response(self.messages, next_message, usage)

        self.messages = self.messages + [next_message]
        return True

-    def _validate_args(self, args: Dict[str, Any]) -> TypeGuard[ModelOverrides]:
-        for key in args:
-            supported_overrides = self.assistant.supported_overrides()
-            if key not in supported_overrides:
-                self.listener.on_error(
-                    InvalidArgumentError(
-                        f"Invalid argument: {key}. Allowed arguments: {supported_overrides}"
-                    )
-                )
-                return False
-        return True
-
-    def _add_user_message(self, user_input: str, args: ModelOverrides):
+    def _add_user_message(self, user_input: str):
        user_message: Message = {"role": "user", "content": user_input}
        self.messages = self.messages + [user_message]
        self.listener.on_chat_message(user_message)
-        self.user_prompts.append((user_message, args))
+        self.user_prompts.append(user_message)

    def _rollback_user_message(self):
        self.messages = self.messages[:-1]
@@ -168,13 +152,10 @@ def _print_help(self):
        with self.listener.response_streamer() as stream:
            stream.on_next_token(COMMANDS_HELP)

-    def process_input(self, user_input: str, args: Dict[str, Any]):
+    def process_input(self, user_input: str):
        """
        Process the user's input and return whether the session should continue.
        """
-        if not self._validate_args(args):
-            return True
-
        if user_input in COMMAND_QUIT:
            return False
        elif user_input in COMMAND_CLEAR:
@@ -187,14 +168,14 @@ def process_input(self, user_input: str, args: Dict[str, Any]):
            self._print_help()
            return True

-        self._add_user_message(user_input, args)
-        response_saved = self._respond(args)
+        self._add_user_message(user_input)
+        response_saved = self._respond()
        if not response_saved:
            self._rollback_user_message()

        return True

    def loop(self, input_provider: UserInputProvider):
        self.listener.on_chat_start()
-        while self.process_input(*input_provider.get_user_input()):
+        while self.process_input(input_provider.get_user_input()):
            pass
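The driving loop now passes the string straight through, with no tuple unpacking or argument validation. A self-contained sketch of the simplified control flow (stub classes, not code from the repository):

```python
class StubSession:
    def process_input(self, user_input: str) -> bool:
        # Returns False to end the session, mirroring COMMAND_QUIT handling.
        print(f"you said: {user_input}")
        return user_input != ":q"

class StubInput:
    def __init__(self, lines):
        self._lines = iter(lines)

    def get_user_input(self) -> str:
        return next(self._lines)

session, provider = StubSession(), StubInput(["hello", ":q"])
while session.process_input(provider.get_user_input()):
    pass
```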