Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure.AI.Projects] Add support for Inference #46972

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eng/Packages.Data.props
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
<PackageReference Update="Azure.Storage.Queues" Version="12.19.1" />
<PackageReference Update="Azure.Storage.Files.Shares" Version="12.19.1" />
<PackageReference Update="Azure.AI.OpenAI" Version="2.0.0" />
<PackageReference Update="Azure.AI.Inference" Version="1.0.0-beta.2" />
<PackageReference Update="Azure.ResourceManager" Version="1.13.0" />
<PackageReference Update="Azure.ResourceManager.AppConfiguration" Version="1.3.2" />
<PackageReference Update="Azure.ResourceManager.ApplicationInsights" Version="1.0.0" />
Expand Down
6 changes: 3 additions & 3 deletions sdk/ai/Azure.AI.Projects/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Azure.AI.Client client library for .NET
# Azure AI Projects client library for .NET

The Azure AI Assistants client library for .NET is an adaptation of OpenAI's REST APIs that provides an idiomatic interface
TODO: [Update README] The Azure AI Assistants client library for .NET is an adaptation of OpenAI's REST APIs that provides an idiomatic interface
and rich integration with the rest of the Azure SDK ecosystem. It will connect to Azure AI resources endpoint.

Use this library to:
Expand All @@ -20,7 +20,7 @@ To use Assistants capabilities, you'll need to use an Azure AI resource, you mus
Install the client library for .NET with [NuGet](https://www.nuget.org/ ):

```dotnetcli
dotnet add package Azure.AI.Project --prerelease
dotnet add package Azure.AI.Projects --prerelease
```

### Authenticate the client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,9 @@ public AIProjectClient(System.Uri endpoint, string subscriptionId, string resour
public AIProjectClient(System.Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, Azure.Core.TokenCredential credential, Azure.AI.Projects.AIProjectClientOptions options) { }
public virtual Azure.Core.Pipeline.HttpPipeline Pipeline { get { throw null; } }
public virtual Azure.AI.Projects.AgentsClient GetAgentsClient(string apiVersion = "2024-07-01-preview") { throw null; }
public virtual Azure.AI.Inference.ChatCompletionsClient GetChatCompletionsClient() { throw null; }
public virtual Azure.AI.Projects.ConnectionsClient GetConnectionsClient(string apiVersion = "2024-07-01-preview") { throw null; }
public virtual Azure.AI.Inference.EmbeddingsClient GetEmbeddingsClient() { throw null; }
public virtual Azure.AI.Projects.EvaluationsClient GetEvaluationsClient(string apiVersion = "2024-07-01-preview") { throw null; }
}
public partial class AIProjectClientOptions : Azure.Core.ClientOptions
Expand Down
1 change: 1 addition & 0 deletions sdk/ai/Azure.AI.Projects/src/Azure.AI.Projects.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<ItemGroup>
<PackageReference Include="Azure.Core" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="Azure.AI.Inference" />
</ItemGroup>

<!-- Shared source from Azure.Core -->
Expand Down
45 changes: 45 additions & 0 deletions sdk/ai/Azure.AI.Projects/src/Custom/AIProjectClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using Azure.AI.Inference;
using Azure.Core;

namespace Azure.AI.Projects
Expand Down Expand Up @@ -36,5 +37,49 @@ public AIProjectClient(string connectionString, TokenCredential credential, AIPr
options)
{
}

private ChatCompletionsClient _chatCompletionsClient;
private EmbeddingsClient _embeddingsClient;

/// <summary> Initializes a new instance of Inference's ChatCompletionsClient. </summary>
public virtual ChatCompletionsClient GetChatCompletionsClient()
{
return _chatCompletionsClient ??= InitializeInferenceClient((endpoint, credential) =>
new ChatCompletionsClient(endpoint, credential, new AzureAIInferenceClientOptions()));
KrzysztofCwalina marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary> Initializes a new instance of Inference's EmbeddingsClient. </summary>
public virtual EmbeddingsClient GetEmbeddingsClient()
{
return _embeddingsClient ??= InitializeInferenceClient((endpoint, credential) =>
new EmbeddingsClient(endpoint, credential, new AzureAIInferenceClientOptions()));
}

/// <summary> Initializes a new instance of Inference client. </summary>
private T InitializeInferenceClient<T>(Func<Uri, AzureKeyCredential, T> clientFactory)
{
var connectionsClient = GetConnectionsClient();
ConnectionsListSecretsResponse connectionSecret = connectionsClient.GetDefaultConnection(ConnectionType.Serverless, true);

if (connectionSecret.Properties is ConnectionPropertiesApiKeyAuth apiKeyAuthProperties)
{
if (string.IsNullOrWhiteSpace(apiKeyAuthProperties.Target))
{
throw new ArgumentException("The API key authentication target URI is missing or invalid.");
}

if (!Uri.TryCreate(apiKeyAuthProperties.Target, UriKind.Absolute, out var endpoint))
{
throw new UriFormatException("Invalid URI format in API key authentication target.");
}

var credential = new AzureKeyCredential(apiKeyAuthProperties.Credentials.Key);
return clientFactory(endpoint, credential);
}
else
{
throw new ArgumentException("Cannot connect with Inference! Ensure valid ConnectionPropertiesApiKeyAuth.");
}
}
}
}
145 changes: 144 additions & 1 deletion sdk/ai/Azure.AI.Projects/src/Custom/Connection/ConnectionsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#nullable disable

using System;
using System.Threading.Tasks;
using System.Threading;
using Azure.Core;
using Azure.Core.Pipeline;

Expand Down Expand Up @@ -73,11 +75,152 @@ public ConnectionsClient(Uri endpoint, string subscriptionId, string resourceGro
ClientDiagnostics = new ClientDiagnostics(options, true);
_tokenCredential = credential;
_pipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), new HttpPipelinePolicy[] { new BearerTokenAuthenticationPolicy(_tokenCredential, AuthorizationScopes) }, new ResponseClassifier());
_endpoint = endpoint;
_endpoint = new Uri("https://management.azure.com");
_subscriptionId = subscriptionId;
_resourceGroupName = resourceGroupName;
_projectName = projectName;
_apiVersion = options.Version;
}

/// <summary> Initializes a new instance of ConnectionsClient. </summary>
/// <param name="clientDiagnostics"> The handler for diagnostic messaging in the client. </param>
/// <param name="pipeline"> The HTTP pipeline for sending and receiving REST requests and responses. </param>
/// <param name="tokenCredential"> The token credential to copy. </param>
/// <param name="endpoint"> The Azure AI Studio project endpoint, in the form `https://&lt;azure-region&gt;.api.azureml.ms` or `https://&lt;private-link-guid&gt;.&lt;azure-region&gt;.api.azureml.ms`, where &lt;azure-region&gt; is the Azure region where the project is deployed (e.g. westus) and &lt;private-link-guid&gt; is the GUID of the Enterprise private link. </param>
/// <param name="subscriptionId"> The Azure subscription ID. </param>
/// <param name="resourceGroupName"> The name of the Azure Resource Group. </param>
/// <param name="projectName"> The Azure AI Studio project name. </param>
/// <param name="apiVersion"> The API version to use for this operation. </param>
internal ConnectionsClient(ClientDiagnostics clientDiagnostics, HttpPipeline pipeline, TokenCredential tokenCredential, Uri endpoint, string subscriptionId, string resourceGroupName, string projectName, string apiVersion)
{
ClientDiagnostics = clientDiagnostics;
_pipeline = pipeline;
_tokenCredential = tokenCredential;
_endpoint = new Uri("https://management.azure.com");
_subscriptionId = subscriptionId;
_resourceGroupName = resourceGroupName;
_projectName = projectName;
_apiVersion = apiVersion;
}

/// <summary> List the details of all the connections (not including their credentials). </summary>
/// <param name="category"> Category of the workspace connection. </param>
/// <param name="withCredential"></param>
/// <param name="includeAll"> Indicates whether to list datastores. Service default: do not list datastores. </param>
/// <param name="target"> Target of the workspace connection. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
internal virtual async Task<Response<ConnectionsListSecretsResponse>> GetDefaultConnectionAsync(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
{
ConnectionsListResponse connections = await GetConnectionsAsync(category, includeAll, target, cancellationToken).ConfigureAwait(false);

if (connections?.Value == null || connections.Value.Count == 0)
{
throw new InvalidOperationException("No connections found for the specified parameters.");
}

var secret = connections.Value[0];
return withCredential.GetValueOrDefault()
? await GetSecretsAsync(secret.Name, "ignored").ConfigureAwait(false)
: await GetConnectionAsync(secret.Name).ConfigureAwait(false);
}

/// <summary> Get the details of a single connection. </summary>
/// <param name="category"> Category of the workspace connection. </param>
/// <param name="withCredential"></param>
/// <param name="includeAll"> Indicates whether to list datastores. Service default: do not list datastores. </param>
/// <param name="target"> Target of the workspace connection. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
internal virtual Response<ConnectionsListSecretsResponse> GetDefaultConnection(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
{
ConnectionsListResponse connections = GetConnections(category, includeAll, target, cancellationToken);

if (connections?.Value == null || connections.Value.Count == 0)
{
throw new InvalidOperationException("No connections found for the specified parameters.");
}

var secret = connections.Value[0];
return withCredential.GetValueOrDefault()
? GetSecrets(secret.Name, "ignored")
: GetConnection(secret.Name);
}

// CUSTOM: Fixed the request URI by removing "/agents/v1.0"
internal HttpMessage CreateGetConnectionsRequest(string category, bool? includeAll, string target, RequestContext context)
{
var message = _pipeline.CreateMessage(context, ResponseClassifier200);
var request = message.Request;
request.Method = RequestMethod.Get;
var uri = new RawRequestUriBuilder();
uri.Reset(_endpoint);
uri.AppendRaw("/subscriptions/", false);
uri.AppendRaw(_subscriptionId, true);
uri.AppendRaw("/resourceGroups/", false);
uri.AppendRaw(_resourceGroupName, true);
uri.AppendRaw("/providers/Microsoft.MachineLearningServices/workspaces/", false);
uri.AppendRaw(_projectName, true);
uri.AppendPath("/connections", false);
uri.AppendQuery("api-version", _apiVersion, true);
if (category != null)
{
uri.AppendQuery("category", category, true);
}
if (includeAll != null)
{
uri.AppendQuery("includeAll", includeAll.Value, true);
}
if (target != null)
{
uri.AppendQuery("target", target, true);
}
request.Uri = uri;
request.Headers.Add("Accept", "application/json");
return message;
}

internal HttpMessage CreateGetConnectionRequest(string connectionName, RequestContext context)
{
var message = _pipeline.CreateMessage(context, ResponseClassifier200);
var request = message.Request;
request.Method = RequestMethod.Get;
var uri = new RawRequestUriBuilder();
uri.Reset(_endpoint);
uri.AppendRaw("/subscriptions/", false);
uri.AppendRaw(_subscriptionId, true);
uri.AppendRaw("/resourceGroups/", false);
uri.AppendRaw(_resourceGroupName, true);
uri.AppendRaw("/providers/Microsoft.MachineLearningServices/workspaces/", false);
uri.AppendRaw(_projectName, true);
uri.AppendPath("/connections/", false);
uri.AppendPath(connectionName, true);
uri.AppendQuery("api-version", _apiVersion, true);
request.Uri = uri;
request.Headers.Add("Accept", "application/json");
return message;
}

internal HttpMessage CreateGetSecretsRequest(string connectionName, RequestContent content, RequestContext context)
{
var message = _pipeline.CreateMessage(context, ResponseClassifier200);
var request = message.Request;
request.Method = RequestMethod.Post;
var uri = new RawRequestUriBuilder();
uri.Reset(_endpoint);
uri.AppendRaw("/subscriptions/", false);
uri.AppendRaw(_subscriptionId, true);
uri.AppendRaw("/resourceGroups/", false);
uri.AppendRaw(_resourceGroupName, true);
uri.AppendRaw("/providers/Microsoft.MachineLearningServices/workspaces/", false);
uri.AppendRaw(_projectName, true);
uri.AppendPath("/connections/", false);
uri.AppendPath(connectionName, true);
uri.AppendPath("/listsecrets", false);
uri.AppendQuery("api-version", _apiVersion, true);
request.Uri = uri;
request.Headers.Add("Accept", "application/json");
request.Headers.Add("Content-Type", "application/json");
request.Content = content;
return message;
}
}
}
Loading
Loading