More work-in-progress OpenAI 1.0 porting, refs #325
simonw committed Jan 26, 2024
1 parent 53c845e commit 23cbb44
Showing 8 changed files with 58 additions and 100 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test.yml
@@ -13,7 +13,6 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
pydantic: ["==1.10.2", ">=2.0.0"]
openai: ["<1.0", ">=1.0"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -26,7 +25,6 @@ jobs:
run: |
pip install -e '.[test]'
pip install 'pydantic${{ matrix.pydantic }}'
pip install 'openai${{ matrix.openai }}'
- name: Run tests
run: |
pytest
6 changes: 3 additions & 3 deletions docs/contributing.md
@@ -22,11 +22,11 @@ To run the tests:

The default OpenAI plugin has a debugging mechanism for showing the exact responses that came back from the OpenAI API.

Set the `LLM_OPENAI_SHOW_RESPONSES` environment variable like this:
Set the `OPENAI_LOG` environment variable like this:
```bash
LLM_OPENAI_SHOW_RESPONSES=1 llm -m chatgpt 'three word slogan for an otter-run bakery'
OPENAI_LOG=debug llm -m chatgpt 'three word slogan for an otter-run bakery'
```
This will output the response (including streaming responses) to standard error, as shown in [issue 286](https://github.com/simonw/llm/issues/286).
This will output details of the API request to the console.

## Documentation

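The `OPENAI_LOG` variable is handled by the openai 1.x SDK itself, which is what makes the plugin's custom `requests` logging hook (removed in the next file) unnecessary. For library code, a hedged equivalent via the standard `logging` module; the `openai` logger name is an assumption to verify against the SDK version in use:

```python
# a minimal sketch, assuming openai>=1.0 logs through Python's logging module
# under the "openai" logger; mirrors what OPENAI_LOG=debug enables
import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("openai").setLevel(logging.DEBUG)  # request/response details
```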
124 changes: 40 additions & 84 deletions llm/default_plugins/openai_models.py
@@ -3,68 +3,23 @@
from llm.utils import dicts_to_table_string
import click
import datetime
import httpx
import openai
import os

try:
# Pydantic 2
from pydantic import field_validator, Field # type: ignore

except ImportError:
# Pydantic 1
from pydantic.fields import Field
from pydantic.class_validators import validator as field_validator # type: ignore [no-redef]
import requests

from typing import List, Iterable, Iterator, Optional, Union
import json
import yaml


def _log_response(response, *args, **kwargs):
click.echo(response.text, err=True)
return response


_log_session = requests.Session()
_log_session.hooks["response"].append(_log_response)


IS_OPENAI_PRE_1 = openai.version.VERSION.startswith("0.")


if IS_OPENAI_PRE_1 and os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
openai.requestssession = _log_session # type: ignore


class OpenAILegacyWrapper:
def __init__(self, client):
self.client = client

@property
def ChatCompletion(self):
return self.client.chat.completions

@property
def Completion(self):
return self.client.completions

@property
def Embedding(self):
return self.client.embeddings


def get_openai_client():
if IS_OPENAI_PRE_1:
return openai

if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
client = openai.OpenAI(requestssession=_log_session)
else:
client = openai.OpenAI()

return OpenAILegacyWrapper(client)


client = get_openai_client()
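The block from `_log_response` down to `client = get_openai_client()` is the pre-1.0 compatibility scaffolding this commit strips out: openai<1.0 exposed module-level singletons (`openai.ChatCompletion` and friends) and accepted a hooked `requests` session for logging, while openai>=1.0 centres everything on an explicit client object. A minimal sketch of the 1.x shape the port moves to (model and prompt are illustrative):

```python
import openai

# openai>=1.0: an explicit client replaces the module-level API that the
# OpenAILegacyWrapper above emulated; OPENAI_API_KEY comes from the environment
client = openai.OpenAI()
completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
```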


@hookimpl
def register_models(register):
register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt"))
@@ -168,7 +123,7 @@ def models(json_, key):
from llm.cli import get_key

api_key = get_key(key, "openai", "OPENAI_API_KEY")
response = requests.get(
response = httpx.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)
@@ -343,8 +298,9 @@ def execute(self, prompt, stream, response, conversation=None):
messages.append({"role": "user", "content": prompt.prompt})
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()
if stream:
completion = client.ChatCompletion.create(
completion = client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=True,
@@ -353,25 +309,22 @@
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk["choices"][0].get("delta", {}).get("content")
content = chunk.choices[0].delta.content
if content is not None:
yield content
response.response_json = combine_chunks(chunks)
else:
completion = client.ChatCompletion.create(
completion = client.chat.completions.create(
model=self.model_name or self.model_id,
messages=messages,
stream=False,
**kwargs,
)
response.response_json = completion.to_dict_recursive()
response.response_json = completion.dict()
yield completion.choices[0].message.content

def build_kwargs(self, prompt):
kwargs = dict(not_nulls(prompt.options))
json_object = kwargs.pop("json_object", None)
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
kwargs["max_tokens"] = self.default_max_tokens
def get_client(self):
kwargs = {}
if self.api_base:
kwargs["api_base"] = self.api_base
if self.api_type:
@@ -380,8 +333,6 @@ def build_kwargs(self, prompt):
kwargs["api_version"] = self.api_version
if self.api_engine:
kwargs["engine"] = self.api_engine
if json_object:
kwargs["response_format"] = {"type": "json_object"}
if self.needs_key:
if self.key:
kwargs["api_key"] = self.key
@@ -391,6 +342,15 @@
kwargs["api_key"] = "DUMMY_KEY"
if self.headers:
kwargs["headers"] = self.headers
return openai.OpenAI(**kwargs)

def build_kwargs(self, prompt):
kwargs = dict(not_nulls(prompt.options))
json_object = kwargs.pop("json_object", None)
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
kwargs["max_tokens"] = self.default_max_tokens
if json_object:
kwargs["response_format"] = {"type": "json_object"}
return kwargs
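One caution on the new `get_client()`: the keyword names it forwards mirror the 0.x module-level settings, but the released openai>=1.0 constructor renamed several of them (`api_base` became `base_url`, ad-hoc `headers` became `default_headers`, and the Azure `api_type`/`api_version`/`engine` settings moved to a separate `openai.AzureOpenAI` client), consistent with the work-in-progress label on this commit. A hedged sketch under the released 1.x names, with illustrative values:

```python
import openai

# sketch using released openai>=1.0 constructor names (the WIP code above
# still passes the 0.x-style keys); values are placeholders
client = openai.OpenAI(
    api_key="sk-...",                      # or omit to read OPENAI_API_KEY
    base_url="https://api.openai.com/v1",  # 0.x called this api_base
    default_headers={"x-example": "1"},    # replaces 0.x per-request headers
)
```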


@@ -422,8 +382,9 @@ def execute(self, prompt, stream, response, conversation=None):
messages.append(prompt.prompt)
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()
if stream:
completion = client.Completion.create(
completion = client.completions.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=True,
@@ -432,26 +393,26 @@
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk["choices"][0].get("text") or ""
content = chunk.choices[0].text
if content is not None:
yield content
response.response_json = combine_chunks(chunks)
else:
completion = client.Completion.create(
completion = client.completions.create(
model=self.model_name or self.model_id,
prompt="\n".join(messages),
stream=False,
**kwargs,
)
response.response_json = completion.to_dict_recursive()
response.response_json = completion.dict()
yield completion.choices[0]["text"]


def not_nulls(data) -> dict:
return {key: value for key, value in data if value is not None}


def combine_chunks(chunks: List[dict]) -> dict:
def combine_chunks(chunks: List) -> dict:
content = ""
role = None
finish_reason = None
@@ -461,28 +422,23 @@ def combine_chunks(chunks: List[dict]) -> dict:
logprobs = []

for item in chunks:
for choice in item["choices"]:
if (
"logprobs" in choice
and "text" in choice
and isinstance(choice["logprobs"], dict)
and "top_logprobs" in choice["logprobs"]
):
for choice in item.choices:
if choice.logprobs:
logprobs.append(
{
"text": choice["text"],
"top_logprobs": choice["logprobs"]["top_logprobs"],
"text": choice.text,
"top_logprobs": choice.logprobs.top_logprobs,
}
)
if "text" in choice and "delta" not in choice:
content += choice["text"]

if not hasattr(choice, "delta"):
content += choice.text
continue
if "role" in choice["delta"]:
role = choice["delta"]["role"]
if "content" in choice["delta"]:
content += choice["delta"]["content"]
if choice.get("finish_reason") is not None:
finish_reason = choice["finish_reason"]
role = choice.delta.role
if choice.delta.content is not None:
content += choice.delta.content
if choice.finish_reason is not None:
finish_reason = choice.finish_reason

# Imitations of the OpenAI API may be missing some of these fields
combined = {
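The through-line in this file's changes: openai<1.0 returned dict-like objects, so code indexed `chunk["choices"][0]`, while openai>=1.0 returns pydantic models, so `combine_chunks` now uses attribute access and checks fields like `choice.delta.content` directly. A minimal streaming sketch under the 1.x API (model and prompt are illustrative):

```python
import openai

# openai>=1.0 streaming: each chunk is a pydantic model, so fields are
# attributes rather than dict keys
client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    content = chunk.choices[0].delta.content  # None on role/finish chunks
    if content is not None:
        print(content, end="")
```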
5 changes: 2 additions & 3 deletions setup.py
@@ -37,7 +37,7 @@ def get_long_description():
""",
install_requires=[
"click",
"openai",
"openai>=1.0",
"click-default-group>=1.2.3",
"sqlite-utils>=3.35.0",
"sqlite-migrate>=0.1a2",
@@ -52,14 +52,13 @@
"test": [
"pytest",
"numpy",
"requests-mock",
"pytest-httpx",
"cogapp",
"mypy",
"black",
"ruff",
"types-click",
"types-PyYAML",
"types-requests",
"types-setuptools",
]
},
8 changes: 5 additions & 3 deletions tests/conftest.py
@@ -139,16 +139,18 @@ def register_models(self, register):


@pytest.fixture
def mocked_openai_chat(requests_mock):
return requests_mock.post(
"https://api.openai.com/v1/chat/completions",
def mocked_openai_chat(httpx_mock):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
json={
"model": "gpt-3.5-turbo",
"usage": {},
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
headers={"Content-Type": "application/json"},
)
return httpx_mock


@pytest.fixture
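The fixtures switch from requests-mock to pytest-httpx because the 1.x SDK (and the `llm models` command above) now talks HTTP via httpx instead of requests. A self-contained sketch of the two pytest-httpx idioms the updated tests lean on, `add_response()` to queue a canned reply and `get_requests()` to inspect what was sent; the URL and payload are illustrative:

```python
import json

import httpx


def test_example(httpx_mock):
    # queue a canned reply for the next matching request
    httpx_mock.add_response(
        method="POST",
        url="https://api.openai.com/v1/chat/completions",
        json={"choices": [{"message": {"content": "hi"}}]},
    )
    httpx.post(
        "https://api.openai.com/v1/chat/completions",
        headers={"Authorization": "Bearer X"},
        json={"model": "gpt-3.5-turbo"},
    )
    # inspect the captured request, as test_keys.py and test_llm.py now do
    request = httpx_mock.get_requests()[-1]
    assert request.headers["Authorization"] == "Bearer X"
    assert json.loads(request.content)["model"] == "gpt-3.5-turbo"
```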
5 changes: 2 additions & 3 deletions tests/test_keys.py
@@ -57,9 +57,8 @@ def test_uses_correct_key(mocked_openai_chat, monkeypatch, tmpdir):
monkeypatch.setenv("OPENAI_API_KEY", "from-env")

def assert_key(key):
assert mocked_openai_chat.last_request.headers[
"Authorization"
] == "Bearer {}".format(key)
request = mocked_openai_chat.get_requests()[-1]
assert request.headers["Authorization"] == "Bearer {}".format(key)

runner = CliRunner()

3 changes: 2 additions & 1 deletion tests/test_llm.py
@@ -232,7 +232,8 @@ def test_llm_default_prompt(
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
assert mocked_openai_chat.last_request.headers["Authorization"] == "Bearer X"
last_request = mocked_openai_chat.get_requests()[-1]
assert last_request.headers["Authorization"] == "Bearer X"

# Was it logged?
rows = list(log_db["responses"].rows)
5 changes: 4 additions & 1 deletion tests/test_templates.py
@@ -1,4 +1,5 @@
from click.testing import CliRunner
import json
from llm import Template
from llm.cli import cli
import os
@@ -173,11 +174,13 @@ def test_template_basic(
)
if expected_error is None:
assert result.exit_code == 0
assert mocked_openai_chat.last_request.json() == {
last_request = mocked_openai_chat.get_requests()[-1]
assert json.loads(last_request.content) == {
"model": expected_model,
"messages": [{"role": "user", "content": expected_input}],
"stream": False,
}
else:
assert result.exit_code == 1
assert result.output.strip() == expected_error
mocked_openai_chat.reset(assert_all_responses_were_requested=False)
