Skip to content

Commit

Permalink
llm openai models command, closes #70
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jul 10, 2023
1 parent ae47f0a commit ce70b4d
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
48 changes: 48 additions & 0 deletions llm/openai_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from . import Model, Prompt, OptionsError, Response, hookimpl
from .errors import NeedsKeyException
from .utils import dicts_to_table_string
import click
import datetime
from typing import Optional
import openai
import requests
import json


@hookimpl
Expand All @@ -12,6 +17,49 @@ def register_models(register):
register(Chat("gpt-4-32k"), aliases=("4-32k",))


@hookimpl
def register_commands(cli):
@cli.group(name="openai")
def openai_():
"Commands for working directly with the OpenAI API"

@openai_.command()
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
@click.option("--key", help="OpenAI API key")
def models(json_, key):
"List models available to you from the OpenAI API"
from .cli import get_key

api_key = get_key(key, "openai", "OPENAI_API_KEY")
response = requests.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)
if response.status_code != 200:
raise click.ClickException(
f"Error {response.status_code} from OpenAI API: {response.text}"
)
models = response.json()["data"]
if json_:
click.echo(json.dumps(models, indent=4))
else:
to_print = []
for model in models:
# Print id, owned_by, root, created as ISO 8601
created_str = datetime.datetime.fromtimestamp(
model["created"]
).isoformat()
to_print.append(
{
"id": model["id"],
"owned_by": model["owned_by"],
"created": created_str,
}
)
done = dicts_to_table_string("id owned_by created".split(), to_print)
print("\n".join(done))


class ChatResponse(Response):
def __init__(self, prompt, stream, key):
super().__init__(prompt)
Expand Down
25 changes: 25 additions & 0 deletions llm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List, Dict


def dicts_to_table_string(
headings: List[str], dicts: List[Dict[str, str]]
) -> List[str]:
max_lengths = [len(h) for h in headings]

# Compute maximum length for each column
for d in dicts:
for i, h in enumerate(headings):
if h in d and len(str(d[h])) > max_lengths[i]:
max_lengths[i] = len(str(d[h]))

# Generate formatted table strings
res = []
res.append(" ".join(h.ljust(max_lengths[i]) for i, h in enumerate(headings)))

for d in dicts:
row = []
for i, h in enumerate(headings):
row.append(str(d.get(h, "")).ljust(max_lengths[i]))
res.append(" ".join(row))

return res
39 changes: 39 additions & 0 deletions tests/test_cli_openai_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from click.testing import CliRunner
from llm.cli import cli
import pytest


@pytest.fixture
def mocked_models(requests_mock):
requests_mock.get(
"https://api.openai.com/v1/models",
json={
"data": [
{
"id": "ada:2020-05-03",
"object": "model",
"created": 1588537600,
"owned_by": "openai",
},
{
"id": "babbage:2020-05-03",
"object": "model",
"created": 1588537600,
"owned_by": "openai",
},
]
},
headers={"Content-Type": "application/json"},
)
return requests_mock


def test_openai_models(mocked_models):
runner = CliRunner()
result = runner.invoke(cli, ["openai", "models", "--key", "x"])
assert result.exit_code == 0
assert result.output == (
"id owned_by created \n"
"ada:2020-05-03 openai 2020-05-03T13:26:40\n"
"babbage:2020-05-03 openai 2020-05-03T13:26:40\n"
)

0 comments on commit ce70b4d

Please sign in to comment.