Skip to content

Commit

Permalink
Add a CLI arg for the chat greedy mode (NVIDIA#450)
Browse files Browse the repository at this point in the history
Add a --greedy-mode flag argument to the chat command.

The flag can be used to enable greedy mode for that invocation,
and it overrides the value in the configuration.

When the flag is not passed, the value from the configuration is
honoured.

Related: NVIDIA#371

Signed-off-by: Andrea Frittoli <[email protected]>
  • Loading branch information
afrittoli authored Mar 8, 2024
1 parent b2a4a5b commit c045340
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
8 changes: 6 additions & 2 deletions cli/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,9 @@ def start_prompt(self, content=None, box=True):
self._update_conversation(response_content.plain, "assistant")


def chat_cli(logger, api_base, config, question, model, context, session, qq):
def chat_cli(
logger, api_base, config, question, model, context, session, qq, greedy_mode
):
"""Starts a CLI-based chat with the server"""
client = OpenAI(base_url=api_base, api_key="no_api_key")

Expand Down Expand Up @@ -414,7 +416,9 @@ def chat_cli(logger, api_base, config, question, model, context, session, qq):
prompt=not qq,
vertical_overflow=("visible" if config.visible_overflow else "ellipsis"),
loaded=loaded,
greedy_mode=config.greedy_mode,
greedy_mode=greedy_mode
if greedy_mode
else config.greedy_mode, # The CLI flag can only be used to enable
)

if not qq:
Expand Down
9 changes: 8 additions & 1 deletion cli/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,14 @@ def generate(
is_flag=True,
help="Exit after answering question",
)
@click.option(
"-gm",
"--greedy-mode",
is_flag=True,
help="Use model greedy decoding. Useful for debugging and reproducing errors.",
)
@click.pass_context
def chat(ctx, question, model, context, session, quick_question):
def chat(ctx, question, model, context, session, quick_question, greedy_mode):
"""Run a chat using the modified model"""
api_base = ctx.obj.config.serve.api_base()
try:
Expand All @@ -352,6 +358,7 @@ def chat(ctx, question, model, context, session, quick_question):
context=context,
session=session,
qq=quick_question,
greedy_mode=greedy_mode,
)
except ChatException as exc:
click.secho(f"Executing chat failed with: {exc}", fg="red")
Expand Down

0 comments on commit c045340

Please sign in to comment.