Azure AI Inference SDK - Beta 2 updates (#36163)
The main reason for this release, coming shortly after the first release:
- Add a strongly typed `model` as an optional input argument to the `complete` method of `ChatCompletionsClient`. This is required for a high-visibility project, whose developers must set `model`.
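As an illustration of where the new argument lands on the wire, here is a minimal sketch (not SDK code — `build_chat_request` and the model name are hypothetical; the root-level placement of `model` follows the serialization shown in the diff below):

```python
import json

def build_chat_request(messages, model=None, **options):
    """Hypothetical helper mirroring the SDK's hand-written serialization:
    the optional `model` setting sits at the root of the JSON payload,
    and None-valued entries are dropped before sending."""
    body = {"messages": messages, "model": model, **options}
    return {key: value for key, value in body.items() if value is not None}

payload = build_chat_request(
    messages=[{"role": "user", "content": "How many feet are in a mile?"}],
    model="my-deployment",  # hypothetical model/deployment name
    temperature=0.7,
)
print(json.dumps(payload, sort_keys=True))
```

Omitting `model` simply leaves the key out of the payload, so existing single-model endpoints are unaffected.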

Breaking change (noted in CHANGELOG.md):
- The field `input_tokens` was removed from class `EmbeddingsUsage`, as this was never defined in the
REST API and the service never returned this value.

Other changes in this release:
- Address some test debt (work in progress):
  - Add tests for setting `model_extras` on the sync and async clients. Verify that the additional parameters appear at the root of the JSON request payload, and that the `unknown_parameters` HTTP request header is set to `pass_through`.
  - Add tests to validate serialization of a dummy chat completion request that includes all types of input objects. This is a regression test (no service response needed): it inspects the JSON request payload and compares it to a hard-coded expected string that was previously verified by hand. The test covers the new `model` argument, as well as all other arguments defined by the REST API, and will catch any regressions in hand-written code.
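The two checks above can be sketched together in a few lines (names taken from the commit text; `serialize_with_extras` is illustrative, not SDK code):

```python
import json

def serialize_with_extras(messages, model_extras=None):
    """Illustrative stand-in for the client's request serialization:
    `model_extras` entries are merged into the ROOT of the JSON payload,
    and the opt-in is signalled via an HTTP request header."""
    body = {"messages": messages}
    headers = {}
    if model_extras:
        body.update(model_extras)  # root-level merge, not nested under a key
        headers["unknown_parameters"] = "pass_through"  # header name/value per the commit text
    return json.dumps(body, sort_keys=True), headers

actual_json, headers = serialize_with_extras(
    [{"role": "user", "content": "hi"}],
    model_extras={"n": 2, "custom_option": True},
)
# Regression style: compare against a hand-verified expected string.
expected = json.dumps(
    {"messages": [{"role": "user", "content": "hi"}], "n": 2, "custom_option": True},
    sort_keys=True,
)
assert actual_json == expected
```

The hard-coded `expected` string is what makes this a regression test: any change to the hand-written serialization shows up as a diff against it.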
- Update ref docs to remove mentions of the old `extras` input argument to chat completions in hand-written code. The name was changed to `model_extras` before the first release, but some leftover ref-doc comments still described the no-longer-existing argument.
- Remove an unused function from the sample `sample_chat_completions_with_image_data.py`; this was missed in the first release.
- Minor changes to root README.md
- Indicate that the `complete` method with `stream=True` returns `Iterable[StreamingChatCompletionsUpdate]` for
the synchronous `ChatCompletionsClient`, and `AsyncIterable[StreamingChatCompletionsUpdate]` for the asynchronous
`ChatCompletionsClient`. Per feedback from Anna T.
- Update environment variable names used by sample code and tests to start with "AZURE_AI", as is common elsewhere, per feedback from Rob C.
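The streaming return type noted above is consumed by plain iteration. A sketch with a stubbed update stream standing in for the service response (`StreamingChatCompletionsUpdate` here is a local stand-in, not the SDK class):

```python
from dataclasses import dataclass
from typing import Iterable, Iterator

@dataclass
class StreamingChatCompletionsUpdate:
    """Local stand-in for the SDK model of the same name."""
    content: str

def fake_stream() -> Iterator[StreamingChatCompletionsUpdate]:
    # Stub for what `complete(stream=True, ...)` would yield.
    for piece in ("The", " answer", " is", " 5280."):
        yield StreamingChatCompletionsUpdate(content=piece)

def collect(updates: Iterable[StreamingChatCompletionsUpdate]) -> str:
    # Typical consumption: iterate the updates and concatenate content deltas.
    return "".join(update.content for update in updates)

print(collect(fake_stream()))  # The answer is 5280.
```

With the asynchronous client the same pattern uses `async for` over the `AsyncIterable`.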
dargilco authored Jun 22, 2024
1 parent f792549 commit 444ed8b
Showing 44 changed files with 594 additions and 479 deletions.
9 changes: 5 additions & 4 deletions sdk/ai/azure-ai-inference/CHANGELOG.md
@@ -1,14 +1,15 @@
 # Release History
 
-## 1.0.0b2 (Unreleased)
+## 1.0.0b2 (2024-06-24)
 
 ### Features Added
 
-### Breaking Changes
+Add `model` as an optional input argument to the `complete` method of `ChatCompletionsClient`.
 
-### Bugs Fixed
+### Breaking Changes
 
-### Other Changes
+The field `input_tokens` was removed from class `EmbeddingsUsage`, as this was never defined in the
+REST API and the service never returned this value.
 
 ## 1.0.0b1 (2024-06-11)
 
2 changes: 1 addition & 1 deletion sdk/ai/azure-ai-inference/README.md
@@ -123,7 +123,7 @@ To load an asynchronous client, import the `load_client` function from `azure.ai

Entra ID authentication is also supported by the `load_client` function. Replace the key authentication above with `credential=DefaultAzureCredential()` for example.

### Getting AI model information
### Get AI model information

All clients provide a `get_model_info` method to retrieve AI model information. This makes a REST call to the `/info` route on the provided endpoint, as documented in [the REST API reference](https://learn.microsoft.com/azure/ai-studio/reference/reference-model-inference-info).

7 changes: 4 additions & 3 deletions sdk/ai/azure-ai-inference/azure/ai/inference/_client.py
@@ -8,6 +8,7 @@

from copy import deepcopy
from typing import Any, TYPE_CHECKING, Union
from typing_extensions import Self

from azure.core import PipelineClient
from azure.core.credentials import AzureKeyCredential
@@ -101,7 +102,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "ChatCompletionsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

@@ -179,7 +180,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "EmbeddingsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

@@ -257,7 +258,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "ImageEmbeddingsClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

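The `-> Self` annotations introduced in this file follow the standard typing pattern: unlike a hard-coded class name, `Self` lets type checkers infer the precise subclass type from `with` blocks. A minimal, generic sketch (not SDK code):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self  # typing.Self on Python 3.11+

class BaseClient:
    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc) -> None:
        pass

class CustomClient(BaseClient):
    def ping(self) -> str:
        return "pong"

with CustomClient() as client:
    # With `-> Self`, type checkers know `client` is a CustomClient here,
    # so subclass-only methods like `ping()` type-check cleanly.
    print(client.ping())
```

Had `__enter__` kept the return annotation `"BaseClient"`, a type checker would flag `client.ping()` as unknown.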
3 changes: 2 additions & 1 deletion sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py
@@ -883,5 +883,6 @@ def rest_discriminator(
*,
name: typing.Optional[str] = None,
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
visibility: typing.Optional[typing.List[str]] = None,
) -> typing.Any:
return _RestField(name=name, type=type, is_discriminator=True)
return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)
@@ -208,6 +208,7 @@ def _complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any
) -> _models.ChatCompletions: ...
@overload
@@ -240,6 +241,7 @@ def _complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any
) -> _models.ChatCompletions:
# pylint: disable=line-too-long
@@ -317,9 +319,12 @@ def _complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
Default value is None.
same seed and parameters should return the same result. Determinism is not guaranteed. Default
value is None.
:paramtype seed: int
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions. The ChatCompletions is compatible with MutableMapping
:rtype: ~azure.ai.inference.models.ChatCompletions
:raises ~azure.core.exceptions.HttpResponseError:
@@ -338,6 +343,8 @@ def _complete(
frequency increases and decrease the likelihood of the model repeating the same
statements verbatim. Supported range is [-2, 2].
"max_tokens": 0, # Optional. The maximum number of tokens to generate.
"model": "str", # Optional. ID of the specific AI model to use, if more than
one model is available on the endpoint.
"presence_penalty": 0.0, # Optional. A value that influences the probability
of generated tokens appearing based on their existing presence in generated text.
Positive values will make tokens less likely to appear when they already exist
@@ -348,7 +355,7 @@
"json_object".
"seed": 0, # Optional. If specified, the system will make a best effort to
sample deterministically such that repeated requests with the same seed and
parameters should return the same result. Determinism is not guaranteed.".
parameters should return the same result. Determinism is not guaranteed.
"stop": [
"str" # Optional. A collection of textual sequences that will end
completions generation.
@@ -435,6 +442,7 @@ def _complete(
"frequency_penalty": frequency_penalty,
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
61 changes: 31 additions & 30 deletions sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py
@@ -24,7 +24,7 @@
import sys

from io import IOBase
from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING
from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable

from azure.core.pipeline import PipelineResponse
from azure.core.credentials import AzureKeyCredential
@@ -75,7 +75,7 @@

def load_client(
endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any
) -> Union[ChatCompletionsClientGenerated, EmbeddingsClientGenerated, ImageEmbeddingsClientGenerated]:
) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]:
"""
Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route
on the given endpoint, to determine the model type and therefore which client to instantiate.
@@ -90,7 +90,7 @@ def load_client(
"2024-05-01-preview". Note that overriding this default value may result in unsupported
behavior.
:paramtype api_version: str
:return: The appropriate client associated with the given endpoint
:return: The appropriate synchronous client associated with the given endpoint
:rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient
or ~azure.ai.inference.ImageEmbeddingsClient
:raises ~azure.core.exceptions.HttpResponseError
@@ -110,7 +110,9 @@
# TODO: Remove "completions" and "embedding" once Mistral Large and Cohere fixes their model type
if model_info.model_type in (_models.ModelType.CHAT, "completion"):
chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs)
chat_completion_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init
chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init
model_info
)
return chat_completion_client

if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"):
@@ -120,7 +122,9 @@

if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS:
image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs)
image_embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init
image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init
model_info
)
return image_embedding_client

raise ValueError(f"No client available to support AI model type `{model_info.model_type}`")
@@ -165,6 +169,7 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> _models.ChatCompletions: ...

@@ -188,8 +193,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> _models.StreamingChatCompletions: ...
) -> Iterable[_models.StreamingChatCompletionsUpdate]: ...

@overload
def complete(
@@ -211,8 +217,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
"""Gets chat completions for the provided chat messages.
Completions support a wide variety of tasks and generate text that continues from or
@@ -294,10 +301,13 @@ def complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
same seed and parameters should return the same result. Determinism is not guaranteed.
Default value is None.
:paramtype seed: int
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
@@ -309,7 +319,7 @@
*,
content_type: str = "application/json",
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
"""Gets chat completions for the provided chat messages.
Completions support a wide variety of tasks and generate text that continues from or
@@ -321,7 +331,7 @@
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
@@ -333,7 +343,7 @@
*,
content_type: str = "application/json",
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
# pylint: disable=too-many-locals
"""Gets chat completions for the provided chat messages.
@@ -345,7 +355,7 @@
:keyword content_type: Body Parameter content-type. Content type parameter for binary body.
Default value is "application/json".
:paramtype content_type: str
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
@@ -370,8 +380,9 @@ def complete(
Union[str, _models.ChatCompletionsToolSelectionPreset, _models.ChatCompletionsNamedToolSelection]
] = None,
seed: Optional[int] = None,
model: Optional[str] = None,
**kwargs: Any,
) -> Union[_models.StreamingChatCompletions, _models.ChatCompletions]:
) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]:
# pylint: disable=line-too-long
# pylint: disable=too-many-locals
"""Gets chat completions for the provided chat messages.
@@ -451,10 +462,13 @@ def complete(
~azure.ai.inference.models.ChatCompletionsNamedToolSelection
:keyword seed: If specified, the system will make a best effort to sample deterministically
such that repeated requests with the
same seed and parameters should return the same result. Determinism is not guaranteed.".
same seed and parameters should return the same result. Determinism is not guaranteed.
Default value is None.
:paramtype seed: int
:return: ChatCompletions for non-streaming, or StreamingChatCompletions for streaming.
:keyword model: ID of the specific AI model to use, if more than one model is available on the
endpoint. Default value is None.
:paramtype model: str
:return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming.
:rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions
:raises ~azure.core.exceptions.HttpResponseError
"""
@@ -479,6 +493,7 @@ def complete(
"frequency_penalty": frequency_penalty,
"max_tokens": max_tokens,
"messages": messages,
"model": model,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
@@ -603,13 +618,6 @@ def embed(
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
standard request payload.
They will be passed to the service as-is in the root of the JSON request payload.
How the service handles these extra parameters depends on the value of the
``extra-parameters``
HTTP request header. Default value is None.
:paramtype extras: dict[str, str]
:keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
have.
Passing null causes the model to use its default value.
@@ -855,13 +863,6 @@ def embed(
:keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
Default value is "application/json".
:paramtype content_type: str
:keyword extras: Extra parameters (in the form of string key-value pairs) that are not in the
standard request payload.
They will be passed to the service as-is in the root of the JSON request payload.
How the service handles these extra parameters depends on the value of the
``extra-parameters``
HTTP request header. Default value is None.
:paramtype extras: dict[str, str]
:keyword dimensions: Optional. The number of dimensions the resulting output embeddings should
have.
Passing null causes the model to use its default value.
7 changes: 4 additions & 3 deletions sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py
@@ -8,6 +8,7 @@

from copy import deepcopy
from typing import Any, Awaitable, TYPE_CHECKING, Union
from typing_extensions import Self

from azure.core import AsyncPipelineClient
from azure.core.credentials import AzureKeyCredential
@@ -105,7 +106,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "ChatCompletionsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

@@ -187,7 +188,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "EmbeddingsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

@@ -269,7 +270,7 @@ def send_request(
async def close(self) -> None:
await self._client.close()

async def __aenter__(self) -> "ImageEmbeddingsClient":
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self
