From aa010d2c2e8c9d3967604aa49d3c82e494e470cb Mon Sep 17 00:00:00 2001
From: jean-malo
Date: Wed, 29 May 2024 15:06:41 +0200
Subject: [PATCH] release 0.3.0: add support for completion

---
 examples/async_completion.py          | 33 +++
 examples/chatbot_with_streaming.py    |  9 +--
 examples/code_completion.py           | 32 +++
 examples/completion_with_streaming.py | 28 ++++
 pyproject.toml                        |  2 +-
 src/mistralai/async_client.py         | 73 +++++++++++++++++++-
 src/mistralai/client.py               | 76 ++++++++++++++++++++-
 src/mistralai/client_base.py          | 71 +++++++++++++++++---
 tests/test_chat.py                    |  8 +--
 tests/test_chat_async.py              |  8 +--
 tests/test_completion.py              | 97 +++++++++++++++++++++++++++
 tests/utils.py                        | 30 +++++++--
 12 files changed, 440 insertions(+), 27 deletions(-)
 create mode 100644 examples/async_completion.py
 create mode 100644 examples/code_completion.py
 create mode 100644 examples/completion_with_streaming.py
 create mode 100644 tests/test_completion.py

diff --git a/examples/async_completion.py b/examples/async_completion.py
new file mode 100644
index 0000000..6aa22b4
--- /dev/null
+++ b/examples/async_completion.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+
+import asyncio
+import os
+
+from mistralai.async_client import MistralAsyncClient
+
+
+async def main():
+    api_key = os.environ["MISTRAL_API_KEY"]
+
+    client = MistralAsyncClient(api_key=api_key)
+
+    prompt = "def fibonacci(n: int):"
+    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"
+
+    response = await client.completion(
+        model="codestral-latest",
+        prompt=prompt,
+        suffix=suffix,
+    )
+
+    print(
+        f"""
+{prompt}
+{response.choices[0].message.content}
+{suffix}
+"""
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/chatbot_with_streaming.py b/examples/chatbot_with_streaming.py
index a815e2f..4304551 100755
--- a/examples/chatbot_with_streaming.py
+++ b/examples/chatbot_with_streaming.py
@@ -12,11 +12,12 @@ from mistralai.models.chat_completion import ChatMessage
 
 
 MODEL_LIST = [
-    "mistral-tiny",
-    "mistral-small",
-    "mistral-medium",
+    "mistral-small-latest",
+    "mistral-medium-latest",
+    "mistral-large-latest",
+    "codestral-latest",
 ]
-DEFAULT_MODEL = "mistral-small"
+DEFAULT_MODEL = "mistral-small-latest"
 DEFAULT_TEMPERATURE = 0.7
 LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
 # A dictionary of all commands and their arguments, used for tab completion.
diff --git a/examples/code_completion.py b/examples/code_completion.py
new file mode 100644
index 0000000..f76f0f1
--- /dev/null
+++ b/examples/code_completion.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+
+import os
+
+from mistralai.client import MistralClient
+
+
+def main():
+    api_key = os.environ["MISTRAL_API_KEY"]
+
+    client = MistralClient(api_key=api_key)
+
+    prompt = "def fibonacci(n: int):"
+    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"
+
+    response = client.completion(
+        model="codestral-latest",
+        prompt=prompt,
+        suffix=suffix,
+    )
+
+    print(
+        f"""
+{prompt}
+{response.choices[0].message.content}
+{suffix}
+"""
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/completion_with_streaming.py b/examples/completion_with_streaming.py
new file mode 100644
index 0000000..f0760bf
--- /dev/null
+++ b/examples/completion_with_streaming.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+
+import os
+
+from mistralai.client import MistralClient
+
+
+def main():
+    api_key = os.environ["MISTRAL_API_KEY"]
+
+    client = MistralClient(api_key=api_key)
+
+    prompt = "def fibonacci(n: int):"
+    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"
+
+    print(prompt)
+    for chunk in client.completion_stream(
+        model="codestral-latest",
+        prompt=prompt,
+        suffix=suffix,
+    ):
+        if chunk.choices[0].delta.content is not None:
+            print(chunk.choices[0].delta.content, end="")
+    print(suffix)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 9a4d726..bf3077e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.2.0"
+version = "0.3.0"
 description = ""
 authors = ["Bam4d "]
 readme = "README.md"
diff --git a/src/mistralai/async_client.py b/src/mistralai/async_client.py
index 2019de5..bc80a8b 100644
--- a/src/mistralai/async_client.py
+++ b/src/mistralai/async_client.py
@@ -92,7 +92,7 @@ async def _check_response(self, response: Response) -> Dict[str, Any]:
     async def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -291,3 +291,74 @@ async def list_models(self) -> ModelList:
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    async def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """An asynchronous completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to use for completion, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the text that should follow the generated completion, used for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability threshold for nucleus sampling, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+        single_response = self._request("post", request, "v1/fim/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    async def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): the name of the model to use for completion, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the text that should follow the generated completion, used for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability threshold for nucleus sampling, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            AsyncGenerator[ChatCompletionStreamResponse, None]: an async generator that yields response chunks containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+        async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
diff --git a/src/mistralai/client.py b/src/mistralai/client.py
index a5daa51..b00ddcf 100644
--- a/src/mistralai/client.py
+++ b/src/mistralai/client.py
@@ -85,7 +85,7 @@ def _check_response(self, response: Response) -> Dict[str, Any]:
     def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -285,3 +285,77 @@ def list_models(self) -> ModelList:
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """A completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to use for completion, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the text that should follow the generated completion, used for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability threshold for nucleus sampling, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+
+        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+        for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Iterable[ChatCompletionStreamResponse]:
+        """A completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): the name of the model to use for completion, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the text that should follow the generated completion, used for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability threshold for nucleus sampling, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            Iterable[ChatCompletionStreamResponse]: a generator that yields response chunks containing the generated text.
+ """ + request = self._make_completion_request( + prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True + ) + + response = self._request("post", request, "v1/fim/completions", stream=True) + + for json_streamed_response in response: + yield ChatCompletionStreamResponse(**json_streamed_response) diff --git a/src/mistralai/client_base.py b/src/mistralai/client_base.py index d58ff14..c38e093 100644 --- a/src/mistralai/client_base.py +++ b/src/mistralai/client_base.py @@ -73,6 +73,63 @@ def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: return parsed_messages + def _make_completion_request( + self, + prompt: str, + model: Optional[str] = None, + suffix: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + random_seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream: Optional[bool] = False, + ) -> Dict[str, Any]: + request_data: Dict[str, Any] = { + "prompt": prompt, + "suffix": suffix, + "model": model, + "stream": stream, + } + + if stop is not None: + request_data["stop"] = stop + + if model is not None: + request_data["model"] = model + else: + if self._default_model is None: + raise MistralException(message="model must be provided") + request_data["model"] = self._default_model + + request_data.update( + self._build_sampling_params( + temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed + ) + ) + + self._logger.debug(f"Completion request: {request_data}") + + return request_data + + def _build_sampling_params( + self, + max_tokens: Optional[int], + random_seed: Optional[int], + temperature: Optional[float], + top_p: Optional[float], + ) -> Dict[str, Any]: + params = {} + if temperature is not None: + params["temperature"] = temperature + if max_tokens is not None: + params["max_tokens"] = max_tokens + if top_p is not None: + params["top_p"] = top_p + if random_seed is not None: + params["random_seed"] = random_seed + return params + def _make_chat_request( self, messages: List[Any], @@ -99,16 +156,14 @@ def _make_chat_request( raise MistralException(message="model must be provided") request_data["model"] = self._default_model + request_data.update( + self._build_sampling_params( + temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed + ) + ) + if tools is not None: request_data["tools"] = self._parse_tools(tools) - if temperature is not None: - request_data["temperature"] = temperature - if max_tokens is not None: - request_data["max_tokens"] = max_tokens - if top_p is not None: - request_data["top_p"] = top_p - if random_seed is not None: - request_data["random_seed"] = random_seed if stream is not None: request_data["stream"] = stream diff --git a/tests/test_chat.py b/tests/test_chat.py index eebc736..6b1658e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -20,7 +20,7 @@ def test_chat(self, client): ) result = client.chat( - model="mistral-small", + model="mistral-small-latest", messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) @@ -34,7 +34,7 @@ def test_chat(self, client): "Content-Type": "application/json", }, json={ - "model": "mistral-small", + "model": "mistral-small-latest", "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": False, @@ -53,7 +53,7 @@ def test_chat_streaming(self, client): ) result = client.chat_stream( - model="mistral-small", + 
model="mistral-small-latest", messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) @@ -69,7 +69,7 @@ def test_chat_streaming(self, client): "Content-Type": "application/json", }, json={ - "model": "mistral-small", + "model": "mistral-small-latest", "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": True, diff --git a/tests/test_chat_async.py b/tests/test_chat_async.py index e68760f..15479ed 100644 --- a/tests/test_chat_async.py +++ b/tests/test_chat_async.py @@ -24,7 +24,7 @@ async def test_chat(self, async_client): ) result = await async_client.chat( - model="mistral-small", + model="mistral-small-latest", messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) @@ -38,7 +38,7 @@ async def test_chat(self, async_client): "Content-Type": "application/json", }, json={ - "model": "mistral-small", + "model": "mistral-small-latest", "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": False, @@ -59,7 +59,7 @@ async def test_chat_streaming(self, async_client): ) result = async_client.chat_stream( - model="mistral-small", + model="mistral-small-latest", messages=[ChatMessage(role="user", content="What is the best French cheese?")], ) @@ -75,7 +75,7 @@ async def test_chat_streaming(self, async_client): "Content-Type": "application/json", }, json={ - "model": "mistral-small", + "model": "mistral-small-latest", "messages": [{"role": "user", "content": "What is the best French cheese?"}], "safe_prompt": False, "stream": True, diff --git a/tests/test_completion.py b/tests/test_completion.py new file mode 100644 index 0000000..1b6f1c1 --- /dev/null +++ b/tests/test_completion.py @@ -0,0 +1,97 @@ +from mistralai.models.chat_completion import ( + ChatCompletionResponse, + ChatCompletionStreamResponse, +) + +from .utils import ( + mock_completion_response_payload, + mock_response, + mock_stream_response, +) + + +class TestCompletion: + def test_completion(self, client): + client._client.request.return_value = mock_response( + 200, + mock_completion_response_payload(), + ) + + result = client.completion( + model="mistral-small-latest", + prompt="def add(a, b):", + suffix="return a + b", + temperature=0.5, + max_tokens=50, + top_p=0.9, + random_seed=42, + ) + + client._client.request.assert_called_once_with( + "post", + "https://api.mistral.ai/v1/fim/completions", + headers={ + "User-Agent": f"mistral-client-python/{client._version}", + "Accept": "application/json", + "Authorization": "Bearer test_api_key", + "Content-Type": "application/json", + }, + json={ + "model": "mistral-small-latest", + "prompt": "def add(a, b):", + "suffix": "return a + b", + "stream": False, + "temperature": 0.5, + "max_tokens": 50, + "top_p": 0.9, + "random_seed": 42, + }, + ) + + assert isinstance(result, ChatCompletionResponse), "Should return an ChatCompletionResponse" + assert len(result.choices) == 1 + assert result.choices[0].index == 0 + assert result.object == "chat.completion" + + def test_completion_streaming(self, client): + client._client.stream.return_value = mock_stream_response( + 200, + mock_completion_response_payload(), + ) + + result = client.completion_stream( + model="mistral-small-latest", prompt="def add(a, b):", suffix="return a + b", stop=["#"] + ) + + results = list(result) + + client._client.stream.assert_called_once_with( + "post", + "https://api.mistral.ai/v1/fim/completions", + headers={ + "User-Agent": 
f"mistral-client-python/{client._version}", + "Accept": "text/event-stream", + "Authorization": "Bearer test_api_key", + "Content-Type": "application/json", + }, + json={ + "model": "mistral-small-latest", + "prompt": "def add(a, b):", + "suffix": "return a + b", + "stream": True, + "stop": ["#"], + }, + ) + + for i, result in enumerate(results): + if i == 0: + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" + assert len(result.choices) == 1 + assert result.choices[0].index == 0 + assert result.choices[0].delta.role == "assistant" + else: + assert isinstance(result, ChatCompletionStreamResponse), "Should return an ChatCompletionStreamResponse" + assert len(result.choices) == 1 + assert result.choices[0].index == i - 1 + assert result.choices[0].delta.content == f"stream response {i - 1}" + assert result.object == "chat.completion.chunk" diff --git a/tests/utils.py b/tests/utils.py index 50637c2..826753d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -67,7 +67,7 @@ def mock_list_models_response_payload() -> str: ], }, { - "id": "mistral-small", + "id": "mistral-small-latest", "object": "model", "created": 1703186988, "owned_by": "mistralai", @@ -178,7 +178,7 @@ def mock_chat_response_payload(): "index": 0, } ], - "model": "mistral-small", + "model": "mistral-small-latest", "usage": {"prompt_tokens": 90, "total_tokens": 90, "completion_tokens": 0}, } ).decode() @@ -190,7 +190,7 @@ def mock_chat_response_streaming_payload(): + orjson.dumps( { "id": "cmpl-8cd9019d21ba490aa6b9740f5d0a883e", - "model": "mistral-small", + "model": "mistral-small-latest", "choices": [ { "index": 0, @@ -208,7 +208,7 @@ def mock_chat_response_streaming_payload(): "id": "cmpl-8cd9019d21ba490aa6b9740f5d0a883e", "object": "chat.completion.chunk", "created": 1703168544, - "model": "mistral-small", + "model": "mistral-small-latest", "choices": [ { "index": i, @@ -223,3 +223,25 @@ def mock_chat_response_streaming_payload(): ], "data: [DONE]\n\n", ] + + +def mock_completion_response_payload() -> str: + return orjson.dumps( + { + "id": "chat-98c8c60e3fbf4fc49658eddaf447357c", + "object": "chat.completion", + "created": 1703165682, + "choices": [ + { + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": " a + b", + }, + "index": 0, + } + ], + "model": "mistral-small-latest", + "usage": {"prompt_tokens": 90, "total_tokens": 90, "completion_tokens": 0}, + } + ).decode()