diff --git a/llm_mistral.py b/llm_mistral.py index f196568..0428ff3 100644 --- a/llm_mistral.py +++ b/llm_mistral.py @@ -1,5 +1,5 @@ import click -from httpx_sse import connect_sse +from httpx_sse import connect_sse, aconnect_sse import httpx import json import llm @@ -29,7 +29,11 @@ def register_models(register): our_model_id = "mistral/" + model_id alias = DEFAULT_ALIASES.get(our_model_id) aliases = [alias] if alias else [] - register(Mistral(our_model_id, model_id, vision), aliases=aliases) + register( + Mistral(our_model_id, model_id, vision), + AsyncMistral(our_model_id, model_id, vision), + aliases=aliases, + ) @llm.hookimpl @@ -103,7 +107,7 @@ def refresh(): click.echo("No changes", err=True) -class Mistral(llm.Model): +class _Shared: can_stream = True needs_key = "mistral" key_env_var = "LLM_MISTRAL_KEY" @@ -210,17 +214,16 @@ def build_messages(self, prompt, conversation): messages.append( {"role": "user", "content": prev_response.prompt.prompt} ) - messages.append({"role": "assistant", "content": prev_response.text()}) + messages.append( + {"role": "assistant", "content": prev_response.text_or_raise()} + ) if prompt.system and prompt.system != current_system: messages.append({"role": "system", "content": prompt.system}) messages.append(latest_message) return messages - def execute(self, prompt, stream, response, conversation): - key = self.get_key() - messages = self.build_messages(prompt, conversation) - response._prompt_json = {"messages": messages} + def build_body(self, prompt, messages): body = { "model": self.mistral_model_id, "messages": messages, @@ -235,6 +238,15 @@ def execute(self, prompt, stream, response, conversation): body["safe_mode"] = prompt.options.safe_mode if prompt.options.random_seed: body["random_seed"] = prompt.options.random_seed + return body + + +class Mistral(_Shared, llm.Model): + def execute(self, prompt, stream, response, conversation): + key = self.get_key() + messages = self.build_messages(prompt, conversation) + response._prompt_json = {"messages": messages} + body = self.build_body(prompt, messages) if stream: body["stream"] = True with httpx.Client() as client: @@ -292,6 +304,69 @@ def execute(self, prompt, stream, response, conversation): response.response_json = api_response.json() +class AsyncMistral(_Shared, llm.AsyncModel): + async def execute(self, prompt, stream, response, conversation): + key = self.get_key() + messages = self.build_messages(prompt, conversation) + response._prompt_json = {"messages": messages} + body = self.build_body(prompt, messages) + if stream: + body["stream"] = True + async with httpx.AsyncClient() as client: + async with aconnect_sse( + client, + "POST", + "https://api.mistral.ai/v1/chat/completions", + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {key}", + }, + json=body, + timeout=None, + ) as event_source: + # In case of unauthorized: + if event_source.response.status_code != 200: + # Try to make this a readable error, it may have a base64 chunk + try: + decoded = json.loads(event_source.response.read()) + type = decoded["type"] + words = decoded["message"].split() + except (json.JSONDecodeError, KeyError): + click.echo( + event_source.response.read().decode()[:200], err=True + ) + event_source.response.raise_for_status() + # Truncate any words longer than 30 characters + words = [word[:30] for word in words] + message = " ".join(words) + raise click.ClickException( + f"{event_source.response.status_code}: {type} - {message}" + ) + event_source.response.raise_for_status() + async for sse in event_source.aiter_sse(): + if sse.data != "[DONE]": + try: + yield sse.json()["choices"][0]["delta"]["content"] + except KeyError: + pass + else: + async with httpx.AsyncClient() as client: + api_response = await client.post( + "https://api.mistral.ai/v1/chat/completions", + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {key}", + }, + json=body, + timeout=None, + ) + api_response.raise_for_status() + yield api_response.json()["choices"][0]["message"]["content"] + response.response_json = api_response.json() + + class MistralEmbed(llm.EmbeddingModel): model_id = "mistral-embed" batch_size = 10 diff --git a/pyproject.toml b/pyproject.toml index f661f2b..9dfd668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "llm>=0.17", + "llm>=0.18", "httpx", "httpx-sse", ] @@ -24,4 +24,4 @@ CI = "https://github.com/simonw/llm-mistral/actions" mistral = "llm_mistral" [project.optional-dependencies] -test = ["pytest", "pytest-httpx"] +test = ["pytest", "pytest-httpx", "pytest-asyncio"] diff --git a/tests/test_mistral.py b/tests/test_mistral.py index 4f9aae2..5da9b3f 100644 --- a/tests/test_mistral.py +++ b/tests/test_mistral.py @@ -89,6 +89,32 @@ def mocked_stream(httpx_mock): return httpx_mock +@pytest.fixture +def mocked_no_stream(httpx_mock): + httpx_mock.add_response( + url="https://api.mistral.ai/v1/chat/completions", + method="POST", + json={ + "id": "cmpl-362653b3050c4939bfa423af5f97709b", + "object": "chat.completion", + "created": 1702614202, + "model": "mistral-tiny", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'm just a computer program, I don't have feelings.", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 16, "total_tokens": 79, "completion_tokens": 63}, + }, + ) + return httpx_mock + + def test_stream(mocked_stream): model = llm.get_model("mistral-tiny") response = model.prompt("How are you?") @@ -104,6 +130,30 @@ def test_stream(mocked_stream): } +@pytest.mark.asyncio +async def test_stream_async(mocked_stream): + model = llm.get_async_model("mistral-tiny") + response = await model.prompt("How are you?") + chunks = [item async for item in response] + assert chunks == ["I am an AI"] + request = mocked_stream.get_request() + assert json.loads(request.content) == { + "model": "mistral-tiny", + "messages": [{"role": "user", "content": "How are you?"}], + "temperature": 0.7, + "top_p": 1, + "stream": True, + } + + +@pytest.mark.asyncio +async def test_async_no_stream(mocked_no_stream): + model = llm.get_async_model("mistral-tiny") + response = await model.prompt("How are you?", stream=False) + text = await response.text() + assert text == "I'm just a computer program, I don't have feelings." + + def test_stream_with_options(mocked_stream): model = llm.get_model("mistral-tiny") model.prompt( @@ -127,28 +177,7 @@ def test_stream_with_options(mocked_stream): } -def test_no_stream(httpx_mock): - httpx_mock.add_response( - url="https://api.mistral.ai/v1/chat/completions", - method="POST", - json={ - "id": "cmpl-362653b3050c4939bfa423af5f97709b", - "object": "chat.completion", - "created": 1702614202, - "model": "mistral-tiny", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "I'm just a computer program, I don't have feelings.", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 16, "total_tokens": 79, "completion_tokens": 63}, - }, - ) +def test_no_stream(mocked_no_stream): model = llm.get_model("mistral-tiny") response = model.prompt("How are you?", stream=False) assert response.text() == "I'm just a computer program, I don't have feelings."