Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure OpenAI updates #46103

Merged
merged 14 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions Custom/Internal/AzureAsyncCollectionResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable enable

using System.ClientModel;
ralph-msft marked this conversation as resolved.
Show resolved Hide resolved
using System.ClientModel.Primitives;

namespace Azure.AI.OpenAI.Utility;

/// <summary>
/// Represents a collection of values returned from an asynchronous Azure cloud service operation.
/// </summary>
/// <typeparam name="TItem">The type of items in the collection.</typeparam>
/// <typeparam name="TContinuation">Type of the continuation token.</typeparam>
internal class AzureAsyncCollectionResult<TItem, TContinuation> : AsyncCollectionResult<TItem> where TContinuation : ContinuationToken
{
private readonly ClientPipeline _pipeline;
private readonly RequestOptions _options;
private readonly Func<TContinuation?, PipelineMessage> _createRequest;
private readonly Func<ClientResult, TContinuation?> _getContinuationToken;
private readonly Func<ClientResult, IEnumerable<TItem>> _getValues;
private readonly CancellationToken _cancellation;

/// <summary>
/// Creates a new instance.
/// </summary>
/// <param name="pipeline">The client pipeline to use to send requests.</param>
/// <param name="options">The request options to use.</param>
/// <param name="createRequest">The function used to create the request to get a page of results. The continuation token
/// may be set to null to get the first page. After that it will be set to a value used to get the next page of results.</param>
/// <param name="getContinuationToken">The function used to create a continuation token from a page of results.</param>
/// <param name="getValues">The function used to extract results from a page.</param>
/// <param name="cancellation">The cancellation token to use.</param>
/// <exception cref="ArgumentNullException">If any of the required arguments are null.</exception>
public AzureAsyncCollectionResult(
ClientPipeline pipeline,
RequestOptions options,
Func<TContinuation?, PipelineMessage> createRequest,
Func<ClientResult, TContinuation?> getContinuationToken,
Func<ClientResult, IEnumerable<TItem>> getValues,
CancellationToken cancellation)
{
_pipeline = pipeline ?? throw new ArgumentNullException(nameof(pipeline));
_options = options ?? new();
_getContinuationToken = getContinuationToken ?? throw new ArgumentNullException(nameof(_getContinuationToken));
_createRequest = createRequest ?? throw new ArgumentNullException(nameof(_createRequest));
_getValues = getValues ?? throw new ArgumentNullException(nameof(_getContinuationToken));
_cancellation = cancellation;
}

/// <inheritdoc />
public override ContinuationToken? GetContinuationToken(ClientResult page) => _getContinuationToken(page);

/// <inheritdoc />
public override async IAsyncEnumerable<ClientResult> GetRawPagesAsync()
{
TContinuation? continuation = null;
do
{
ClientResult page = await SendRequestAsync(continuation).ConfigureAwait(false);
continuation = _getContinuationToken(page);

yield return page;
} while (continuation != null);
}

/// <inheritdoc />
protected override IAsyncEnumerable<TItem> GetValuesFromPageAsync(ClientResult page)
=> _getValues(page).ToAsyncEnumerable(_cancellation);

/// <summary>
/// Sends a request to get the first page of results (<paramref name="continuationToken"/> is null),
/// or the next page of results (<paramref name="continuationToken"/> has a non-null value).
/// </summary>
/// <param name="continuationToken">The continuation token to use. Will be null when retrieving the first page of results.</param>
/// <returns>The result containing the page of results.</returns>
protected virtual async Task<ClientResult> SendRequestAsync(TContinuation? continuationToken)
{
using PipelineMessage message = _createRequest(continuationToken);
PipelineResponse response = await _pipeline.ProcessMessageAsync(message, _options).ConfigureAwait(false);
return ClientResult.FromResponse(response);
}
}
80 changes: 80 additions & 0 deletions Custom/Internal/AzureCollectionResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable enable

using System.ClientModel;
using System.ClientModel.Primitives;

namespace Azure.AI.OpenAI.Utility;

/// <summary>
/// Represents a collection of values returned from an Azure cloud service operation.
/// </summary>
/// <typeparam name="TItem">The type of items in the collection.</typeparam>
/// <typeparam name="TContinuation">Type of the continuation token.</typeparam>
internal class AzureCollectionResult<TItem, TContinuation> : CollectionResult<TItem> where TContinuation : ContinuationToken
{
private readonly ClientPipeline _pipeline;
private readonly RequestOptions _options;
private readonly Func<TContinuation?, PipelineMessage> _createRequest;
private readonly Func<ClientResult, TContinuation?> _getContinuationToken;
private readonly Func<ClientResult, IEnumerable<TItem>> _getValues;

/// <summary>
/// Creates a new instance.
/// </summary>
/// <param name="pipeline">The client pipeline to use to send requests.</param>
/// <param name="options">The request options to use.</param>
/// <param name="createRequest">The function used to create the request to get a page of results. The continuation token
/// may be set to null to get the first page. After that it will be set to a value used to get the next page of results.</param>
/// <param name="getContinuationToken">The function used to create a continuation token from a page of results.</param>
/// <param name="getValues">The function used to extract results from a page.</param>
/// <exception cref="ArgumentNullException">If any of the required arguments are null.</exception>
public AzureCollectionResult(
ClientPipeline pipeline,
RequestOptions options,
Func<TContinuation?, PipelineMessage> createRequest,
Func<ClientResult, TContinuation?> getContinuationToken,
Func<ClientResult, IEnumerable<TItem>> getValues)
{
_pipeline = pipeline ?? throw new ArgumentNullException(nameof(pipeline));
_options = options ?? new();
_createRequest = createRequest ?? throw new ArgumentNullException(nameof(_createRequest));
_getContinuationToken = getContinuationToken ?? throw new ArgumentNullException(nameof(_getContinuationToken));
_getValues = getValues ?? throw new ArgumentNullException(nameof(_getContinuationToken));
}

/// <inheritdoc />
public override ContinuationToken? GetContinuationToken(ClientResult page) => _getContinuationToken(page);

/// <inheritdoc />
public override IEnumerable<ClientResult> GetRawPages()
{
TContinuation? continuation = null;

do
{
ClientResult page = SendRequest(continuation);
continuation = _getContinuationToken(page);

yield return page;
}
while (continuation != null);
}

/// <inheritdoc />
protected override IEnumerable<TItem> GetValuesFromPage(ClientResult page) => _getValues(page);

/// <summary>
/// Sends a request to get the first page of results (<paramref name="continuationToken"/> is null),
/// or the next page of results (<paramref name="continuationToken"/> has a non-null value).
/// </summary>
/// <param name="continuationToken">The continuation token to use. Will be null when retrieving the first page of results.</param>
/// <returns>The result containing the page of results.</returns>
protected virtual ClientResult SendRequest(TContinuation? continuationToken)
{
using PipelineMessage message = _createRequest(continuationToken);
return ClientResult.FromResponse(_pipeline.ProcessMessage(message, _options));
}
}
5 changes: 3 additions & 2 deletions eng/Packages.Data.props
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@

<!-- BCL packages -->
<PackageReference Update="System.Buffers" Version="4.5.1" />
<PackageReference Update="System.ClientModel" Version="1.0.0" />
<PackageReference Update="System.ClientModel" Version="1.1.0" />
ralph-msft marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Update="System.IO.Hashing" Version="6.0.0" />
<PackageReference Update="System.Memory" Version="4.5.5" />
<PackageReference Update="System.Memory.Data" Version="1.0.2" />
Expand Down Expand Up @@ -179,7 +179,7 @@
</ItemGroup>

<ItemGroup Condition="$(MSBuildProjectName.StartsWith('Azure.AI.OpenAI'))">
<PackageReference Update="OpenAI" Version="2.0.0-beta.7" />
<PackageReference Update="OpenAI" Version="2.0.0-beta.12" />
</ItemGroup>

<!--
Expand Down Expand Up @@ -350,6 +350,7 @@
<PackageReference Update="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Update="Portable.BouncyCastle" Version="1.9.0" />
<PackageReference Update="PublicApiGenerator" Version="10.0.1" />
<PackageReference Update="System.ClientModel" Version="1.1.0" />
<PackageReference Update="System.Diagnostics.TraceSource" Version="4.3.0" />
<PackageReference Update="System.IO.Compression" Version="4.3.0" />
<PackageReference Update="System.IO.Pipelines" Version="4.5.1" />
Expand Down
2 changes: 1 addition & 1 deletion sdk/openai/Azure.AI.OpenAI/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<ExternalAzureCoreLibrary>../../external/Azure.Core.Slim/netstandard2.0/Azure.Core.Slim.dll</ExternalAzureCoreLibrary>
-->
</PropertyGroup>

<!--
Add any shared properties you want for the projects under this package directory that need to be set before the auto imported Directory.Build.props
-->
Expand Down
4 changes: 3 additions & 1 deletion sdk/openai/Azure.AI.OpenAI/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ foreach (KeyValuePair<int, string> indexToIdPair in toolCallIdsByIndex)
functionArgumentBuildersByIndex[indexToIdPair.Key].ToString()));
}

conversationMessages.Add(new AssistantChatMessage(toolCalls, contentBuilder.ToString()));
var assistantChatMessage = new AssistantChatMessage(toolCalls);
assistantChatMessage.Content.Add(ChatMessageContentPart.CreateTextPart(contentBuilder.ToString()));
conversationMessages.Add(assistantChatMessage);

// Placeholder: each tool call must be resolved, like in the non-streaming case
string GetToolCallOutput(ChatToolCall toolCall) => null;
Expand Down
Loading
Loading