Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure.AI.Projects] Fix Inference endpoint construction and expose Connection types #47219

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdk/ai/Azure.AI.Projects/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ To further diagnose and troubleshoot issues, you can enable logging following th
## Next steps

Beyond the introductory scenarios discussed, the AI Projects client library offers support for additional scenarios to help take advantage of the full feature set of the AI services. In order to help explore some of these scenarios, the AI Projects client library offers a set of samples to serve as an illustration for common scenarios. Please see the `Azure.AI.Projects/tests/Samples` for details.
Beyond the introductory scenarios discussed, the AI Projects client library offers support for additional scenarios to help take advantage of the full feature set of the AI services. In order to help explore some of these scenarios, the AI Projects client library offers a set of samples to serve as an illustration for common scenarios. Please see the [Samples][samples] for details.

## Contributing

Expand All @@ -384,6 +384,7 @@ This project has adopted the [Microsoft Open Source Code of Conduct][code_of_con

<!-- LINKS -->
[RequestFailedException]: https://learn.microsoft.com/dotnet/api/azure.requestfailedexception?view=azure-dotnet
[samples]: https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/ai/Azure.AI.Projects/tests/Samples
[azure_identity]: https://learn.microsoft.com/dotnet/api/overview/azure/identity-readme?view=azure-dotnet
[azure_identity_dac]: https://learn.microsoft.com/dotnet/api/azure.identity.defaultazurecredential?view=azure-dotnet
[aiprojects_contrib]: https://github.com/Azure/azure-sdk-for-net/blob/main/CONTRIBUTING.md
Expand Down
80 changes: 60 additions & 20 deletions sdk/ai/Azure.AI.Projects/api/Azure.AI.Projects.netstandard2.0.cs

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions sdk/ai/Azure.AI.Projects/src/Custom/AIProjectClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Linq;
using Azure.AI.Inference;
using Azure.Core;

Expand Down Expand Up @@ -65,9 +66,9 @@ private T InitializeInferenceClient<T>(Func<Uri, AzureKeyCredential, T> clientFa
bool useServerlessConnection = Environment.GetEnvironmentVariable("USE_SERVERLESS_CONNECTION") == "true";
ConnectionType connectionType = useServerlessConnection ? ConnectionType.Serverless : ConnectionType.AzureAIServices;

GetConnectionResponse connectionSecret = connectionsClient.GetDefaultConnection(connectionType, true);
ConnectionResponse connection = connectionsClient.GetDefaultConnection(connectionType, true);

if (connectionSecret.Properties is InternalConnectionPropertiesApiKeyAuth apiKeyAuthProperties)
if (connection.Properties is ConnectionPropertiesApiKeyAuth apiKeyAuthProperties)
{
if (string.IsNullOrWhiteSpace(apiKeyAuthProperties.Target))
{
Expand All @@ -80,6 +81,14 @@ private T InitializeInferenceClient<T>(Func<Uri, AzureKeyCredential, T> clientFa
}

var credential = new AzureKeyCredential(apiKeyAuthProperties.Credentials.Key);
if (!useServerlessConnection)
{
// Be sure to use the Azure resource name here, not the connection name. Connection name is something that
// admins can pick when they manually create a new connection (or use bicep). Get the Azure resource name
// from the end of the connection id.
var azureResourceName = connection.Id.Split('/').Last();
endpoint = new Uri($"https://{azureResourceName}.services.ai.azure.com/models");
}
return clientFactory(endpoint, credential);
}
else
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;

namespace Azure.AI.Projects
{
[CodeGenModel("InternalConnectionProperties")]
public abstract partial class ConnectionProperties
{
/// <summary> Authentication type of the connection target. </summary>
public AuthenticationType AuthType { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Azure.Core;

namespace Azure.AI.Projects
{
[CodeGenModel("InternalConnectionPropertiesApiKeyAuth")]
public partial class ConnectionPropertiesApiKeyAuth : ConnectionProperties
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using Azure.Core;

namespace Azure.AI.Projects
{
/// <summary> Response from the listSecrets operation. </summary>
[CodeGenModel("GetConnectionResponse")]
public partial class ConnectionResponse
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,27 +292,27 @@ public virtual Response GetConnections(string category, bool? includeAll, string
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="connectionName"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="connectionName"/> is an empty string, and was expected to be non-empty. </exception>
public virtual async Task<Response<GetConnectionResponse>> GetConnectionAsync(string connectionName, CancellationToken cancellationToken = default)
public virtual async Task<Response<ConnectionResponse>> GetConnectionAsync(string connectionName, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(connectionName, nameof(connectionName));

RequestContext context = FromCancellationToken(cancellationToken);
Response response = await GetConnectionAsync(connectionName, context).ConfigureAwait(false);
return Response.FromValue(GetConnectionResponse.FromResponse(response), response);
return Response.FromValue(ConnectionResponse.FromResponse(response), response);
}

/// <summary> Get the details of a single connection, without credentials. </summary>
/// <param name="connectionName"> Connection Name. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="connectionName"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="connectionName"/> is an empty string, and was expected to be non-empty. </exception>
public virtual Response<GetConnectionResponse> GetConnection(string connectionName, CancellationToken cancellationToken = default)
public virtual Response<ConnectionResponse> GetConnection(string connectionName, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(connectionName, nameof(connectionName));

RequestContext context = FromCancellationToken(cancellationToken);
Response response = GetConnection(connectionName, context);
return Response.FromValue(GetConnectionResponse.FromResponse(response), response);
return Response.FromValue(ConnectionResponse.FromResponse(response), response);
}

/// <summary>
Expand Down Expand Up @@ -399,15 +399,15 @@ public virtual Response GetConnection(string connectionName, RequestContext cont
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="connectionName"/> or <paramref name="ignored"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="connectionName"/> is an empty string, and was expected to be non-empty. </exception>
public virtual async Task<Response<GetConnectionResponse>> GetConnectionWithSecretsAsync(string connectionName, string ignored, CancellationToken cancellationToken = default)
public virtual async Task<Response<ConnectionResponse>> GetConnectionWithSecretsAsync(string connectionName, string ignored, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(connectionName, nameof(connectionName));
Argument.AssertNotNull(ignored, nameof(ignored));

GetConnectionWithSecretsRequest getConnectionWithSecretsRequest = new GetConnectionWithSecretsRequest(ignored, null);
RequestContext context = FromCancellationToken(cancellationToken);
Response response = await GetConnectionWithSecretsAsync(connectionName, getConnectionWithSecretsRequest.ToRequestContent(), context).ConfigureAwait(false);
return Response.FromValue(GetConnectionResponse.FromResponse(response), response);
return Response.FromValue(ConnectionResponse.FromResponse(response), response);
}

/// <summary> Get the details of a single connection, including credentials (if available). </summary>
Expand All @@ -416,15 +416,15 @@ public virtual async Task<Response<GetConnectionResponse>> GetConnectionWithSecr
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="connectionName"/> or <paramref name="ignored"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="connectionName"/> is an empty string, and was expected to be non-empty. </exception>
public virtual Response<GetConnectionResponse> GetConnectionWithSecrets(string connectionName, string ignored, CancellationToken cancellationToken = default)
public virtual Response<ConnectionResponse> GetConnectionWithSecrets(string connectionName, string ignored, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(connectionName, nameof(connectionName));
Argument.AssertNotNull(ignored, nameof(ignored));

GetConnectionWithSecretsRequest getConnectionWithSecretsRequest = new GetConnectionWithSecretsRequest(ignored, null);
RequestContext context = FromCancellationToken(cancellationToken);
Response response = GetConnectionWithSecrets(connectionName, getConnectionWithSecretsRequest.ToRequestContent(), context);
return Response.FromValue(GetConnectionResponse.FromResponse(response), response);
return Response.FromValue(ConnectionResponse.FromResponse(response), response);
}

/// <summary>
Expand Down Expand Up @@ -515,7 +515,7 @@ public virtual Response GetConnectionWithSecrets(string connectionName, RequestC
/// <param name="includeAll"> Indicates whether to list datastores. Service default: do not list datastores. </param>
/// <param name="target"> Target of the workspace connection. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
public virtual async Task<Response<GetConnectionResponse>> GetDefaultConnectionAsync(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
public virtual async Task<Response<ConnectionResponse>> GetDefaultConnectionAsync(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
{
ListConnectionsResponse connections = await GetConnectionsAsync(category, includeAll, target, cancellationToken).ConfigureAwait(false);

Expand All @@ -536,7 +536,7 @@ public virtual async Task<Response<GetConnectionResponse>> GetDefaultConnectionA
/// <param name="includeAll"> Indicates whether to list datastores. Service default: do not list datastores. </param>
/// <param name="target"> Target of the workspace connection. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
public virtual Response<GetConnectionResponse> GetDefaultConnection(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
public virtual Response<ConnectionResponse> GetDefaultConnection(ConnectionType category, bool? withCredential = null, bool? includeAll = null, string target = null, CancellationToken cancellationToken = default)
{
ListConnectionsResponse connections = GetConnections(category, includeAll, target, cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace Azure.AI.Projects
{
/// <summary> The credentials needed for API key authentication. </summary>
public partial class CredentialsApiKeyAuth
{
}
}

This file was deleted.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading