Skip to content

Commit

Permalink
feat: validation of API keys (#86)
Browse files Browse the repository at this point in the history
* feat: validation of API keys

* fix: black formating
  • Loading branch information
zhuraromdev authored Jun 20, 2024
1 parent ad3ca69 commit e7b21a2
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions surfkit/env_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

import typer
import openai
from mllm import Router

from surfkit.types import AgentType
Expand Down Expand Up @@ -34,10 +35,19 @@ def find_local_llm_keys(typ: AgentType) -> Optional[dict]:
return env_vars


def is_api_key_valid(api_key: str) -> bool:
client = openai.OpenAI(api_key=api_key)
try:
client.models.list()
except openai.AuthenticationError:
return False
else:
return True


def find_llm_keys(typ: AgentType, llm_providers_local: bool) -> Optional[dict]:
env_vars = None
if typ.llm_providers and typ.llm_providers.preference:

found = {}

if llm_providers_local:
Expand All @@ -55,10 +65,11 @@ def find_llm_keys(typ: AgentType, llm_providers_local: bool) -> Optional[dict]:
if api_key_env:
typer.echo(f" - {api_key_env}")
typer.echo("")

for provider_name in typ.llm_providers.preference:
api_key_env = Router.provider_api_keys.get(provider_name)
if not api_key_env:
raise ValueError(f"no api key env for provider {provider_name}")
raise ValueError(f"No API key env for provider {provider_name}")

if found.get(api_key_env):
continue
Expand All @@ -75,18 +86,28 @@ def find_llm_keys(typ: AgentType, llm_providers_local: bool) -> Optional[dict]:

if not found:
for provider_name in typ.llm_providers.preference:
add = typer.confirm(
f"Would you like to enter an API key for '{provider_name}'"
)
if add:
api_key_env = Router.provider_api_keys.get(provider_name)
if not api_key_env:
continue
response = typer.prompt(api_key_env)
found[api_key_env] = response
while True:
add = typer.confirm(
f"Would you like to enter an API key for '{provider_name}'"
)
if add:
api_key_env = Router.provider_api_keys.get(provider_name)
if not api_key_env:
continue
response = typer.prompt(api_key_env)
if is_api_key_valid(response):
found[api_key_env] = response
break
else:
typer.echo(
f"The API Key is not valid for '{provider_name}'. Please try again."
)
else:
break

if not found:
raise ValueError(
"No API keys given for any of the llm providers in the agent type"
"No valid API keys given for any of the llm providers in the agent type"
)

env_vars = found
Expand Down

0 comments on commit e7b21a2

Please sign in to comment.