Async model support, closes #13
simonw committed Nov 19, 2024
1 parent aed678f commit 98e0ae7
Showing 3 changed files with 136 additions and 32 deletions.
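As context for the diff below: with this commit, the Mistral models can be driven asynchronously through llm's async API. A minimal sketch, assuming the llm 0.18 interface that the new tests exercise (llm.get_async_model(), await model.prompt(), async iteration over the response) and a configured LLM_MISTRAL_KEY:

import asyncio
import llm

async def main():
    # AsyncMistral is registered alongside the sync model by this commit
    model = llm.get_async_model("mistral-tiny")
    # Streaming path: chunks arrive incrementally over SSE
    response = await model.prompt("How are you?")
    async for chunk in response:
        print(chunk, end="", flush=True)

asyncio.run(main())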
91 changes: 83 additions & 8 deletions llm_mistral.py
@@ -1,5 +1,5 @@
import click
-from httpx_sse import connect_sse
+from httpx_sse import connect_sse, aconnect_sse
import httpx
import json
import llm
@@ -29,7 +29,11 @@ def register_models(register):
our_model_id = "mistral/" + model_id
alias = DEFAULT_ALIASES.get(our_model_id)
aliases = [alias] if alias else []
-        register(Mistral(our_model_id, model_id, vision), aliases=aliases)
+        register(
+            Mistral(our_model_id, model_id, vision),
+            AsyncMistral(our_model_id, model_id, vision),
+            aliases=aliases,
+        )


@llm.hookimpl
@@ -103,7 +107,7 @@ def refresh():
click.echo("No changes", err=True)


-class Mistral(llm.Model):
+class _Shared:
can_stream = True
needs_key = "mistral"
key_env_var = "LLM_MISTRAL_KEY"
@@ -210,17 +214,16 @@ def build_messages(self, prompt, conversation):
messages.append(
{"role": "user", "content": prev_response.prompt.prompt}
)
messages.append({"role": "assistant", "content": prev_response.text()})
messages.append(
{"role": "assistant", "content": prev_response.text_or_raise()}
)
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})

messages.append(latest_message)
return messages

-    def execute(self, prompt, stream, response, conversation):
-        key = self.get_key()
-        messages = self.build_messages(prompt, conversation)
-        response._prompt_json = {"messages": messages}
+    def build_body(self, prompt, messages):
body = {
"model": self.mistral_model_id,
"messages": messages,
@@ -235,6 +238,15 @@ def execute(self, prompt, stream, response, conversation):
body["safe_mode"] = prompt.options.safe_mode
if prompt.options.random_seed:
body["random_seed"] = prompt.options.random_seed
return body


class Mistral(_Shared, llm.Model):
def execute(self, prompt, stream, response, conversation):
key = self.get_key()
messages = self.build_messages(prompt, conversation)
response._prompt_json = {"messages": messages}
body = self.build_body(prompt, messages)
if stream:
body["stream"] = True
with httpx.Client() as client:
@@ -292,6 +304,69 @@ def execute(self, prompt, stream, response, conversation):
response.response_json = api_response.json()


class AsyncMistral(_Shared, llm.AsyncModel):
async def execute(self, prompt, stream, response, conversation):
key = self.get_key()
messages = self.build_messages(prompt, conversation)
response._prompt_json = {"messages": messages}
body = self.build_body(prompt, messages)
if stream:
body["stream"] = True
async with httpx.AsyncClient() as client:
async with aconnect_sse(
client,
"POST",
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
) as event_source:
# In case of unauthorized:
if event_source.response.status_code != 200:
# Try to make this a readable error, it may have a base64 chunk
try:
decoded = json.loads(event_source.response.read())
type = decoded["type"]
words = decoded["message"].split()
except (json.JSONDecodeError, KeyError):
click.echo(
event_source.response.read().decode()[:200], err=True
)
event_source.response.raise_for_status()
# Truncate any words longer than 30 characters
words = [word[:30] for word in words]
message = " ".join(words)
raise click.ClickException(
f"{event_source.response.status_code}: {type} - {message}"
)
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
if sse.data != "[DONE]":
try:
yield sse.json()["choices"][0]["delta"]["content"]
except KeyError:
pass
else:
async with httpx.AsyncClient() as client:
api_response = await client.post(
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
)
api_response.raise_for_status()
yield api_response.json()["choices"][0]["message"]["content"]
response.response_json = api_response.json()


class MistralEmbed(llm.EmbeddingModel):
model_id = "mistral-embed"
batch_size = 10
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -9,7 +9,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"llm>=0.17",
"llm>=0.18",
"httpx",
"httpx-sse",
]
@@ -24,4 +24,4 @@ CI = "https://github.com/simonw/llm-mistral/actions"
mistral = "llm_mistral"

[project.optional-dependencies]
test = ["pytest", "pytest-httpx"]
test = ["pytest", "pytest-httpx", "pytest-asyncio"]
73 changes: 51 additions & 22 deletions tests/test_mistral.py
@@ -89,6 +89,32 @@ def mocked_stream(httpx_mock):
return httpx_mock


@pytest.fixture
def mocked_no_stream(httpx_mock):
httpx_mock.add_response(
url="https://api.mistral.ai/v1/chat/completions",
method="POST",
json={
"id": "cmpl-362653b3050c4939bfa423af5f97709b",
"object": "chat.completion",
"created": 1702614202,
"model": "mistral-tiny",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'm just a computer program, I don't have feelings.",
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 16, "total_tokens": 79, "completion_tokens": 63},
},
)
return httpx_mock


def test_stream(mocked_stream):
model = llm.get_model("mistral-tiny")
response = model.prompt("How are you?")
@@ -104,6 +130,30 @@ def test_stream(mocked_stream):
}


@pytest.mark.asyncio
async def test_stream_async(mocked_stream):
model = llm.get_async_model("mistral-tiny")
response = await model.prompt("How are you?")
chunks = [item async for item in response]
assert chunks == ["I am an AI"]
request = mocked_stream.get_request()
assert json.loads(request.content) == {
"model": "mistral-tiny",
"messages": [{"role": "user", "content": "How are you?"}],
"temperature": 0.7,
"top_p": 1,
"stream": True,
}


@pytest.mark.asyncio
async def test_async_no_stream(mocked_no_stream):
model = llm.get_async_model("mistral-tiny")
response = await model.prompt("How are you?", stream=False)
text = await response.text()
assert text == "I'm just a computer program, I don't have feelings."


def test_stream_with_options(mocked_stream):
model = llm.get_model("mistral-tiny")
model.prompt(
@@ -127,28 +177,7 @@ def test_stream_with_options(mocked_stream):
}


-def test_no_stream(httpx_mock):
-    httpx_mock.add_response(
-        url="https://api.mistral.ai/v1/chat/completions",
-        method="POST",
-        json={
-            "id": "cmpl-362653b3050c4939bfa423af5f97709b",
-            "object": "chat.completion",
-            "created": 1702614202,
-            "model": "mistral-tiny",
-            "choices": [
-                {
-                    "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": "I'm just a computer program, I don't have feelings.",
-                    },
-                    "finish_reason": "stop",
-                }
-            ],
-            "usage": {"prompt_tokens": 16, "total_tokens": 79, "completion_tokens": 63},
-        },
-    )
+def test_no_stream(mocked_no_stream):
model = llm.get_model("mistral-tiny")
response = model.prompt("How are you?", stream=False)
assert response.text() == "I'm just a computer program, I don't have feelings."
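For the non-streaming path that test_async_no_stream exercises, caller code would look like this (same assumptions as the sketch at the top; await response.text() collects the full completion):

import asyncio
import llm

async def main():
    model = llm.get_async_model("mistral-tiny")
    # stream=False issues a single POST instead of opening an SSE stream
    response = await model.prompt("How are you?", stream=False)
    print(await response.text())

asyncio.run(main())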
