Add a greedy mode to the chat (NVIDIA#393)
Add a greedy mode to the chat, which sets temperature=0 in the
endpoint call. This makes the model perform greedy decoding, which
is deterministic and therefore useful for debugging and for
reproducing errors.

Fixes: NVIDIA#371

Signed-off-by: Andrea Frittoli <[email protected]>
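As background, greedy decoding means the model always picks the highest-probability next token, so the same prompt yields the same completion. A minimal sketch of the equivalent direct API call, assuming an OpenAI-compatible server; the base URL, API key, and model name are hypothetical placeholders:

from openai import OpenAI

# Hypothetical local endpoint and model name, for illustration only.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="no-key-required")

response = client.chat.completions.create(
    model="my-local-model",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    temperature=0,  # greedy decoding: always take the most likely token
)
print(response.choices[0].message.content)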
afrittoli authored Mar 8, 2024
1 parent f34ed32 commit b2a4a5b
Showing 2 changed files with 18 additions and 2 deletions.
cli/chat/chat.py: 17 changes (15 additions, 2 deletions)
@@ -59,7 +59,7 @@ class ChatException(Exception):
 
 
 # TODO Autosave chat history
-class ConsoleChatBot:
+class ConsoleChatBot:  # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
         model,
@@ -69,13 +69,16 @@ def __init__(
         vertical_overflow="ellipsis",
         loaded={},
         log_file=None,
+        greedy_mode=False,
     ):
         self.client = client
         self.model = model
         self.vi_mode = vi_mode
         self.vertical_overflow = vertical_overflow
         self.loaded = loaded
         self.log_file = log_file
+        self.greedy_mode = greedy_mode
+
         self.console = Console()
         self.input = (
             PromptSession(history=FileHistory(PROMPT_HISTORY_FILEPATH))
@@ -306,10 +309,19 @@ def start_prompt(self, content=None, box=True):
             self.multiline_mode = 0
             self.multiline = not self.multiline
 
+        # Optional parameters
+        create_params = {}
+        if self.greedy_mode:
+            # https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature
+            create_params["temperature"] = 0
+
         # Get and parse response
         try:
             response = self.client.chat.completions.create(
-                model=self.model, messages=self.info["messages"], stream=True
+                model=self.model,
+                messages=self.info["messages"],
+                stream=True,
+                **create_params,
             )
             assert (
                 next(response).choices[0].delta.role == "assistant"
@@ -402,6 +414,7 @@ def chat_cli(logger, api_base, config, question, model, context, session, qq):
         prompt=not qq,
         vertical_overflow=("visible" if config.visible_overflow else "ellipsis"),
         loaded=loaded,
+        greedy_mode=config.greedy_mode,
     )
 
     if not qq:
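Because temperature=0 removes sampling randomness, two calls with the same prompt should return identical text, which is what makes greedy mode useful for reproducing errors. A hypothetical reproducibility check, not part of this commit, reusing the client sketched above:

def complete(client, model, prompt):
    # Non-streaming variant so full outputs are easy to compare.
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,  # greedy mode
    )
    return response.choices[0].message.content

first = complete(client, "my-local-model", "Summarize greedy decoding.")
second = complete(client, "my-local-model", "Summarize greedy decoding.")
assert first == second, "greedy decoding should be deterministic"

Note that temperature=0 eliminates sampling randomness, but full determinism can also depend on the serving backend (for example, nondeterministic GPU kernels).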
cli/config.py: 3 changes (3 additions, 0 deletions)
@@ -18,6 +18,7 @@
 DEFAULT_PROMPT_FILE = "prompt.txt"
 DEFAULT_SEED_FILE = "seed_tasks.json"
 DEFAULT_GENERATED_FILES_OUTPUT_DIR = "generated"
+DEFAULT_GREEDY_MODE = False
 
 
 class ConfigException(Exception):
@@ -43,6 +44,7 @@ class _chat:
     context: str
     session: str
     logs_dir: str
+    greedy_mode: bool
 
 
 @dataclass
@@ -116,6 +118,7 @@ def get_default_config():
         context="default",
         session=None,
         logs_dir=DEFAULT_CHAT_LOGS,
+        greedy_mode=DEFAULT_GREEDY_MODE,
     )
     generate = _generate(
         model=DEFAULT_MODEL,
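With the new field in place, greedy mode can also be toggled programmatically; a minimal sketch, assuming get_default_config() returns a config object that exposes the chat section shown above (the import path is inferred from the file location):

from cli.config import get_default_config

config = get_default_config()
config.chat.greedy_mode = True  # opt in to deterministic decoding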
