From 32ec8b63b2e4f470842aad22c2904e8d2fd8f22b Mon Sep 17 00:00:00 2001 From: Alexis Tacnet Date: Thu, 23 May 2024 18:34:09 +0200 Subject: [PATCH] Release 0.2.0 (#94) --- .github/workflows/build_publish.yaml | 90 ++++++++++++++++++++++++++++ examples/chatbot_with_streaming.py | 36 +++-------- examples/function_calling.py | 13 ++-- examples/json_format.py | 1 - pyproject.toml | 2 +- src/mistralai/async_client.py | 1 - src/mistralai/client.py | 1 - src/mistralai/client_base.py | 9 ++- src/mistralai/constants.py | 2 - src/mistralai/exceptions.py | 6 +- src/mistralai/models/models.py | 1 + tests/conftest.py | 19 ++++++ tests/test_chat.py | 43 +++---------- tests/test_chat_async.py | 62 +++++++------------ tests/test_embedder.py | 23 ++----- tests/test_embedder_async.py | 42 +++++-------- tests/test_list_models.py | 13 +--- tests/test_list_models_async.py | 22 ++----- tests/utils.py | 4 +- 19 files changed, 189 insertions(+), 201 deletions(-) create mode 100644 .github/workflows/build_publish.yaml create mode 100644 tests/conftest.py diff --git a/.github/workflows/build_publish.yaml b/.github/workflows/build_publish.yaml new file mode 100644 index 0000000..a696f10 --- /dev/null +++ b/.github/workflows/build_publish.yaml @@ -0,0 +1,90 @@ +name: Lint / Test / Publish + +on: + push: + branches: ["main"] + + # We only deploy on tags and main branch + tags: + # Only run on tags that match the following regex + # This will match tags like 1.0.0, 1.0.1, etc. + - "[0-9]+.[0-9]+.[0-9]+" + + # Lint and test on pull requests + pull_request: + +jobs: + lint_and_test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + # Checkout the repository + - name: Checkout + uses: actions/checkout@v4 + + # Set python version to 3.11 + - name: set python version + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + # Install Build stuff + - name: Install Dependencies + run: | + pip install poetry \ + && poetry config virtualenvs.create false \ + && poetry install + + # Ruff + - name: Ruff check + run: | + poetry run ruff check . + + - name: Ruff check + run: | + poetry run ruff format . --check + + # Mypy + - name: Mypy Check + run: | + poetry run mypy . + + # Tests + - name: Run Tests + run: | + poetry run pytest . + + publish: + if: startsWith(github.ref, 'refs/tags') + runs-on: ubuntu-latest + needs: lint_and_test + steps: + # Checkout the repository + - name: Checkout + uses: actions/checkout@v4 + + # Set python version to 3.11 + - name: set python version + uses: actions/setup-python@v4 + with: + python-version: 3.11 + + # Install Build stuff + - name: Install Dependencies + run: | + pip install poetry \ + && poetry config virtualenvs.create false \ + && poetry install + + # build package using poetry + - name: Build Package + run: | + poetry build + + # Publish to PyPi + - name: Pypi publish + run: | + poetry config pypi-token.pypi ${{ secrets.PYPI_TOKEN }} + poetry publish diff --git a/examples/chatbot_with_streaming.py b/examples/chatbot_with_streaming.py index 6654cb4..a815e2f 100755 --- a/examples/chatbot_with_streaming.py +++ b/examples/chatbot_with_streaming.py @@ -63,9 +63,7 @@ def completer(text, state): class ChatBot: - def __init__( - self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE - ): + def __init__(self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE): if not api_key: raise ValueError("An API key must be provided to use the Mistral API.") self.client = MistralClient(api_key=api_key) @@ -89,15 +87,11 @@ def opening_instructions(self): def new_chat(self): print("") - print( - f"Starting new chat with model: {self.model}, temperature: {self.temperature}" - ) + print(f"Starting new chat with model: {self.model}, temperature: {self.temperature}") print("") self.messages = [] if self.system_message: - self.messages.append( - ChatMessage(role="system", content=self.system_message) - ) + self.messages.append(ChatMessage(role="system", content=self.system_message)) def switch_model(self, input): model = self.get_arguments(input) @@ -146,13 +140,9 @@ def run_inference(self, content): self.messages.append(ChatMessage(role="user", content=content)) assistant_response = "" - logger.debug( - f"Running inference with model: {self.model}, temperature: {self.temperature}" - ) + logger.debug(f"Running inference with model: {self.model}, temperature: {self.temperature}") logger.debug(f"Sending messages: {self.messages}") - for chunk in self.client.chat_stream( - model=self.model, temperature=self.temperature, messages=self.messages - ): + for chunk in self.client.chat_stream(model=self.model, temperature=self.temperature, messages=self.messages): response = chunk.choices[0].delta.content if response is not None: print(response, end="", flush=True) @@ -161,9 +151,7 @@ def run_inference(self, content): print("", flush=True) if assistant_response: - self.messages.append( - ChatMessage(role="assistant", content=assistant_response) - ) + self.messages.append(ChatMessage(role="assistant", content=assistant_response)) logger.debug(f"Current messages: {self.messages}") def get_command(self, input): @@ -215,9 +203,7 @@ def exit(self): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A simple chatbot using the Mistral API" - ) + parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API") parser.add_argument( "--api-key", default=os.environ.get("MISTRAL_API_KEY"), @@ -230,9 +216,7 @@ def exit(self): default=DEFAULT_MODEL, help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s", ) - parser.add_argument( - "-s", "--system-message", help="Optional system message to prepend." - ) + parser.add_argument("-s", "--system-message", help="Optional system message to prepend.") parser.add_argument( "-t", "--temperature", @@ -240,9 +224,7 @@ def exit(self): default=DEFAULT_TEMPERATURE, help="Optional temperature for chat inference. Defaults to %(default)s", ) - parser.add_argument( - "-d", "--debug", action="store_true", help="Enable debug logging" - ) + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging") args = parser.parse_args() diff --git a/examples/function_calling.py b/examples/function_calling.py index 9d6b89f..e6e6f28 100644 --- a/examples/function_calling.py +++ b/examples/function_calling.py @@ -15,13 +15,15 @@ "payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"], } -def retrieve_payment_status(data: Dict[str,List], transaction_id: str) -> str: + +def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str: for i, r in enumerate(data["transaction_id"]): if r == transaction_id: return json.dumps({"status": data["payment_status"][i]}) else: return json.dumps({"status": "Error - transaction id not found"}) + def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str: for i, r in enumerate(data["transaction_id"]): if r == transaction_id: @@ -29,9 +31,10 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str: else: return json.dumps({"status": "Error - transaction id not found"}) + names_to_functions = { - "retrieve_payment_status": functools.partial(retrieve_payment_status, data=data), - "retrieve_payment_date": functools.partial(retrieve_payment_date, data=data) + "retrieve_payment_status": functools.partial(retrieve_payment_status, data=data), + "retrieve_payment_date": functools.partial(retrieve_payment_date, data=data), } tools = [ @@ -75,9 +78,7 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str: messages.append(ChatMessage(role="assistant", content=response.choices[0].message.content)) messages.append(ChatMessage(role="user", content="My transaction ID is T1001.")) -response = client.chat( - model=model, messages=messages, tools=tools -) +response = client.chat(model=model, messages=messages, tools=tools) tool_call = response.choices[0].message.tool_calls[0] function_name = tool_call.function.name diff --git a/examples/json_format.py b/examples/json_format.py index 5c03d35..749965b 100755 --- a/examples/json_format.py +++ b/examples/json_format.py @@ -16,7 +16,6 @@ def main(): model=model, response_format={"type": "json_object"}, messages=[ChatMessage(role="user", content="What is the best French cheese? Answer shortly in JSON.")], - ) print(chat_response.choices[0].message.content) diff --git a/pyproject.toml b/pyproject.toml index 8423099..9a4d726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mistralai" -version = "0.0.1" +version = "0.2.0" description = "" authors = ["Bam4d "] readme = "README.md" diff --git a/src/mistralai/async_client.py b/src/mistralai/async_client.py index d04edba..2019de5 100644 --- a/src/mistralai/async_client.py +++ b/src/mistralai/async_client.py @@ -1,5 +1,4 @@ import asyncio -import os import posixpath from json import JSONDecodeError from typing import Any, AsyncGenerator, Dict, List, Optional, Union diff --git a/src/mistralai/client.py b/src/mistralai/client.py index 40b46e8..a5daa51 100644 --- a/src/mistralai/client.py +++ b/src/mistralai/client.py @@ -1,4 +1,3 @@ -import os import posixpath import time from json import JSONDecodeError diff --git a/src/mistralai/client_base.py b/src/mistralai/client_base.py index 2497fb8..d58ff14 100644 --- a/src/mistralai/client_base.py +++ b/src/mistralai/client_base.py @@ -10,6 +10,8 @@ ) from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice +CLIENT_VERSION = "0.2.0" + class ClientBase(ABC): def __init__( @@ -25,9 +27,7 @@ def __init__( if api_key is None: api_key = os.environ.get("MISTRAL_API_KEY") if api_key is None: - raise MistralException( - message="API key not provided. Please set MISTRAL_API_KEY environment variable." - ) + raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.") self._api_key = api_key self._endpoint = endpoint self._logger = logging.getLogger(__name__) @@ -36,8 +36,7 @@ def __init__( if "inference.azure.com" in self._endpoint: self._default_model = "mistral" - # This should be automatically updated by the deploy script - self._version = "0.0.1" + self._version = CLIENT_VERSION def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: parsed_tools: List[Dict[str, Any]] = [] diff --git a/src/mistralai/constants.py b/src/mistralai/constants.py index b274a4c..c70331b 100644 --- a/src/mistralai/constants.py +++ b/src/mistralai/constants.py @@ -1,5 +1,3 @@ - - RETRY_STATUS_CODES = {429, 500, 502, 503, 504} ENDPOINT = "https://api.mistral.ai" diff --git a/src/mistralai/exceptions.py b/src/mistralai/exceptions.py index 9c9da81..5728a1c 100644 --- a/src/mistralai/exceptions.py +++ b/src/mistralai/exceptions.py @@ -35,9 +35,7 @@ def __init__( self.headers = headers or {} @classmethod - def from_response( - cls, response: Response, message: Optional[str] = None - ) -> MistralAPIException: + def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException: return cls( message=message or response.text, http_status=response.status_code, @@ -47,8 +45,10 @@ def from_response( def __repr__(self) -> str: return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})" + class MistralAPIStatusException(MistralAPIException): """Returned when we receive a non-200 response from the API that we should retry""" + class MistralConnectionException(MistralException): """Returned when the SDK can not reach the API server for any reason""" diff --git a/src/mistralai/models/models.py b/src/mistralai/models/models.py index 8b3b6d7..0acd402 100644 --- a/src/mistralai/models/models.py +++ b/src/mistralai/models/models.py @@ -17,6 +17,7 @@ class ModelPermission(BaseModel): group: Optional[str] = None is_blocking: bool = False + class ModelCard(BaseModel): id: str object: str diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c43f7aa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +from unittest import mock + +import pytest +from mistralai.async_client import MistralAsyncClient +from mistralai.client import MistralClient + + +@pytest.fixture() +def client(): + client = MistralClient(api_key="test_api_key") + client._client = mock.MagicMock() + return client + + +@pytest.fixture() +def async_client(): + client = MistralAsyncClient(api_key="test_api_key") + client._client = mock.AsyncMock() + return client diff --git a/tests/test_chat.py b/tests/test_chat.py index e64e68a..eebc736 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,7 +1,3 @@ -import unittest.mock as mock - -import pytest -from mistralai.client import MistralClient from mistralai.models.chat_completion import ( ChatCompletionResponse, ChatCompletionStreamResponse, @@ -16,13 +12,6 @@ ) -@pytest.fixture() -def client(): - client = MistralClient() - client._client = mock.MagicMock() - return client - - class TestChat: def test_chat(self, client): client._client.request.return_value = mock_response( @@ -32,9 +21,7 @@ def test_chat(self, client): result = client.chat( model="mistral-small", - messages=[ - ChatMessage(role="user", content="What is the best French cheese?") - ], + messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) client._client.request.assert_called_once_with( @@ -43,22 +30,18 @@ def test_chat(self, client): headers={ "User-Agent": f"mistral-client-python/{client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ "model": "mistral-small", - "messages": [ - {"role": "user", "content": "What is the best French cheese?"} - ], + "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": False, }, ) - assert isinstance( - result, ChatCompletionResponse - ), "Should return an ChatCompletionResponse" + assert isinstance(result, ChatCompletionResponse), "Should return an ChatCompletionResponse" assert len(result.choices) == 1 assert result.choices[0].index == 0 assert result.object == "chat.completion" @@ -71,9 +54,7 @@ def test_chat_streaming(self, client): result = client.chat_stream( model="mistral-small", - messages=[ - ChatMessage(role="user", content="What is the best French cheese?") - ], + messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) results = list(result) @@ -84,14 +65,12 @@ def test_chat_streaming(self, client): headers={ "User-Agent": f"mistral-client-python/{client._version}", "Accept": "text/event-stream", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ "model": "mistral-small", - "messages": [ - {"role": "user", "content": "What is the best French cheese?"} - ], + "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": True, }, @@ -99,16 +78,12 @@ def test_chat_streaming(self, client): for i, result in enumerate(results): if i == 0: - assert isinstance( - result, ChatCompletionStreamResponse - ), "Should return an ChatCompletionStreamResponse" + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" assert len(result.choices) == 1 assert result.choices[0].index == 0 assert result.choices[0].delta.role == "assistant" else: - assert isinstance( - result, ChatCompletionStreamResponse - ), "Should return an ChatCompletionStreamResponse" + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" assert len(result.choices) == 1 assert result.choices[0].index == i - 1 assert result.choices[0].delta.content == f"stream response {i-1}" diff --git a/tests/test_chat_async.py b/tests/test_chat_async.py index 7e51a97..e68760f 100644 --- a/tests/test_chat_async.py +++ b/tests/test_chat_async.py @@ -1,7 +1,6 @@ import unittest.mock as mock import pytest -from mistralai.async_client import MistralAsyncClient from mistralai.models.chat_completion import ( ChatCompletionResponse, ChatCompletionStreamResponse, @@ -16,85 +15,68 @@ ) -@pytest.fixture() -def client(): - client = MistralAsyncClient() - client._client = mock.AsyncMock() - client._client.stream = mock.Mock() - return client - - class TestAsyncChat: @pytest.mark.asyncio - async def test_chat(self, client): - client._client.request.return_value = mock_response( + async def test_chat(self, async_client): + async_client._client.request.return_value = mock_response( 200, mock_chat_response_payload(), ) - result = await client.chat( + result = await async_client.chat( model="mistral-small", - messages=[ - ChatMessage(role="user", content="What is the best French cheese?") - ], + messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) - client._client.request.assert_awaited_once_with( + async_client._client.request.assert_awaited_once_with( "post", "https://api.mistral.ai/v1/chat/completions", headers={ - "User-Agent": f"mistral-client-python/{client._version}", + "User-Agent": f"mistral-client-python/{async_client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ "model": "mistral-small", - "messages": [ - {"role": "user", "content": "What is the best French cheese?"} - ], + "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": False, }, ) - assert isinstance( - result, ChatCompletionResponse - ), "Should return an ChatCompletionResponse" + assert isinstance(result, ChatCompletionResponse), "Should return an ChatCompletionResponse" assert len(result.choices) == 1 assert result.choices[0].index == 0 assert result.object == "chat.completion" @pytest.mark.asyncio - async def test_chat_streaming(self, client): - client._client.stream.return_value = mock_async_stream_response( + async def test_chat_streaming(self, async_client): + async_client._client.stream = mock.Mock() + async_client._client.stream.return_value = mock_async_stream_response( 200, mock_chat_response_streaming_payload(), ) - result = client.chat_stream( + result = async_client.chat_stream( model="mistral-small", - messages=[ - ChatMessage(role="user", content="What is the best French cheese?") - ], + messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) results = [r async for r in result] - client._client.stream.assert_called_once_with( + async_client._client.stream.assert_called_once_with( "post", "https://api.mistral.ai/v1/chat/completions", headers={ "Accept": "text/event-stream", - "User-Agent": f"mistral-client-python/{client._version}", - "Authorization": "Bearer None", + "User-Agent": f"mistral-client-python/{async_client._version}", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ "model": "mistral-small", - "messages": [ - {"role": "user", "content": "What is the best French cheese?"} - ], + "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": True, }, @@ -102,16 +84,12 @@ async def test_chat_streaming(self, client): for i, result in enumerate(results): if i == 0: - assert isinstance( - result, ChatCompletionStreamResponse - ), "Should return an ChatCompletionStreamResponse" + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" assert len(result.choices) == 1 assert result.choices[0].index == 0 assert result.choices[0].delta.role == "assistant" else: - assert isinstance( - result, ChatCompletionStreamResponse - ), "Should return an ChatCompletionStreamResponse" + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" assert len(result.choices) == 1 assert result.choices[0].index == i - 1 assert result.choices[0].delta.content == f"stream response {i-1}" diff --git a/tests/test_embedder.py b/tests/test_embedder.py index 59e30fa..56cd4c5 100644 --- a/tests/test_embedder.py +++ b/tests/test_embedder.py @@ -1,19 +1,8 @@ -import unittest.mock as mock - -import pytest -from mistralai.client import MistralClient from mistralai.models.embeddings import EmbeddingResponse from .utils import mock_embedding_response_payload, mock_response -@pytest.fixture() -def client(): - client = MistralClient() - client._client = mock.MagicMock() - return client - - class TestEmbeddings: def test_embeddings(self, client): client._client.request.return_value = mock_response( @@ -32,15 +21,13 @@ def test_embeddings(self, client): headers={ "User-Agent": f"mistral-client-python/{client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={"model": "mistral-embed", "input": "What is the best French cheese?"}, ) - assert isinstance( - result, EmbeddingResponse - ), "Should return an EmbeddingResponse" + assert isinstance(result, EmbeddingResponse), "Should return an EmbeddingResponse" assert len(result.data) == 1 assert result.data[0].index == 0 assert result.object == "list" @@ -62,7 +49,7 @@ def test_embeddings_batch(self, client): headers={ "User-Agent": f"mistral-client-python/{client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ @@ -71,9 +58,7 @@ def test_embeddings_batch(self, client): }, ) - assert isinstance( - result, EmbeddingResponse - ), "Should return an EmbeddingResponse" + assert isinstance(result, EmbeddingResponse), "Should return an EmbeddingResponse" assert len(result.data) == 10 assert result.data[0].index == 0 assert result.object == "list" diff --git a/tests/test_embedder_async.py b/tests/test_embedder_async.py index 8d08d6a..d95fdd4 100644 --- a/tests/test_embedder_async.py +++ b/tests/test_embedder_async.py @@ -1,70 +1,58 @@ -import unittest.mock as mock - import pytest -from mistralai.async_client import MistralAsyncClient from mistralai.models.embeddings import EmbeddingResponse from .utils import mock_embedding_response_payload, mock_response -@pytest.fixture() -def client(): - client = MistralAsyncClient() - client._client = mock.AsyncMock() - return client - - class TestAsyncEmbeddings: @pytest.mark.asyncio - async def test_embeddings(self, client): - client._client.request.return_value = mock_response( + async def test_embeddings(self, async_client): + async_client._client.request.return_value = mock_response( 200, mock_embedding_response_payload(), ) - result = await client.embeddings( + result = await async_client.embeddings( model="mistral-embed", input="What is the best French cheese?", ) - client._client.request.assert_awaited_once_with( + async_client._client.request.assert_awaited_once_with( "post", "https://api.mistral.ai/v1/embeddings", headers={ - "User-Agent": f"mistral-client-python/{client._version}", + "User-Agent": f"mistral-client-python/{async_client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={"model": "mistral-embed", "input": "What is the best French cheese?"}, ) - assert isinstance( - result, EmbeddingResponse - ), "Should return an EmbeddingResponse" + assert isinstance(result, EmbeddingResponse), "Should return an EmbeddingResponse" assert len(result.data) == 1 assert result.data[0].index == 0 assert result.object == "list" @pytest.mark.asyncio - async def test_embeddings_batch(self, client): - client._client.request.return_value = mock_response( + async def test_embeddings_batch(self, async_client): + async_client._client.request.return_value = mock_response( 200, mock_embedding_response_payload(batch_size=10), ) - result = await client.embeddings( + result = await async_client.embeddings( model="mistral-embed", input=["What is the best French cheese?"] * 10, ) - client._client.request.assert_awaited_once_with( + async_client._client.request.assert_awaited_once_with( "post", "https://api.mistral.ai/v1/embeddings", headers={ - "User-Agent": f"mistral-client-python/{client._version}", + "User-Agent": f"mistral-client-python/{async_client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={ @@ -73,9 +61,7 @@ async def test_embeddings_batch(self, client): }, ) - assert isinstance( - result, EmbeddingResponse - ), "Should return an EmbeddingResponse" + assert isinstance(result, EmbeddingResponse), "Should return an EmbeddingResponse" assert len(result.data) == 10 assert result.data[0].index == 0 assert result.object == "list" diff --git a/tests/test_list_models.py b/tests/test_list_models.py index 1a048fa..6b73978 100644 --- a/tests/test_list_models.py +++ b/tests/test_list_models.py @@ -1,19 +1,8 @@ -import unittest.mock as mock - -import pytest -from mistralai.client import MistralClient from mistralai.models.models import ModelList from .utils import mock_list_models_response_payload, mock_response -@pytest.fixture() -def client(): - client = MistralClient() - client._client = mock.MagicMock() - return client - - class TestListModels: def test_list_models(self, client): client._client.request.return_value = mock_response( @@ -29,7 +18,7 @@ def test_list_models(self, client): headers={ "User-Agent": f"mistral-client-python/{client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={}, diff --git a/tests/test_list_models_async.py b/tests/test_list_models_async.py index 6d572d1..a876484 100644 --- a/tests/test_list_models_async.py +++ b/tests/test_list_models_async.py @@ -1,36 +1,26 @@ -import unittest.mock as mock - import pytest -from mistralai.async_client import MistralAsyncClient from mistralai.models.models import ModelList from .utils import mock_list_models_response_payload, mock_response -@pytest.fixture() -def client(): - client = MistralAsyncClient() - client._client = mock.AsyncMock() - return client - - class TestAsyncListModels: @pytest.mark.asyncio - async def test_list_models(self, client): - client._client.request.return_value = mock_response( + async def test_list_models(self, async_client): + async_client._client.request.return_value = mock_response( 200, mock_list_models_response_payload(), ) - result = await client.list_models() + result = await async_client.list_models() - client._client.request.assert_awaited_once_with( + async_client._client.request.assert_awaited_once_with( "get", "https://api.mistral.ai/v1/models", headers={ - "User-Agent": f"mistral-client-python/{client._version}", + "User-Agent": f"mistral-client-python/{async_client._version}", "Accept": "application/json", - "Authorization": "Bearer None", + "Authorization": "Bearer test_api_key", "Content-Type": "application/json", }, json={}, diff --git a/tests/utils.py b/tests/utils.py index e5edef6..50637c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,9 +27,7 @@ async def async_iter(content: List[str]): yield response -def mock_response( - status_code: int, content: str, is_json: bool = True -) -> mock.MagicMock: +def mock_response(status_code: int, content: str, is_json: bool = True) -> mock.MagicMock: response = mock.Mock(Response) response.status_code = status_code if is_json: