Skip to content

Commit

Permalink
Release 0.2.0 (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuegoio authored May 23, 2024
1 parent 374fd01 commit 32ec8b6
Show file tree
Hide file tree
Showing 19 changed files with 189 additions and 201 deletions.
90 changes: 90 additions & 0 deletions .github/workflows/build_publish.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
name: Lint / Test / Publish

on:
push:
branches: ["main"]

# We only deploy on tags and main branch
tags:
# Only run on tags that match the following regex
# This will match tags like 1.0.0, 1.0.1, etc.
- "[0-9]+.[0-9]+.[0-9]+"

# Lint and test on pull requests
pull_request:

jobs:
lint_and_test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
# Checkout the repository
- name: Checkout
uses: actions/checkout@v4

# Set python version to 3.11
- name: set python version
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

# Install Build stuff
- name: Install Dependencies
run: |
pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install
# Ruff
- name: Ruff check
run: |
poetry run ruff check .
- name: Ruff check
run: |
poetry run ruff format . --check
# Mypy
- name: Mypy Check
run: |
poetry run mypy .
# Tests
- name: Run Tests
run: |
poetry run pytest .
publish:
if: startsWith(github.ref, 'refs/tags')
runs-on: ubuntu-latest
needs: lint_and_test
steps:
# Checkout the repository
- name: Checkout
uses: actions/checkout@v4

# Set python version to 3.11
- name: set python version
uses: actions/setup-python@v4
with:
python-version: 3.11

# Install Build stuff
- name: Install Dependencies
run: |
pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install
# build package using poetry
- name: Build Package
run: |
poetry build
# Publish to PyPi
- name: Pypi publish
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_TOKEN }}
poetry publish
36 changes: 9 additions & 27 deletions examples/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def completer(text, state):


class ChatBot:
def __init__(
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
):
def __init__(self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE):
if not api_key:
raise ValueError("An API key must be provided to use the Mistral API.")
self.client = MistralClient(api_key=api_key)
Expand All @@ -89,15 +87,11 @@ def opening_instructions(self):

def new_chat(self):
print("")
print(
f"Starting new chat with model: {self.model}, temperature: {self.temperature}"
)
print(f"Starting new chat with model: {self.model}, temperature: {self.temperature}")
print("")
self.messages = []
if self.system_message:
self.messages.append(
ChatMessage(role="system", content=self.system_message)
)
self.messages.append(ChatMessage(role="system", content=self.system_message))

def switch_model(self, input):
model = self.get_arguments(input)
Expand Down Expand Up @@ -146,13 +140,9 @@ def run_inference(self, content):
self.messages.append(ChatMessage(role="user", content=content))

assistant_response = ""
logger.debug(
f"Running inference with model: {self.model}, temperature: {self.temperature}"
)
logger.debug(f"Running inference with model: {self.model}, temperature: {self.temperature}")
logger.debug(f"Sending messages: {self.messages}")
for chunk in self.client.chat_stream(
model=self.model, temperature=self.temperature, messages=self.messages
):
for chunk in self.client.chat_stream(model=self.model, temperature=self.temperature, messages=self.messages):
response = chunk.choices[0].delta.content
if response is not None:
print(response, end="", flush=True)
Expand All @@ -161,9 +151,7 @@ def run_inference(self, content):
print("", flush=True)

if assistant_response:
self.messages.append(
ChatMessage(role="assistant", content=assistant_response)
)
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
logger.debug(f"Current messages: {self.messages}")

def get_command(self, input):
Expand Down Expand Up @@ -215,9 +203,7 @@ def exit(self):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A simple chatbot using the Mistral API"
)
parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
parser.add_argument(
"--api-key",
default=os.environ.get("MISTRAL_API_KEY"),
Expand All @@ -230,19 +216,15 @@ def exit(self):
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
)
parser.add_argument(
"-s", "--system-message", help="Optional system message to prepend."
)
parser.add_argument("-s", "--system-message", help="Optional system message to prepend.")
parser.add_argument(
"-t",
"--temperature",
type=float,
default=DEFAULT_TEMPERATURE,
help="Optional temperature for chat inference. Defaults to %(default)s",
)
parser.add_argument(
"-d", "--debug", action="store_true", help="Enable debug logging"
)
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")

args = parser.parse_args()

Expand Down
13 changes: 7 additions & 6 deletions examples/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,26 @@
"payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"],
}

def retrieve_payment_status(data: Dict[str,List], transaction_id: str) -> str:

def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str:
for i, r in enumerate(data["transaction_id"]):
if r == transaction_id:
return json.dumps({"status": data["payment_status"][i]})
else:
return json.dumps({"status": "Error - transaction id not found"})


def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
for i, r in enumerate(data["transaction_id"]):
if r == transaction_id:
return json.dumps({"date": data["payment_date"][i]})
else:
return json.dumps({"status": "Error - transaction id not found"})


names_to_functions = {
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data)
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data),
}

tools = [
Expand Down Expand Up @@ -75,9 +78,7 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
messages.append(ChatMessage(role="assistant", content=response.choices[0].message.content))
messages.append(ChatMessage(role="user", content="My transaction ID is T1001."))

response = client.chat(
model=model, messages=messages, tools=tools
)
response = client.chat(model=model, messages=messages, tools=tools)

tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
Expand Down
1 change: 0 additions & 1 deletion examples/json_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def main():
model=model,
response_format={"type": "json_object"},
messages=[ChatMessage(role="user", content="What is the best French cheese? Answer shortly in JSON.")],

)
print(chat_response.choices[0].message.content)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistralai"
version = "0.0.1"
version = "0.2.0"
description = ""
authors = ["Bam4d <[email protected]>"]
readme = "README.md"
Expand Down
1 change: 0 additions & 1 deletion src/mistralai/async_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
import posixpath
from json import JSONDecodeError
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
Expand Down
1 change: 0 additions & 1 deletion src/mistralai/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import posixpath
import time
from json import JSONDecodeError
Expand Down
9 changes: 4 additions & 5 deletions src/mistralai/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
)
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

CLIENT_VERSION = "0.2.0"


class ClientBase(ABC):
def __init__(
Expand All @@ -25,9 +27,7 @@ def __init__(
if api_key is None:
api_key = os.environ.get("MISTRAL_API_KEY")
if api_key is None:
raise MistralException(
message="API key not provided. Please set MISTRAL_API_KEY environment variable."
)
raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
self._api_key = api_key
self._endpoint = endpoint
self._logger = logging.getLogger(__name__)
Expand All @@ -36,8 +36,7 @@ def __init__(
if "inference.azure.com" in self._endpoint:
self._default_model = "mistral"

# This should be automatically updated by the deploy script
self._version = "0.0.1"
self._version = CLIENT_VERSION

def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
parsed_tools: List[Dict[str, Any]] = []
Expand Down
2 changes: 0 additions & 2 deletions src/mistralai/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


RETRY_STATUS_CODES = {429, 500, 502, 503, 504}

ENDPOINT = "https://api.mistral.ai"
6 changes: 3 additions & 3 deletions src/mistralai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def __init__(
self.headers = headers or {}

@classmethod
def from_response(
cls, response: Response, message: Optional[str] = None
) -> MistralAPIException:
def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
return cls(
message=message or response.text,
http_status=response.status_code,
Expand All @@ -47,8 +45,10 @@ def from_response(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"


class MistralAPIStatusException(MistralAPIException):
"""Returned when we receive a non-200 response from the API that we should retry"""


class MistralConnectionException(MistralException):
"""Returned when the SDK can not reach the API server for any reason"""
1 change: 1 addition & 0 deletions src/mistralai/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ModelPermission(BaseModel):
group: Optional[str] = None
is_blocking: bool = False


class ModelCard(BaseModel):
id: str
object: str
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from unittest import mock

import pytest
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient


@pytest.fixture()
def client():
client = MistralClient(api_key="test_api_key")
client._client = mock.MagicMock()
return client


@pytest.fixture()
def async_client():
client = MistralAsyncClient(api_key="test_api_key")
client._client = mock.AsyncMock()
return client
Loading

0 comments on commit 32ec8b6

Please sign in to comment.