Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure AI Inference SDK - Beta 2 updates #36163

Merged
merged 10 commits into from
Jun 22, 2024
9 changes: 5 additions & 4 deletions sdk/ai/azure-ai-inference/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Release History

## 1.0.0b2 (Unreleased)
## 1.0.0b2 (2024-06-24)

### Features Added

### Breaking Changes
Add `model` as an optional input argument to the `complete` method of `ChatCompletionsClient`.

### Bugs Fixed
### Breaking Changes

### Other Changes
The field `input_tokens` was removed from class `EmbeddingsUsage`, as this was never defined in the
REST API and the service never returned this value.

## 1.0.0b1 (2024-06-11)

Expand Down
2 changes: 1 addition & 1 deletion sdk/ai/azure-ai-inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ To load an asynchronous client, import the `load_client` function from `azure.ai

Entra ID authentication is also supported by the `load_client` function. Replace the key authentication above with `credential=DefaultAzureCredential()` for example.

### Getting AI model information
### Get AI model information

All clients provide a `get_model_info` method to retrive AI model information. This makes a REST call to the `/info` route on the provided endpoint, as documented in [the REST API reference](https://learn.microsoft.com/azure/ai-studio/reference/reference-model-inference-info).

Expand Down
7 changes: 4 additions & 3 deletions sdk/ai/azure-ai-inference/azure/ai/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from copy import deepcopy
from typing import Any, TYPE_CHECKING, Union
from typing_extensions import Self

from azure.core import PipelineClient
from azure.core.credentials import AzureKeyCredential
Expand Down Expand Up @@ -101,7 +102,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "ChatCompletionsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

Expand Down Expand Up @@ -179,7 +180,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "EmbeddingsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

Expand Down Expand Up @@ -257,7 +258,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "ImageEmbeddingsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,5 +883,6 @@ def rest_discriminator(
*,
name: typing.Optional[str] = None,
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
visibility: typing.Optional[typing.List[str]] = None,
) -> typing.Any:
return _RestField(name=name, type=type, is_discriminator=True)
return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def _complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any
) -> _models.ChatCompletions: ...
@overload
Expand Down Expand Up @@ -240,6 +241,7 @@ def _complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any
) -> _models.ChatCompletions:
# pylint: disable=line-too-long
Expand Down Expand Up @@ -317,9 +319,12 @@ def _complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
Default value is None.
same seed and parameters should return the same result. Determinism is not guaranteed. Default
value is None.
:paramtype seed: int
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions. The ChatCompletions is compatible with MutableMapping
:rtype: ~azure.ai.inference.models.ChatCompletions
:raises ~azure.core.exceptions.HttpResponseError:
Expand All @@ -338,6 +343,8 @@ def _complete(
frequency increases and decrease the likelihood of the model repeating the same
statements verbatim. Supported range is [-2, 2].
"max_tokens": 0, # Optional. The maximum number of tokens to generate.
"model": "str", # Optional. ID of the specific AI model to use, if more than
one model is available on the endpoint.
"presence_penalty": 0.0, # Optional. A value that influences the probability
of generated tokens appearing based on their existing presence in generated text.
Positive values will make tokens less likely to appear when they already exist
Expand All @@ -348,7 +355,7 @@ def _complete(
"json_object".
"seed": 0, # Optional. If specified, the system will make a best effort to
sample deterministically such that repeated requests with the same seed and
parameters should return the same result. Determinism is not guaranteed.".
parameters should return the same result. Determinism is not guaranteed.
"stop": [
"str" # Optional. A collection of textual sequences that will end
completions generation.
Expand Down Expand Up @@ -435,6 +442,7 @@ def _complete(
"frequency_penalty": frequency_penalty,
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
Expand Down
61 changes: 31 additions & 30 deletions sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sys

from io import IOBase
from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING
from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable

from azure.core.pipeline import PipelineResponse
from azure.core.credentials import AzureKeyCredential
Expand Down Expand Up @@ -75,7 +75,7 @@

def load_client(
endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any
) -> Union[ChatCompletionsClientGenerated, EmbeddingsClientGenerated, ImageEmbeddingsClientGenerated]:
) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]:
"""
Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route
on the given endpoint, to determine the model type and therefore which client to instantiate.
Expand All @@ -90,7 +90,7 @@ def load_client(
"2024-05-01-preview". Note that overriding this default value may result in unsupported
behavior.
:paramtype api_version: str
:return: The appropriate client associated with the given endpoint
:return: The appropriate synchronous client associated with the given endpoint
:rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient
or ~azure.ai.inference.ImageEmbeddingsClient
:raises ~azure.core.exceptions.HttpResponseError
Expand All @@ -110,7 +110,9 @@ def load_client(
# TODO: Remove "completions" and "embedding" once Mistral Large and Cohere fixes their model type
if model_info.model_type in (_models.ModelType.CHAT, "completion"):
chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs)
chat_completion_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init
chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init
model_info
)
return chat_completion_client

if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"):
Expand All @@ -120,7 +122,9 @@ def load_client(

if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS:
image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs)
image_embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init
image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init
model_info
)
return image_embedding_client

raise ValueError(f"No client available to support AI model type `{model_info.model_type}`")
Expand Down Expand Up @@ -165,6 +169,7 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> _models.ChatCompletions: ...

Expand All @@ -188,8 +193,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> _models.StreamingChatCompletions: ...
) -> Iterable[_models.StreamingChatCompletionsUpdate]: ...

@overload
def complete(
Expand All @@ -211,8 +217,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
"""Gets chat completions for the provided chat messages.
Completions support a wide variety of tasks and generate text that continues from or
Expand Down Expand Up @@ -294,10 +301,13 @@ def complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
same seed and parameters should return the same result. Determinism is not guaranteed.
Default value is None.
:paramtype seed: int
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
Expand All @@ -309,7 +319,7 @@ def complete(
*,
content_type: str = "application/json",
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
"""Gets chat completions for the provided chat messages.
Completions support a wide variety of tasks and generate text that continues from or
Expand All @@ -321,7 +331,7 @@ def complete(
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
Expand All @@ -333,7 +343,7 @@ def complete(
*,
content_type: str = "application/json",
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
# pylint: disable=too-many-locals
"""Gets chat completions for the provided chat messages.
Expand All @@ -345,7 +355,7 @@ def complete(
:keyword content_type: Body Parameter content-type. Content type parameter for binary body.
Default value is "application/json".
:paramtype content_type: str
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
Expand All @@ -370,8 +380,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
# pylint: disable=too-many-locals
"""Gets chat completions for the provided chat messages.
Expand Down Expand Up @@ -451,10 +462,13 @@ def complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
same seed and parameters should return the same result. Determinism is not guaranteed.
Default value is None.
:paramtype seed: int
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
Expand All @@ -479,6 +493,7 @@ def complete(
"frequency_penalty": frequency_penalty,
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
Expand Down Expand Up @@ -603,13 +618,6 @@ def embed(
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
standard request payload.
They will be passed to the service as-is in the root of the JSON request payload.
How the service handles these extra parameters depends on the value of the
``extra-parameters``
HTTP request header. Default value is None.
:paramtype extras: dict[str, str]
:keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
have.
Passing null causes the model to use its default value.
Expand Down Expand Up @@ -855,13 +863,6 @@ def embed(
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
standard request payload.
They will be passed to the service as-is in the root of the JSON request payload.
How the service handles these extra parameters depends on the value of the
``extra-parameters``
HTTP request header. Default value is None.
:paramtype extras: dict[str, str]
:keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
have.
Passing null causes the model to use its default value.
Expand Down
7 changes: 4 additions & 3 deletions sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from copy import deepcopy
from typing import Any, Awaitable, TYPE_CHECKING, Union
from typing_extensions import Self

from azure.core import AsyncPipelineClient
from azure.core.credentials import AzureKeyCredential
Expand Down Expand Up @@ -105,7 +106,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "ChatCompletionsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

Expand Down Expand Up @@ -187,7 +188,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "EmbeddingsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

Expand Down Expand Up @@ -269,7 +270,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "ImageEmbeddingsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

Expand Down
Loading