Merge pull request #748 from PrefectHQ/refactor-fallback
Update openai.py
jlowin authored Jan 15, 2024
2 parents 72ef069 + aec3245 commit c365faf
Showing 1 changed file with 25 additions and 33 deletions.
58 changes: 25 additions & 33 deletions src/marvin/client/openai.py
@@ -32,11 +32,24 @@
 FALLBACK_CHAT_COMPLETIONS_MODEL = "gpt-3.5-turbo"
 
 
-def _request_is_using_marvin_default_model(request: ChatRequest) -> bool:
-    return (
-        request.model.startswith("gpt-4")
+async def should_fallback(e: NotFoundError, request: ChatRequest) -> bool:
+    if (
+        "you do not have access" in str(e)
+        and request.model.startswith("gpt-4")
         and request.model == marvin.settings.openai.chat.completions.model
-    )
+    ):
+        get_logger().warning(
+            "Marvin's default chat model is"
+            f" {marvin.settings.openai.chat.completions.model!r}, which your"
+            " API key likely does not give you access to. This API call will"
+            f" fall back to the {FALLBACK_CHAT_COMPLETIONS_MODEL!r} model. To"
+            " avoid this warning, please set"
+            " `MARVIN_OPENAI_CHAT_COMPLETIONS_MODEL=<accessible model>` in"
+            f" `~/.marvin/.env` - for example, `gpt-3.5-turbo`.\n\n {e}"
+        )
+        return True
+    else:
+        return False
 
 
 def _get_default_client(client_type: str) -> Union[Client, AsyncClient]:
@@ -53,7 +66,8 @@ def _get_default_client(client_type: str) -> Union[Client, AsyncClient]:
             )
         except AttributeError:
             raise ValueError(
-                "To use Azure OpenAI, please set all of the following environment variables in `~/.marvin/.env`:"
+                "To use Azure OpenAI, please set all of the following environment"
+                " variables in `~/.marvin/.env`:"
                 "\n\n"
                 "```"
                 "\nMARVIN_USE_AZURE_OPENAI=true"
@@ -107,20 +121,9 @@ def generate_chat(
         try:
             response: "ChatCompletion" = create(**request.model_dump(exclude_none=True))
         except NotFoundError as e:
-            if "you do not have access" in str(
-                e
-            ) and _request_is_using_marvin_default_model(request):
-                get_logger().warning(
-                    f"Marvin's default chat model is {marvin.settings.openai.chat.completions.model!r}, which"
-                    " your API key likely does not give you access to. This API call will fall back to the"
-                    f" {FALLBACK_CHAT_COMPLETIONS_MODEL!r} model. To avoid this warning,"
-                    " please set `MARVIN_OPENAI_CHAT_COMPLETIONS_MODEL=<accessible model>` in `~/.marvin/.env` -"
-                    f" for example, `gpt-3.5-turbo`.\n\n {e}"
-                )
-                return create(
-                    **request.model_dump(
-                        exclude_none=True,
-                    )
+            if should_fallback(e, request):
+                response = create(
+                    **request.model_dump(exclude_none=True)
                     | dict(model=FALLBACK_CHAT_COMPLETIONS_MODEL)
                 )
             else:
@@ -192,20 +195,9 @@ async def generate_chat(
                 **request.model_dump(exclude_none=True)
             )
         except NotFoundError as e:
-            if "you do not have access" in str(
-                e
-            ) and _request_is_using_marvin_default_model(request):
-                get_logger().warning(
-                    f"Marvin's default chat model is {marvin.settings.openai.chat.completions.model!r}, which"
-                    " your API key likely does not give you access to. This API call will fall back to the"
-                    f" {FALLBACK_CHAT_COMPLETIONS_MODEL!r} model. To avoid this warning,"
-                    " please set `MARVIN_OPENAI_CHAT_COMPLETIONS_MODEL=<accessible model>` in `~/.marvin/.env` -"
-                    f" for example, `gpt-3.5-turbo`.\n\n {e}"
-                )
-                return await create(
-                    **request.model_dump(
-                        exclude_none=True,
-                    )
+            if should_fallback(e, request):
+                response = await create(
+                    **request.model_dump(exclude_none=True)
                     | dict(model=FALLBACK_CHAT_COMPLETIONS_MODEL)
                 )
             else:
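For readers skimming the diff: the refactor replaces the old _request_is_using_marvin_default_model helper and the duplicated warn-and-retry blocks in the sync and async generate_chat methods with a single should_fallback(e, request) check. The sketch below is a minimal, self-contained illustration of that control flow only; the Request class, create() function, DEFAULT_MODEL constant, and NotFoundError stub are hypothetical stand-ins rather than Marvin's or OpenAI's actual API, and should_fallback is simplified to a plain synchronous function.

# Minimal sketch of the refactored fallback flow (hypothetical stand-ins, not Marvin's real API).
from dataclasses import dataclass

FALLBACK_CHAT_COMPLETIONS_MODEL = "gpt-3.5-turbo"
DEFAULT_MODEL = "gpt-4"  # stand-in for marvin.settings.openai.chat.completions.model


class NotFoundError(Exception):
    """Stand-in for openai.NotFoundError."""


@dataclass
class Request:
    model: str
    prompt: str


def create(model: str, prompt: str) -> str:
    # Pretend this API key only has access to gpt-3.5-turbo.
    if model.startswith("gpt-4"):
        raise NotFoundError(f"The model `{model}` does not exist or you do not have access to it.")
    return f"[{model}] response to {prompt!r}"


def should_fallback(e: NotFoundError, request: Request) -> bool:
    # Fall back only when the failure looks like an access problem on the default gpt-4 model.
    return (
        "you do not have access" in str(e)
        and request.model.startswith("gpt-4")
        and request.model == DEFAULT_MODEL
    )


def generate_chat(request: Request) -> str:
    try:
        return create(model=request.model, prompt=request.prompt)
    except NotFoundError as e:
        if should_fallback(e, request):
            # Retry once with the fallback model instead of surfacing the error.
            return create(model=FALLBACK_CHAT_COMPLETIONS_MODEL, prompt=request.prompt)
        raise


print(generate_chat(Request(model="gpt-4", prompt="hello")))  # -> [gpt-3.5-turbo] response to 'hello'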
